Skip to content

Commit

Permalink
Merge pull request #674 from andrsd/nl-delegates
Browse files Browse the repository at this point in the history
Use delegates in NonlinearProblem for residual|Jacobian evaluation
  • Loading branch information
andrsd authored Jan 5, 2025
2 parents 8c9a539 + df4a292 commit 86e1cf1
Show file tree
Hide file tree
Showing 12 changed files with 110 additions and 70 deletions.
3 changes: 1 addition & 2 deletions include/godzilla/ExplicitDGLinearProblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ class ExplicitDGLinearProblem :
void post_step() override;

private:
void compute_residual(const Vector & x, Vector & f) final;
void compute_jacobian(const Vector & x, Matrix & J, Matrix & Jp) final;
SNESolver create_sne_solver() override;

public:
static Parameters parameters();
Expand Down
1 change: 1 addition & 0 deletions include/godzilla/ExplicitFELinearProblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class ExplicitFELinearProblem : public FENonlinearProblem, public ExplicitProble
void post_step() override;

private:
SNESolver create_sne_solver() override;
void compute_rhs_local(Real time, const Vector & x, Vector & F) override;
void compute_rhs_function_fem(Real time, const Vector & loc_x, Vector & loc_g);

Expand Down
3 changes: 1 addition & 2 deletions include/godzilla/ExplicitFVLinearProblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ class ExplicitFVLinearProblem :
void post_step() override;

private:
SNESolver create_sne_solver() override;
void compute_rhs_local(Real time, const Vector & x, Vector & F) override;
void compute_residual(const Vector & x, Vector & f) final;
void compute_jacobian(const Vector & x, Matrix & J, Matrix & Jp) final;

public:
static Parameters parameters();
Expand Down
4 changes: 2 additions & 2 deletions include/godzilla/FENonlinearProblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ class FENonlinearProblem : public NonlinearProblem, public FEProblemInterface {
const IndexSet & facets);

private:
void compute_residual(const Vector & x, Vector & f) override;
void compute_jacobian(const Vector & x, Matrix & J, Matrix & Jp) override;
void compute_residual(const Vector & x, Vector & f);
void compute_jacobian(const Vector & x, Matrix & J, Matrix & Jp);
virtual void compute_boundary(Vector & x);

enum State { INITIAL, FINAL } state;
Expand Down
2 changes: 2 additions & 0 deletions include/godzilla/ImplicitFENonlinearProblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class ImplicitFENonlinearProblem : public FENonlinearProblem, public TransientPr
}

private:
SNESolver create_sne_solver() override;

/// Form the local residual `f` from the local input `x`
///
/// @param time The time
Expand Down
30 changes: 27 additions & 3 deletions include/godzilla/NonlinearProblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,34 @@ class NonlinearProblem : public Problem {

void solve();

/// Set residual evaluation function
///
/// @tparam T C++ class type
/// @param instance Instance of class T
/// @param method Member function in class T to compute residual
template <class T>
void
set_function(T * instance, void (T::*method)(const Vector &, Vector &))
{
this->snes.set_function(this->r, instance, method);
}

/// Set Jacobian evaluation function
///
/// @tparam T C++ class type
/// @param instance Instance of class T
/// @param method Member function in class T to compute Jacobian
template <class T>
void
set_jacobian(T * instance, void (T::*method)(const Vector &, Matrix &, Matrix &))
{
this->snes.set_jacobian(this->J, this->J, instance, method);
}

private:
/// Create a SNESolver
virtual SNESolver create_sne_solver();

/// Set up line search
virtual void set_up_line_search();

Expand All @@ -91,9 +118,6 @@ class NonlinearProblem : public Problem {
/// Method for setting matrix properties
virtual void set_up_matrix_properties();

virtual void compute_residual(const Vector & x, Vector & f) = 0;
virtual void compute_jacobian(const Vector & x, Matrix & J, Matrix & Jp) = 0;

/// Nonlinear solver
SNESolver snes;
/// Linear solver
Expand Down
19 changes: 7 additions & 12 deletions src/ExplicitDGLinearProblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,18 @@ ExplicitDGLinearProblem::get_step_num() const
return ExplicitProblemInterface::get_step_number();
}

SNESolver
ExplicitDGLinearProblem::create_sne_solver()
{
return ExplicitProblemInterface::get_snes();
}

void
ExplicitDGLinearProblem::init()
{
CALL_STACK_MSG();
ExplicitProblemInterface::init();
auto snes = ExplicitProblemInterface::get_snes();
NonlinearProblem::set_snes(snes);
NonlinearProblem::init();
DGProblemInterface::init();
}

Expand Down Expand Up @@ -129,14 +134,4 @@ ExplicitDGLinearProblem::post_step()
output(EXECUTE_ON_TIMESTEP);
}

