Skip to content

Commit

Permalink
increase test coverage for M3GNetSite describer
Browse files Browse the repository at this point in the history
  • Loading branch information
JiQi535 committed Jan 27, 2024
1 parent 8d3627f commit a73ef6b
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions tests/describers/test_m3gnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import numpy as np
import pytest
import unittest

from pymatgen.core import Lattice, Structure
Expand All @@ -13,23 +14,29 @@
from maml.describers import M3GNetStructure, M3GNetSite


@unittest.skipIf(M3GNet is None, "M3GNet package is required")
@unittest.skipIf(M3GNet is None, "M3GNet package is required.")
class M3GNetTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.s = Structure.from_spacegroup(
"Fm-3m", Lattice.cubic(5.69169), ["Na", "Cl"], [[0, 0, 0], [0, 0, 0.5]]
)
cls.m3gnet_struct = M3GNetStructure()
cls.m3gnet_site = M3GNetSite(feature_batch="pandas_concat")

def test_m3gnet_site_transform(self):
atom_features2 = self.m3gnet_site.transform([self.s] * 2)
self.assertListEqual(list(np.array(atom_features2).shape), [16, 64])
atom_feat_2s = M3GNetSite(feature_batch="pandas_concat").transform([self.s] * 2)
self.assertListEqual(list(np.array(atom_feat_2s).shape), [16, 64])
with pytest.raises(ValueError, match="Invalid output_layers"):
M3GNetSite(output_layers=["whatever"])
atom_feat_2s_2l = M3GNetSite(
output_layers=["embedding", "gc_3"], feature_batch="pandas_concat"
).transform([self.s] * 2)
self.assertListEqual(list(np.array(atom_feat_2s_2l).shape), [16, 128])
atom_feat_dict = M3GNetSite(return_type=dict).transform_one(self.s)
assert type(atom_feat_dict) == dict

def test_m3gnet_structure_transform(self):
struct_feature2 = self.m3gnet_struct.transform([self.s] * 2)
self.assertListEqual(list(np.array(struct_feature2).shape), [2, 128])
struct_feat_2s = M3GNetStructure().transform([self.s] * 2)
self.assertListEqual(list(np.array(struct_feat_2s).shape), [2, 128])


if __name__ == "__main__":
Expand Down

0 comments on commit a73ef6b

Please sign in to comment.