Pytorch inference optimization #2

Merged (4 commits), Jul 30, 2023
configure: 13 additions & 1 deletion
@@ -680,6 +680,7 @@ infodir
docdir
oldincludedir
includedir
runstatedir
localstatedir
sharedstatedir
sysconfdir
@@ -809,6 +810,7 @@ datadir='${datarootdir}'
sysconfdir='${prefix}/etc'
sharedstatedir='${prefix}/com'
localstatedir='${prefix}/var'
runstatedir='${localstatedir}/run'
includedir='${prefix}/include'
oldincludedir='/usr/include'
docdir='${datarootdir}/doc/${PACKAGE_TARNAME}'
@@ -1061,6 +1063,15 @@ do
| -silent | --silent | --silen | --sile | --sil)
silent=yes ;;

-runstatedir | --runstatedir | --runstatedi | --runstated \
| --runstate | --runstat | --runsta | --runst | --runs \
| --run | --ru | --r)
ac_prev=runstatedir ;;
-runstatedir=* | --runstatedir=* | --runstatedi=* | --runstated=* \
| --runstate=* | --runstat=* | --runsta=* | --runst=* | --runs=* \
| --run=* | --ru=* | --r=*)
runstatedir=$ac_optarg ;;

-sbindir | --sbindir | --sbindi | --sbind | --sbin | --sbi | --sb)
ac_prev=sbindir ;;
-sbindir=* | --sbindir=* | --sbindi=* | --sbind=* | --sbin=* \
@@ -1198,7 +1209,7 @@ fi
for ac_var in exec_prefix prefix bindir sbindir libexecdir datarootdir \
datadir sysconfdir sharedstatedir localstatedir includedir \
oldincludedir docdir infodir htmldir dvidir pdfdir psdir \
libdir localedir mandir
libdir localedir mandir runstatedir
do
eval ac_val=\$$ac_var
# Remove trailing slashes.
@@ -1351,6 +1362,7 @@ Fine tuning of the installation directories:
--sysconfdir=DIR read-only single-machine data [PREFIX/etc]
--sharedstatedir=DIR modifiable architecture-independent data [PREFIX/com]
--localstatedir=DIR modifiable single-machine data [PREFIX/var]
--runstatedir=DIR modifiable per-process data [LOCALSTATEDIR/run]
--libdir=DIR object code libraries [EPREFIX/lib]
--includedir=DIR C header files [PREFIX/include]
--oldincludedir=DIR C header files for non-gcc [/usr/include]
src/pytorch/PytorchModel.cpp: 106 additions & 25 deletions
@@ -1,5 +1,5 @@
/* +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Copyright (c) 2022-2023 of Luigi Bonati.
Copyright (c) 2022-2023 of Luigi Bonati and Enrico Trizio.

The pytorch module is free software: you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
@@ -31,6 +31,23 @@ along with plumed. If not, see <http://www.gnu.org/licenses/>.
#include <fstream>
#include <cmath>

// We need a backward-compatibility hack for libtorch <=1.10
// https://discuss.pytorch.org/t/how-to-check-libtorch-version/77709/4
// Basically, the check inside torch::jit::freeze
// (see https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/torch/csrc/jit/api/module.cpp#L479)
// is broken in 1.8 and 1.9, so we have to "reimplement" the function
// to get around it; in 1.10 the check passes but the internal logic
// of the function is wrong, so we only call torch::jit::freeze on >=1.11.
// Credits for this implementation of the hack go to the NequIP developers.
#if (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR <= 10)
#define DO_TORCH_FREEZE_HACK
// For the hack, need more headers:
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#endif

using namespace std;

namespace PLMD {
@@ -68,6 +85,7 @@ class PytorchModel :
unsigned _n_in;
unsigned _n_out;
torch::jit::script::Module _model;
torch::Device device = torch::kCPU;

public:
explicit PytorchModel(const ActionOptions&);
@@ -103,10 +121,18 @@ PytorchModel::PytorchModel(const ActionOptions&ao):
std::string fname="model.ptc";
parse("FILE",fname);


// we create the metadata dict
std::unordered_map<std::string, std::string> metadata = {
{"_jit_bailout_depth", ""},
{"_jit_fusion_strategy", ""}
};
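// torch::jit::load() fills these entries with the contents of the matching
// extra files stored in the model archive; keys the exporter did not save stay empty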

//deserialize the model from file
try {
_model = torch::jit::load(fname);
_model = torch::jit::load(fname, device, metadata);
}

//if an error is thrown, check whether the file exists
catch (const c10::Error& e) {
std::ifstream infile(fname);
@@ -124,13 +150,67 @@ PytorchModel::PytorchModel(const ActionOptions&ao):
plumed_merror("The FILE: '"+fname+"' does not exist.");
}
}

checkRead();

// Optimize model
_model.eval();
#ifdef DO_TORCH_FREEZE_HACK
// Do the hack
// Copied from the implementation of torch::jit::freeze,
// except without the broken check
// See https://github.com/pytorch/pytorch/blob/dfbd030854359207cb3040b864614affeace11ce/torch/csrc/jit/api/module.cpp
bool optimize_numerics = true; // the default
// the {} is preserved_attrs
auto out_mod = torch::jit::freeze_module(
_model, {}
);
// See 1.11 bugfix in https://github.com/pytorch/pytorch/pull/71436
auto graph = out_mod.get_method("forward").graph();
OptimizeFrozenGraph(graph, optimize_numerics);
_model = out_mod;
#else
// Do it normally
_model = torch::jit::freeze(_model);
#endif
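// freezing inlines submodules and folds parameters/attributes into the graph
// as constants, which enables the graph-level optimizations applied below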

#if (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR <= 10)
// Set JIT bailout to avoid long recompilations for many steps
size_t jit_bailout_depth;
if (metadata["_jit_bailout_depth"].empty()) {
// This is the default used in the Python code
jit_bailout_depth = 1;
} else {
jit_bailout_depth = std::stoi(metadata["_jit_bailout_depth"]);
}
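// getBailoutDepth() returns a mutable reference to the global JIT setting,
// so assigning to it caps how often the profiling executor re-specializes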
torch::jit::getBailoutDepth() = jit_bailout_depth;
#else
// In PyTorch >=1.11, this is now set_fusion_strategy
torch::jit::FusionStrategy strategy;
if (metadata["_jit_fusion_strategy"].empty()) {
// This is the default used in the Python code
strategy = {{torch::jit::FusionBehavior::DYNAMIC, 0}};
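// (DYNAMIC, 0) allows zero shape specializations, so the JIT does not
// keep recompiling when input shapes vary between calls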
} else {
std::stringstream strat_stream(metadata["_jit_fusion_strategy"]);
std::string fusion_type, fusion_depth;
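// the stored strategy has the form "DYNAMIC,3;STATIC,2":
// type,depth pairs separated by ';'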
while(std::getline(strat_stream, fusion_type, ',')) {
std::getline(strat_stream, fusion_depth, ';');
strategy.push_back({fusion_type == "STATIC" ? torch::jit::FusionBehavior::STATIC : torch::jit::FusionBehavior::DYNAMIC, std::stoi(fusion_depth)});
}
}
torch::jit::setFusionStrategy(strategy);
#endif

// TODO check torch::jit::optimize_for_inference() for more complex models
// This could speed up the code; it was not available in the LTS release
#if (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 10)
_model = torch::jit::optimize_for_inference(_model);
#endif

//check the dimension of the output
log.printf("Checking output dimension:\n");
std::vector<float> input_test (_n_in);
torch::Tensor single_input = torch::tensor(input_test).view({1,_n_in});
single_input = single_input.to(device);
std::vector<torch::jit::IValue> inputs;
inputs.push_back( single_input );
torch::Tensor output = _model.forward( inputs ).toTensor();
@@ -150,54 +230,55 @@ PytorchModel::PytorchModel(const ActionOptions&ao):
log.printf("Number of outputs: %d \n",_n_out);
log.printf(" Bibliography: ");
log<<plumed.cite("Bonati, Rizzi and Parrinello, J. Phys. Chem. Lett. 11, 2998-3004 (2020)");
log<<plumed.cite("Trizio and Parrinello, J. Phys. Chem. Lett. 12, 8621-8626 (2021)");
log.printf("\n");

}


void PytorchModel::calculate() {

//retrieve arguments
// retrieve arguments
vector<float> current_S(_n_in);
for(unsigned i=0; i<_n_in; i++)
current_S[i]=getArgument(i);
//convert to tensor
torch::Tensor input_S = torch::tensor(current_S).view({1,_n_in});
torch::Tensor input_S = torch::tensor(current_S).view({1,_n_in}).to(device);
input_S.set_requires_grad(true);
//convert to IValue
std::vector<torch::jit::IValue> inputs;
inputs.push_back( input_S );
//calculate output
torch::Tensor output = _model.forward( inputs ).toTensor();
//set CV values
vector<float> cvs = this->tensor_to_vector (output);
for(unsigned j=0; j<_n_out; j++) {
string name_comp = "node-"+std::to_string(j);
getPntrToComponent(name_comp)->set(cvs[j]);
}
//derivatives


for(unsigned j=0; j<_n_out; j++) {
// expand dim to have shape (1,_n_out)
int batch_size = 1;
auto grad_output = torch::ones({1}).expand({batch_size, 1});
// calculate derivatives with automatic differentiation
auto grad_output = torch::ones({1}).expand({1, 1}).to(device);
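// the all-ones seed makes autograd::grad return the vector-Jacobian product,
// i.e. d(node-j)/d(input_i) for this single output component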
auto gradient = torch::autograd::grad({output.slice(/*dim=*/1, /*start=*/j, /*end=*/j+1)},
{input_S},
/*grad_outputs=*/ {grad_output},
/*retain_graph=*/true,
/*create_graph=*/false);
// add dimension
auto grad = gradient[0].unsqueeze(/*dim=*/1);
//convert to vector
vector<float> der = this->tensor_to_vector ( grad );
/*create_graph=*/false)[0]; // the [0] extracts the tensor from the returned vector<at::Tensor>

vector<float> der = this->tensor_to_vector ( gradient );
string name_comp = "node-"+std::to_string(j);
//set derivatives of component j
for(unsigned i=0; i<_n_in; i++)
setDerivative( getPntrToComponent(name_comp),i, der[i] );
}
}
}
}

//set CV values
vector<float> cvs = this->tensor_to_vector (output);
for(unsigned j=0; j<_n_out; j++) {
string name_comp = "node-"+std::to_string(j);
getPntrToComponent(name_comp)->set(cvs[j]);
}

}

#endif

} //PLMD
} //function
} //pytorch

#endif //PLUMED_HAS_LIBTORCH
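
The derivative loop in calculate() seeds torch::autograd::grad with a tensor of ones to extract one Jacobian row per output component. As a standalone illustration of that pattern (not part of the patch; the toy function and values are made up for the example), a minimal libtorch program:

// toy_jacobian.cpp -- illustrative sketch only; compile against libtorch
#include <torch/torch.h>
#include <iostream>

int main() {
  // one "frame" with three inputs, like a (1, _n_in) batch
  torch::Tensor x = torch::tensor({1.0f, 2.0f, 3.0f}).view({1, 3});
  x.set_requires_grad(true);

  // two outputs: y0 = sum(x^2), y1 = sum(x); y has shape (1, 2)
  torch::Tensor y = torch::stack({(x * x).sum(1), x.sum(1)}, /*dim=*/1);

  for (int64_t j = 0; j < y.size(1); ++j) {
    auto seed = torch::ones({1, 1}); // grad_outputs seed, as in calculate()
    auto g = torch::autograd::grad({y.slice(/*dim=*/1, j, j + 1)}, {x},
                                   {seed}, /*retain_graph=*/true,
                                   /*create_graph=*/false)[0];
    std::cout << "d(y_" << j << ")/dx = " << g << "\n"; // row j of the Jacobian
  }
}

Running this prints [2, 4, 6] for y0 (i.e. 2x) and [1, 1, 1] for y1, the two Jacobian rows that calculate() would pass to setDerivative().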