Skip to content

Commit

Permalink
fix: import/export typing
Browse files Browse the repository at this point in the history
  • Loading branch information
thearchitector committed Jan 9, 2025
1 parent 3bde9ce commit cfb3c15
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 47 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ You will also need the following env variables set for the Exchange integration
2. Activate the virtual environment
`source venv/bin/activate`
3. Install the client
`pip3 install --editable .`
`pip3 install --editable .[all]`
4. Install test deps
`pip3 install "pytest<8" "requests-mock>=1.8.0" "pytest-asyncio>0.21"`
5. Run tests
Expand Down
2 changes: 0 additions & 2 deletions docker-compose.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
version: "3"

services:
indico-client-build:
build:
Expand Down
29 changes: 18 additions & 11 deletions indico/queries/model_export.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from typing import TYPE_CHECKING

from indico.client.request import Delay, GraphQLRequest, RequestChain
from indico.types.model_export import ModelExport

if TYPE_CHECKING: # pragma: no cover
from typing import Any, Iterator, List, Union

from indico.typing import Payload


class _CreateModelExport(GraphQLRequest):
class _CreateModelExport(GraphQLRequest["ModelExport"]):
query = """
mutation ($modelId: Int!) {
createModelExport(
Expand All @@ -20,11 +27,11 @@ def __init__(self, model_id: int):
self.model_id = model_id
super().__init__(self.query, variables={"modelId": model_id})

def process_response(self, response) -> ModelExport:
return ModelExport(**super().process_response(response)["createModelExport"])
def process_response(self, response: "Payload") -> ModelExport:
return ModelExport(**super().parse_payload(response)["createModelExport"])


class CreateModelExport(RequestChain):
class CreateModelExport(RequestChain["List[ModelExport]"]):
"""
Create a model export.
Expand All @@ -36,20 +43,20 @@ class CreateModelExport(RequestChain):
request_interval (int | float): the interval between requests in seconds. Defaults to 5.
"""

previous: ModelExport | None = None
previous: "Any" = None

def __init__(
self,
model_id: int,
wait: bool = True,
request_interval: int | float = 5,
request_interval: "Union[int, float]" = 5,
):
self.wait = wait
self.model_id = model_id
self.request_interval = request_interval
super().__init__()

def requests(self):
def requests(self) -> "Iterator[Union[_CreateModelExport, Delay, GetModelExports]]":
yield _CreateModelExport(self.model_id)
if self.wait:
while self.previous and self.previous.status not in ["COMPLETE", "FAILED"]:
Expand All @@ -60,7 +67,7 @@ def requests(self):
yield GetModelExports([self.previous.id], with_signed_url=self.wait is True)


class GetModelExports(GraphQLRequest):
class GetModelExports(GraphQLRequest["List[ModelExport]"]):
"""
Get model export(s).
Expand Down Expand Up @@ -91,17 +98,17 @@ class GetModelExports(GraphQLRequest):
"createdBy",
]

def __init__(self, export_ids: list[int], with_signed_url: bool = False):
def __init__(self, export_ids: "List[int]", with_signed_url: bool = False):
if with_signed_url:
self._base_fields.append("signedUrl")

query_with_fields = self.query.replace("{fields}", "\n".join(self._base_fields))
super().__init__(query_with_fields, variables={"exportIds": export_ids})

def process_response(self, response) -> list[ModelExport]:
def process_response(self, response: "Payload") -> "List[ModelExport]":
return [
ModelExport(**export)
for export in super().process_response(response)["modelExports"][
for export in super().parse_payload(response)["modelExports"][
"modelExports"
]
]
48 changes: 30 additions & 18 deletions indico/queries/model_import.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from typing import Generator
from typing import TYPE_CHECKING, cast

import requests

from indico.client.request import GraphQLRequest, RequestChain
from indico.errors import IndicoInputError, IndicoRequestError
from indico.queries.jobs import JobStatus
from indico.types.jobs import Job

from .jobs import JobStatus

class _UploadSMExport(GraphQLRequest):
if TYPE_CHECKING: # pragma: no cover
from typing import Dict, Iterator, Optional, Union # noqa: F401

from indico.typing import Payload


class _UploadSMExport(GraphQLRequest[str]):
query = """
query exportUpload {
exportUpload {
Expand All @@ -22,25 +28,26 @@ def __init__(self, file_path: str):
self.file_path = file_path
super().__init__(self.query)

def process_response(self, response) -> str:
resp = super().process_response(response)["exportUpload"]
def process_response(self, response: "Payload") -> str:
resp: "Dict[str, str]" = super().parse_payload(response)["exportUpload"]
signed_url = resp["signedUrl"]
storage_uri = resp["storageUri"]

with open(self.file_path, "rb") as file:
file_content = file.read()

headers = {"Content-Type": "application/zip"}
response = requests.put(signed_url, data=file_content, headers=headers)
export_response = requests.put(signed_url, data=file_content, headers=headers)

if response.status_code != 200:
if export_response.status_code != 200:
raise IndicoRequestError(
f"Failed to upload static model export: {response.text}"
f"Failed to upload static model export: {export_response.text}",
export_response.status_code,
)
return storage_uri


class ProcessStaticModelExport(GraphQLRequest):
class ProcessStaticModelExport(GraphQLRequest["Job"]):
"""
Process a static model export.
Expand Down Expand Up @@ -77,12 +84,12 @@ def __init__(
},
)

