Skip to content

Commit

Permalink
Merge pull request #219 from melo-gonzo/checkpoint-check-bugfix
Browse files Browse the repository at this point in the history
Checkpoint Path Check Bugfix
  • Loading branch information
laserkelvin authored May 17, 2024
2 parents f494d60 + d418d81 commit 6b2bcb0
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions matsciml/interfaces/ase/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def recursive_type_cast(
return data_dict


def __checkpoint_conversion_exist_check(ckpt_path: str | Path) -> Path:
def _checkpoint_conversion_exist_check(ckpt_path: str | Path) -> Path:
"""Standardizes and checks for checkpoint path existence."""
if isinstance(ckpt_path, str):
ckpt_path = Path(ckpt_path)
Expand Down Expand Up @@ -261,22 +261,22 @@ def calculate(
def from_pretrained_force_regression(
cls, ckpt_path: str | Path, *args, **kwargs
) -> MatSciMLCalculator:
ckpt_path = __checkpoint_conversion_exist_check(ckpt_path)
ckpt_path = _checkpoint_conversion_exist_check(ckpt_path)
task = ForceRegressionTask.load_from_checkpoint(ckpt_path)
return cls(task, *args, **kwargs)

@classmethod
def from_pretrained_gradfree_task(
cls, ckpt_path: str | Path, *args, **kwargs
) -> MatSciMLCalculator:
ckpt_path = __checkpoint_conversion_exist_check(ckpt_path)
ckpt_path = _checkpoint_conversion_exist_check(ckpt_path)
task = GradFreeForceRegressionTask.load_from_checkpoint(ckpt_path)
return cls(task, *args, **kwargs)

@classmethod
def from_pretrained_scalar_task(
cls, ckpt_path: str | Path, *args, **kwargs
) -> MatSciMLCalculator:
ckpt_path = __checkpoint_conversion_exist_check(ckpt_path)
ckpt_path = _checkpoint_conversion_exist_check(ckpt_path)
task = ScalarRegressionTask.load_from_checkpoint(ckpt_path)
return cls(task, *args, **kwargs)

0 comments on commit 6b2bcb0

Please sign in to comment.