diff --git a/chgnet/model/model.py b/chgnet/model/model.py index 7c6fd783..76fd0271 100644 --- a/chgnet/model/model.py +++ b/chgnet/model/model.py @@ -324,12 +324,11 @@ def forward( return_site_energies: bool = False, return_atom_feas: bool = False, return_crystal_feas: bool = False, - ) -> dict: + ) -> dict[str, Tensor]: """Get prediction associated with input graphs Args: graphs (List): a list of CrystalGraphs - task (str): the prediction task - eg: 'e', 'em', 'ef', 'efs', 'efsm' + task (str): the prediction task. One of 'e', 'em', 'ef', 'efs', 'efsm'. Default = 'e' return_site_energies (bool): whether to return per-site energies, only available if self.mlp_first == True @@ -662,26 +661,28 @@ def from_file(cls, path, **kwargs): @classmethod def load(cls, model_name="0.3.0"): - """Load pretrained CHGNet.""" - current_dir = os.path.dirname(os.path.abspath(__file__)) - if model_name == "0.3.0": - return cls.from_file( - os.path.join( - current_dir, - "../pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar", - ) - ) - elif model_name == "0.2.0": # noqa: RET505 - return cls.from_file( - os.path.join( - current_dir, - "../pretrained/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar", - ), - mlp_out_bias=True, - ) - else: + """Load pretrained CHGNet model. + + Args: + model_name (str, optional): Defaults to "0.3.0". + + Raises: + ValueError: On unknown model_name. + """ + checkpoint_path = { + "0.3.0": "../pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar", + "0.2.0": "../pretrained/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar", + }.get(model_name) + + if checkpoint_path is None: raise ValueError(f"Unknown {model_name=}") + return cls.from_file( + os.path.join(module_dir, checkpoint_path), + mlp_out_bias=model_name == "0.2.0", + version=model_name, + ) + @dataclass class BatchedGraph: