diff --git a/maml/describers/_m3gnet.py b/maml/describers/_m3gnet.py index e16bb74e..46d41d30 100644 --- a/maml/describers/_m3gnet.py +++ b/maml/describers/_m3gnet.py @@ -6,14 +6,17 @@ import numpy as np import pandas as pd -from maml.base import BaseDescriber +from maml.base import BaseDescriber, describer_type if TYPE_CHECKING: from pymatgen.core import Molecule, Structure -DEFAULT_MODEL = Path(__file__).parent / "data/m3gnet_models/matbench_mp_e_form/0/m3gnet/" +DEFAULT_MODEL = ( + Path(__file__).parent / "data/m3gnet_models/matbench_mp_e_form/0/m3gnet/" +) +@describer_type("structure") class M3GNetStructure(BaseDescriber): """Use M3GNet pre-trained models as featurizer to get Structural features.""" @@ -57,7 +60,9 @@ def transform_one(self, structure: Structure | Molecule): graph = self.describer_model.graph_converter.convert(structure).as_list() graph = tf_compute_distance_angle(graph) three_basis = self.describer_model.basis_expansion(graph) - three_cutoff = polynomial(graph[Index.BONDS], self.describer_model.threebody_cutoff) + three_cutoff = polynomial( + graph[Index.BONDS], self.describer_model.threebody_cutoff + ) g = self.describer_model.featurizer(graph) g = self.describer_model.feature_adjust(g) for i in range(self.describer_model.n_blocks): @@ -67,6 +72,7 @@ def transform_one(self, structure: Structure | Molecule): return np.array(layer_before_readout(g))[0] +@describer_type("site") class M3GNetSite(BaseDescriber): """Use M3GNet pre-trained models as featurizer to get atomic features.""" @@ -100,11 +106,17 @@ def __init__( else: self.describer_model = M3GNet.from_dir(DEFAULT_MODEL) self.model_path = str(DEFAULT_MODEL) - allowed_output_layers = ["embedding"] + [f"gc_{i + 1}" for i in range(self.describer_model.n_blocks)] + allowed_output_layers = ["embedding"] + [ + f"gc_{i + 1}" for i in range(self.describer_model.n_blocks) + ] if output_layers is None: output_layers = ["gc_1"] - elif not isinstance(output_layers, list) or set(output_layers).difference(allowed_output_layers): - raise ValueError(f"Invalid output_layers, it must be a sublist of {allowed_output_layers}.") + elif not isinstance(output_layers, list) or set(output_layers).difference( + allowed_output_layers + ): + raise ValueError( + f"Invalid output_layers, it must be a sublist of {allowed_output_layers}." + ) self.output_layers = output_layers self.return_type = return_type super().__init__(**kwargs) @@ -123,14 +135,16 @@ def transform_one(self, structure: Structure | Molecule): graph = self.describer_model.graph_converter.convert(structure).as_list() graph = tf_compute_distance_angle(graph) three_basis = self.describer_model.basis_expansion(graph) - three_cutoff = polynomial(graph[Index.BONDS], self.describer_model.threebody_cutoff) + three_cutoff = polynomial( + graph[Index.BONDS], self.describer_model.threebody_cutoff + ) g = self.describer_model.featurizer(graph) atom_fea = {"embedding": g[Index.ATOMS]} g = self.describer_model.feature_adjust(g) for i in range(self.describer_model.n_blocks): g = self.describer_model.three_interactions[i](g, three_basis, three_cutoff) g = self.describer_model.graph_layers[i](g) - atom_fea[f"gc_{i+1}"] = g[Index.ATOMS] + atom_fea[f"gc_{i + 1}"] = g[Index.ATOMS] atom_fea_dict = {k: v for k, v in atom_fea.items() if k in self.output_layers} if self.return_type == dict: return atom_fea_dict