From e7ef61860dec0ac3864bb4071d101500cb8792de Mon Sep 17 00:00:00 2001 From: Janosh Riebesell Date: Wed, 12 Jun 2024 18:51:39 -0400 Subject: [PATCH] test wandb_path and wandb_kwargs in test_trainer --- tests/test_trainer.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/test_trainer.py b/tests/test_trainer.py index bcf44f64..3f9da1f6 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING import numpy as np +import pytest import torch from pymatgen.core import Lattice, Structure @@ -36,7 +37,7 @@ ) -def test_trainer(tmp_path: Path) -> None: +def test_trainer(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: chgnet = CHGNet.load() train_loader, val_loader, _test_loader = get_train_val_test_loader( data, batch_size=16, train_ratio=0.9, val_ratio=0.05 @@ -47,7 +48,9 @@ def test_trainer(tmp_path: Path) -> None: optimizer="Adam", criterion="MSE", learning_rate=1e-2, - epochs=5, + epochs=500, + wandb_path="/", + wandb_kwargs=dict(anonymous="allow"), ) dir_name = "test_tmp_dir" test_dir = tmp_path / dir_name @@ -63,6 +66,12 @@ def test_trainer(tmp_path: Path) -> None: n_matches == 1 ), f"Expected 1 {prefix} file, found {n_matches} in {output_files}" + # expect ImportError when passing wandb_path without wandb installed + err_msg = "Weights and Biases not installed. pip install wandb to use wandb logging" + with monkeypatch.context() as ctx, pytest.raises(ImportError, match=err_msg): # noqa: PT012 + ctx.setattr("chgnet.trainer.trainer.wandb", None) + _ = Trainer(model=chgnet, wandb_path="radicalai/chgnet-test-finetune") + def test_trainer_composition_model(tmp_path: Path) -> None: chgnet = CHGNet.load()