void
ExplicitDGLinearProblem::compute_residual(const Vector & x, Vector & f)
{
}

void
ExplicitDGLinearProblem::compute_jacobian(const Vector & x, Matrix & J, Matrix & Jp)
{
}

} // namespace godzilla
9 changes: 7 additions & 2 deletions src/ExplicitFELinearProblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,18 @@ ExplicitFELinearProblem::get_step_num() const
return ExplicitProblemInterface::get_step_number();
}

SNESolver
ExplicitFELinearProblem::create_sne_solver()
{
return ExplicitProblemInterface::get_snes();
}

void
ExplicitFELinearProblem::init()
{
CALL_STACK_MSG();
ExplicitProblemInterface::init();
auto snes = ExplicitProblemInterface::get_snes();
NonlinearProblem::set_snes(snes);
NonlinearProblem::init();
FEProblemInterface::init();
// so that the call to DMTSCreateRHSMassMatrix would form the mass matrix
for (Int i = 0; i < get_num_fields(); i++)
Expand Down
19 changes: 7 additions & 12 deletions src/ExplicitFVLinearProblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,18 @@ ExplicitFVLinearProblem::get_step_num() const
return ExplicitProblemInterface::get_step_number();
}

SNESolver
ExplicitFVLinearProblem::create_sne_solver()
{
return ExplicitProblemInterface::get_snes();
}

void
ExplicitFVLinearProblem::init()
{
CALL_STACK_MSG();
ExplicitProblemInterface::init();
auto snes = ExplicitProblemInterface::get_snes();
NonlinearProblem::set_snes(snes);
NonlinearProblem::init();
FVProblemInterface::init();
}

Expand Down Expand Up @@ -124,16 +129,6 @@ ExplicitFVLinearProblem::compute_rhs_local(Real time, const Vector & x, Vector &
PETSC_CHECK(DMPlexTSComputeRHSFunctionFVM(get_dm(), time, x, F, this));
}

void
ExplicitFVLinearProblem::compute_residual(const Vector & x, Vector & f)
{
}

void
ExplicitFVLinearProblem::compute_jacobian(const Vector & x, Matrix & J, Matrix & Jp)
{
}

void
ExplicitFVLinearProblem::post_step()
{
Expand Down
10 changes: 7 additions & 3 deletions src/ImplicitFENonlinearProblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,18 @@ ImplicitFENonlinearProblem::get_step_num() const
return TransientProblemInterface::get_step_number();
}

SNESolver
ImplicitFENonlinearProblem::create_sne_solver()
{
return TransientProblemInterface::get_snes();
}

void
ImplicitFENonlinearProblem::init()
{
CALL_STACK_MSG();
TransientProblemInterface::init();
auto snes = TransientProblemInterface::get_snes();
NonlinearProblem::set_snes(snes);
FEProblemInterface::init();
FENonlinearProblem::init();
}

void
Expand Down
19 changes: 12 additions & 7 deletions src/NonlinearProblem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,15 +131,22 @@ NonlinearProblem::get_ksp() const
return this->ksp;
}

SNESolver
NonlinearProblem::create_sne_solver()
{
CALL_STACK_MSG();
SNESolver snes;
snes.create(get_comm());
snes.set_dm(get_dm());
PETSC_CHECK(DMSetApplicationContext(get_dm(), this));
return snes;
}

