diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 85ec2291..639d5557 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -14,14 +14,6 @@ def test_log_images(m, every_n_epochs, tmpdir): m.trainer.fit(m) saved_images = glob.glob("{}/*.png".format(tmpdir)) assert len(saved_images) == 1 - -@pytest.mark.parametrize("every_n_epochs", [1, 2, 3]) -def test_log_images_multiclass(two_class_m, every_n_epochs, tmpdir): - im_callback = callbacks.images_callback(savedir=tmpdir, every_n_epochs=every_n_epochs) - two_class_m.create_trainer(callbacks=[im_callback]) - two_class_m.trainer.fit(two_class_m) - saved_images = glob.glob("{}/*.png".format(tmpdir)) - assert len(saved_images) == 1 def test_create_checkpoint(m, tmpdir): checkpoint_callback = ModelCheckpoint(