diff --git a/tests/test_model.py b/tests/test_model.py index 378ecdd1..6e7964a1 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -223,3 +223,21 @@ def test_as_to_from_dict() -> None: model_3 = CHGNet(**to_dict["model_args"]) assert model_3.todict() == to_dict + + +def test_model_load(capsys: pytest.CaptureFixture) -> None: + model = CHGNet.load() + assert model.version == "0.3.0" + stdout, stderr = capsys.readouterr() + assert stdout == "CHGNet initialized with 412,525 parameters\n" + assert stderr == "" + + model = CHGNet.load(model_name="0.2.0") + assert model.version == "0.2.0" + stdout, stderr = capsys.readouterr() + assert stdout == "CHGNet initialized with 400,438 parameters\n" + assert stderr == "" + + model_name = "0.1.0" # invalid + with pytest.raises(ValueError, match=f"Unknown {model_name=}"): + CHGNet.load(model_name=model_name)