From 24b2043c5260a9495deab07d2fb9647b97d98c29 Mon Sep 17 00:00:00 2001 From: "Lee, Kin Long Kelvin" Date: Wed, 15 May 2024 10:47:21 -0700 Subject: [PATCH] feat: added scalar regression constructor method --- matsciml/interfaces/ase/base.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/matsciml/interfaces/ase/base.py b/matsciml/interfaces/ase/base.py index b90475a0..f6453bbb 100644 --- a/matsciml/interfaces/ase/base.py +++ b/matsciml/interfaces/ase/base.py @@ -220,3 +220,14 @@ def from_pretrained_gradfree_task( raise FileNotFoundError(f"Checkpoint file not found; passed {ckpt_path}") task = GradFreeForceRegressionTask.load_from_checkpoint(ckpt_path) return cls(task, *args, **kwargs) + + @classmethod + def from_pretrained_scalar_task( + cls, ckpt_path: str | Path, *args, **kwargs + ) -> MatSciMLCalculator: + if isinstance(ckpt_path, str): + ckpt_path = Path(ckpt_path) + if not ckpt_path.exists(): + raise FileNotFoundError(f"Checkpoint file not found; passed {ckpt_path}") + task = ScalarRegressionTask.load_from_checkpoint(ckpt_path) + return cls(task, *args, **kwargs)