diff --git a/configure b/configure index cbfff8ac9a..068480a473 100755 --- a/configure +++ b/configure @@ -680,6 +680,7 @@ infodir docdir oldincludedir includedir +runstatedir localstatedir sharedstatedir sysconfdir @@ -809,6 +810,7 @@ datadir='${datarootdir}' sysconfdir='${prefix}/etc' sharedstatedir='${prefix}/com' localstatedir='${prefix}/var' +runstatedir='${localstatedir}/run' includedir='${prefix}/include' oldincludedir='/usr/include' docdir='${datarootdir}/doc/${PACKAGE_TARNAME}' @@ -1061,6 +1063,15 @@ do | -silent | --silent | --silen | --sile | --sil) silent=yes ;; + -runstatedir | --runstatedir | --runstatedi | --runstated \ + | --runstate | --runstat | --runsta | --runst | --runs \ + | --run | --ru | --r) + ac_prev=runstatedir ;; + -runstatedir=* | --runstatedir=* | --runstatedi=* | --runstated=* \ + | --runstate=* | --runstat=* | --runsta=* | --runst=* | --runs=* \ + | --run=* | --ru=* | --r=*) + runstatedir=$ac_optarg ;; + -sbindir | --sbindir | --sbindi | --sbind | --sbin | --sbi | --sb) ac_prev=sbindir ;; -sbindir=* | --sbindir=* | --sbindi=* | --sbind=* | --sbin=* \ @@ -1198,7 +1209,7 @@ fi for ac_var in exec_prefix prefix bindir sbindir libexecdir datarootdir \ datadir sysconfdir sharedstatedir localstatedir includedir \ oldincludedir docdir infodir htmldir dvidir pdfdir psdir \ - libdir localedir mandir + libdir localedir mandir runstatedir do eval ac_val=\$$ac_var # Remove trailing slashes. 
@@ -1351,6 +1362,7 @@ Fine tuning of the installation directories: --sysconfdir=DIR read-only single-machine data [PREFIX/etc] --sharedstatedir=DIR modifiable architecture-independent data [PREFIX/com] --localstatedir=DIR modifiable single-machine data [PREFIX/var] + --runstatedir=DIR modifiable per-process data [LOCALSTATEDIR/run] --libdir=DIR object code libraries [EPREFIX/lib] --includedir=DIR C header files [PREFIX/include] --oldincludedir=DIR C header files for non-gcc [/usr/include] diff --git a/src/pytorch/PytorchModel.cpp b/src/pytorch/PytorchModel.cpp index ae5861a0bb..7382ed25e5 100644 --- a/src/pytorch/PytorchModel.cpp +++ b/src/pytorch/PytorchModel.cpp @@ -1,5 +1,5 @@ /* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ -Copyright (c) 2022-2023 of Luigi Bonati. +Copyright (c) 2022-2023 of Luigi Bonati and Enrico Trizio. The pytorch module is free software: you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by @@ -31,6 +31,23 @@ along with plumed. If not, see <http://www.gnu.org/licenses/>. #include <torch/torch.h> #include <torch/script.h> +// We have to do a backward compatibility hack for <1.10 +// https://discuss.pytorch.org/t/how-to-check-libtorch-version/77709/4 +// Basically, the check in torch::jit::freeze +// (see https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/torch/csrc/jit/api/module.cpp#L479) +// is wrong, and we have to "reimplement" the function +// to get around that...
+// it's broken in 1.8 and 1.9 +// BUT the internal logic in the function is wrong in 1.10 +// So we only use torch::jit::freeze in >=1.11 +// credits for this implementation of the hack to the NequIP guys +#if (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR <= 10) +#define DO_TORCH_FREEZE_HACK +// For the hack, need more headers: +#include <torch/csrc/jit/passes/freeze_module.h> +#include <torch/csrc/jit/passes/frozen_graph_optimizations.h> +#endif + using namespace std; namespace PLMD { @@ -68,6 +85,7 @@ class PytorchModel : unsigned _n_in; unsigned _n_out; torch::jit::script::Module _model; + torch::Device device = torch::kCPU; public: explicit PytorchModel(const ActionOptions&); @@ -103,10 +121,18 @@ PytorchModel::PytorchModel(const ActionOptions&ao): std::string fname="model.ptc"; parse("FILE",fname); + + // we create the metadata dict + std::unordered_map<std::string, std::string> metadata = { + {"_jit_bailout_depth", ""}, + {"_jit_fusion_strategy", ""} + }; + //deserialize the model from file try { - _model = torch::jit::load(fname); + _model = torch::jit::load(fname, device, metadata); } + //if an error is thrown check if the file exists or not catch (const c10::Error& e) { std::ifstream infile(fname); @@ -124,13 +150,67 @@ PytorchModel::PytorchModel(const ActionOptions&ao): plumed_merror("The FILE: '"+fname+"' does not exist."); } } - checkRead(); +// Optimize model + _model.eval(); +#ifdef DO_TORCH_FREEZE_HACK + // Do the hack + // Copied from the implementation of torch::jit::freeze, + // except without the broken check + // See https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/torch/csrc/jit/api/module.cpp + bool optimize_numerics = true; // the default + // the {} is preserved_attrs + auto out_mod = torch::jit::freeze_module( + _model, {} + ); + // See 1.11 bugfix in https://github.com/pytorch/pytorch/pull/71436 + auto graph = out_mod.get_method("forward").graph(); + OptimizeFrozenGraph(graph, optimize_numerics); + _model = out_mod; +#else + // Do it normally + _model = torch::jit::freeze(_model); +#endif + +#if (TORCH_VERSION_MAJOR ==
1 && TORCH_VERSION_MINOR <= 10) + // Set JIT bailout to avoid long recompilations for many steps + size_t jit_bailout_depth; + if (metadata["_jit_bailout_depth"].empty()) { + // This is the default used in the Python code + jit_bailout_depth = 1; + } else { + jit_bailout_depth = std::stoi(metadata["_jit_bailout_depth"]); + } + torch::jit::getBailoutDepth() = jit_bailout_depth; +#else + // In PyTorch >=1.11, this is now set_fusion_strategy + torch::jit::FusionStrategy strategy; + if (metadata["_jit_fusion_strategy"].empty()) { + // This is the default used in the Python code + strategy = {{torch::jit::FusionBehavior::DYNAMIC, 0}}; + } else { + std::stringstream strat_stream(metadata["_jit_fusion_strategy"]); + std::string fusion_type, fusion_depth; + while(std::getline(strat_stream, fusion_type, ',')) { + std::getline(strat_stream, fusion_depth, ';'); + strategy.push_back({fusion_type == "STATIC" ? torch::jit::FusionBehavior::STATIC : torch::jit::FusionBehavior::DYNAMIC, std::stoi(fusion_depth)}); + } + } + torch::jit::setFusionStrategy(strategy); +#endif + +// TODO check torch::jit::optimize_for_inference() for more complex models +// This could speed up the code, it was not available on LTS +#if (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 10) + _model = torch::jit::optimize_for_inference(_model); +#endif + //check the dimension of the output log.printf("Checking output dimension:\n"); std::vector<float> input_test (_n_in); torch::Tensor single_input = torch::tensor(input_test).view({1,_n_in}); + single_input = single_input.to(device); std::vector<torch::jit::IValue> inputs; inputs.push_back( single_input ); torch::Tensor output = _model.forward( inputs ).toTensor(); @@ -150,54 +230,55 @@ PytorchModel::PytorchModel(const ActionOptions&ao): log.printf("Number of outputs: %d \n",_n_out); log.printf(" Bibliography: "); log<<plumed.cite("Bonati, Rizzi and Parrinello, J. Phys. Chem. Lett. 12, 2751-2756 (2021)"); log.printf("\n"); } void PytorchModel::calculate() { vector<float> current_S(_n_in); for(unsigned i=0; i<_n_in; i++) current_S[i]=getArgument(i); //convert to tensor - torch::Tensor input_S =
torch::tensor(current_S).view({1,_n_in}); + torch::Tensor input_S = torch::tensor(current_S).view({1,_n_in}).to(device); input_S.set_requires_grad(true); //convert to Ivalue std::vector<torch::jit::IValue> inputs; inputs.push_back( input_S ); //calculate output torch::Tensor output = _model.forward( inputs ).toTensor(); - //set CV values - vector<float> cvs = this->tensor_to_vector (output); - for(unsigned j=0; j<_n_out; j++) { - string name_comp = "node-"+std::to_string(j); - getPntrToComponent(name_comp)->set(cvs[j]); - } - //derivatives + + for(unsigned j=0; j<_n_out; j++) { - // expand dim to have shape (1,_n_out) - int batch_size = 1; - auto grad_output = torch::ones({1}).expand({batch_size, 1}); - // calculate derivatives with automatic differentiation + auto grad_output = torch::ones({1}).expand({1, 1}).to(device); auto gradient = torch::autograd::grad({output.slice(/*dim=*/1, /*start=*/j, /*end=*/j+1)}, {input_S}, /*grad_outputs=*/ {grad_output}, /*retain_graph=*/true, - /*create_graph=*/false); - // add dimension - auto grad = gradient[0].unsqueeze(/*dim=*/1); - //convert to vector - vector<float> der = this->tensor_to_vector ( grad ); + /*create_graph=*/false)[0]; // the [0] is to get a tensor and not a vector + vector<float> der = this->tensor_to_vector ( gradient ); string name_comp = "node-"+std::to_string(j); //set derivatives of component j for(unsigned i=0; i<_n_in; i++) setDerivative( getPntrToComponent(name_comp),i, der[i] ); } -} -} -} + + //set CV values + vector<float> cvs = this->tensor_to_vector (output); + for(unsigned j=0; j<_n_out; j++) { + string name_comp = "node-"+std::to_string(j); + getPntrToComponent(name_comp)->set(cvs[j]); + } + } -#endif + +} //PLMD +} //function +} //pytorch + +#endif //PLUMED_HAS_LIBTORCH