Skip to content

Commit

Permalink
Merge pull request #675 from andrsd/interp
Browse files Browse the repository at this point in the history
feat: Adding `Interpolation` class
  • Loading branch information
andrsd authored Jan 8, 2025
2 parents 86e1cf1 + 6b420f2 commit b810c74
Show file tree
Hide file tree
Showing 3 changed files with 272 additions and 0 deletions.
92 changes: 92 additions & 0 deletions include/godzilla/Interpolation.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// SPDX-FileCopyrightText: 2025 David Andrs <[email protected]>
// SPDX-License-Identifier: MIT

#pragma once

#include "godzilla/Error.h"
#include "godzilla/Types.h"
#include "godzilla/Vector.h"
#include "petscdm.h"

namespace godzilla {

class Interpolation {
public:
/// Construct an interpolation object
Interpolation();

~Interpolation();

/// Creates an interpolation object
void create(MPI_Comm comm);

/// Destroys the interpolation object
void destroy();

/// Add points at which we will interpolate the fields
///
/// @param n The number of points
/// @param points The coordinates of the points, an array of size `n * dim`
void add_points(Int n, const Real points[]);

void add_points(const std::vector<Real> & points);

/// Gets a Vec with the coordinates of each interpolation point
///
/// @return The coordinates of the interpolation points
Vector get_coordinates() const;

/// Gets the spatial dimension for the interpolation context
///
/// @return The spatial dimension
Int get_dim() const;

/// Gets the number of fields interpolated at a point
///
/// @return The number of fields
Int get_dof() const;

/// Gets a `Vector` which can hold all the interpolated field values
///
/// @return A vector capable of holding the interpolated field values
///
/// @note This vector should be returned using `RestoreVector()
Vector get_vector();

/// Restores the vector returned by `get_vector()`
///
/// @param v The vector to restore
void restore_vector(Vector & v);

/// Sets the spatial dimension for the interpolation context
///
/// @param dim The spatial dimension
void set_dim(Int dim);

/// Sets the number of fields interpolated at a point for the interpolation context
///
/// @param dof The number of fields
void set_dof(Int dof);

/// Compute spatial indices for point location during interpolation
///
/// @param dm The DM object
/// @param redundant_points If `true`, all processes are passing in the same array of points.
/// Otherwise, points need to be communicated among processes.
/// @param ignore_outside_domain If `true`, ignore points outside the domain, otherwise return
/// an error
void set_up(DM dm, bool redundant_points, bool ignore_outside_domain);

/// Using the input from `dm` and `x`, calculate interpolated field values at the interpolation
/// points.
///
/// @param dm The DM object
/// @param x The local vector containing the field to be interpolated
/// @param values The vector containing the interpolated values, obtained with `get_vector()`
void evaluate(DM dm, const Vector & x, Vector & values);

private:
DMInterpolationInfo info;
};

} // namespace godzilla
110 changes: 110 additions & 0 deletions src/Interpolation.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// SPDX-FileCopyrightText: 2025 David Andrs <[email protected]>
// SPDX-License-Identifier: MIT

#include "godzilla/Interpolation.h"
#include "petscdm.h"

