Skip to content

Commit

Permalink
Merge pull request #53 from gizatechxyz/fix/endpoints-retrieval
Browse files Browse the repository at this point in the history
Fix jobs and endpoints retrieval using RootModel
  • Loading branch information
Gonmeso authored Mar 14, 2024
2 parents 3bc27d4 + 1891141 commit 24a19d6
Show file tree
Hide file tree
Showing 9 changed files with 37 additions and 10 deletions.
2 changes: 1 addition & 1 deletion docs/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
description: Giza CLI 0.14.0
description: Giza CLI 0.14.1
---

# Giza CLI
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/full_transpilation.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pip install -r requirements.txt
Or:

```bash
pip install giza-cli==0.14.0 onnx==1.14.1 torch==2.1.0 torchvision==0.16.0
pip install giza-cli==0.14.1 onnx==1.14.1 torch==2.1.0 torchvision==0.16.0
```

We will use the libraries for the following purposes:
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/mnist_pytorch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"Or:\n",
"\n",
"```bash\n",
"pip install giza-cli==0.14.0 onnx==1.14.1 torch==2.1.0 torchvision==0.16.0\n",
"pip install giza-cli==0.14.1 onnx==1.14.1 torch==2.1.0 torchvision==0.16.0\n",
"```\n",
"\n",
"We will use the libraries for the following purposes:\n",
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
giza-cli==0.14.0
giza-cli==0.14.1
onnx==1.14.1
tf2onnx==1.15.1
torch==2.1.0
Expand Down
2 changes: 1 addition & 1 deletion giza/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os

__version__ = "0.14.0"
__version__ = "0.14.1"
# Until DNS is fixed
API_HOST = os.environ.get("GIZA_API_HOST", "https://api.gizatech.xyz")
4 changes: 2 additions & 2 deletions giza/schemas/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, RootModel

from giza.utils.enums import Framework, ServiceSize

Expand Down Expand Up @@ -33,5 +33,5 @@ class Endpoint(BaseModel):
model_config["protected_namespaces"] = ()


class EndpointsList(BaseModel):
class EndpointsList(RootModel):
root: list[Endpoint]
4 changes: 2 additions & 2 deletions giza/schemas/jobs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
from typing import Optional

from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, RootModel

from giza.utils.enums import Framework, JobKind, JobSize, JobStatus

Expand Down Expand Up @@ -29,5 +29,5 @@ class JobCreate(BaseModel):
model_config["protected_namespaces"] = ()


class JobList(BaseModel):
class JobList(RootModel):
root: list[Job]
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "giza-cli"
version = "0.14.0"
version = "0.14.1"
description = "CLI for interacting with Giza"
authors = ["Gonzalo Mellizo-Soto <[email protected]>"]
readme = "README.md"
Expand Down
27 changes: 27 additions & 0 deletions tests/commands/test_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,33 @@ def test_list_deployments():
assert "giza-deployment-2" in result.stdout


def test_create_deployments_empty():
deployments_list = EndpointsList(root=[])
with patch.object(
EndpointsClient, "list", return_value=deployments_list
) as mock_list, patch.object(
EndpointsClient,
"create",
return_value=Endpoint(
id=1,
status="COMPLETED",
uri="https://giza-api.com/deployments/1",
size="S",
service_name="giza-deployment-1",
model_id=1,
version_id=1,
is_active=True,
),
):
result = invoke_cli_runner(
["endpoints", "deploy", "--model-id", "1", "--version-id", "1"],
)
mock_list.assert_called_once()
assert result.exit_code == 0
assert "Endpoint is successful" in result.stdout
assert "https://giza-api.com/deployments/1" in result.stdout


def test_list_deployments_http_error():
with patch.object(EndpointsClient, "list", side_effect=HTTPError):
result = invoke_cli_runner(
Expand Down

0 comments on commit 24a19d6

Please sign in to comment.