diff --git a/tests/include/TransientRegression.h b/tests/include/TransientRegression.h index 87ab262e39..b369d249fc 100644 --- a/tests/include/TransientRegression.h +++ b/tests/include/TransientRegression.h @@ -29,6 +29,25 @@ namespace neml2 { + +struct AutoDefaultDtypeMode +{ + static std::mutex default_dtype_mutex; + + AutoDefaultDtypeMode(c10::ScalarType default_dtype) + : prev_default_dtype(torch::typeMetaToScalarType(torch::get_default_dtype())) + { + default_dtype_mutex.lock(); + torch::set_default_dtype(torch::scalarTypeToTypeMeta(default_dtype)); + } + ~AutoDefaultDtypeMode() + { + default_dtype_mutex.unlock(); + torch::set_default_dtype(torch::scalarTypeToTypeMeta(prev_default_dtype)); + } + c10::ScalarType prev_default_dtype; +}; + class TransientDriver; class TransientRegression : public Driver diff --git a/tests/src/TransientRegression.cxx b/tests/src/TransientRegression.cxx index 35c9537007..a9f62074bb 100644 --- a/tests/src/TransientRegression.cxx +++ b/tests/src/TransientRegression.cxx @@ -30,6 +30,8 @@ namespace fs = std::filesystem; namespace neml2 { +std::mutex AutoDefaultDtypeMode::default_dtype_mutex; + register_NEML2_object(TransientRegression); OptionSet @@ -56,6 +58,9 @@ TransientRegression::TransientRegression(const OptionSet & options) bool TransientRegression::run() { + // Paranoid guard + AutoDefaultDtypeMode dtype_mode(torch::kFloat64); + _driver.run(); // Verify the result