diff --git a/tests/describers/test_m3gnet.py b/tests/describers/test_m3gnet.py index 657bf761..eb695926 100644 --- a/tests/describers/test_m3gnet.py +++ b/tests/describers/test_m3gnet.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +import pytest import unittest from pymatgen.core import Lattice, Structure @@ -13,23 +14,29 @@ from maml.describers import M3GNetStructure, M3GNetSite -@unittest.skipIf(M3GNet is None, "M3GNet package is required") +@unittest.skipIf(M3GNet is None, "M3GNet package is required.") class M3GNetTest(unittest.TestCase): @classmethod def setUpClass(cls): cls.s = Structure.from_spacegroup( "Fm-3m", Lattice.cubic(5.69169), ["Na", "Cl"], [[0, 0, 0], [0, 0, 0.5]] ) - cls.m3gnet_struct = M3GNetStructure() - cls.m3gnet_site = M3GNetSite(feature_batch="pandas_concat") def test_m3gnet_site_transform(self): - atom_features2 = self.m3gnet_site.transform([self.s] * 2) - self.assertListEqual(list(np.array(atom_features2).shape), [16, 64]) + atom_feat_2s = M3GNetSite(feature_batch="pandas_concat").transform([self.s] * 2) + self.assertListEqual(list(np.array(atom_feat_2s).shape), [16, 64]) + with pytest.raises(ValueError, match="Invalid output_layers"): + M3GNetSite(output_layers=["whatever"]) + atom_feat_2s_2l = M3GNetSite( + output_layers=["embedding", "gc_3"], feature_batch="pandas_concat" + ).transform([self.s] * 2) + self.assertListEqual(list(np.array(atom_feat_2s_2l).shape), [16, 128]) + atom_feat_dict = M3GNetSite(return_type=dict).transform_one(self.s) + assert type(atom_feat_dict) == dict def test_m3gnet_structure_transform(self): - struct_feature2 = self.m3gnet_struct.transform([self.s] * 2) - self.assertListEqual(list(np.array(struct_feature2).shape), [2, 128]) + struct_feat_2s = M3GNetStructure().transform([self.s] * 2) + self.assertListEqual(list(np.array(struct_feat_2s).shape), [2, 128]) if __name__ == "__main__":