namespace godzilla {

Interpolation::Interpolation() : info(nullptr) {}

Interpolation::~Interpolation()
{
if (this->info != nullptr)
destroy();
}

void
Interpolation::create(MPI_Comm comm)
{
PETSC_CHECK(DMInterpolationCreate(comm, &this->info));
}

void
Interpolation::destroy()
{
PETSC_CHECK(DMInterpolationDestroy(&this->info));
this->info = nullptr;
}

void
Interpolation::add_points(Int n, const Real points[])
{
auto * data = const_cast<Real *>(points);
PETSC_CHECK(DMInterpolationAddPoints(this->info, n, data));
}

void
Interpolation::add_points(const std::vector<Real> & points)
{
auto * data = const_cast<Real *>(points.data());
PETSC_CHECK(DMInterpolationAddPoints(this->info, points.size(), data));
}

Vector
Interpolation::get_coordinates() const
{
Vec v;
PETSC_CHECK(DMInterpolationGetCoordinates(this->info, &v));
return Vector(v);
}

Int
Interpolation::get_dim() const
{
Int dim;
PETSC_CHECK(DMInterpolationGetDim(this->info, &dim));
return dim;
}

Int
Interpolation::get_dof() const
{
Int dof;
PETSC_CHECK(DMInterpolationGetDof(this->info, &dof));
return dof;
}

Vector
Interpolation::get_vector()
{
Vec v;
PETSC_CHECK(DMInterpolationGetVector(this->info, &v));
return Vector(v);
}

void
Interpolation::restore_vector(Vector & v)
{
Vec vec = v;
PETSC_CHECK(DMInterpolationRestoreVector(this->info, &vec));
}

void
Interpolation::set_dim(Int dim)
{
PETSC_CHECK(DMInterpolationSetDim(this->info, dim));
}

void
Interpolation::set_dof(Int dof)
{
PETSC_CHECK(DMInterpolationSetDof(this->info, dof));
}

void
Interpolation::set_up(DM dm, bool redundant_points, bool ignore_outside_domain)
{
PETSC_CHECK(DMInterpolationSetUp(this->info,
dm,
redundant_points ? PETSC_TRUE : PETSC_FALSE,
ignore_outside_domain ? PETSC_TRUE : PETSC_FALSE));
}

void
Interpolation::evaluate(DM dm, const Vector & x, Vector & values)
{
PETSC_CHECK(DMInterpolationEvaluate(this->info, dm, x, values));
}

} // namespace godzilla
70 changes: 70 additions & 0 deletions test/src/Interpolation_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#include "gmock/gmock.h"
#include "TestApp.h"
#include "godzilla/Interpolation.h"
#include "godzilla/LineMesh.h"
#include "petscdmplex.h"
#include "petscfe.h"

using namespace godzilla;

TEST(InterpolationTest, test_1d)
{
TestApp app;

auto comm = app.get_comm();

Int dim = 1;
std::array<Real, 1> lower = { 0 };
std::array<Real, 1> upper = { 1 };
std::array<Int, 1> faces = { 4 };
std::array<DMBoundaryType, 1> periodicity = { DM_BOUNDARY_GHOSTED };
DM dm;
PETSC_CHECK(DMPlexCreateBoxMesh(comm,
1,
PETSC_TRUE,
faces.data(),
lower.data(),
upper.data(),
periodicity.data(),
PETSC_FALSE,
&dm));
PetscFE fe;
PETSC_CHECK(PetscFECreateLagrange(comm, dim, 1, PETSC_TRUE, 1, PETSC_DECIDE, &fe));
PETSC_CHECK(DMSetField(dm, 0, nullptr, (PetscObject) fe));
PETSC_CHECK(DMCreateDS(dm));

Vec v_sln;
PETSC_CHECK(DMCreateGlobalVector(dm, &v_sln));
Vector sln(v_sln);
sln.set_values({ 0, 1, 2, 3, 4 }, { 2., 2.5, 3., 3.5, 4. });
sln.assemble();

Interpolation interp;
interp.create(app.get_comm());
interp.set_dim(dim);
interp.set_dof(1);
interp.add_points({ 0.125, 0.875 });
interp.set_up(dm, false, true);

auto coord = interp.get_coordinates();
auto * c = coord.get_array_read();
EXPECT_NEAR(c[0], 0.125, 1.0e-12);
EXPECT_NEAR(c[1], 0.875, 1.0e-12);
coord.restore_array_read(c);

EXPECT_EQ(interp.get_dim(), 1);
EXPECT_EQ(interp.get_dof(), 1);

auto vals = interp.get_vector();
interp.evaluate(dm, sln, vals);
auto * v = vals.get_array_read();
EXPECT_NEAR(v[0], 2.25, 1e-10);
EXPECT_NEAR(v[1], 3.75, 1e-10);
vals.restore_array_read(v);

interp.restore_vector(vals);

interp.destroy();
PetscFEDestroy(&fe);
DMDestroy(&dm);
}

0 comments on commit b810c74

Please sign in to comment.