Add copy_model_version support to the Unity Catalog REST store and the
SQLAlchemy model registry store.

Summary of what the hunks below do:
  * AbstractStore.copy_model_version is no longer abstract; its default body
    raises MlflowException. _copy_model_version_impl compares the exception's
    error_code against ErrorCode.Name(RESOURCE_ALREADY_EXISTS).
  * UcModelRegistryStore and SqlAlchemyStore gain copy_model_version, both
    delegating to _copy_model_version_impl.
  * model_versions gains a nullable storage_location VARCHAR(500) column
    (Alembic revision acf3f17fdcc7, revises 2d6e25af4d3e; downgrade is a no-op).
    SqlAlchemyStore.create_model_version fills it (resolving "models:/" sources
    through get_model_version_download_uri), and
    get_model_version_download_uri returns storage_location when set, falling
    back to source.
  * Schema snapshots and tests are updated to match.

diff --git a/mlflow/store/_unity_catalog/registry/rest_store.py b/mlflow/store/_unity_catalog/registry/rest_store.py
index aaa20074ba6c3..4bb4f16f8f12a 100644
--- a/mlflow/store/_unity_catalog/registry/rest_store.py
+++ b/mlflow/store/_unity_catalog/registry/rest_store.py
@@ -745,6 +745,19 @@ def get_model_version_by_alias(self, name, alias):
         response_proto = self._call_endpoint(GetModelVersionByAliasRequest, req_body)
         return model_version_from_uc_proto(response_proto.model_version)
 
+    def copy_model_version(self, src_mv, dst_name):
+        """
+        Copy a model version from one registered model to another as a new model version.
+
+        :param src_mv: A :py:class:`mlflow.entities.model_registry.ModelVersion` object representing
+                       the source model version.
+        :param dst_name: the name of the registered model to copy the model version to. If a
+                         registered model with this name does not exist, it will be created.
+        :return: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object representing
+                 the cloned model version.
+        """
+        return self._copy_model_version_impl(src_mv, dst_name)
+
     def _await_model_version_creation(self, mv, await_creation_for):
         """
         Does not wait for the model version to become READY as a successful creation will
diff --git a/mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py b/mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py
new file mode 100644
index 0000000000000..03a40f3259b12
--- /dev/null
+++ b/mlflow/store/db_migrations/versions/acf3f17fdcc7_add_storage_location_field_to_model_.py
@@ -0,0 +1,28 @@
+"""add storage location field to model versions
+
+Revision ID: acf3f17fdcc7
+Revises: 2d6e25af4d3e
+Create Date: 2023-10-23 15:26:53.062080
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from mlflow.store.model_registry.dbmodels.models import SqlModelVersion
+
+
+# revision identifiers, used by Alembic.
+revision = "acf3f17fdcc7"
+down_revision = "2d6e25af4d3e"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.add_column(
+        SqlModelVersion.__tablename__,
+        sa.Column("storage_location", sa.String(500), nullable=True, default=None),
+    )
+
+
+def downgrade():
+    pass
diff --git a/mlflow/store/model_registry/abstract_store.py b/mlflow/store/model_registry/abstract_store.py
index 764ceead20a1e..e7fdcab21cb6c 100644
--- a/mlflow/store/model_registry/abstract_store.py
+++ b/mlflow/store/model_registry/abstract_store.py
@@ -5,7 +5,7 @@
 from mlflow.entities.model_registry import ModelVersionTag
 from mlflow.entities.model_registry.model_version_status import ModelVersionStatus
 from mlflow.exceptions import MlflowException
-from mlflow.protos.databricks_pb2 import RESOURCE_ALREADY_EXISTS
+from mlflow.protos.databricks_pb2 import RESOURCE_ALREADY_EXISTS, ErrorCode
 from mlflow.utils.annotations import developer_stable
 
 _logger = logging.getLogger(__name__)
