Skip to content

Commit

Permalink
refactor CHGNet.load() and pass model_name as version to CHGNet.__ini…
Browse files Browse the repository at this point in the history
…t__()
  • Loading branch information
janosh committed Oct 23, 2023
1 parent d0ee30f commit 1747180
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions chgnet/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,11 @@ def forward(
return_site_energies: bool = False,
return_atom_feas: bool = False,
return_crystal_feas: bool = False,
) -> dict:
) -> dict[str, Tensor]:
"""Get prediction associated with input graphs
Args:
graphs (List): a list of CrystalGraphs
task (str): the prediction task
eg: 'e', 'em', 'ef', 'efs', 'efsm'
task (str): the prediction task. One of 'e', 'em', 'ef', 'efs', 'efsm'.
Default = 'e'
return_site_energies (bool): whether to return per-site energies,
only available if self.mlp_first == True
Expand Down Expand Up @@ -662,26 +661,28 @@ def from_file(cls, path, **kwargs):

@classmethod
def load(cls, model_name="0.3.0"):
"""Load pretrained CHGNet."""
current_dir = os.path.dirname(os.path.abspath(__file__))
if model_name == "0.3.0":
return cls.from_file(
os.path.join(
current_dir,
"../pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar",
)
)
elif model_name == "0.2.0": # noqa: RET505
return cls.from_file(
os.path.join(
current_dir,
"../pretrained/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar",
),
mlp_out_bias=True,
)
else:
"""Load pretrained CHGNet model.
Args:
model_name (str, optional): Defaults to "0.3.0".
Raises:
ValueError: On unknown model_name.
"""
checkpoint_path = {
"0.3.0": "../pretrained/0.3.0/chgnet_0.3.0_e29f68s314m37.pth.tar",
"0.2.0": "../pretrained/0.2.0/chgnet_0.2.0_e30f77s348m32.pth.tar",
}.get(model_name)

if checkpoint_path is None:
raise ValueError(f"Unknown {model_name=}")

return cls.from_file(
os.path.join(module_dir, checkpoint_path),
mlp_out_bias=model_name == "0.2.0",
version=model_name,
)


@dataclass
class BatchedGraph:
Expand Down

0 comments on commit 1747180

Please sign in to comment.