def process_response(self, response) -> Job:
job_id = super().process_response(response)["processStaticModelExport"]["jobId"]
def process_response(self, response: "Payload") -> Job:
job_id = super().parse_payload(response)["processStaticModelExport"]["jobId"]
return Job(id=job_id)


class UploadStaticModelExport(RequestChain):
class UploadStaticModelExport(RequestChain["Union[Job, str]"]):
"""
Upload a static model export to Indico.
Expand All @@ -100,22 +107,27 @@ class UploadStaticModelExport(RequestChain):
"""

def __init__(
self, file_path: str, auto_process: bool = False, workflow_id: int | None = None
self,
file_path: str,
auto_process: bool = False,
workflow_id: "Optional[int]" = None,
):
self.file_path = file_path
self.auto_process = auto_process
if auto_process and not workflow_id:
if auto_process and workflow_id is None:
raise IndicoInputError(
"Must provide `workflow_id` if `auto_process` is True."
)

self.file_path = file_path
self.auto_process = auto_process
self.workflow_id = workflow_id

def requests(self) -> Generator[str | Job, None, None]:
def requests(
self,
) -> "Iterator[Union[_UploadSMExport, ProcessStaticModelExport, JobStatus]]":
if self.auto_process:
yield _UploadSMExport(self.file_path)
yield ProcessStaticModelExport(
storage_uri=self.previous, workflow_id=self.workflow_id
storage_uri=self.previous, workflow_id=cast(int, self.workflow_id)
)
yield JobStatus(self.previous.id)
if self.previous.status == "FAILURE":
Expand Down
24 changes: 13 additions & 11 deletions indico/queries/workflow_components.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast

import jsons

Expand All @@ -15,7 +15,7 @@
)

if TYPE_CHECKING: # pragma: no cover
from typing import Iterator, List, Optional, Union
from typing import Any, Iterator, List, Optional, Union

from indico.typing import AnyDict, Payload

Expand Down Expand Up @@ -455,7 +455,7 @@ def process_response(self, response: "Payload") -> "Workflow":
)


class AddStaticModelComponent(RequestChain):
class AddStaticModelComponent(RequestChain["Workflow"]):
"""
Add a static model component to a workflow.
Expand All @@ -470,17 +470,17 @@ class AddStaticModelComponent(RequestChain):
`export_file(str)`: the path to the static model export file.
"""

previous = None
previous: "Any" = None

def __init__(
self,
workflow_id: int,
after_component_id: int | None = None,
after_component_link_id: int | None = None,
static_component_config: dict[str, Any] | None = None,
component_name: str | None = None,
after_component_id: "Optional[int]" = None,
after_component_link_id: "Optional[int]" = None,
static_component_config: "Optional[AnyDict]" = None,
component_name: "Optional[str]" = None,
auto_process: bool = False,
export_file: str | None = None,
export_file: "Optional[str]" = None,
):
if not export_file and auto_process:
raise IndicoInputError("Must provide export_file if auto_process is True.")
Expand Down Expand Up @@ -511,11 +511,13 @@ def __init__(
self.auto_process = auto_process
self.export_file = export_file

def requests(self):
def requests(
self,
) -> "Iterator[Union[UploadStaticModelExport, _AddWorkflowComponent]]":
if self.auto_process:
yield UploadStaticModelExport(
auto_process=True,
file_path=self.export_file,
file_path=cast(str, self.export_file),
workflow_id=self.workflow_id,
)
self.component.update(
Expand Down
8 changes: 4 additions & 4 deletions tests/integration/queries/test_workflow_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
ModelGroup,
ModelTaskType,
NewLabelsetArguments,
StaticModelConfig,
# StaticModelConfig,
)

from ..data.datasets import * # noqa
Expand Down Expand Up @@ -257,9 +257,9 @@ def test_add_static_model_component(indico, org_annotate_dataset):
static_model_req = AddStaticModelComponent(
workflow_id=wf.id,
after_component_id=after_component_id,
static_component_config=StaticModelConfig(
export_meta=finished_job.result,
),
static_component_config={
"export_meta": finished_job.result,
},
)
wf = client.call(static_model_req)

Expand Down

0 comments on commit cfb3c15

Please sign in to comment.