@@ -323,7 +323,6 @@ def get_model_version_by_alias(self, name, alias):
         """
         pass
 
-    @abstractmethod
     def copy_model_version(self, src_mv, dst_name):
         """
         Copy a model version from one registered model to another as a new model version.
@@ -335,13 +334,17 @@ def copy_model_version(self, src_mv, dst_name):
         :return: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object representing
                  the cloned model version.
         """
-        pass
+        raise MlflowException(
+            "Method 'copy_model_version' has not yet been implemented for the current model "
+            "registry backend. To request support for implementing this method with this backend, "
+            "please submit an issue on GitHub."
+        )
 
     def _copy_model_version_impl(self, src_mv, dst_name):
         try:
             self.create_registered_model(dst_name)
         except MlflowException as e:
-            if e.error_code != RESOURCE_ALREADY_EXISTS:
+            if e.error_code != ErrorCode.Name(RESOURCE_ALREADY_EXISTS):
                 raise
 
         return self.create_model_version(
diff --git a/mlflow/store/model_registry/dbmodels/models.py b/mlflow/store/model_registry/dbmodels/models.py
index fbe1f5a57bc65..30039eceec30f 100644
--- a/mlflow/store/model_registry/dbmodels/models.py
+++ b/mlflow/store/model_registry/dbmodels/models.py
@@ -80,6 +80,8 @@ class SqlModelVersion(Base):
 
     source = Column(String(500), nullable=True, default=None)
 
+    storage_location = Column(String(500), nullable=True, default=None)
+
     run_id = Column(String(32), nullable=True, default=None)
 
     run_link = Column(String(500), nullable=True, default=None)
diff --git a/mlflow/store/model_registry/file_store.py b/mlflow/store/model_registry/file_store.py
index fa2b2c6a4df73..513a8a2463048 100644
--- a/mlflow/store/model_registry/file_store.py
+++ b/mlflow/store/model_registry/file_store.py
@@ -570,7 +570,8 @@ def create_model_version(
         Create a new model version from given source and run ID.
 
         :param name: Registered model name.
-        :param source: Source path where the MLflow model is stored.
+        :param source: Source path or model version URI (in the format
+                       ``models:/<model_name>/<model_version>``) where the MLflow model is stored.
         :param run_id: Run ID from MLflow tracking server that generated the model.
         :param tags: A list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
                      instances associated with this model version.
@@ -593,10 +594,10 @@ def next_version(registered_model_name):
             _validate_model_version_tag(tag.key, tag.value)
         storage_location = source
         if urllib.parse.urlparse(source).scheme == "models":
-            (src_model_name, src_model_version, _, _) = _parse_model_uri(source)
+            parsed_model_uri = _parse_model_uri(source)
             try:
                 storage_location = self.get_model_version_download_uri(
-                    src_model_name, src_model_version
+                    parsed_model_uri.name, parsed_model_uri.version
                 )
             except Exception as e:
                 raise MlflowException(
diff --git a/mlflow/store/model_registry/sqlalchemy_store.py b/mlflow/store/model_registry/sqlalchemy_store.py
index a4a65e552f7a0..cabef5c65e674 100644
--- a/mlflow/store/model_registry/sqlalchemy_store.py
+++ b/mlflow/store/model_registry/sqlalchemy_store.py
@@ -1,9 +1,11 @@
 import logging
+import urllib.parse
 
 import sqlalchemy
 from sqlalchemy.future import select
 
 import mlflow.store.db.utils
+from mlflow.entities.model_registry import ModelVersion
 from mlflow.entities.model_registry.model_version_stages import (
     ALL_STAGES,
     DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS,
@@ -18,6 +20,7 @@
     RESOURCE_ALREADY_EXISTS,
     RESOURCE_DOES_NOT_EXIST,
 )
+from mlflow.store.artifact.utils.models import _parse_model_uri
 from mlflow.store.entities.paged_list import PagedList
 from mlflow.store.model_registry import (
     SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT,
@@ -612,7 +615,8 @@ def create_model_version(
         Create a new model version from given source and run ID.
 
         :param name: Registered model name.
-        :param source: Source path where the MLflow model is stored.
+        :param source: Source path or model version URI (in the format
+                       ``models:/<model_name>/<model_version>``) where the MLflow model is stored.
         :param run_id: Run ID from MLflow tracking server that generated the model.
         :param tags: A list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
                      instances associated with this model version.
@@ -631,6 +635,18 @@ def next_version(sql_registered_model):
         _validate_model_name(name)
         for tag in tags or []:
             _validate_model_version_tag(tag.key, tag.value)
+        storage_location = source
+        if urllib.parse.urlparse(source).scheme == "models":
+            parsed_model_uri = _parse_model_uri(source)
+            try:
+                storage_location = self.get_model_version_download_uri(
+                    parsed_model_uri.name, parsed_model_uri.version
+                )
+            except Exception as e:
+                raise MlflowException(
+                    f"Unable to fetch model from model URI source artifact location '{source}'. "
+                    f"Error: {e}"
+                ) from e
         with self.ManagedSessionMaker() as session:
             creation_time = get_current_time_millis()
             for attempt in range(self.CREATE_MODEL_VERSION_RETRIES):
@@ -644,6 +660,7 @@ def next_version(sql_registered_model):
                         creation_time=creation_time,
                         last_updated_time=creation_time,
                         source=source,
+                        storage_location=storage_location,
                         run_id=run_id,
                         run_link=run_link,
                         description=description,
@@ -856,7 +873,7 @@ def get_model_version_download_uri(self, name, version):
         """
         with self.ManagedSessionMaker() as session:
             sql_model_version = self._get_sql_model_version(session, name, version)
