-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #675 from andrsd/interp
feat: Adding `Interpolation` class
- Loading branch information
Showing
3 changed files
with
272 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
// SPDX-FileCopyrightText: 2025 David Andrs <[email protected]> | ||
// SPDX-License-Identifier: MIT | ||
|
||
#pragma once | ||
|
||
#include "godzilla/Error.h" | ||
#include "godzilla/Types.h" | ||
#include "godzilla/Vector.h" | ||
#include "petscdm.h" | ||
|
||
namespace godzilla { | ||
|
||
class Interpolation { | ||
public: | ||
/// Construct an interpolation object | ||
Interpolation(); | ||
|
||
~Interpolation(); | ||
|
||
/// Creates an interpolation object | ||
void create(MPI_Comm comm); | ||
|
||
/// Destroys the interpolation object | ||
void destroy(); | ||
|
||
/// Add points at which we will interpolate the fields | ||
/// | ||
/// @param n The number of points | ||
/// @param points The coordinates of the points, an array of size `n * dim` | ||
void add_points(Int n, const Real points[]); | ||
|
||
void add_points(const std::vector<Real> & points); | ||
|
||
/// Gets a Vec with the coordinates of each interpolation point | ||
/// | ||
/// @return The coordinates of the interpolation points | ||
Vector get_coordinates() const; | ||
|
||
/// Gets the spatial dimension for the interpolation context | ||
/// | ||
/// @return The spatial dimension | ||
Int get_dim() const; | ||
|
||
/// Gets the number of fields interpolated at a point | ||
/// | ||
/// @return The number of fields | ||
Int get_dof() const; | ||
|
||
/// Gets a `Vector` which can hold all the interpolated field values | ||
/// | ||
/// @return A vector capable of holding the interpolated field values | ||
/// | ||
/// @note This vector should be returned using `RestoreVector() | ||
Vector get_vector(); | ||
|
||
/// Restores the vector returned by `get_vector()` | ||
/// | ||
/// @param v The vector to restore | ||
void restore_vector(Vector & v); | ||
|
||
/// Sets the spatial dimension for the interpolation context | ||
/// | ||
/// @param dim The spatial dimension | ||
void set_dim(Int dim); | ||
|
||
/// Sets the number of fields interpolated at a point for the interpolation context | ||
/// | ||
/// @param dof The number of fields | ||
void set_dof(Int dof); | ||
|
||
/// Compute spatial indices for point location during interpolation | ||
/// | ||
/// @param dm The DM object | ||
/// @param redundant_points If `true`, all processes are passing in the same array of points. | ||
/// Otherwise, points need to be communicated among processes. | ||
/// @param ignore_outside_domain If `true`, ignore points outside the domain, otherwise return | ||
/// an error | ||
void set_up(DM dm, bool redundant_points, bool ignore_outside_domain); | ||
|
||
/// Using the input from `dm` and `x`, calculate interpolated field values at the interpolation | ||
/// points. | ||
/// | ||
/// @param dm The DM object | ||
/// @param x The local vector containing the field to be interpolated | ||
/// @param values The vector containing the interpolated values, obtained with `get_vector()` | ||
void evaluate(DM dm, const Vector & x, Vector & values); | ||
|
||
private: | ||
DMInterpolationInfo info; | ||
}; | ||
|
||
} // namespace godzilla |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
// SPDX-FileCopyrightText: 2025 David Andrs <[email protected]> | ||
// SPDX-License-Identifier: MIT | ||
|
||
#include "godzilla/Interpolation.h" | ||
#include "petscdm.h" | ||
|
||
namespace godzilla { | ||
|
||
Interpolation::Interpolation() : info(nullptr) {} | ||
|
||
Interpolation::~Interpolation() | ||
{ | ||
if (this->info != nullptr) | ||
destroy(); | ||
} | ||
|
||
void | ||
Interpolation::create(MPI_Comm comm) | ||
{ | ||
PETSC_CHECK(DMInterpolationCreate(comm, &this->info)); | ||
} | ||
|
||
void | ||
Interpolation::destroy() | ||
{ | ||
PETSC_CHECK(DMInterpolationDestroy(&this->info)); | ||
this->info = nullptr; | ||
} | ||
|
||
void | ||
Interpolation::add_points(Int n, const Real points[]) | ||
{ | ||
auto * data = const_cast<Real *>(points); | ||
PETSC_CHECK(DMInterpolationAddPoints(this->info, n, data)); | ||
} | ||
|
||
void | ||
Interpolation::add_points(const std::vector<Real> & points) | ||
{ | ||
auto * data = const_cast<Real *>(points.data()); | ||
PETSC_CHECK(DMInterpolationAddPoints(this->info, points.size(), data)); | ||
} | ||
|
||
Vector | ||
Interpolation::get_coordinates() const | ||
{ | ||
Vec v; | ||
PETSC_CHECK(DMInterpolationGetCoordinates(this->info, &v)); | ||
return Vector(v); | ||
} | ||
|
||
Int | ||
Interpolation::get_dim() const | ||
{ | ||
Int dim; | ||
PETSC_CHECK(DMInterpolationGetDim(this->info, &dim)); | ||
return dim; | ||
} | ||
|
||
Int | ||
Interpolation::get_dof() const | ||
{ | ||
Int dof; | ||
PETSC_CHECK(DMInterpolationGetDof(this->info, &dof)); | ||
return dof; | ||
} | ||
|
||
Vector | ||
Interpolation::get_vector() | ||
{ | ||
Vec v; | ||
PETSC_CHECK(DMInterpolationGetVector(this->info, &v)); | ||
return Vector(v); | ||
} | ||
|
||
void | ||
Interpolation::restore_vector(Vector & v) | ||
{ | ||
Vec vec = v; | ||
PETSC_CHECK(DMInterpolationRestoreVector(this->info, &vec)); | ||
} | ||
|
||
void | ||
Interpolation::set_dim(Int dim) | ||
{ | ||
PETSC_CHECK(DMInterpolationSetDim(this->info, dim)); | ||
} | ||
|
||
void | ||
Interpolation::set_dof(Int dof) | ||
{ | ||
PETSC_CHECK(DMInterpolationSetDof(this->info, dof)); | ||
} | ||
|
||
void | ||
Interpolation::set_up(DM dm, bool redundant_points, bool ignore_outside_domain) | ||
{ | ||
PETSC_CHECK(DMInterpolationSetUp(this->info, | ||
dm, | ||
redundant_points ? PETSC_TRUE : PETSC_FALSE, | ||
ignore_outside_domain ? PETSC_TRUE : PETSC_FALSE)); | ||
} | ||
|
||
void | ||
Interpolation::evaluate(DM dm, const Vector & x, Vector & values) | ||
{ | ||
PETSC_CHECK(DMInterpolationEvaluate(this->info, dm, x, values)); | ||
} | ||
|
||
} // namespace godzilla |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
#include "gmock/gmock.h" | ||
#include "TestApp.h" | ||
#include "godzilla/Interpolation.h" | ||
#include "godzilla/LineMesh.h" | ||
#include "petscdmplex.h" | ||
#include "petscfe.h" | ||
|
||
using namespace godzilla; | ||
|
||
TEST(InterpolationTest, test_1d) | ||
{ | ||
TestApp app; | ||
|
||
auto comm = app.get_comm(); | ||
|
||
Int dim = 1; | ||
std::array<Real, 1> lower = { 0 }; | ||
std::array<Real, 1> upper = { 1 }; | ||
std::array<Int, 1> faces = { 4 }; | ||
std::array<DMBoundaryType, 1> periodicity = { DM_BOUNDARY_GHOSTED }; | ||
DM dm; | ||
PETSC_CHECK(DMPlexCreateBoxMesh(comm, | ||
1, | ||
PETSC_TRUE, | ||
faces.data(), | ||
lower.data(), | ||
upper.data(), | ||
periodicity.data(), | ||
PETSC_FALSE, | ||
&dm)); | ||
PetscFE fe; | ||
PETSC_CHECK(PetscFECreateLagrange(comm, dim, 1, PETSC_TRUE, 1, PETSC_DECIDE, &fe)); | ||
PETSC_CHECK(DMSetField(dm, 0, nullptr, (PetscObject) fe)); | ||
PETSC_CHECK(DMCreateDS(dm)); | ||
|
||
Vec v_sln; | ||
PETSC_CHECK(DMCreateGlobalVector(dm, &v_sln)); | ||
Vector sln(v_sln); | ||
sln.set_values({ 0, 1, 2, 3, 4 }, { 2., 2.5, 3., 3.5, 4. }); | ||
sln.assemble(); | ||
|
||
Interpolation interp; | ||
interp.create(app.get_comm()); | ||
interp.set_dim(dim); | ||
interp.set_dof(1); | ||
interp.add_points({ 0.125, 0.875 }); | ||
interp.set_up(dm, false, true); | ||
|
||
auto coord = interp.get_coordinates(); | ||
auto * c = coord.get_array_read(); | ||
EXPECT_NEAR(c[0], 0.125, 1.0e-12); | ||
EXPECT_NEAR(c[1], 0.875, 1.0e-12); | ||
coord.restore_array_read(c); | ||
|
||
EXPECT_EQ(interp.get_dim(), 1); | ||
EXPECT_EQ(interp.get_dof(), 1); | ||
|
||
auto vals = interp.get_vector(); | ||
interp.evaluate(dm, sln, vals); | ||
auto * v = vals.get_array_read(); | ||
EXPECT_NEAR(v[0], 2.25, 1e-10); | ||
EXPECT_NEAR(v[1], 3.75, 1e-10); | ||
vals.restore_array_read(v); | ||
|
||
interp.restore_vector(vals); | ||
|
||
interp.destroy(); | ||
PetscFEDestroy(&fe); | ||
DMDestroy(&dm); | ||
} |