Skip to content

Commit

Permalink
Make some tests use deterministic torch algorithms (#108)
Browse files Browse the repository at this point in the history
* Make some tests use deterministic torch algorithms

* bump ci

* need this for deterministic tests

---------

Co-authored-by: Mike Henry <[email protected]>
  • Loading branch information
RaulPPelaez and mikemhenry authored Jul 25, 2023
1 parent 5e2438d commit d15cb91
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 54 deletions.
1 change: 1 addition & 0 deletions .github/workflows/self-hosted-gpu-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ jobs:
run: |
conda activate nnpops
cd build
export CUBLAS_WORKSPACE_CONFIG=:4096:8
ctest --verbose
stop-runner:
Expand Down
56 changes: 34 additions & 22 deletions src/pytorch/TestBatchedNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,40 +34,52 @@ def test_import():
import NNPOps
import NNPOps.BatchedNN

class DeterministicTorch:
def __enter__(self):
if torch.are_deterministic_algorithms_enabled():
self._already_enabled = True
return
self._already_enabled = False
torch.use_deterministic_algorithms(True)

def __exit__(self, type, value, traceback):
if not self._already_enabled:
torch.use_deterministic_algorithms(False)

@pytest.mark.parametrize('deviceString', ['cpu', 'cuda'])
@pytest.mark.parametrize('molFile', ['1hvj', '1hvk', '2iuz', '3hkw', '3hky', '3lka', '3o99'])
def test_compare_with_native(deviceString, molFile):

if deviceString == 'cuda' and not torch.cuda.is_available():
pytest.skip('CUDA is not available')
with DeterministicTorch():
from NNPOps.BatchedNN import TorchANIBatchedNN

from NNPOps.BatchedNN import TorchANIBatchedNN
device = torch.device(deviceString)

device = torch.device(deviceString)

mol = mdtraj.load(os.path.join(molecules, f'{molFile}_ligand.mol2'))
atomicNumbers = torch.tensor([[atom.element.atomic_number for atom in mol.top.atoms]], device=device)
atomicPositions = torch.tensor(mol.xyz, dtype=torch.float32, requires_grad=True, device=device)
mol = mdtraj.load(os.path.join(molecules, f'{molFile}_ligand.mol2'))
atomicNumbers = torch.tensor([[atom.element.atomic_number for atom in mol.top.atoms]], device=device)
atomicPositions = torch.tensor(mol.xyz, dtype=torch.float32, requires_grad=True, device=device)

nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
energy_ref = nnp((atomicNumbers, atomicPositions)).energies
energy_ref.backward()
grad_ref = atomicPositions.grad.clone()
nnp = torchani.models.ANI2x(periodic_table_index=True).to(device)
energy_ref = nnp((atomicNumbers, atomicPositions)).energies
energy_ref.backward()
grad_ref = atomicPositions.grad.clone()

nnp.neural_networks = TorchANIBatchedNN(nnp.species_converter, nnp.neural_networks, atomicNumbers).to(device)
energy = nnp((atomicNumbers, atomicPositions)).energies
atomicPositions.grad.zero_()
energy.backward()
grad = atomicPositions.grad.clone()
nnp.neural_networks = TorchANIBatchedNN(nnp.species_converter, nnp.neural_networks, atomicNumbers).to(device)
energy = nnp((atomicNumbers, atomicPositions)).energies
atomicPositions.grad.zero_()
energy.backward()
grad = atomicPositions.grad.clone()

energy_error = torch.abs((energy - energy_ref)/energy_ref)
grad_error = torch.max(torch.abs((grad - grad_ref)/grad_ref))
energy_error = torch.abs((energy - energy_ref)/energy_ref)
grad_error = torch.max(torch.abs((grad - grad_ref)/grad_ref))

assert energy_error < 5e-7
if molFile == '3o99':
assert grad_error < 0.025 # Some numerical instability
else:
assert grad_error < 5e-3
assert energy_error < 5e-7
if molFile == '3o99':
assert grad_error < 0.025 # Some numerical instability
else:
assert grad_error < 5e-3

@pytest.mark.parametrize('deviceString', ['cpu', 'cuda'])
@pytest.mark.parametrize('molFile', ['1hvj', '1hvk', '2iuz', '3hkw', '3hky', '3lka', '3o99'])
Expand Down
76 changes: 44 additions & 32 deletions src/pytorch/TestCFConv.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,18 @@ def test_gradients(deviceString):
# return torch.sum(conv(neighbors, pos, input))
# assert torch.autograd.gradcheck(func, positions)

class DeterministicTorch:
def __enter__(self):
if torch.are_deterministic_algorithms_enabled():
self._already_enabled = True
return
self._already_enabled = False
torch.use_deterministic_algorithms(True)

def __exit__(self, type, value, traceback):
if not self._already_enabled:
torch.use_deterministic_algorithms(False)

@pytest.mark.parametrize('deviceString', ['cpu', 'cuda'])
def test_model_serialization(deviceString):

Expand All @@ -94,35 +106,35 @@ def test_model_serialization(deviceString):
device = torch.device(deviceString)
numAtoms = 7
numFilters = 5

neighbors_ref, conv_ref = getCFConv(numFilters, device)
positions = (10*torch.rand(numAtoms, 3, dtype=torch.float32, device=device) - 5).detach()
positions.requires_grad = True
input = torch.rand(numAtoms, numFilters, dtype=torch.float32, device=device)

neighbors_ref.build(positions)
output_ref = conv_ref(neighbors_ref, positions, input)
total_ref = torch.sum(output_ref)
total_ref.backward()
grad_ref = positions.grad.clone()

with tempfile.NamedTemporaryFile() as fd1, tempfile.NamedTemporaryFile() as fd2:

torch.jit.script(neighbors_ref).save(fd1.name)
neighbors = torch.jit.load(fd1.name).to(device)

torch.jit.script(conv_ref).save(fd2.name)
conv = torch.jit.load(fd2.name).to(device)

neighbors.build(positions)
output = conv(neighbors, positions, input)
total = torch.sum(output)
positions.grad.zero_()
total.backward()
grad = positions.grad.clone()

assert torch.allclose(output, output_ref, rtol=1e-07)
if deviceString == 'cuda':
assert torch.allclose(grad, grad_ref, rtol=1e-07, atol=1e-6) # Numerical noise
else:
assert torch.allclose(grad, grad_ref, rtol=1e-07)
with DeterministicTorch():
neighbors_ref, conv_ref = getCFConv(numFilters, device)
positions = (10*torch.rand(numAtoms, 3, dtype=torch.float32, device=device) - 5).detach()
positions.requires_grad = True
input = torch.rand(numAtoms, numFilters, dtype=torch.float32, device=device)

neighbors_ref.build(positions)
output_ref = conv_ref(neighbors_ref, positions, input)
total_ref = torch.sum(output_ref)
total_ref.backward()
grad_ref = positions.grad.clone()

with tempfile.NamedTemporaryFile() as fd1, tempfile.NamedTemporaryFile() as fd2:

torch.jit.script(neighbors_ref).save(fd1.name)
neighbors = torch.jit.load(fd1.name).to(device)

torch.jit.script(conv_ref).save(fd2.name)
conv = torch.jit.load(fd2.name).to(device)

neighbors.build(positions)
output = conv(neighbors, positions, input)
total = torch.sum(output)
positions.grad.zero_()
total.backward()
grad = positions.grad.clone()

assert torch.allclose(output, output_ref, rtol=1e-07)
if deviceString == 'cuda':
assert torch.allclose(grad, grad_ref, rtol=1e-07, atol=1e-6) # Numerical noise
else:
assert torch.allclose(grad, grad_ref, rtol=1e-07)

0 comments on commit d15cb91

Please sign in to comment.