void
NonlinearProblem::init()
{
CALL_STACK_MSG();
DM dm = get_dm();
this->snes.create(get_comm());
this->snes.set_dm(dm);
PETSC_CHECK(DMSetApplicationContext(dm, this));
set_snes(this->snes);
set_snes(create_sne_solver());
}

void
Expand Down Expand Up @@ -191,8 +198,6 @@ void
NonlinearProblem::set_up_callbacks()
{
CALL_STACK_MSG();
this->snes.set_function(this->r, this, &NonlinearProblem::compute_residual);
this->snes.set_jacobian(this->J, this->J, this, &NonlinearProblem::compute_jacobian);
}

void
Expand Down
61 changes: 36 additions & 25 deletions test/src/NonlinearProblem_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ class G1DTestNonlinearProblem : public NonlinearProblem {
void call_initial_guess();

protected:
void compute_residual(const Vector & x, Vector & f) override;
void compute_jacobian(const Vector & x, Matrix & J, Matrix & Jp) override;
void set_up_callbacks() override;
void compute_residual(const Vector & x, Vector & f);
void compute_jacobian(const Vector & x, Matrix & J, Matrix & Jp);

PetscSection s;
};

Expand Down Expand Up @@ -56,6 +58,13 @@ G1DTestNonlinearProblem::call_initial_guess()
NonlinearProblem::set_up_initial_guess();
}

void
G1DTestNonlinearProblem::set_up_callbacks()
{
set_function(this, &G1DTestNonlinearProblem::compute_residual);
set_jacobian(this, &G1DTestNonlinearProblem::compute_jacobian);
}

void
G1DTestNonlinearProblem::compute_residual(const Vector & x, Vector & f)
{
Expand Down Expand Up @@ -139,8 +148,29 @@ TEST(NonlinearProblemTest, run)

MOCK_METHOD(void, set_up_initial_guess, ());
MOCK_METHOD(void, on_initial, ());
MOCK_METHOD(void, compute_residual, (const Vector & x, Vector & f));
MOCK_METHOD(void, compute_jacobian, (const Vector & x, Matrix & J, Matrix & Jp));

void
compute_residual(const Vector & x, Vector & f)
{
f.zero();
this->compute_residual_called = true;
}

void
compute_jacobian(const Vector & x, Matrix & J, Matrix & Jp)
{
this->compute_jacobian_called = true;
}

void
set_up_callbacks()
{
set_function(this, &MockNonlinearProblem::compute_residual);
set_jacobian(this, &MockNonlinearProblem::compute_jacobian);
}

bool compute_residual_called = false;
bool compute_jacobian_called = false;
};

TestApp app;
Expand All @@ -159,8 +189,9 @@ TEST(NonlinearProblemTest, run)

EXPECT_CALL(prob, set_up_initial_guess);
EXPECT_CALL(prob, on_initial);
EXPECT_CALL(prob, compute_residual);
prob.run();
EXPECT_TRUE(prob.compute_residual_called);
EXPECT_FALSE(prob.compute_jacobian_called);
}

TEST(NonlinearProblemTest, line_search_type)
Expand All @@ -174,16 +205,6 @@ TEST(NonlinearProblemTest, line_search_type)
{
return NonlinearProblem::get_snes();
}

void
compute_residual(const Vector & x, Vector & F) override
{
}

void
compute_jacobian(const Vector & x, Matrix & J, Matrix & Jp) override
{
}
};

TestApp app;
Expand Down Expand Up @@ -219,16 +240,6 @@ TEST(NonlinearProblemTest, invalid_line_search_type)
class MockNonlinearProblem : public NonlinearProblem {
public:
explicit MockNonlinearProblem(const Parameters & params) : NonlinearProblem(params) {}

void
compute_residual(const Vector & x, Vector & f) override
{
}

void
compute_jacobian(const Vector & x, Matrix & J, Matrix & Jp) override
{
}
};

TestApp app;
Expand Down

0 comments on commit 86e1cf1

Please sign in to comment.