diff --git a/configure b/configure
index cbfff8ac9a..068480a473 100755
--- a/configure
+++ b/configure
@@ -680,6 +680,7 @@ infodir
docdir
oldincludedir
includedir
+runstatedir
localstatedir
sharedstatedir
sysconfdir
@@ -809,6 +810,7 @@ datadir='${datarootdir}'
sysconfdir='${prefix}/etc'
sharedstatedir='${prefix}/com'
localstatedir='${prefix}/var'
+runstatedir='${localstatedir}/run'
includedir='${prefix}/include'
oldincludedir='/usr/include'
docdir='${datarootdir}/doc/${PACKAGE_TARNAME}'
@@ -1061,6 +1063,15 @@ do
| -silent | --silent | --silen | --sile | --sil)
silent=yes ;;
+ -runstatedir | --runstatedir | --runstatedi | --runstated \
+ | --runstate | --runstat | --runsta | --runst | --runs \
+ | --run | --ru | --r)
+ ac_prev=runstatedir ;;
+ -runstatedir=* | --runstatedir=* | --runstatedi=* | --runstated=* \
+ | --runstate=* | --runstat=* | --runsta=* | --runst=* | --runs=* \
+ | --run=* | --ru=* | --r=*)
+ runstatedir=$ac_optarg ;;
+
-sbindir | --sbindir | --sbindi | --sbind | --sbin | --sbi | --sb)
ac_prev=sbindir ;;
-sbindir=* | --sbindir=* | --sbindi=* | --sbind=* | --sbin=* \
@@ -1198,7 +1209,7 @@ fi
for ac_var in exec_prefix prefix bindir sbindir libexecdir datarootdir \
datadir sysconfdir sharedstatedir localstatedir includedir \
oldincludedir docdir infodir htmldir dvidir pdfdir psdir \
- libdir localedir mandir
+ libdir localedir mandir runstatedir
do
eval ac_val=\$$ac_var
# Remove trailing slashes.
@@ -1351,6 +1362,7 @@ Fine tuning of the installation directories:
--sysconfdir=DIR read-only single-machine data [PREFIX/etc]
--sharedstatedir=DIR modifiable architecture-independent data [PREFIX/com]
--localstatedir=DIR modifiable single-machine data [PREFIX/var]
+ --runstatedir=DIR modifiable per-process data [LOCALSTATEDIR/run]
--libdir=DIR object code libraries [EPREFIX/lib]
--includedir=DIR C header files [PREFIX/include]
--oldincludedir=DIR C header files for non-gcc [/usr/include]
diff --git a/src/pytorch/PytorchModel.cpp b/src/pytorch/PytorchModel.cpp
index ae5861a0bb..7382ed25e5 100644
--- a/src/pytorch/PytorchModel.cpp
+++ b/src/pytorch/PytorchModel.cpp
@@ -1,5 +1,5 @@
/* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-Copyright (c) 2022-2023 of Luigi Bonati.
+Copyright (c) 2022-2023 of Luigi Bonati and Enrico Trizio.
The pytorch module is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
@@ -31,6 +31,23 @@ along with plumed. If not, see .
#include
#include
+// We have to do a backward compatibility hack for <1.10
+// https://discuss.pytorch.org/t/how-to-check-libtorch-version/77709/4
+// Basically, the check in torch::jit::freeze
+// (see https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/torch/csrc/jit/api/module.cpp#L479)
+// is wrong, and we have to "reimplement" the function
+// to get around that...
+// it's broken in 1.8 and 1.9
+// BUT the internal logic in the function is wrong in 1.10
+// So we only use torch::jit::freeze in >=1.11
+// credits for this implementation of the hack to the NequIP guys
+#if (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR <= 10)
+#define DO_TORCH_FREEZE_HACK
+// For the hack, need more headers:
+#include <torch/csrc/jit/passes/freeze_module.h>
+#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
+#endif
+
using namespace std;
namespace PLMD {
@@ -68,6 +85,7 @@ class PytorchModel :
unsigned _n_in;
unsigned _n_out;
torch::jit::script::Module _model;
+ torch::Device device = torch::kCPU;
public:
explicit PytorchModel(const ActionOptions&);
@@ -103,10 +121,18 @@ PytorchModel::PytorchModel(const ActionOptions&ao):
std::string fname="model.ptc";
parse("FILE",fname);
+
+  // we create the metadata dict
+  std::unordered_map<std::string, std::string> metadata = {
+ {"_jit_bailout_depth", ""},
+ {"_jit_fusion_strategy", ""}
+ };
+
//deserialize the model from file
try {
- _model = torch::jit::load(fname);
+ _model = torch::jit::load(fname, device, metadata);
}
+
//if an error is thrown check if the file exists or not
catch (const c10::Error& e) {
std::ifstream infile(fname);
@@ -124,13 +150,67 @@ PytorchModel::PytorchModel(const ActionOptions&ao):
plumed_merror("The FILE: '"+fname+"' does not exist.");
}
}
-
checkRead();
+// Optimize model
+ _model.eval();
+#ifdef DO_TORCH_FREEZE_HACK
+ // Do the hack
+ // Copied from the implementation of torch::jit::freeze,
+ // except without the broken check
+ // See https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/torch/csrc/jit/api/module.cpp
+ bool optimize_numerics = true; // the default
+ // the {} is preserved_attrs
+ auto out_mod = torch::jit::freeze_module(
+ _model, {}
+ );
+ // See 1.11 bugfix in https://github.com/pytorch/pytorch/pull/71436
+ auto graph = out_mod.get_method("forward").graph();
+ OptimizeFrozenGraph(graph, optimize_numerics);
+ _model = out_mod;
+#else
+ // Do it normally
+ _model = torch::jit::freeze(_model);
+#endif
+
+#if (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR <= 10)
+ // Set JIT bailout to avoid long recompilations for many steps
+ size_t jit_bailout_depth;
+ if (metadata["_jit_bailout_depth"].empty()) {
+ // This is the default used in the Python code
+ jit_bailout_depth = 1;
+ } else {
+ jit_bailout_depth = std::stoi(metadata["_jit_bailout_depth"]);
+ }
+ torch::jit::getBailoutDepth() = jit_bailout_depth;
+#else
+ // In PyTorch >=1.11, this is now set_fusion_strategy
+ torch::jit::FusionStrategy strategy;
+ if (metadata["_jit_fusion_strategy"].empty()) {
+ // This is the default used in the Python code
+ strategy = {{torch::jit::FusionBehavior::DYNAMIC, 0}};
+ } else {
+ std::stringstream strat_stream(metadata["_jit_fusion_strategy"]);
+ std::string fusion_type, fusion_depth;
+ while(std::getline(strat_stream, fusion_type, ',')) {
+ std::getline(strat_stream, fusion_depth, ';');
+ strategy.push_back({fusion_type == "STATIC" ? torch::jit::FusionBehavior::STATIC : torch::jit::FusionBehavior::DYNAMIC, std::stoi(fusion_depth)});
+ }
+ }
+ torch::jit::setFusionStrategy(strategy);
+#endif
+
+// TODO check torch::jit::optimize_for_inference() for more complex models
+// This could speed up the code, it was not available on LTS
+#if (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 10)
+ _model = torch::jit::optimize_for_inference(_model);
+#endif
+
//check the dimension of the output
log.printf("Checking output dimension:\n");
std::vector<float> input_test (_n_in);
torch::Tensor single_input = torch::tensor(input_test).view({1,_n_in});
+ single_input = single_input.to(device);
std::vector<torch::jit::IValue> inputs;
inputs.push_back( single_input );
torch::Tensor output = _model.forward( inputs ).toTensor();
@@ -150,54 +230,55 @@ PytorchModel::PytorchModel(const ActionOptions&ao):
log.printf("Number of outputs: %d \n",_n_out);
log.printf(" Bibliography: ");
log< current_S(_n_in);
for(unsigned i=0; i<_n_in; i++)
current_S[i]=getArgument(i);
//convert to tensor
- torch::Tensor input_S = torch::tensor(current_S).view({1,_n_in});
+ torch::Tensor input_S = torch::tensor(current_S).view({1,_n_in}).to(device);
input_S.set_requires_grad(true);
//convert to Ivalue
std::vector<torch::jit::IValue> inputs;
inputs.push_back( input_S );
//calculate output
torch::Tensor output = _model.forward( inputs ).toTensor();
- //set CV values
-  vector<float> cvs = this->tensor_to_vector (output);
- for(unsigned j=0; j<_n_out; j++) {
- string name_comp = "node-"+std::to_string(j);
- getPntrToComponent(name_comp)->set(cvs[j]);
- }
- //derivatives
+
+
for(unsigned j=0; j<_n_out; j++) {
- // expand dim to have shape (1,_n_out)
- int batch_size = 1;
- auto grad_output = torch::ones({1}).expand({batch_size, 1});
- // calculate derivatives with automatic differentiation
+ auto grad_output = torch::ones({1}).expand({1, 1}).to(device);
auto gradient = torch::autograd::grad({output.slice(/*dim=*/1, /*start=*/j, /*end=*/j+1)},
{input_S},
/*grad_outputs=*/ {grad_output},
/*retain_graph=*/true,
- /*create_graph=*/false);
- // add dimension
- auto grad = gradient[0].unsqueeze(/*dim=*/1);
- //convert to vector
-    vector<float> der = this->tensor_to_vector ( grad );
+ /*create_graph=*/false)[0]; // the [0] is to get a tensor and not a vector
+    vector<float> der = this->tensor_to_vector ( gradient );
string name_comp = "node-"+std::to_string(j);
//set derivatives of component j
for(unsigned i=0; i<_n_in; i++)
setDerivative( getPntrToComponent(name_comp),i, der[i] );
}
-}
-}
-}
+
+ //set CV values
+  vector<float> cvs = this->tensor_to_vector (output);
+ for(unsigned j=0; j<_n_out; j++) {
+ string name_comp = "node-"+std::to_string(j);
+ getPntrToComponent(name_comp)->set(cvs[j]);
+ }
+
}
-#endif
+
+} //pytorch
+} //function
+} //PLMD
+
+#endif //PLUMED_HAS_LIBTORCH