Skip to content

Commit

Permalink
add describer_type
Browse files Browse the repository at this point in the history
  • Loading branch information
JiQi535 committed Jan 28, 2024
1 parent e928630 commit 96fc1bd
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions maml/describers/_m3gnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
import numpy as np
import pandas as pd

from maml.base import BaseDescriber
from maml.base import BaseDescriber, describer_type

if TYPE_CHECKING:
from pymatgen.core import Molecule, Structure

DEFAULT_MODEL = Path(__file__).parent / "data/m3gnet_models/matbench_mp_e_form/0/m3gnet/"
DEFAULT_MODEL = (
Path(__file__).parent / "data/m3gnet_models/matbench_mp_e_form/0/m3gnet/"
)


@describer_type("structure")
class M3GNetStructure(BaseDescriber):
"""Use M3GNet pre-trained models as featurizer to get Structural features."""

Expand Down Expand Up @@ -57,7 +60,9 @@ def transform_one(self, structure: Structure | Molecule):
graph = self.describer_model.graph_converter.convert(structure).as_list()
graph = tf_compute_distance_angle(graph)
three_basis = self.describer_model.basis_expansion(graph)
three_cutoff = polynomial(graph[Index.BONDS], self.describer_model.threebody_cutoff)
three_cutoff = polynomial(
graph[Index.BONDS], self.describer_model.threebody_cutoff
)
g = self.describer_model.featurizer(graph)
g = self.describer_model.feature_adjust(g)
for i in range(self.describer_model.n_blocks):
Expand All @@ -67,6 +72,7 @@ def transform_one(self, structure: Structure | Molecule):
return np.array(layer_before_readout(g))[0]


@describer_type("site")
class M3GNetSite(BaseDescriber):
"""Use M3GNet pre-trained models as featurizer to get atomic features."""

Expand Down Expand Up @@ -100,11 +106,17 @@ def __init__(
else:
self.describer_model = M3GNet.from_dir(DEFAULT_MODEL)
self.model_path = str(DEFAULT_MODEL)
allowed_output_layers = ["embedding"] + [f"gc_{i + 1}" for i in range(self.describer_model.n_blocks)]
allowed_output_layers = ["embedding"] + [
f"gc_{i + 1}" for i in range(self.describer_model.n_blocks)
]
if output_layers is None:
output_layers = ["gc_1"]
elif not isinstance(output_layers, list) or set(output_layers).difference(allowed_output_layers):
raise ValueError(f"Invalid output_layers, it must be a sublist of {allowed_output_layers}.")
elif not isinstance(output_layers, list) or set(output_layers).difference(
allowed_output_layers
):
raise ValueError(
f"Invalid output_layers, it must be a sublist of {allowed_output_layers}."
)
self.output_layers = output_layers
self.return_type = return_type
super().__init__(**kwargs)
Expand All @@ -123,14 +135,16 @@ def transform_one(self, structure: Structure | Molecule):
graph = self.describer_model.graph_converter.convert(structure).as_list()
graph = tf_compute_distance_angle(graph)
three_basis = self.describer_model.basis_expansion(graph)
three_cutoff = polynomial(graph[Index.BONDS], self.describer_model.threebody_cutoff)
three_cutoff = polynomial(
graph[Index.BONDS], self.describer_model.threebody_cutoff
)
g = self.describer_model.featurizer(graph)
atom_fea = {"embedding": g[Index.ATOMS]}
g = self.describer_model.feature_adjust(g)
for i in range(self.describer_model.n_blocks):
g = self.describer_model.three_interactions[i](g, three_basis, three_cutoff)
g = self.describer_model.graph_layers[i](g)
atom_fea[f"gc_{i+1}"] = g[Index.ATOMS]
atom_fea[f"gc_{i + 1}"] = g[Index.ATOMS]
atom_fea_dict = {k: v for k, v in atom_fea.items() if k in self.output_layers}
if self.return_type == dict:
return atom_fea_dict
Expand Down

0 comments on commit 96fc1bd

Please sign in to comment.