diff --git a/include/godzilla/Interpolation.h b/include/godzilla/Interpolation.h new file mode 100644 index 00000000..c1740e96 --- /dev/null +++ b/include/godzilla/Interpolation.h @@ -0,0 +1,92 @@ +// SPDX-FileCopyrightText: 2025 David Andrs +// SPDX-License-Identifier: MIT + +#pragma once + +#include "godzilla/Error.h" +#include "godzilla/Types.h" +#include "godzilla/Vector.h" +#include "petscdm.h" + +namespace godzilla { + +class Interpolation { +public: + /// Construct an interpolation object + Interpolation(); + + ~Interpolation(); + + /// Creates an interpolation object + void create(MPI_Comm comm); + + /// Destroys the interpolation object + void destroy(); + + /// Add points at which we will interpolate the fields + /// + /// @param n The number of points + /// @param points The coordinates of the points, an array of size `n * dim` + void add_points(Int n, const Real points[]); + + void add_points(const std::vector & points); + + /// Gets a Vec with the coordinates of each interpolation point + /// + /// @return The coordinates of the interpolation points + Vector get_coordinates() const; + + /// Gets the spatial dimension for the interpolation context + /// + /// @return The spatial dimension + Int get_dim() const; + + /// Gets the number of fields interpolated at a point + /// + /// @return The number of fields + Int get_dof() const; + + /// Gets a `Vector` which can hold all the interpolated field values + /// + /// @return A vector capable of holding the interpolated field values + /// + /// @note This vector should be returned using `RestoreVector() + Vector get_vector(); + + /// Restores the vector returned by `get_vector()` + /// + /// @param v The vector to restore + void restore_vector(Vector & v); + + /// Sets the spatial dimension for the interpolation context + /// + /// @param dim The spatial dimension + void set_dim(Int dim); + + /// Sets the number of fields interpolated at a point for the interpolation context + /// + /// @param dof The number of fields + void set_dof(Int dof); + + /// Compute spatial indices for point location during interpolation + /// + /// @param dm The DM object + /// @param redundant_points If `true`, all processes are passing in the same array of points. + /// Otherwise, points need to be communicated among processes. + /// @param ignore_outside_domain If `true`, ignore points outside the domain, otherwise return + /// an error + void set_up(DM dm, bool redundant_points, bool ignore_outside_domain); + + /// Using the input from `dm` and `x`, calculate interpolated field values at the interpolation + /// points. + /// + /// @param dm The DM object + /// @param x The local vector containing the field to be interpolated + /// @param values The vector containing the interpolated values, obtained with `get_vector()` + void evaluate(DM dm, const Vector & x, Vector & values); + +private: + DMInterpolationInfo info; +}; + +} // namespace godzilla diff --git a/src/Interpolation.cpp b/src/Interpolation.cpp new file mode 100644 index 00000000..15ea1951 --- /dev/null +++ b/src/Interpolation.cpp @@ -0,0 +1,110 @@ +// SPDX-FileCopyrightText: 2025 David Andrs +// SPDX-License-Identifier: MIT + +#include "godzilla/Interpolation.h" +#include "petscdm.h" + +namespace godzilla { + +Interpolation::Interpolation() : info(nullptr) {} + +Interpolation::~Interpolation() +{ + if (this->info != nullptr) + destroy(); +} + +void +Interpolation::create(MPI_Comm comm) +{ + PETSC_CHECK(DMInterpolationCreate(comm, &this->info)); +} + +void +Interpolation::destroy() +{ + PETSC_CHECK(DMInterpolationDestroy(&this->info)); + this->info = nullptr; +} + +void +Interpolation::add_points(Int n, const Real points[]) +{ + auto * data = const_cast(points); + PETSC_CHECK(DMInterpolationAddPoints(this->info, n, data)); +} + +void +Interpolation::add_points(const std::vector & points) +{ + auto * data = const_cast(points.data()); + PETSC_CHECK(DMInterpolationAddPoints(this->info, points.size(), data)); +} + +Vector +Interpolation::get_coordinates() const +{ + Vec v; + PETSC_CHECK(DMInterpolationGetCoordinates(this->info, &v)); + return Vector(v); +} + +Int +Interpolation::get_dim() const +{ + Int dim; + PETSC_CHECK(DMInterpolationGetDim(this->info, &dim)); + return dim; +} + +Int +Interpolation::get_dof() const +{ + Int dof; + PETSC_CHECK(DMInterpolationGetDof(this->info, &dof)); + return dof; +} + +Vector +Interpolation::get_vector() +{ + Vec v; + PETSC_CHECK(DMInterpolationGetVector(this->info, &v)); + return Vector(v); +} + +void +Interpolation::restore_vector(Vector & v) +{ + Vec vec = v; + PETSC_CHECK(DMInterpolationRestoreVector(this->info, &vec)); +} + +void +Interpolation::set_dim(Int dim) +{ + PETSC_CHECK(DMInterpolationSetDim(this->info, dim)); +} + +void +Interpolation::set_dof(Int dof) +{ + PETSC_CHECK(DMInterpolationSetDof(this->info, dof)); +} + +void +Interpolation::set_up(DM dm, bool redundant_points, bool ignore_outside_domain) +{ + PETSC_CHECK(DMInterpolationSetUp(this->info, + dm, + redundant_points ? PETSC_TRUE : PETSC_FALSE, + ignore_outside_domain ? PETSC_TRUE : PETSC_FALSE)); +} + +void +Interpolation::evaluate(DM dm, const Vector & x, Vector & values) +{ + PETSC_CHECK(DMInterpolationEvaluate(this->info, dm, x, values)); +} + +} // namespace godzilla diff --git a/test/src/Interpolation_test.cpp b/test/src/Interpolation_test.cpp new file mode 100644 index 00000000..ef145b6f --- /dev/null +++ b/test/src/Interpolation_test.cpp @@ -0,0 +1,70 @@ +#include "gmock/gmock.h" +#include "TestApp.h" +#include "godzilla/Interpolation.h" +#include "godzilla/LineMesh.h" +#include "petscdmplex.h" +#include "petscfe.h" + +using namespace godzilla; + +TEST(InterpolationTest, test_1d) +{ + TestApp app; + + auto comm = app.get_comm(); + + Int dim = 1; + std::array lower = { 0 }; + std::array upper = { 1 }; + std::array faces = { 4 }; + std::array periodicity = { DM_BOUNDARY_GHOSTED }; + DM dm; + PETSC_CHECK(DMPlexCreateBoxMesh(comm, + 1, + PETSC_TRUE, + faces.data(), + lower.data(), + upper.data(), + periodicity.data(), + PETSC_FALSE, + &dm)); + PetscFE fe; + PETSC_CHECK(PetscFECreateLagrange(comm, dim, 1, PETSC_TRUE, 1, PETSC_DECIDE, &fe)); + PETSC_CHECK(DMSetField(dm, 0, nullptr, (PetscObject) fe)); + PETSC_CHECK(DMCreateDS(dm)); + + Vec v_sln; + PETSC_CHECK(DMCreateGlobalVector(dm, &v_sln)); + Vector sln(v_sln); + sln.set_values({ 0, 1, 2, 3, 4 }, { 2., 2.5, 3., 3.5, 4. }); + sln.assemble(); + + Interpolation interp; + interp.create(app.get_comm()); + interp.set_dim(dim); + interp.set_dof(1); + interp.add_points({ 0.125, 0.875 }); + interp.set_up(dm, false, true); + + auto coord = interp.get_coordinates(); + auto * c = coord.get_array_read(); + EXPECT_NEAR(c[0], 0.125, 1.0e-12); + EXPECT_NEAR(c[1], 0.875, 1.0e-12); + coord.restore_array_read(c); + + EXPECT_EQ(interp.get_dim(), 1); + EXPECT_EQ(interp.get_dof(), 1); + + auto vals = interp.get_vector(); + interp.evaluate(dm, sln, vals); + auto * v = vals.get_array_read(); + EXPECT_NEAR(v[0], 2.25, 1e-10); + EXPECT_NEAR(v[1], 3.75, 1e-10); + vals.restore_array_read(v); + + interp.restore_vector(vals); + + interp.destroy(); + PetscFEDestroy(&fe); + DMDestroy(&dm); +}