Skip to content

Commit

Permalink
Build copy model version API for SQL and UC model registry stores (ml…
Browse files Browse the repository at this point in the history
…flow#10078)

Signed-off-by: Jerry Liang <[email protected]>
Signed-off-by: mlflow-automation <[email protected]>
Co-authored-by: mlflow-automation <[email protected]>
  • Loading branch information
jerrylian-db and mlflow-automation authored Oct 26, 2023
1 parent 0d66ba2 commit 779451b
Show file tree
Hide file tree
Showing 14 changed files with 149 additions and 21 deletions.
13 changes: 13 additions & 0 deletions mlflow/store/_unity_catalog/registry/rest_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,19 @@ def get_model_version_by_alias(self, name, alias):
response_proto = self._call_endpoint(GetModelVersionByAliasRequest, req_body)
return model_version_from_uc_proto(response_proto.model_version)

def copy_model_version(self, src_mv, dst_name):
    """
    Create a new model version under ``dst_name`` that is a copy of ``src_mv``.

    The work is delegated to the shared ``_copy_model_version_impl`` helper.

    :param src_mv: A :py:class:`mlflow.entities.model_registry.ModelVersion` object
                   representing the source model version.
    :param dst_name: The name of the registered model to copy the model version to. If a
                     registered model with this name does not exist, it will be created.
    :return: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object
             representing the cloned model version.
    """
    copied_version = self._copy_model_version_impl(src_mv, dst_name)
    return copied_version

def _await_model_version_creation(self, mv, await_creation_for):
"""
Does not wait for the model version to become READY as a successful creation will
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""add storage location field to model versions
Revision ID: acf3f17fdcc7
Revises: 2d6e25af4d3e
Create Date: 2023-10-23 15:26:53.062080
"""
from alembic import op
import sqlalchemy as sa
from mlflow.store.model_registry.dbmodels.models import SqlModelVersion


# revision identifiers, used by Alembic.
revision = "acf3f17fdcc7"
down_revision = "2d6e25af4d3e"
branch_labels = None
depends_on = None


def upgrade():
    """Add the nullable ``storage_location`` string column to the model versions table."""
    storage_location_column = sa.Column(
        "storage_location", sa.String(500), nullable=True, default=None
    )
    op.add_column(SqlModelVersion.__tablename__, storage_location_column)


def downgrade():
    """No-op: this revision does not undo the ``storage_location`` column on downgrade."""
11 changes: 7 additions & 4 deletions mlflow/store/model_registry/abstract_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from mlflow.entities.model_registry import ModelVersionTag
from mlflow.entities.model_registry.model_version_status import ModelVersionStatus
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import RESOURCE_ALREADY_EXISTS
from mlflow.protos.databricks_pb2 import RESOURCE_ALREADY_EXISTS, ErrorCode
from mlflow.utils.annotations import developer_stable

_logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -323,7 +323,6 @@ def get_model_version_by_alias(self, name, alias):
"""
pass

@abstractmethod
def copy_model_version(self, src_mv, dst_name):
"""
Copy a model version from one registered model to another as a new model version.
Expand All @@ -335,13 +334,17 @@ def copy_model_version(self, src_mv, dst_name):
:return: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object representing
the cloned model version.
"""
pass
raise MlflowException(
"Method 'copy_model_version' has not yet been implemented for the current model "
"registry backend. To request support for implementing this method with this backend, "
"please submit an issue on GitHub."
)

def _copy_model_version_impl(self, src_mv, dst_name):
try:
self.create_registered_model(dst_name)
except MlflowException as e:
if e.error_code != RESOURCE_ALREADY_EXISTS:
if e.error_code != ErrorCode.Name(RESOURCE_ALREADY_EXISTS):
raise

return self.create_model_version(
Expand Down
2 changes: 2 additions & 0 deletions mlflow/store/model_registry/dbmodels/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class SqlModelVersion(Base):

source = Column(String(500), nullable=True, default=None)

storage_location = Column(String(500), nullable=True, default=None)

run_id = Column(String(32), nullable=True, default=None)

run_link = Column(String(500), nullable=True, default=None)
Expand Down
7 changes: 4 additions & 3 deletions mlflow/store/model_registry/file_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,8 @@ def create_model_version(
Create a new model version from given source and run ID.
:param name: Registered model name.
:param source: Source path where the MLflow model is stored.
:param source: Source path or model version URI (in the format
``models:/<model_name>/<version>``) where the MLflow model is stored.
:param run_id: Run ID from MLflow tracking server that generated the model.
:param tags: A list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
instances associated with this model version.
Expand All @@ -593,10 +594,10 @@ def next_version(registered_model_name):
_validate_model_version_tag(tag.key, tag.value)
storage_location = source
if urllib.parse.urlparse(source).scheme == "models":
(src_model_name, src_model_version, _, _) = _parse_model_uri(source)
parsed_model_uri = _parse_model_uri(source)
try:
storage_location = self.get_model_version_download_uri(
src_model_name, src_model_version
parsed_model_uri.name, parsed_model_uri.version
)
except Exception as e:
raise MlflowException(
Expand Down
34 changes: 32 additions & 2 deletions mlflow/store/model_registry/sqlalchemy_store.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
import urllib

import sqlalchemy
from sqlalchemy.future import select

import mlflow.store.db.utils
from mlflow.entities.model_registry import ModelVersion
from mlflow.entities.model_registry.model_version_stages import (
ALL_STAGES,
DEFAULT_STAGES_FOR_GET_LATEST_VERSIONS,
Expand All @@ -18,6 +20,7 @@
RESOURCE_ALREADY_EXISTS,
RESOURCE_DOES_NOT_EXIST,
)
from mlflow.store.artifact.utils.models import _parse_model_uri
from mlflow.store.entities.paged_list import PagedList
from mlflow.store.model_registry import (
SEARCH_MODEL_VERSION_MAX_RESULTS_DEFAULT,
Expand Down Expand Up @@ -612,7 +615,8 @@ def create_model_version(
Create a new model version from given source and run ID.
:param name: Registered model name.
:param source: Source path where the MLflow model is stored.
:param source: Source path or model version URI (in the format
``models:/<model_name>/<version>``) where the MLflow model is stored.
:param run_id: Run ID from MLflow tracking server that generated the model.
:param tags: A list of :py:class:`mlflow.entities.model_registry.ModelVersionTag`
instances associated with this model version.
Expand All @@ -631,6 +635,18 @@ def next_version(sql_registered_model):
_validate_model_name(name)
for tag in tags or []:
_validate_model_version_tag(tag.key, tag.value)
storage_location = source
if urllib.parse.urlparse(source).scheme == "models":
parsed_model_uri = _parse_model_uri(source)
try:
storage_location = self.get_model_version_download_uri(
parsed_model_uri.name, parsed_model_uri.version
)
except Exception as e:
raise MlflowException(
f"Unable to fetch model from model URI source artifact location '{source}'."
f"Error: {e}"
) from e
with self.ManagedSessionMaker() as session:
creation_time = get_current_time_millis()
for attempt in range(self.CREATE_MODEL_VERSION_RETRIES):
Expand All @@ -644,6 +660,7 @@ def next_version(sql_registered_model):
creation_time=creation_time,
last_updated_time=creation_time,
source=source,
storage_location=storage_location,
run_id=run_id,
run_link=run_link,
description=description,
Expand Down Expand Up @@ -856,7 +873,7 @@ def get_model_version_download_uri(self, name, version):
"""
with self.ManagedSessionMaker() as session:
sql_model_version = self._get_sql_model_version(session, name, version)
return sql_model_version.source
return sql_model_version.storage_location or sql_model_version.source

def search_model_versions(
self,
Expand Down Expand Up @@ -1099,6 +1116,19 @@ def get_model_version_by_alias(self, name, alias):
f"Registered model alias {alias} not found.", INVALID_PARAMETER_VALUE
)

def copy_model_version(self, src_mv, dst_name) -> ModelVersion:
    """
    Create a new model version under ``dst_name`` that is a copy of ``src_mv``.

    The work is delegated to the shared ``_copy_model_version_impl`` helper.

    :param src_mv: A :py:class:`mlflow.entities.model_registry.ModelVersion` object
                   representing the source model version.
    :param dst_name: The name of the registered model to copy the model version to. If a
                     registered model with this name does not exist, it will be created.
    :return: Single :py:class:`mlflow.entities.model_registry.ModelVersion` object
             representing the cloned model version.
    """
    copied_version = self._copy_model_version_impl(src_mv, dst_name)
    return copied_version

def _await_model_version_creation(self, mv, await_creation_for):
"""
Does not wait for the model version to become READY as a successful creation will
Expand Down
2 changes: 1 addition & 1 deletion tests/db/check_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def post_migration():
for table in TABLES:
df_actual = pd.read_sql(sa.text(f"SELECT * FROM {table}"), conn)
df_expected = pd.read_pickle(SNAPSHOTS_DIR / f"{table}.pkl")
pd.testing.assert_frame_equal(df_actual, df_expected)
pd.testing.assert_frame_equal(df_actual[df_expected.columns], df_expected)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions tests/db/schemas/mssql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ CREATE TABLE model_versions (
status VARCHAR(20) COLLATE "SQL_Latin1_General_CP1_CI_AS",
status_message VARCHAR(500) COLLATE "SQL_Latin1_General_CP1_CI_AS",
run_link VARCHAR(500) COLLATE "SQL_Latin1_General_CP1_CI_AS",
storage_location VARCHAR(500) COLLATE "SQL_Latin1_General_CP1_CI_AS",
CONSTRAINT model_version_pk PRIMARY KEY (name, version),
CONSTRAINT "FK__model_vers__name__5812160E" FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE
)
Expand Down
1 change: 1 addition & 0 deletions tests/db/schemas/mysql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ CREATE TABLE model_versions (
status VARCHAR(20),
status_message VARCHAR(500),
run_link VARCHAR(500),
storage_location VARCHAR(500),
PRIMARY KEY (name, version),
CONSTRAINT model_versions_ibfk_1 FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE
)
Expand Down
1 change: 1 addition & 0 deletions tests/db/schemas/postgresql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ CREATE TABLE model_versions (
status VARCHAR(20),
status_message VARCHAR(500),
run_link VARCHAR(500),
storage_location VARCHAR(500),
CONSTRAINT model_version_pk PRIMARY KEY (name, version),
CONSTRAINT model_versions_name_fkey FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE
)
Expand Down
1 change: 1 addition & 0 deletions tests/db/schemas/sqlite.sql
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ CREATE TABLE model_versions (
status VARCHAR(20),
status_message VARCHAR(500),
run_link VARCHAR(500),
storage_location VARCHAR(500),
CONSTRAINT model_version_pk PRIMARY KEY (name, version),
FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE
)
Expand Down
1 change: 1 addition & 0 deletions tests/resources/db/latest_schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ CREATE TABLE model_versions (
status VARCHAR(20),
status_message VARCHAR(500),
run_link VARCHAR(500),
storage_location VARCHAR(500),
CONSTRAINT model_version_pk PRIMARY KEY (name, version),
FOREIGN KEY(name) REFERENCES registered_models (name) ON UPDATE CASCADE
)
Expand Down
23 changes: 12 additions & 11 deletions tests/store/model_registry/test_file_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,32 +1514,33 @@ def predict(self, context, model_input, params=None):
assert mv2[0].name == "model2"


def test_copy_model_version(store):
@pytest.mark.parametrize("copy_to_same_model", [False, True])
def test_copy_model_version(store, copy_to_same_model):
name1 = "test_for_copy_MV1"
store.create_registered_model(name1)
src_tags = [
ModelVersionTag("key", "value"),
ModelVersionTag("anotherKey", "some other value"),
]
with mock.patch("time.time", return_value=456778):
src_mv = _create_model_version(
store, name1, tags=src_tags, run_link="dummylink", description="test description"
)
src_mv = _create_model_version(
store, name1, tags=src_tags, run_link="dummylink", description="test description"
)

# Make some changes to the src MV that won't be copied over
store.transition_model_version_stage(
name1, src_mv.version, "Production", archive_existing_versions=False
)

name2 = "test_for_copy_MV2"
copy_rm_name = name1 if copy_to_same_model else "test_for_copy_MV2"
copy_mv_version = 2 if copy_to_same_model else 1
timestamp = time.time()
dst_mv = store.copy_model_version(src_mv, name2)
assert dst_mv.name == name2
assert dst_mv.version == 1
dst_mv = store.copy_model_version(src_mv, copy_rm_name)
assert dst_mv.name == copy_rm_name
assert dst_mv.version == copy_mv_version

copied_mv = store.get_model_version(dst_mv.name, dst_mv.version)
assert copied_mv.name == name2
assert copied_mv.version == 1
assert copied_mv.name == copy_rm_name
assert copied_mv.version == copy_mv_version
assert copied_mv.current_stage == "None"
assert copied_mv.creation_timestamp >= timestamp
assert copied_mv.last_updated_timestamp >= timestamp
Expand Down
45 changes: 45 additions & 0 deletions tests/store/model_registry/test_sqlalchemy_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1639,3 +1639,48 @@ def test_delete_model_deletes_alias(store):
match=r"Registered model alias test_alias not found.",
):
store.get_model_version_by_alias(model_name, "test_alias")


@pytest.mark.parametrize("copy_to_same_model", [False, True])
def test_copy_model_version(store, copy_to_same_model):
    """
    Check that ``copy_model_version`` creates a fresh model version that carries over the
    source's description, run info and tags, records a ``models:/`` URI as its source, and
    resolves its download URI to the original artifact location -- both for a direct copy
    and for a copy of a copy.

    :param store: The model registry store under test (pytest fixture).
    :param copy_to_same_model: If True, copy into the source's own registered model (the copy
                               is expected to be version 2); otherwise copy into a new
                               registered model (the copy is expected to be version 1).
    """
    name1 = "test_for_copy_MV1"
    store.create_registered_model(name1)
    src_tags = [
        ModelVersionTag("key", "value"),
        ModelVersionTag("anotherKey", "some other value"),
    ]
    src_mv = _mv_maker(
        store, name1, tags=src_tags, run_link="dummylink", description="test description"
    )

    # Make some changes to the src MV that won't be copied over
    store.transition_model_version_stage(
        name1, src_mv.version, "Production", archive_existing_versions=False
    )

    copy_rm_name = name1 if copy_to_same_model else "test_for_copy_MV2"
    copy_mv_version = 2 if copy_to_same_model else 1
    timestamp = time.time()
    dst_mv = store.copy_model_version(src_mv, copy_rm_name)
    assert dst_mv.name == copy_rm_name
    assert dst_mv.version == copy_mv_version

    copied_mv = store.get_model_version(dst_mv.name, dst_mv.version)
    assert copied_mv.name == copy_rm_name
    assert copied_mv.version == copy_mv_version
    assert copied_mv.current_stage == "None"
    assert copied_mv.creation_timestamp >= timestamp
    assert copied_mv.last_updated_timestamp >= timestamp
    assert copied_mv.description == "test description"
    assert copied_mv.source == f"models:/{src_mv.name}/{src_mv.version}"
    assert store.get_model_version_download_uri(dst_mv.name, dst_mv.version) == src_mv.source
    assert copied_mv.run_link == "dummylink"
    assert copied_mv.run_id == src_mv.run_id
    assert copied_mv.status == "READY"
    assert copied_mv.status_message is None
    assert copied_mv.tags == {"key": "value", "anotherKey": "some other value"}

    # Copy a model version copy
    double_copy_mv = store.copy_model_version(copied_mv, "test_for_copy_MV3")
    assert double_copy_mv.source == f"models:/{copied_mv.name}/{copied_mv.version}"
    # Query the download URI of the double copy itself: the first copy's URI is already
    # asserted above, and re-checking it here would leave the copy-of-a-copy path untested.
    assert (
        store.get_model_version_download_uri(double_copy_mv.name, double_copy_mv.version)
        == src_mv.source
    )

0 comments on commit 779451b

Please sign in to comment.