From 64f5f8cd90f4ff30c5e7627c3782d7e20d541931 Mon Sep 17 00:00:00 2001 From: Tsz Wai Ko <47970742+kenko911@users.noreply.github.com> Date: Mon, 6 May 2024 13:27:09 -0700 Subject: [PATCH] Included more united tests to improve code coverage (#253) * model version for Potential class is added * model version for Potential class is modified * Enable the smooth version of Spherical Bessel function in TensorNet * max_n, max_l for SphericalBessel radial basis functions are included in TensorNet class * adding united tests for improving the coverage score * little clean up in _so3.py and so3.py * remove unnecessary data storage in dgl graphs * update pymatgen version to fix the bug * refractor all include_states into include_state for consistency * change include_states into include_state in test_graph_conv.py * Ensure the state attr from molecule graph is consistent with matgl.float_th and including linear layer in TensorNet to match the original implementations * Fix the jupyter-notebook for M3GNet training * included more united tests to improve code coverage --- src/matgl/graph/data.py | 2 - tests/conftest.py | 33 ++++++++++ tests/ext/test_ase.py | 10 +-- tests/graph/test_data.py | 139 +++++++++++++++++++++++++++++++++++++++ tests/utils/test_so3.py | 6 ++ 5 files changed, 183 insertions(+), 7 deletions(-) diff --git a/src/matgl/graph/data.py b/src/matgl/graph/data.py index a6db553e..cf3ed746 100644 --- a/src/matgl/graph/data.py +++ b/src/matgl/graph/data.py @@ -70,8 +70,6 @@ def collate_fn_pes(batch, include_stress: bool = True, include_line_graph: bool if include_magmom: return g, torch.squeeze(lat), l_g, state_attr, e, f, s, m return g, torch.squeeze(lat), l_g, state_attr, e, f, s - if include_magmom: - return g, torch.squeeze(lat), state_attr, e, f, s, m return g, torch.squeeze(lat), state_attr, e, f, s diff --git a/tests/conftest.py b/tests/conftest.py index 4cc75763..7721c482 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -181,6 +181,39 @@ def Mo(): return Structure(Lattice.cubic(3.17), ["Mo", "Mo"], [[0.01, 0, 0], [0.5, 0.5, 0.5]]) +@pytest.fixture(scope="session") +def MoS2(): + return Structure( + [ + [3.18430383, 0.0, 1.9498237464610788e-16], + [-1.5921519149999994, 2.757688010148085, 1.9498237464610788e-16], + [0.0, 0.0, 19.44629514], + ], + [ + "Mo", + "Mo", + "Mo", + "S", + "S", + "S", + "S", + "S", + "S", + ], + [ + [0.00000000e00, 0.00000000e00, 1.94419205e01], + [1.59215192e00, 9.19229337e-01, 6.47772374e00], + [4.44089210e-16, 1.83845867e00, 1.29598221e01], + [0.00000000e00, 0.00000000e00, 4.92566372e00], + [1.59215192e00, 9.19229337e-01, 1.14077621e01], + [4.44089210e-16, 1.83845867e00, 1.78898605e01], + [0.00000000e00, 0.00000000e00, 8.02996293e00], + [1.59215192e00, 9.19229337e-01, 1.45120613e01], + [4.44089210e-16, 1.83845867e00, 1.54786455e00], + ], + ) + + @pytest.fixture(scope="session") def graph_Mo(Mo): return get_graph(Mo, 5.0) diff --git a/tests/ext/test_ase.py b/tests/ext/test_ase.py index 89c71e03..077bb627 100644 --- a/tests/ext/test_ase.py +++ b/tests/ext/test_ase.py @@ -115,14 +115,14 @@ def test_get_graph_from_atoms_mol(): assert np.allclose(state, [0.0, 0.0]) -def test_molecular_dynamics(MoS): +def test_molecular_dynamics(MoS2): pot = load_model("pretrained_models/M3GNet-MP-2021.2.8-PES/") for ensemble in ["nvt", "nvt_langevin", "nvt_andersen", "npt", "npt_berendsen", "npt_nose_hoover"]: - md = MolecularDynamics(MoS, potential=pot, ensemble=ensemble, taut=0.1, taup=0.1, compressibility_au=10) + md = MolecularDynamics(MoS2, potential=pot, ensemble=ensemble, taut=0.1, taup=0.1, compressibility_au=10) md.run(10) assert md.dyn is not None - md.set_atoms(MoS) - md = MolecularDynamics(MoS, potential=pot, ensemble=ensemble, taut=None, taup=None, compressibility_au=10) + md.set_atoms(MoS2) + md = MolecularDynamics(MoS2, potential=pot, ensemble=ensemble, taut=None, taup=None, compressibility_au=10) md.run(10) with pytest.raises(ValueError, match="Ensemble not supported"): - MolecularDynamics(MoS, potential=pot, ensemble="notanensemble") + MolecularDynamics(MoS2, potential=pot, ensemble="notanensemble") diff --git a/tests/graph/test_data.py b/tests/graph/test_data.py index 9db6a927..094c5a57 100644 --- a/tests/graph/test_data.py +++ b/tests/graph/test_data.py @@ -481,3 +481,142 @@ def test_mgl_dataloader_with_magmom(self, LiFePO4, BaNiO3): assert len(train_loader) == 8 assert len(val_loader) == 1 assert len(test_loader) == 1 + + def test_mgl_dataloader_without_collate_fn(self, LiFePO4, BaNiO3): + structures = [LiFePO4, BaNiO3] * 10 + energies = np.zeros(20).tolist() + f1 = np.zeros((28, 3)).tolist() + f2 = np.zeros((10, 3)).tolist() + s = np.zeros((3, 3)).tolist() + m1 = np.zeros(28).tolist() + m2 = np.zeros(10).tolist() + np.zeros((3, 3)).tolist() + forces = [f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2, f1, f2] + stresses = [s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s, s] + magmoms = [m1, m2, m1, m2, m1, m2, m1, m2, m1, m2, m1, m2, m1, m2, m1, m2, m1, m2, m1, m2] + + labels = { + "energies": energies, + "forces": forces, + "stresses": stresses, + "magmoms": magmoms, + } + element_types = get_element_list(structures) + cry_graph = Structure2Graph(element_types=element_types, cutoff=5.0) + dataset = MGLDataset( + threebody_cutoff=3.0, + structures=structures, + converter=cry_graph, + labels=labels, + include_line_graph=True, + clear_processed=True, + save_cache=False, + ) + + train_data, val_data, test_data = split_dataset( + dataset, + frac_list=[0.8, 0.1, 0.1], + shuffle=True, + random_state=42, + ) + # This modification is required for M3GNet property dataset + train_loader, val_loader, test_loader = MGLDataLoader( + train_data=train_data, + val_data=val_data, + test_data=test_data, + batch_size=2, + num_workers=1, + ) + + assert len(train_loader) == 8 + assert len(val_loader) == 1 + assert len(test_loader) == 1 + + labels.pop("magmoms") + dataset = MGLDataset( + threebody_cutoff=3.0, + structures=structures, + converter=cry_graph, + labels=labels, + include_line_graph=True, + clear_processed=True, + save_cache=False, + ) + + train_data, val_data, test_data = split_dataset( + dataset, + frac_list=[0.8, 0.1, 0.1], + shuffle=True, + random_state=42, + ) + # This modification is required for M3GNet property dataset + train_loader, val_loader, test_loader = MGLDataLoader( + train_data=train_data, + val_data=val_data, + test_data=test_data, + batch_size=2, + num_workers=1, + ) + + assert len(train_loader) == 8 + assert len(val_loader) == 1 + assert len(test_loader) == 1 + + labels.pop("stresses") + dataset = MGLDataset( + threebody_cutoff=3.0, + structures=structures, + converter=cry_graph, + labels=labels, + include_line_graph=True, + clear_processed=True, + save_cache=False, + ) + + train_data, val_data, test_data = split_dataset( + dataset, + frac_list=[0.8, 0.1, 0.1], + shuffle=True, + random_state=42, + ) + # This modification is required for M3GNet property dataset + train_loader, val_loader, test_loader = MGLDataLoader( + train_data=train_data, + val_data=val_data, + test_data=test_data, + batch_size=2, + num_workers=1, + ) + + assert len(train_loader) == 8 + assert len(val_loader) == 1 + assert len(test_loader) == 1 + labels.pop("forces") + dataset = MGLDataset( + threebody_cutoff=3.0, + structures=structures, + converter=cry_graph, + labels=labels, + include_line_graph=True, + clear_processed=True, + save_cache=False, + ) + + train_data, val_data, test_data = split_dataset( + dataset, + frac_list=[0.8, 0.1, 0.1], + shuffle=True, + random_state=42, + ) + # This modification is required for M3GNet property dataset + train_loader, val_loader, test_loader = MGLDataLoader( + train_data=train_data, + val_data=val_data, + test_data=test_data, + batch_size=2, + num_workers=1, + ) + + assert len(train_loader) == 8 + assert len(val_loader) == 1 + assert len(test_loader) == 1 diff --git a/tests/utils/test_so3.py b/tests/utils/test_so3.py index 3b46b215..45335243 100644 --- a/tests/utils/test_so3.py +++ b/tests/utils/test_so3.py @@ -317,6 +317,12 @@ def test_generate_clebsch_gordan_rsh(): expected_cg_rsh_0 = torch.eye(1).unsqueeze(0) # Expected cg_rsh result assert torch.allclose(cg_rsh_0, expected_cg_rsh_0, atol=1e-4) # Use torch.allclose for numerical comparisons + cg_rsh_0_without_pi = generate_clebsch_gordan_rsh(lmax_0, False) + expected_cg_rsh_0_without_pi = torch.eye(1).unsqueeze(0) # Expected cg_rsh without parity invariance result + assert torch.allclose( + cg_rsh_0_without_pi, expected_cg_rsh_0_without_pi, atol=1e-4 + ) # Use torch.allclose for numerical comparisons + # Test case 2: lmax = 1 lmax_1 = 1 cg_rsh_1 = generate_clebsch_gordan_rsh(lmax_1)