-            return sql_model_version.source
+            return sql_model_version.storage_location or sql_model_version.source
 
     def search_model_versions(
         self,
@@ -1099,6 +1116,19 @@ def get_model_version_by_alias(self, name, alias):
                 f"Registered model alias {alias} not found.", INVALID_PARAMETER_VALUE
             )
 
+    def copy_model_version(self, src_mv, dst_name) -> ModelVersion:
+        """
+        Copy a model version from one registered model to another as a new model version.
+
+        :param src_mv: A :py:class:`mlflow.entities.model_registry.ModelVersion` object representing
+                       the source model version.
+        :param dst_name: the name of the registered model to copy the model version to. If a
+                         registered model with this name does not exist, it will be created.
+        :return: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object representing
+                 the cloned model version.
+        """
+        return self._copy_model_version_impl(src_mv, dst_name)
+
     def _await_model_version_creation(self, mv, await_creation_for):
         """
         Does not wait for the model version to become READY as a successful creation will
diff --git a/tests/db/check_migration.py b/tests/db/check_migration.py
index 104f1cf6ee43a..3bff39184671a 100644
--- a/tests/db/check_migration.py
+++ b/tests/db/check_migration.py
@@ -109,7 +109,7 @@ def post_migration():
         for table in TABLES:
             df_actual = pd.read_sql(sa.text(f"SELECT * FROM {table}"), conn)
             df_expected = pd.read_pickle(SNAPSHOTS_DIR / f"{table}.pkl")
-            pd.testing.assert_frame_equal(df_actual, df_expected)
+            pd.testing.assert_frame_equal(df_actual[df_expected.columns], df_expected)
 
 
 if __name__ == "__main__":
diff --git a/tests/db/schemas/mssql.sql b/tests/db/schemas/mssql.sql
index 9021ebb57aef4..8361adab9c319 100644
--- a/tests/db/schemas/mssql.sql
+++ b/tests/db/schemas/mssql.sql
@@ -79,6 +79,7 @@ CREATE TABLE model_versions (
 	status VARCHAR(20) COLLATE "SQL_Latin1_General_CP1_CI_AS",
 	status_message VARCHAR(500) COLLATE "SQL_Latin1_General_CP1_CI_AS",
 	run_link VARCHAR(500) COLLATE "SQL_Latin1_General_CP1_CI_AS",
+	storage_location VARCHAR(500) COLLATE "SQL_Latin1_General_CP1_CI_AS",
 	CONSTRAINT model_version_pk PRIMARY KEY (name, version),
 	CONSTRAINT "FK__model_vers__name__5812160E" FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE
 )
diff --git a/tests/db/schemas/mysql.sql b/tests/db/schemas/mysql.sql
index dccf071fe4220..757201f205dfd 100644
--- a/tests/db/schemas/mysql.sql
+++ b/tests/db/schemas/mysql.sql
@@ -80,6 +80,7 @@ CREATE TABLE model_versions (
 	status VARCHAR(20),
 	status_message VARCHAR(500),
 	run_link VARCHAR(500),
+	storage_location VARCHAR(500),
 	PRIMARY KEY (name, version),
 	CONSTRAINT model_versions_ibfk_1 FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE
 )
diff --git a/tests/db/schemas/postgresql.sql b/tests/db/schemas/postgresql.sql
index f1134e4d03d94..158fe88288698 100644
--- a/tests/db/schemas/postgresql.sql
+++ b/tests/db/schemas/postgresql.sql
@@ -81,6 +81,7 @@ CREATE TABLE model_versions (
 	status VARCHAR(20),
 	status_message VARCHAR(500),
 	run_link VARCHAR(500),
+	storage_location VARCHAR(500),
 	CONSTRAINT model_version_pk PRIMARY KEY (name, version),
 	CONSTRAINT model_versions_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE
 )
diff --git a/tests/db/schemas/sqlite.sql b/tests/db/schemas/sqlite.sql
index 2b385dc506822..09e310ea2e99c 100644
--- a/tests/db/schemas/sqlite.sql
+++ b/tests/db/schemas/sqlite.sql
@@ -82,6 +82,7 @@ CREATE TABLE model_versions (
 	status VARCHAR(20),
 	status_message VARCHAR(500),
 	run_link VARCHAR(500),
+	storage_location VARCHAR(500),
 	CONSTRAINT model_version_pk PRIMARY KEY (name, version),
 	FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE
 )
diff --git a/tests/resources/db/latest_schema.sql b/tests/resources/db/latest_schema.sql
index 23dd6f13b5383..2ec1661808255 100644
--- a/tests/resources/db/latest_schema.sql
+++ b/tests/resources/db/latest_schema.sql
@@ -82,6 +82,7 @@ CREATE TABLE model_versions (
 	status VARCHAR(20),
 	status_message VARCHAR(500),
 	run_link VARCHAR(500),
+	storage_location VARCHAR(500),
 	CONSTRAINT model_version_pk PRIMARY KEY (name, version),
 	FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE
 )
diff --git a/tests/store/model_registry/test_file_store.py b/tests/store/model_registry/test_file_store.py
index d4c577fbfe5c5..ef12611c2ba9e 100644
--- a/tests/store/model_registry/test_file_store.py
+++ b/tests/store/model_registry/test_file_store.py
@@ -1514,32 +1514,33 @@ def predict(self, context, model_input, params=None):
     assert mv2[0].name == "model2"
 
 
-def test_copy_model_version(store):
+@pytest.mark.parametrize("copy_to_same_model", [False, True])
+def test_copy_model_version(store, copy_to_same_model):
     name1 = "test_for_copy_MV1"
     store.create_registered_model(name1)
     src_tags = [
         ModelVersionTag("key", "value"),
         ModelVersionTag("anotherKey", "some other value"),
     ]
-    with mock.patch("time.time", return_value=456778):
-        src_mv = _create_model_version(
-            store, name1, tags=src_tags, run_link="dummylink", description="test description"
-        )
+    src_mv = _create_model_version(
+        store, name1, tags=src_tags, run_link="dummylink", description="test description"
+    )
 
     # Make some changes to the src MV that won't be copied over
     store.transition_model_version_stage(
         name1, src_mv.version, "Production", archive_existing_versions=False
     )
 
-    name2 = "test_for_copy_MV2"
+    copy_rm_name = name1 if copy_to_same_model else "test_for_copy_MV2"
+    copy_mv_version = 2 if copy_to_same_model else 1
     timestamp = time.time()
-    dst_mv = store.copy_model_version(src_mv, name2)
-    assert dst_mv.name == name2
-    assert dst_mv.version == 1
+    dst_mv = store.copy_model_version(src_mv, copy_rm_name)
+    assert dst_mv.name == copy_rm_name
+    assert dst_mv.version == copy_mv_version
 
     copied_mv = store.get_model_version(dst_mv.name, dst_mv.version)
-    assert copied_mv.name == name2
-    assert copied_mv.version == 1
+    assert copied_mv.name == copy_rm_name
+    assert copied_mv.version == copy_mv_version
     assert copied_mv.current_stage == "None"
     assert copied_mv.creation_timestamp >= timestamp
     assert copied_mv.last_updated_timestamp >= timestamp
diff --git a/tests/store/model_registry/test_sqlalchemy_store.py b/tests/store/model_registry/test_sqlalchemy_store.py
index 1e3c52da79bf4..40d8ea3a7be61 100644
--- a/tests/store/model_registry/test_sqlalchemy_store.py
+++ b/tests/store/model_registry/test_sqlalchemy_store.py
@@ -1639,3 +1639,51 @@ def test_delete_model_deletes_alias(store):
         match=r"Registered model alias test_alias not found.",
     ):
         store.get_model_version_by_alias(model_name, "test_alias")
+
+
+@pytest.mark.parametrize("copy_to_same_model", [False, True])
+def test_copy_model_version(store, copy_to_same_model):
+    name1 = "test_for_copy_MV1"
+    store.create_registered_model(name1)
+    src_tags = [
+        ModelVersionTag("key", "value"),
+        ModelVersionTag("anotherKey", "some other value"),
+    ]
+    src_mv = _mv_maker(
+        store, name1, tags=src_tags, run_link="dummylink", description="test description"
+    )
+
+    # Make some changes to the src MV that won't be copied over
+    store.transition_model_version_stage(
+        name1, src_mv.version, "Production", archive_existing_versions=False
+    )
+
+    copy_rm_name = name1 if copy_to_same_model else "test_for_copy_MV2"
+    copy_mv_version = 2 if copy_to_same_model else 1
+    timestamp = time.time()
+    dst_mv = store.copy_model_version(src_mv, copy_rm_name)
+    assert dst_mv.name == copy_rm_name
+    assert dst_mv.version == copy_mv_version
+
+    copied_mv = store.get_model_version(dst_mv.name, dst_mv.version)
+    assert copied_mv.name == copy_rm_name
+    assert copied_mv.version == copy_mv_version
+    assert copied_mv.current_stage == "None"
+    assert copied_mv.creation_timestamp >= timestamp
+    assert copied_mv.last_updated_timestamp >= timestamp
+    assert copied_mv.description == "test description"
+    assert copied_mv.source == f"models:/{src_mv.name}/{src_mv.version}"
+    assert store.get_model_version_download_uri(dst_mv.name, dst_mv.version) == src_mv.source
+    assert copied_mv.run_link == "dummylink"
+    assert copied_mv.run_id == src_mv.run_id
+    assert copied_mv.status == "READY"
+    assert copied_mv.status_message is None
+    assert copied_mv.tags == {"key": "value", "anotherKey": "some other value"}
+
+    # Copy a model version copy
+    double_copy_mv = store.copy_model_version(copied_mv, "test_for_copy_MV3")
+    assert double_copy_mv.source == f"models:/{copied_mv.name}/{copied_mv.version}"
+    assert (
+        store.get_model_version_download_uri(double_copy_mv.name, double_copy_mv.version)
+        == src_mv.source
+    )