Skip to content

Commit

Permalink
fix naming of checkpoint check function
Browse files Browse the repository at this point in the history
  • Loading branch information
melo-gonzo committed May 17, 2024
1 parent e839ce0 commit d418d81
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions matsciml/interfaces/ase/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def recursive_type_cast(
return data_dict


def __checkpoint_conversion_exist_check(ckpt_path: str | Path) -> Path:
def _checkpoint_conversion_exist_check(ckpt_path: str | Path) -> Path:
"""Standardizes and checks for checkpoint path existence."""
if isinstance(ckpt_path, str):
ckpt_path = Path(ckpt_path)
Expand Down Expand Up @@ -261,22 +261,22 @@ def calculate(
def from_pretrained_force_regression(
cls, ckpt_path: str | Path, *args, **kwargs
) -> MatSciMLCalculator:
ckpt_path = __checkpoint_conversion_exist_check(ckpt_path)
ckpt_path = _checkpoint_conversion_exist_check(ckpt_path)
task = ForceRegressionTask.load_from_checkpoint(ckpt_path)
return cls(task, *args, **kwargs)

@classmethod
def from_pretrained_gradfree_task(
cls, ckpt_path: str | Path, *args, **kwargs
) -> MatSciMLCalculator:
ckpt_path = __checkpoint_conversion_exist_check(ckpt_path)
ckpt_path = _checkpoint_conversion_exist_check(ckpt_path)
task = GradFreeForceRegressionTask.load_from_checkpoint(ckpt_path)
return cls(task, *args, **kwargs)

@classmethod
def from_pretrained_scalar_task(
cls, ckpt_path: str | Path, *args, **kwargs
) -> MatSciMLCalculator:
ckpt_path = __checkpoint_conversion_exist_check(ckpt_path)
ckpt_path = _checkpoint_conversion_exist_check(ckpt_path)
task = ScalarRegressionTask.load_from_checkpoint(ckpt_path)
return cls(task, *args, **kwargs)

0 comments on commit d418d81

Please sign in to comment.