Skip to content

Commit

Permalink
Update the test models and CI dependencies (#73)
Browse files Browse the repository at this point in the history
* Update dependencies

* Downgrade GCC

* Reverse enginier a model generator

* Regenerate the model with PyTorch 1.11

* Update to PyTorch 1.11

* Update to Python 3.10

* Empty line

* Simplify the test models
  • Loading branch information
Raimondas Galvelis authored May 26, 2022
1 parent 84f7d88 commit fa102e6
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 7 deletions.
14 changes: 7 additions & 7 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,22 @@ jobs:
matrix:
include:
# Oldest supported versions
- name: Linux (CUDA 10.2, Python 3.7, PyTorch 1.7)
- name: Linux (CUDA 10.2, Python 3.7, PyTorch 1.11)
os: ubuntu-18.04
cuda-version: "10.2.89"
gcc-version: "9.4.*"
gcc-version: "8.5.*"
nvcc-version: "10.2"
python-version: "3.7"
pytorch-version: "1.7.*"
pytorch-version: "1.11.*"

# Latest supported versions
- name: Linux (CUDA 11.2, Python 3.9, PyTorch 1.10)
- name: Linux (CUDA 11.2, Python 3.10, PyTorch 1.11)
os: ubuntu-18.04
cuda-version: "11.2.2"
gcc-version: "11.2.*"
gcc-version: "10.3.*"
nvcc-version: "11.2"
python-version: "3.9"
pytorch-version: "1.10.*"
python-version: "3.10"
pytorch-version: "1.11.*"

- name: MacOS (Python 3.9, PyTorch 1.9)
os: macos-11
Expand Down
Binary file modified tests/central.pt
Binary file not shown.
Binary file modified tests/forces.pt
Binary file not shown.
24 changes: 24 additions & 0 deletions tests/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch as pt

class Central(pt.nn.Module):
def forward(self, pos):
return pos.pow(2).sum()

class Forces(pt.nn.Module):
def forward(self, pos):
return pos.pow(2).sum(), -2 * pos

class Global(pt.nn.Module):
def forward(self, pos, k):
return k * pos.pow(2).sum()

class Periodic(pt.nn.Module):
def forward(self, pos, box):
box = box.diagonal().unsqueeze(0)
pos = pos - (pos / box).floor() * box
return pos.pow(2).sum()

pt.jit.script(Central()).save('central.pt')
pt.jit.script(Forces()).save('forces.pt')
pt.jit.script(Global()).save('global.pt')
pt.jit.script(Periodic()).save('periodic.pt')
Binary file modified tests/global.pt
Binary file not shown.
Binary file modified tests/periodic.pt
Binary file not shown.

0 comments on commit fa102e6

Please sign in to comment.