Skip to content

Commit

Permalink
test: add unit test for retrieval of MOJO model
Browse files Browse the repository at this point in the history
  • Loading branch information
thebrianbn committed Sep 26, 2024
1 parent dd54552 commit d12f037
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions tests/unit/client/test_artifact_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import xgboost
from h2o import H2OFrame
from h2o.estimators.random_forest import H2ORandomForestEstimator
from h2o.estimators.generic import H2OGenericEstimator

from rubicon_ml import domain
from rubicon_ml.client import Artifact, Rubicon
Expand Down Expand Up @@ -159,8 +160,15 @@ def test_download_location(mock_open, project_client):
mock_file().write.assert_called_once_with(data)


@pytest.mark.parametrize(
["use_mojo", "deserialization_method"],
[
(False, "h2o_binary"),
(True, "h2o_mojo"),
],
)
def test_get_data_deserialize_h2o(
make_classification_df, rubicon_local_filesystem_client_with_project
make_classification_df, rubicon_local_filesystem_client_with_project, use_mojo, deserialization_method
):
"""Test logging `h2o` model data."""
_, project = rubicon_local_filesystem_client_with_project
Expand All @@ -181,10 +189,13 @@ def test_get_data_deserialize_h2o(
y=target_name,
)

artifact = project.log_h2o_model(h2o_model)
artifact_data = artifact.get_data(deserialize="h2o")
artifact = project.log_h2o_model(h2o_model, use_mojo=use_mojo)
artifact_data = artifact.get_data(deserialize=deserialization_method)

assert artifact_data.__class__ == h2o_model.__class__
if use_mojo:
assert isinstance(artifact_data, H2OGenericEstimator)
else:
assert artifact_data.__class__ == h2o_model.__class__


def test_get_data_deserialize_xgboost(
Expand Down

0 comments on commit d12f037

Please sign in to comment.