Torch2 compatibility (#106)
* Update ci

* port the latest CI versions to the self-hosted runner

* forgot to remove bit from the version I copied

* see if cuda 11.7 works

* Update swig typecheck to use py::isinstance instead of as_module

* run CUDA tests on GPU runner

* Update TorchForce swig input typemap to avoid as_module

* Fix formatting

* Switch to using micromamba

* Make sure temporary files are deleted even if there was an error

* Update swig typemaps to not use a temp file

* Revert "Merge remote-tracking branch 'origin/feat/add_aws_gpu_runner' into torch2"

This reverts commit 0a58de4, reversing
changes made to a1c16c6.

* run gpu tests on this branch

* actually test with pytorch2

* Revert "actually test with pytorch2"

This reverts commit 4e7a4fb.

* Revert "Merge remote-tracking branch 'origin/feat/add_aws_gpu_runner' into torch2"

This reverts commit 0a58de4, reversing
changes made to a1c16c6.

---------

Co-authored-by: Mike Henry <[email protected]>
RaulPPelaez and mikemhenry authored Jul 24, 2023
1 parent b76deb4 commit e196c8d
Showing 2 changed files with 18 additions and 13 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/CI.yml
@@ -33,13 +33,13 @@ jobs:
           pytorch-version: "1.11.*"
 
         # Latest supported versions
-        - name: Linux (CUDA 11.2, Python 3.10, PyTorch 1.12)
+        - name: Linux (CUDA 11.8, Python 3.10, PyTorch 2.0)
           os: ubuntu-22.04
-          cuda-version: "11.2.2"
+          cuda-version: "11.8.0"
           gcc-version: "10.3.*"
-          nvcc-version: "11.2"
+          nvcc-version: "11.8"
           python-version: "3.10"
-          pytorch-version: "1.12.*"
+          pytorch-version: "2.0.*"
 
         - name: MacOS (Python 3.9, PyTorch 1.9)
           os: macos-11
23 changes: 14 additions & 9 deletions python/openmmtorch.i
@@ -14,6 +14,7 @@
 #include "openmm/RPMDIntegrator.h"
 #include "openmm/RPMDMonteCarloBarostat.h"
 #include <torch/csrc/jit/python/module_python.h>
+#include <torch/csrc/jit/serialization/import.h>
 %}

/*
@@ -28,23 +29,27 @@
     }
 }
 
-%typemap(in) const torch::jit::Module&(torch::jit::Module module) {
+%typemap(in) const torch::jit::Module&(torch::jit::Module mod) {
     py::object o = py::reinterpret_borrow<py::object>($input);
-    module = torch::jit::as_module(o).value();
-    $1 = &module;
+    py::object pybuffer = py::module::import("io").attr("BytesIO")();
+    py::module::import("torch.jit").attr("save")(o, pybuffer);
+    std::string s = py::cast<std::string>(pybuffer.attr("getvalue")());
+    std::stringstream buffer(s);
+    mod = torch::jit::load(buffer);
+    $1 = &mod;
 }
 
 %typemap(out) const torch::jit::Module& {
-    auto fileName = std::tmpnam(nullptr);
-    $1->save(fileName);
-    $result = py::module::import("torch.jit").attr("load")(fileName).release().ptr();
-    //This typemap assumes that torch does not require the file to exist after construction
-    std::remove(fileName);
+    std::stringstream buffer;
+    $1->save(buffer);
+    auto pybuffer = py::module::import("io").attr("BytesIO")(py::bytes(buffer.str()));
+    $result = py::module::import("torch.jit").attr("load")(pybuffer).release().ptr();
 }
 
 %typecheck(SWIG_TYPECHECK_POINTER) const torch::jit::Module& {
     py::object o = py::reinterpret_borrow<py::object>($input);
-    $1 = torch::jit::as_module(o).has_value() ? 1 : 0;
+    py::handle ScriptModule = py::module::import("torch.jit").attr("ScriptModule");
+    $1 = py::isinstance(o, ScriptModule);
 }
 
 namespace std {
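The new typemaps drop both torch::jit::as_module and the std::tmpnam temp file in favor of an in-memory round trip: the module is serialized into a buffer (io.BytesIO on the Python side, std::stringstream in C++) and deserialized on the other side. Below is a minimal sketch of that pattern; it uses the standard-library pickle module as a stand-in so it runs without PyTorch installed, and the helper names to_buffer/from_buffer are illustrative, not part of the patch. In the actual typemaps the save/load roles are played by torch.jit.save and torch.jit.load.

```python
import io
import pickle


def to_buffer(obj):
    """Serialize obj into an in-memory buffer, never touching disk.

    Stand-in for the typemap(in) path, where torch.jit.save(o, BytesIO())
    produces the bytes that C++ then reads via std::stringstream.
    """
    buf = io.BytesIO()
    pickle.dump(obj, buf)
    buf.seek(0)  # rewind so the consumer reads from the start
    return buf


def from_buffer(buf):
    """Deserialize from the buffer; stand-in for torch.jit.load(buffer)."""
    return pickle.load(buf)


params = {"cuda-version": "11.8.0", "pytorch-version": "2.0.*"}
roundtripped = from_buffer(to_buffer(params))
assert roundtripped == params  # same content, no temporary file created
```

Besides working with PyTorch 2.x, the in-memory route sidesteps the cleanup and race concerns of std::tmpnam that the removed comment ("assumes that torch does not require the file to exist after construction") had to hedge against.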
