diff --git a/tests/unit/client/test_mixin_client.py b/tests/unit/client/test_mixin_client.py index fdcddcca..f08cfc01 100644 --- a/tests/unit/client/test_mixin_client.py +++ b/tests/unit/client/test_mixin_client.py @@ -201,8 +201,8 @@ def test_log_json(project_client): assert artifact_a.id in [a.id for a in artifacts] assert artifact_b.id in [a.id for a in artifacts] - -def test_log_h2o_model(make_classification_df, rubicon_local_filesystem_client_with_project): +@pytest.mark.parametrize("use_mojo", [False, True]) +def test_log_h2o_model(make_classification_df, rubicon_local_filesystem_client_with_project, use_mojo): """Test logging `h2o` model data.""" _, project = rubicon_local_filesystem_client_with_project X, y = make_classification_df @@ -222,7 +222,7 @@ def test_log_h2o_model(make_classification_df, rubicon_local_filesystem_client_w y=target_name, ) - artifact = project.log_h2o_model(h2o_model, tags=["h2o"]) + artifact = project.log_h2o_model(h2o_model, use_mojo=use_mojo, tags=["h2o"]) read_artifact = project.artifact(name=artifact.name) assert artifact.id == read_artifact.id