Skip to content

Commit

Permalink
py: fix order by ID (#355)
Browse files Browse the repository at this point in the history
* api: fix ID field in orderBy

Fixes: #353

Signed-off-by: Isabella do Amaral <[email protected]>

* py: please mypy

Signed-off-by: Isabella do Amaral <[email protected]>

* add stubs for dateutil

Signed-off-by: Isabella do Amaral <[email protected]>

---------

Signed-off-by: Isabella do Amaral <[email protected]>
  • Loading branch information
isinyaaa authored Sep 6, 2024
1 parent 2b1ad4b commit 12f6cb9
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 17 deletions.
4 changes: 2 additions & 2 deletions api/openapi/model-registry.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1348,7 +1348,7 @@ components:
enum:
- CREATE_TIME
- LAST_UPDATE_TIME
- Id
- ID
type: string
Artifact:
oneOf:
Expand Down Expand Up @@ -1661,7 +1661,7 @@ components:
explode: true
examples:
orderBy:
value: Id
value: ID
name: orderBy
description: Specifies the order by criteria for listing entities.
schema:
Expand Down
5 changes: 4 additions & 1 deletion clients/python/noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def lint(session: Session) -> None:
def mypy(session: Session) -> None:
"""Type check using mypy."""
session.install(".")
session.install("mypy")
session.install(
"mypy",
"types-python-dateutil",
)

session.run("mypy", "src/model_registry")

Expand Down
14 changes: 12 additions & 2 deletions clients/python/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions clients/python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ mypy = "^1.7.0"
pytest-asyncio = ">=0.23.7,<0.25.0"
requests = "^2.32.2"
black = "^24.4.2"
types-python-dateutil = "^2.9.0.20240906"

[tool.coverage.run]
branch = true
Expand Down
6 changes: 3 additions & 3 deletions clients/python/src/model_registry/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
*,
author: str,
is_secure: bool = True,
user_token: bytes | None = None,
user_token: str | None = None,
custom_ca: str | None = None,
):
"""Constructor.
Expand All @@ -44,8 +44,8 @@ def __init__(
Keyword Args:
author: Name of the author.
is_secure: Whether to use a secure connection. Defaults to True.
user_token: The PEM-encoded user token as a byte string. Defaults to content of path on envvar KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: Path to the PEM-encoded root certificates as a byte string. Defaults to path on envvar CERT.
user_token: The PEM-encoded user token as a string. Defaults to content of path on envvar KF_PIPELINES_SA_TOKEN_PATH.
custom_ca: Path to the PEM-encoded root certificates as a string. Defaults to path on envvar CERT.
"""
import nest_asyncio

Expand Down
8 changes: 4 additions & 4 deletions clients/python/src/model_registry/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def secure_connection(
server_address: str,
port: int = 443,
*,
user_token: bytes,
user_token: str,
custom_ca: str | None = None,
) -> ModelRegistryAPIClient:
"""Constructor.
Expand All @@ -52,7 +52,7 @@ def secure_connection(
port: Server port. Defaults to 443.
Keyword Args:
user_token: The PEM-encoded user token as a byte string.
user_token: The PEM-encoded user token as a string.
custom_ca: The path to a PEM-
"""
return cls(
Expand All @@ -68,14 +68,14 @@ def insecure_connection(
cls,
server_address: str,
port: int,
user_token: bytes | None = None,
user_token: str | None = None,
) -> ModelRegistryAPIClient:
"""Constructor.
Args:
server_address: Server address.
port: Server port.
user_token: The PEM-encoded user token as a byte string.
user_token: The PEM-encoded user token as a string.
"""
return cls(
Configuration(host=f"{server_address}:{port}", access_token=user_token)
Expand Down
4 changes: 2 additions & 2 deletions clients/python/src/model_registry/types/pager.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def restart(self) -> Pager[T]:
This keeps the current options and page function, but resets the internal state.
"""
# as MLMD loops over pages, we need to keep track of the first page or we'll loop forever
self._start = None
self._current_page = None
self._start: str | None = None
self._current_page: list[T] | None = None
# tracks the next item on the current page
self._i = 0
self.options.next_page_token = None
Expand Down
2 changes: 1 addition & 1 deletion clients/python/src/mr_openapi/models/order_by_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class OrderByField(str, Enum):
"""
CREATE_TIME = "CREATE_TIME"
LAST_UPDATE_TIME = "LAST_UPDATE_TIME"
ID = "Id"
ID = "ID"

@classmethod
def from_json(cls, json_str: str) -> Self:
Expand Down
124 changes: 124 additions & 0 deletions clients/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,72 @@ def test_get_registered_models(client: ModelRegistry):
assert i == models


@pytest.mark.e2e
def test_get_registered_models_order_by(client: ModelRegistry):
models = 5

rms = []
for name in [f"test_model{i}" for i in range(models)]:
rms.append(
client.register_model(
name,
"s3",
model_format_name="test_format",
model_format_version="test_version",
version="1.0.0",
)
)

# id ordering should match creation order
i = 0
for rm, by_id in zip(
rms,
client.get_registered_models().order_by_id(),
):
assert rm.id == by_id.id
i += 1

assert i == models

# and obviously, creation ordering should match creation ordering
i = 0
for rm, by_creation in zip(
rms,
client.get_registered_models().order_by_creation_time(),
):
assert rm.id == by_creation.id
i += 1

assert i == models

# update order should match creation ordering by default
i = 0
for rm, by_update in zip(
rms,
client.get_registered_models().order_by_update_time(),
):
assert rm.id == by_update.id
i += 1

assert i == models

# now update the models in reverse order
for rm in reversed(rms):
rm.description = "updated"
client.update(rm)

# and they should match in reverse
i = 0
for rm, by_update in zip(
reversed(rms),
client.get_registered_models().order_by_update_time(),
):
assert rm.id == by_update.id
i += 1

assert i == models


@pytest.mark.e2e
def test_get_registered_models_and_reset(client: ModelRegistry):
model_count = 6
Expand Down Expand Up @@ -260,6 +326,64 @@ def test_get_model_versions(client: ModelRegistry):
assert i == models


@pytest.mark.e2e
def test_get_model_versions_order_by(client: ModelRegistry):
name = "test_model"
models = 5
mvs = []
for v in [f"1.0.{i}" for i in range(models)]:
client.register_model(
name,
"s3",
model_format_name="test_format",
model_format_version="test_version",
version=v,
)
mvs.append(client.get_model_version(name, v))

i = 0
for mv, by_id in zip(
mvs,
client.get_model_versions(name).order_by_id(),
):
assert mv.id == by_id.id
i += 1

assert i == models

i = 0
for mv, by_creation in zip(
mvs,
client.get_model_versions(name).order_by_creation_time(),
):
assert mv.id == by_creation.id
i += 1

assert i == models

i = 0
for mv, by_update in zip(
mvs,
client.get_model_versions(name).order_by_update_time(),
):
assert mv.id == by_update.id
i += 1

assert i == models

for mv in reversed(mvs):
mv.description = "updated"
client.update(mv)

i = 0
for mv, by_update in zip(
reversed(mvs),
client.get_model_versions(name).order_by_update_time(),
):
assert mv.id == by_update.id
i += 1


@pytest.mark.e2e
def test_get_model_versions_and_reset(client: ModelRegistry):
name = "test_model"
Expand Down
4 changes: 2 additions & 2 deletions pkg/openapi/model_order_by_field.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 12f6cb9

Please sign in to comment.