diff --git a/include/godzilla/SNESolver.h b/include/godzilla/SNESolver.h index eba1e87c..9aba0030 100644 --- a/include/godzilla/SNESolver.h +++ b/include/godzilla/SNESolver.h @@ -78,6 +78,16 @@ class SNESolver { /// @param dm The `DM` void set_dm(DM dm); + /// Sets the method for the nonlinear solver. + /// + /// @param A known method + void set_type(SNESType type); + + /// Gets the SNES method type (as a string). + /// + /// @return SNES method + std::string get_type() const; + /// Sets non-linear solver options from the options database void set_from_options(); diff --git a/src/SNESolver.cpp b/src/SNESolver.cpp index ff65986c..f44b06ef 100644 --- a/src/SNESolver.cpp +++ b/src/SNESolver.cpp @@ -139,6 +139,22 @@ SNESolver::set_dm(DM dm) PETSC_CHECK(SNESSetDM(this->snes, dm)); } +void +SNESolver::set_type(SNESType type) +{ + CALL_STACK_MSG(); + PETSC_CHECK(SNESSetType(this->snes, type)); +} + +std::string +SNESolver::get_type() const +{ + CALL_STACK_MSG(); + SNESType type; + PETSC_CHECK(SNESGetType(this->snes, &type)); + return std::string(type); +} + void SNESolver::set_from_options() { diff --git a/test/src/SNESolver_test.cpp b/test/src/SNESolver_test.cpp index 1b3fb32e..e05d4fa0 100644 --- a/test/src/SNESolver_test.cpp +++ b/test/src/SNESolver_test.cpp @@ -71,3 +71,17 @@ TEST(SNESolverTest, mat_create_mf) auto m = snes.mat_create_mf(); EXPECT_EQ(m.get_type(), "mffd"); } + +TEST(SNESolverTest, type) +{ + TestApp app; + TestNLProblem prob; + auto comm = app.get_comm(); + + SNESolver snes; + snes.create(comm); + + snes.set_type("newtontr"); + + EXPECT_EQ(snes.get_type(), "newtontr"); +}