Skip to content

Commit

Permalink
model version for Potential class is added (#194)
Browse files Browse the repository at this point in the history
  • Loading branch information
kenko911 authored Nov 17, 2023
1 parent 8b5dcef commit a7ae4a9
Show file tree
Hide file tree
Showing 27 changed files with 24 additions and 48 deletions.
1 change: 0 additions & 1 deletion dev/refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from __future__ import annotations

import torch

from matgl.models import MEGNet

# model_path = "pretrained_models/MEGNet-MP-2019.4.1-BandGap-mfi"
Expand Down
2 changes: 1 addition & 1 deletion pretrained_models/M3GNet-MP-2021.2.8-DIRECT-PES/model.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"@class": "Potential",
"@module": "matgl.apps.pes",
"@model_version": 1,
"@model_version": 2,
"metadata": null,
"kwargs": {
"model": {
Expand Down
2 changes: 1 addition & 1 deletion pretrained_models/M3GNet-MP-2021.2.8-PES/model.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"@class": "Potential",
"@module": "matgl.apps.pes",
"@model_version": 1,
"@model_version": 2,
"metadata": null,
"kwargs": {
"model": {
Expand Down
2 changes: 1 addition & 1 deletion src/matgl/apps/pes.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
class Potential(nn.Module, IOMixIn):
"""A class representing an interatomic potential."""

__version__ = 1
__version__ = 2

def __init__(
self,
Expand Down
5 changes: 2 additions & 3 deletions tests/apps/test_pes.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

import matgl
import numpy as np
import pytest
import torch
from pymatgen.core import Lattice, Structure

import matgl
from matgl.apps.pes import Potential
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.models._m3gnet import M3GNet
from pymatgen.core import Lattice, Structure


@pytest.fixture()
Expand Down
7 changes: 3 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,15 @@
"""
from __future__ import annotations

import matgl
import pytest
import torch
from pymatgen.core import Lattice, Molecule, Structure
from pymatgen.util.testing import PymatgenTest

import matgl
from matgl.ext.pymatgen import Molecule2Graph, Structure2Graph, get_element_list
from matgl.graph.compute import (
compute_pair_vector_and_distance,
)
from pymatgen.core import Lattice, Molecule, Structure
from pymatgen.util.testing import PymatgenTest

matgl.clear_cache(confirm=False)

Expand Down
1 change: 0 additions & 1 deletion tests/data/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest
import torch

from matgl.data.transformer import LogTransformer, Normalizer


Expand Down
3 changes: 1 addition & 2 deletions tests/ext/test_ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
import pytest
import torch
from ase.build import molecule
from pymatgen.io.ase import AseAtomsAdaptor

from matgl import load_model
from matgl.ext.ase import Atoms2Graph, M3GNetCalculator, MolecularDynamics, Relaxer
from pymatgen.io.ase import AseAtomsAdaptor


def test_M3GNetCalculator(MoS):
Expand Down
3 changes: 1 addition & 2 deletions tests/ext/test_pymatgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

import numpy as np
import torch
from pymatgen.core import Lattice, Structure

from matgl.ext.pymatgen import Structure2Graph, get_element_list
from pymatgen.core import Lattice, Structure

module_dir = os.path.dirname(os.path.abspath(__file__))

Expand Down
5 changes: 2 additions & 3 deletions tests/graph/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@

from functools import partial

import matgl
import numpy as np
import pytest
import torch
import torch.testing as tt
from pymatgen.core import Lattice, Structure

import matgl
from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.compute import (
compute_pair_vector_and_distance,
Expand All @@ -18,6 +16,7 @@
ensure_line_graph_compatibility,
prune_edges_by_features,
)
from pymatgen.core import Lattice, Structure


def _loop_indices(bond_atom_indices, pair_dist, cutoff=4.0):
Expand Down
3 changes: 1 addition & 2 deletions tests/graph/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@

import numpy as np
from dgl.data.utils import split_dataset
from pymatgen.core import Molecule

from matgl.ext.pymatgen import Molecule2Graph, Structure2Graph, get_element_list
from matgl.graph.data import M3GNetDataset, MEGNetDataset, MGLDataLoader, collate_fn, collate_fn_efs
from pymatgen.core import Molecule

module_dir = os.path.dirname(os.path.abspath(__file__))

Expand Down
1 change: 0 additions & 1 deletion tests/layers/test_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
import pytest
import torch

from matgl.layers._activations import SoftExponential, SoftPlus2


Expand Down
1 change: 0 additions & 1 deletion tests/layers/test_atom_ref.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import dgl
import numpy as np
import torch

from matgl.layers._atom_ref import AtomRef


Expand Down
3 changes: 1 addition & 2 deletions tests/layers/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import numpy as np
import pytest
import torch
from torch.testing import assert_close

from matgl.graph.compute import (
compute_theta_and_phi,
create_line_graph,
Expand All @@ -19,6 +17,7 @@
spherical_bessel_smooth,
)
from matgl.layers._three_body import combine_sbf_shf
from torch.testing import assert_close


def test_gaussian():
Expand Down
1 change: 0 additions & 1 deletion tests/layers/test_bond.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import pytest

from matgl.layers import BondExpansion


Expand Down
3 changes: 1 addition & 2 deletions tests/layers/test_core_and_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,9 @@

import pytest
import torch
from torch import nn

from matgl.layers import BondExpansion, EmbeddingBlock
from matgl.layers._core import MLP, GatedMLP
from torch import nn


@pytest.fixture()
Expand Down
3 changes: 1 addition & 2 deletions tests/layers/test_graph_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@

import dgl
import torch
from torch import nn

from matgl.layers import BondExpansion, EmbeddingBlock
from matgl.layers._graph_convolution import (
MLP,
Expand All @@ -15,6 +13,7 @@
MEGNetGraphConv,
)
from matgl.utils.cutoff import polynomial_cutoff
from torch import nn


class Graph(NamedTuple):
Expand Down
3 changes: 1 addition & 2 deletions tests/layers/test_readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

import pytest
import torch
from torch import nn

from matgl.layers import BondExpansion, EmbeddingBlock
from matgl.layers._readout import (
ReduceReadOut,
Set2SetReadOut,
WeightedReadOut,
WeightedReadOutPair,
)
from torch import nn


class TestReadOut:
Expand Down
3 changes: 1 addition & 2 deletions tests/layers/test_three_body.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from __future__ import annotations

import torch
from torch import nn

from matgl.graph.compute import (
compute_theta_and_phi,
create_line_graph,
Expand All @@ -11,6 +9,7 @@
from matgl.layers._core import MLP, GatedMLP
from matgl.layers._three_body import ThreeBodyInteractions
from matgl.utils.cutoff import polynomial_cutoff
from torch import nn


def test_three_body_interactions(graph_MoS):
Expand Down
3 changes: 1 addition & 2 deletions tests/models/test_m3gnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import os

import matgl
import numpy as np
import pytest
import torch

import matgl
from matgl.models import M3GNet


Expand Down
5 changes: 2 additions & 3 deletions tests/models/test_megnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@

import os

import matgl
import numpy as np
import pytest
import torch as th
from pymatgen.core import Lattice, Structure

import matgl
from matgl.graph.compute import compute_pair_vector_and_distance
from matgl.layers import BondExpansion
from matgl.models import MEGNet
from pymatgen.core import Lattice, Structure


class TestMEGNet:
Expand Down
1 change: 0 additions & 1 deletion tests/models/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest
import torch.nn

from matgl.data.transformer import Normalizer
from matgl.models._wrappers import TransformedTargetModel

Expand Down
3 changes: 1 addition & 2 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""This is an integration test file that checks on pre-trained models to ensure they still work."""
from __future__ import annotations

import pytest

import matgl
import pytest


def test_form_e(LiFePO4):
Expand Down
3 changes: 1 addition & 2 deletions tests/utils/test_cutoff.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations

import torch
from torch.testing import assert_close

from matgl.layers._basis import SphericalBesselFunction
from matgl.utils.cutoff import cosine_cutoff, polynomial_cutoff
from torch.testing import assert_close


def test_cosine():
Expand Down
1 change: 0 additions & 1 deletion tests/utils/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pytest
import requests
import torch

from matgl.utils.io import IOMixIn, RemoteFile, get_available_pretrained_models, load_model

this_dir = Path(os.path.abspath(os.path.dirname(__file__)))
Expand Down
1 change: 0 additions & 1 deletion tests/utils/test_maths.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import numpy as np
import pytest
import torch

from matgl.utils.maths import (
SPHERICAL_BESSEL_ROOTS,
broadcast_states_to_atoms,
Expand Down
3 changes: 1 addition & 2 deletions tests/utils/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@
import pytorch_lightning as pl
import torch.backends.mps
from dgl.data.utils import split_dataset
from pymatgen.core import Lattice, Structure

from matgl.ext.pymatgen import Structure2Graph, get_element_list
from matgl.graph.data import M3GNetDataset, MEGNetDataset, MGLDataLoader, collate_fn, collate_fn_efs
from matgl.models import M3GNet, MEGNet
from matgl.utils.training import ModelLightningModule, PotentialLightningModule, xavier_init
from pymatgen.core import Lattice, Structure

module_dir = os.path.dirname(os.path.abspath(__file__))

Expand Down

0 comments on commit a7ae4a9

Please sign in to comment.