Skip to content

Commit

Permalink
add test to compile foundation models
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyes319 committed Jan 15, 2025
1 parent 83bfd9d commit d5e8a38
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tests/test_foundations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.nn.functional
from ase.build import molecule
from e3nn import o3
from e3nn.util import jit
from scipy.spatial.transform import Rotation as R

from mace import data, modules, tools
Expand Down Expand Up @@ -176,6 +177,33 @@ def test_multi_reference():
)


@pytest.mark.parametrize(
"calc",
[
mace_mp(device="cpu", default_dtype="float64"),
mace_mp(model="small", device="cpu", default_dtype="float64"),
mace_mp(model="medium", device="cpu", default_dtype="float64"),
mace_mp(model="large", device="cpu", default_dtype="float64"),
mace_mp(model=MODEL_PATH, device="cpu", default_dtype="float64"),
mace_off(model="small", device="cpu", default_dtype="float64"),
mace_off(model="medium", device="cpu", default_dtype="float64"),
mace_off(model="large", device="cpu", default_dtype="float64"),
mace_off(model=MODEL_PATH, device="cpu", default_dtype="float64"),
],
)
def test_compile_foundation(calc):
model = calc.models[0]
atoms = molecule("CH4")
atoms.positions += np.random.randn(*atoms.positions.shape) * 0.1
batch = calc._atoms_to_batch(atoms)
output_1 = model(batch.to_dict())
model_compiled = jit.compile(model)
output = model_compiled(batch.to_dict())
for key in output_1.keys():
if isinstance(output_1[key], torch.Tensor):
assert torch.allclose(output_1[key], output[key], atol=1e-5)


@pytest.mark.parametrize(
"model",
[
Expand Down

0 comments on commit d5e8a38

Please sign in to comment.