From 6f22bce9741cf80147ec0ade5479936d09c09d65 Mon Sep 17 00:00:00 2001 From: Brian Nguyen Date: Mon, 30 Sep 2024 14:09:12 -0400 Subject: [PATCH] Add support for logging of H2O MOJO Models (#486) * feat: h2o model logging supports MOJO * feat: add support for deserializing of h2o MOJO artifacts * fix: log MOJO model as directory * test: update unit test to test for binary vs MOJO * test: update h2o schema integration test to make sure binary model is used * fix: download_mojo instead of save_mojo * test: add unit test for retrieval of MOJO model * chore: pre-commit * chore: backwards compatability for h2o * chore: update literals --- rubicon_ml/client/artifact.py | 23 ++++++++++++++++++++--- rubicon_ml/client/mixin.py | 20 ++++++++++++++------ tests/integration/test_schema.py | 4 +++- tests/unit/client/test_artifact_client.py | 23 +++++++++++++++++++---- tests/unit/client/test_mixin_client.py | 7 +++++-- 5 files changed, 61 insertions(+), 16 deletions(-) diff --git a/rubicon_ml/client/artifact.py b/rubicon_ml/client/artifact.py index 8ffcbdde..56d3a8c1 100644 --- a/rubicon_ml/client/artifact.py +++ b/rubicon_ml/client/artifact.py @@ -71,7 +71,7 @@ def _get_data(self): @failsafe def get_data( self, - deserialize: Optional[Literal["h2o", "pickle", "xgboost"]] = None, + deserialize: Optional[Literal["h2o", "h2o_binary", "h2o_mojo", "pickle", "xgboost"]] = None, unpickle: bool = False, # TODO: deprecate & move to `deserialize` ): """Loads the data associated with this artifact and @@ -82,7 +82,8 @@ def get_data( deseralize : str, optional Method to use to deseralize this artifact's data. * None to disable deseralization and return the raw data. - * "h2o" to use `h2o.load_model` to load the data. + * "h2o" or "h2o_binary" to use `h2o.load_model` to load the data. + * "h2o_mojo" to use `h2o.import_mojo` to load the data. * "pickle" to use pickles to load the data. * "xgboost" to use xgboost's JSON loader to load the data as a fitted model. Defaults to None. @@ -101,6 +102,13 @@ def get_data( ) deserialize = "pickle" + if deserialize == "h2o": + warnings.warn( + "'deserialize' method 'h2o' will be deprecated in a future release," + " please use 'h2o_binary' instead.", + DeprecationWarning, + ) + for repo in self.repositories or []: try: if deserialize == "xgboost": @@ -119,12 +127,21 @@ def get_data( except Exception as err: return_err = err else: - if deserialize == "h2o": + if deserialize in [ + "h2o", + "h2o_binary", + ]: # "h2o" will be deprecated in a future release import h2o data = h2o.load_model( repo._get_artifact_data_path(project_name, experiment_id, self.id) ) + elif deserialize == "h2o_mojo": + import h2o + + data = h2o.import_mojo( + repo._get_artifact_data_path(project_name, experiment_id, self.id) + ) elif deserialize == "pickle": data = pickle.loads(data) diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index cf7074c2..fb362162 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -244,6 +244,7 @@ def log_h2o_model( h2o_model, artifact_name: Optional[str] = None, export_cross_validation_predictions: bool = False, + use_mojo: bool = False, **log_artifact_kwargs, ) -> Artifact: """Log an `h2o` model as an artifact using `h2o.save_model`. @@ -256,6 +257,9 @@ def log_h2o_model( The name of the artifact. Defaults to None, using `h2o_model`'s class name. export_cross_validation_predictions: bool, optional (default False) Passed directly to `h2o.save_model`. + use_mojo: bool, optional (default False) + Whether to log the model in MOJO format. If False, the model will be + logged in binary format. log_artifact_kwargs : dict Additional kwargs to be passed directly to `self.log_artifact`. """ @@ -268,12 +272,16 @@ def log_h2o_model( artifact_name = h2o_model.__class__.__name__ with tempfile.TemporaryDirectory() as temp_dir_name: - model_data_path = h2o.save_model( - h2o_model, - export_cross_validation_predictions=export_cross_validation_predictions, - filename=artifact_name, - path=temp_dir_name, - ) + if use_mojo: + model_data_path = f"{temp_dir_name}/{artifact_name}.zip" + h2o_model.download_mojo(path=model_data_path) + else: + model_data_path = h2o.save_model( + h2o_model, + export_cross_validation_predictions=export_cross_validation_predictions, + filename=artifact_name, + path=temp_dir_name, + ) artifact = self.log_artifact( name=artifact_name, diff --git a/tests/integration/test_schema.py b/tests/integration/test_schema.py index 6aa82054..73fe4575 100644 --- a/tests/integration/test_schema.py +++ b/tests/integration/test_schema.py @@ -111,4 +111,6 @@ def test_estimator_h2o_schema_train( model_artifact = experiment.artifact(name=schema_cls.__name__) assert len(project.schema_["parameters"]) == len(experiment.parameters()) - assert model_artifact.get_data(deserialize="h2o").__class__.__name__ == schema_cls.__name__ + assert ( + model_artifact.get_data(deserialize="h2o_binary").__class__.__name__ == schema_cls.__name__ + ) diff --git a/tests/unit/client/test_artifact_client.py b/tests/unit/client/test_artifact_client.py index 10135cfd..7b1c6f72 100644 --- a/tests/unit/client/test_artifact_client.py +++ b/tests/unit/client/test_artifact_client.py @@ -8,6 +8,7 @@ import pytest import xgboost from h2o import H2OFrame +from h2o.estimators.generic import H2OGenericEstimator from h2o.estimators.random_forest import H2ORandomForestEstimator from rubicon_ml import domain @@ -159,8 +160,19 @@ def test_download_location(mock_open, project_client): mock_file().write.assert_called_once_with(data) +@pytest.mark.parametrize( + ["use_mojo", "deserialization_method"], + [ + (False, "h2o"), + (False, "h2o_binary"), + (True, "h2o_mojo"), + ], +) def test_get_data_deserialize_h2o( - make_classification_df, rubicon_local_filesystem_client_with_project + make_classification_df, + rubicon_local_filesystem_client_with_project, + use_mojo, + deserialization_method, ): """Test logging `h2o` model data.""" _, project = rubicon_local_filesystem_client_with_project @@ -181,10 +193,13 @@ def test_get_data_deserialize_h2o( y=target_name, ) - artifact = project.log_h2o_model(h2o_model) - artifact_data = artifact.get_data(deserialize="h2o") + artifact = project.log_h2o_model(h2o_model, use_mojo=use_mojo) + artifact_data = artifact.get_data(deserialize=deserialization_method) - assert artifact_data.__class__ == h2o_model.__class__ + if use_mojo: + assert isinstance(artifact_data, H2OGenericEstimator) + else: + assert artifact_data.__class__ == h2o_model.__class__ def test_get_data_deserialize_xgboost( diff --git a/tests/unit/client/test_mixin_client.py b/tests/unit/client/test_mixin_client.py index fdcddcca..aad16bf7 100644 --- a/tests/unit/client/test_mixin_client.py +++ b/tests/unit/client/test_mixin_client.py @@ -202,7 +202,10 @@ def test_log_json(project_client): assert artifact_b.id in [a.id for a in artifacts] -def test_log_h2o_model(make_classification_df, rubicon_local_filesystem_client_with_project): +@pytest.mark.parametrize("use_mojo", [False, True]) +def test_log_h2o_model( + make_classification_df, rubicon_local_filesystem_client_with_project, use_mojo +): """Test logging `h2o` model data.""" _, project = rubicon_local_filesystem_client_with_project X, y = make_classification_df @@ -222,7 +225,7 @@ def test_log_h2o_model(make_classification_df, rubicon_local_filesystem_client_w y=target_name, ) - artifact = project.log_h2o_model(h2o_model, tags=["h2o"]) + artifact = project.log_h2o_model(h2o_model, use_mojo=use_mojo, tags=["h2o"]) read_artifact = project.artifact(name=artifact.name) assert artifact.id == read_artifact.id