diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 00000000..b9d0a461
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,11 @@
+venv
+dist
+docs
+docsrc
+.hypothesis
+.mypy_cache
+.pytest_cache
+.tox
+build
+readme_docs
+scripts
diff --git a/.gitignore b/.gitignore
index 353831df..229ed956 100644
--- a/.gitignore
+++ b/.gitignore
@@ -26,6 +26,7 @@ pip-log.txt
 .coverage
 .tox
 nosetests.xml
+.mypy_cache
 
 # Translations
 *.mo
@@ -37,6 +38,7 @@ nosetests.xml
 .noseids
 .pytest*
 *.json
+.hypothesis
 
 codeship.aes
 .env
@@ -55,10 +57,11 @@ examples/readme.py
 docsrc/_build
 docs/.doctrees
 setup_tests.sh
+test.csv
 
 .venv
 .vscode/
 .DS_Store
 .ipynb_checkpoints/
-*.ipynb
\ No newline at end of file
+*.ipynb
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 81452a10..4a16eeb8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,12 +1,19 @@
 repos:
 - repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.2.0
+  rev: v4.5.0
   hooks:
   - id: check-yaml
+    exclude: ^docs
   - id: end-of-file-fixer
+    exclude: ^docs
   - id: trailing-whitespace
+    exclude: ^docs
+- repo: https://github.com/pycqa/isort
+  rev: 5.12.0
+  hooks:
+  - id: isort
+    name: isort (python)
 - repo: https://github.com/psf/black
-  rev: 22.3.0
+  rev: 23.11.0
   hooks:
   - id: black
-    language_version: python3
diff --git a/AUTHORS b/AUTHORS
index ca3d20cd..06b5eca9 100644
--- a/AUTHORS
+++ b/AUTHORS
@@ -1,2 +1,2 @@
 Chris Lee
-Nicolas Kassis
\ No newline at end of file
+Nicolas Kassis
diff --git a/CHANGES.txt b/CHANGES.txt
index 94d1e6f5..3366f7d4 100644
--- a/CHANGES.txt
+++ b/CHANGES.txt
@@ -1 +1 @@
-Oct 16 2019 v2.0.0: Complete change of IndicoIo API through the Indico IPA platform. For IPA Platform API use only.
\ No newline at end of file
+Oct 16 2019 v2.0.0: Complete change of IndicoIo API through the Indico IPA platform. For IPA Platform API use only.
diff --git a/CODEOWNERS b/CODEOWNERS
index 07e76c37..18b87b03 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -1 +1 @@
-* @IndicoDataSolutions/pr-be-indicodata-ai
\ No newline at end of file
+* @IndicoDataSolutions/pr-be-indicodata-ai
diff --git a/README.md b/README.md
index 3303f451..280be057 100644
--- a/README.md
+++ b/README.md
@@ -105,13 +105,20 @@ You will also need the following env variables set for the Exchange integration
 2. Activate the virtual environment
    `source venv/bin/activate`
 3. Install the client
-   `python3 setup.py install`
+   `pip3 install --editable .`
-4. Install pytest
-   `pip3 install pytest`
+4. Install test deps
+   `pip3 install "pytest<8" "requests-mock>=1.8.0" "pytest-asyncio>0.21"`
 5. Run tests
    `pytest -sv --host tests/`
    * Only run unit tests
    `pytest -sv --host tests/unit/`
    * Only run integration tests
    `pytest -sv --host tests/integration/`
+6. Check typing
+   ```sh
+   pip3 install "mypy==1.8" typing_extensions pandas-stubs types-requests types-pytz types-openpyxl
+   mypy --config-file=pyproject.toml
+   ```
+
+Alternatively, run all the unit tests via `docker compose run --rm tester`
 
 # Contributing
diff --git a/README.rst b/README.rst
index 7234a576..4d555f37 100644
--- a/README.rst
+++ b/README.rst
@@ -19,4 +19,4 @@ From source:
 .. code:: bash
 
     git clone https://github.com/IndicoDataSolutions/indico-client-python.git
-    python setup.py install
\ No newline at end of file
+    python setup.py install
diff --git a/docker-compose.yaml b/docker-compose.yaml
index 815e8da0..5b6b4c59 100644
--- a/docker-compose.yaml
+++ b/docker-compose.yaml
@@ -4,4 +4,15 @@ services:
   indico-client-build:
     build:
       context: .
- entrypoint: [ "sleep", "infinity" ] + entrypoint: ["sleep", "infinity"] + + tester: + build: + context: . + dockerfile: tox.Dockerfile + volumes: + - ./:/indico-client + - ${INDICO_API_TOKEN_PATH:-$HOME/indico_api_token.txt}:/root/indico_api_token.txt + environment: + INDICO_HOST: ${INDICO_HOST:-dev-ci.us-east-2.indico-dev.indico.io} + command: tox -vvvvp diff --git a/docsrc/conf.py b/docsrc/conf.py index 8e56797b..9a25d7d7 100644 --- a/docsrc/conf.py +++ b/docsrc/conf.py @@ -35,7 +35,7 @@ "sphinx.ext.autodoc", "sphinx.ext.napoleon", "sphinx_autodoc_typehints", - "sphinx_markdown_builder" + "sphinx_markdown_builder", ] # Add any paths that contain templates here, relative to this directory. diff --git a/examples/aio_client.py b/examples/aio_client.py index efaa9db9..cc73059a 100644 --- a/examples/aio_client.py +++ b/examples/aio_client.py @@ -4,10 +4,10 @@ """ import asyncio + from indico import AsyncIndicoClient, IndicoConfig from indico.queries import CreateDataset - """ Examples for client creation """ diff --git a/examples/classifier_extraction_workflow/component_demo.py b/examples/classifier_extraction_workflow/component_demo.py index a90c8fbb..7910c505 100644 --- a/examples/classifier_extraction_workflow/component_demo.py +++ b/examples/classifier_extraction_workflow/component_demo.py @@ -3,32 +3,31 @@ Note the dataset in this example was built from data/swaps_with_class.csv """ + from datetime import datetime +from component_links import get_component_link_id +from config import CLASSIFIER_CLASSES, DATASET_ID, INDICO_CLIENT +from queries import get_user_ids + +from indico import IndicoClient, IndicoConfig from indico.queries import ( - GetDataset, - AddModelGroupComponent, AddLinkClassificationComponent, + AddModelGroupComponent, CreateWorkflow, + GetDataset, ) -from indico.types import NewLabelsetArguments, NewQuestionnaireArguments, ModelTaskType - -from config import INDICO_CLIENT, DATASET_ID, CLASSIFIER_CLASSES -from queries import get_user_ids -from component_links import get_component_link_id - - -from indico import IndicoClient, IndicoConfig +from indico.types import ModelTaskType, NewLabelsetArguments, NewQuestionnaireArguments HOST = "indico.host" API_TOKEN_PATH = "/path/to/indico_api_token.txt" INDICO_CONFIG = IndicoConfig(host=HOST, api_token_path=API_TOKEN_PATH) INDICO_CLIENT = IndicoClient(config=INDICO_CONFIG) -DATASET_ID = +DATASET_ID = int("") CLASSIFIER_CLASSES = ["class A", "class B"] - + timestamp = datetime.now() dataset = INDICO_CLIENT.call(GetDataset(DATASET_ID)) @@ -50,13 +49,13 @@ "name": classifier_name, "num_labelers_required": 1, "task_type": ModelTaskType.CLASSIFICATION, - "target_names": CLASSIFIER_CLASSES + "target_names": CLASSIFIER_CLASSES, } questionnaire_args = { - "instructions": "Click things", - "show_predictions": True, - "users": user_ids - } + "instructions": "Click things", + "show_predictions": True, + "users": user_ids, +} classifier_wf = INDICO_CLIENT.call( AddModelGroupComponent( @@ -82,7 +81,6 @@ after_component_id=classifier_component.id, filtered_classes=filtered_classes, labels=None, - ) ) @@ -93,14 +91,14 @@ class_a_link_id = get_component_link_id(workflow.id, ["class A"]) class_b_link_id = get_component_link_id(workflow.id, ["class B"]) -#create extraction model 1 +# create extraction model 1 extraction_model_1_name = f"Extraction 1 {timestamp}" new_labelset_args = { "datacolumn_id": source_col_id, "name": extraction_model_1_name, "num_labelers_required": 1, "task_type": ModelTaskType.ANNOTATION, - "target_names": ["extraction 
class 1"] + "target_names": ["extraction class 1"], } extraction_model_1 = INDICO_CLIENT.call( AddModelGroupComponent( @@ -114,14 +112,14 @@ ) ) -#create extraction model 2 +# create extraction model 2 extraction_model_2_name = f"Extraction 2 {timestamp}" new_labelset_args = { "datacolumn_id": source_col_id, "name": extraction_model_2_name, "num_labelers_required": 1, "task_type": ModelTaskType.ANNOTATION, - "target_names": ["extraction class 2"] + "target_names": ["extraction class 2"], } extraction_model_2 = INDICO_CLIENT.call( AddModelGroupComponent( diff --git a/examples/classifier_extraction_workflow/component_links.py b/examples/classifier_extraction_workflow/component_links.py index 518f8aef..3dc4c565 100644 --- a/examples/classifier_extraction_workflow/component_links.py +++ b/examples/classifier_extraction_workflow/component_links.py @@ -1,5 +1,6 @@ from queries import get_component_links + def get_component_link_id(workflow_id, class_filter): component_links = get_component_links(workflow_id) for component_link in component_links: @@ -7,5 +8,3 @@ def get_component_link_id(workflow_id, class_filter): if link_filter: if class_filter == link_filter["classes"]: return component_link["id"] - - \ No newline at end of file diff --git a/examples/classifier_extraction_workflow/queries.py b/examples/classifier_extraction_workflow/queries.py index dadde731..879b1806 100644 --- a/examples/classifier_extraction_workflow/queries.py +++ b/examples/classifier_extraction_workflow/queries.py @@ -24,18 +24,18 @@ def get_component_links(workflow_id, indico_client): workflow_query = """ query ListWorkflows($workflowIds: [Int]){ workflows(workflowIds: $workflowIds){ - workflows { + workflows { componentLinks{ id headComponentId tailComponentId filters{ classes - } + } } } } - } + } """ variables = {"workflowIds": [workflow_id]} resp = indico_client.call(GraphQLRequest(workflow_query, variables)) diff --git a/examples/create-download-export.py b/examples/create-download-export.py index 93fe93f3..0bb60845 100644 --- a/examples/create-download-export.py +++ b/examples/create-download-export.py @@ -1,7 +1,7 @@ -from indico import IndicoClient, IndicoConfig -from indico.queries import GetDataset, CreateExport, DownloadExport import pandas as pd +from indico import IndicoClient, IndicoConfig +from indico.queries import CreateExport, DownloadExport, GetDataset """Example 1: Create export from dataset and download as csv""" diff --git a/examples/create-image-dataset.py b/examples/create-image-dataset.py index 6ef618f4..4ecfa1fc 100644 --- a/examples/create-image-dataset.py +++ b/examples/create-image-dataset.py @@ -6,9 +6,10 @@ below shows you what to do if the images are on your computer. 
""" +import pandas as pd + from indico import IndicoClient, IndicoConfig from indico.queries import CreateDataset -import pandas as pd # Create an Indico API client my_config = IndicoConfig( diff --git a/examples/custom_client.py b/examples/custom_client.py index 2004a529..22f31944 100644 --- a/examples/custom_client.py +++ b/examples/custom_client.py @@ -1,4 +1,5 @@ import os + from indico import IndicoClient, IndicoConfig # Will connect to https://try.indico.io diff --git a/examples/get-document-report.py b/examples/get-document-report.py index 4f266f32..aed8e4c6 100644 --- a/examples/get-document-report.py +++ b/examples/get-document-report.py @@ -1,10 +1,10 @@ from datetime import datetime from typing import List -from indico import IndicoConfig, IndicoClient + +from indico import IndicoClient, IndicoConfig +from indico.filters import DocumentReportFilter from indico.queries.document_report import GetDocumentReport from indico.types.document_report import DocumentReport -from indico.filters import DocumentReportFilter - """Example 1: Document Report for a date range and page through the results""" diff --git a/examples/get-user-metrics.py b/examples/get-user-metrics.py index 2086207f..b86531f1 100644 --- a/examples/get-user-metrics.py +++ b/examples/get-user-metrics.py @@ -1,13 +1,13 @@ from datetime import datetime, timedelta -from indico import IndicoConfig, IndicoClient -from indico.filters import or_, UserMetricsFilter +from indico import IndicoClient, IndicoConfig +from indico.filters import UserMetricsFilter, or_ from indico.queries import JobStatus, RetrieveStorageObject from indico.queries.usermetrics import ( - GetUserSummary, - GetUserSnapshots, GenerateChangelogReport, GetUserChangelog, + GetUserSnapshots, + GetUserSummary, ) from indico.types.user_metrics import UserSummary diff --git a/examples/get-workflow-metrics.py b/examples/get-workflow-metrics.py index 32550293..9ad03fff 100644 --- a/examples/get-workflow-metrics.py +++ b/examples/get-workflow-metrics.py @@ -1,7 +1,9 @@ from datetime import datetime from typing import List + import pandas as pd -from indico import IndicoConfig, IndicoClient + +from indico import IndicoClient, IndicoConfig from indico.queries import ListWorkflows from indico.queries.workflow_metrics import GetWorkflowMetrics from indico.types.workflow_metrics import WorkflowMetrics, WorkflowMetricsOptions diff --git a/examples/graphql.py b/examples/graphql.py index a35ee1de..f7c42856 100644 --- a/examples/graphql.py +++ b/examples/graphql.py @@ -1,7 +1,6 @@ from indico import IndicoClient, IndicoConfig from indico.queries import GraphQLRequest - my_config = IndicoConfig( host="try.indico.io", api_token_path="./path/to/indico_api_token.txt" ) diff --git a/examples/train-classification-model.py b/examples/train-classification-model.py index 7f916b70..03e5b6a7 100644 --- a/examples/train-classification-model.py +++ b/examples/train-classification-model.py @@ -1,10 +1,10 @@ from indico import IndicoClient, IndicoConfig from indico.queries import ( - CreateDataset, - ModelGroupPredict, AddModelGroupComponent, - JobStatus, + CreateDataset, GetModelGroupSelectedModelStatus, + JobStatus, + ModelGroupPredict, ) # Create an Indico API client diff --git a/examples/typed_calls.py b/examples/typed_calls.py new file mode 100644 index 00000000..776a4559 --- /dev/null +++ b/examples/typed_calls.py @@ -0,0 +1,52 @@ +""" +The async client can be used as a replacement for the synchronous +IndicoClient to make concurrent calls to the platform +""" + +import asyncio +from 
typing import AsyncIterator, List + +from indico import AsyncIndicoClient, IndicoConfig +from indico.queries import CreateDataset +from indico.queries.datasets import GetDataset, ListDatasets +from indico.types.dataset import Dataset +from indico.types.submission import Submission + +""" +Example illustrating how to use the client in typed contexts +""" + +config = IndicoConfig(host="try.indico.io") + + +async def main(): + async with AsyncIndicoClient(config=config) as client: + ipa_version: str = await client.get_ipa_version() + print(ipa_version) + + filename: str = "my_file_for_all_datasets.pdf" + + # CreateDataset is typed to return a Dataset, so multiple concurrent calls + # via asyncio.gather should, and does, return List[Dataset] + datasets: List[Dataset] = await asyncio.gather( + *( + client.call(CreateDataset(name=f"My Dataset {i}", files=[filename])) + for i in range(1, 4) + ) + ) + assert len(datasets) == 3 + + # paginated calls are also properly typed + pages: AsyncIterator[List[Dataset]] = client.paginate(ListDatasets()) + async for datasets in pages: + for d in datasets: + print(d.id) + + # incorrect typing will throw mypy / ide linting errors when using those tools. + # here, Pyright correctly reports '"Dataset" is not the same as "Submission"' + not_a_submission: Submission = await client.call(GetDataset(datasets[0].id)) + + +if __name__ == "__main__": + # How to run a Python script using async + asyncio.run(main()) diff --git a/examples/workflow.py b/examples/workflow.py index 5d6f5566..cf0702f0 100644 --- a/examples/workflow.py +++ b/examples/workflow.py @@ -1,21 +1,21 @@ from indico import IndicoClient, IndicoConfig from indico.queries import ( - JobStatus, - RetrieveStorageObject, - WorkflowSubmission, + AddLinkClassificationComponent, + AddLinkedLabelComponent, + AddModelGroupComponent, CreateWorkflow, GetDataset, - ListWorkflows, - AddModelGroupComponent, GetWorkflow, - AddLinkedLabelComponent, - AddLinkClassificationComponent, + JobStatus, + ListWorkflows, + RetrieveStorageObject, + WorkflowSubmission, ) from indico.types import ( - NewLabelsetArguments, - ModelTaskType, LinkedLabelGroup, LinkedLabelStrategy, + ModelTaskType, + NewLabelsetArguments, ) # Use your dataset's id to call its associated workflow diff --git a/indico/__init__.py b/indico/__init__.py index 13c158f2..21dfec16 100644 --- a/indico/__init__.py +++ b/indico/__init__.py @@ -1,5 +1,6 @@ import logging -from indico._version import get_versions + +from indico._version import get_versions # type: ignore logging.basicConfig() Version = version = __version__ = VERSION = get_versions()["version"] @@ -7,7 +8,7 @@ from indico.client import * from indico.client.request import * -from ._version import get_versions +from ._version import get_versions # type: ignore __version__ = get_versions()["version"] del get_versions diff --git a/indico/_version.py b/indico/_version.py index 0af9ef31..ebc46f87 100644 --- a/indico/_version.py +++ b/indico/_version.py @@ -1,3 +1,6 @@ +# type: ignore +# pragma: no cover + # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). 
Distribution tarballs (built by setup.py sdist) and build diff --git a/indico/client/client.py b/indico/client/client.py index 8710a359..b25db3cd 100644 --- a/indico/client/client.py +++ b/indico/client/client.py @@ -2,21 +2,42 @@ import asyncio import time -from typing import Optional, Union +from typing import TYPE_CHECKING, cast import urllib3 -from indico.client.request import ( - GraphQLRequest, - HTTPRequest, - PagedRequest, - RequestChain, -) +from indico.client.request import Delay, GraphQLRequest, HTTPRequest, RequestChain from indico.config import IndicoConfig from indico.errors import IndicoError -from indico.client.request import Delay from indico.http.client import AIOHTTPClient, HTTPClient +if TYPE_CHECKING: # pragma: no cover + from types import TracebackType + from typing import AsyncIterator, Iterator, Optional, Type, TypeVar, Union + + from typing_extensions import Self + + from indico.client.request import PagedRequest + from indico.typing import Payload + + ReturnType = TypeVar("ReturnType") + + +# here to avoid circular imports +class GetIPAVersion(GraphQLRequest[str]): + query = """ + "query getIPAVersion { + ipaVersion + } + """ + + def __init__(self) -> None: + super().__init__(self.query) + + def process_response(self, response: "Payload") -> str: + version: str = super().parse_payload(response)["ipaVersion"] + return version + class IndicoClient: """ @@ -34,35 +55,44 @@ class IndicoClient: RuntimeError: If api_token_path does not exist. """ - def __init__(self, config: Optional[IndicoConfig] = None): + def __init__(self, config: "Optional[IndicoConfig]" = None): if not config: config = IndicoConfig() + if not config.verify_ssl: urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + self.config = config self._http = HTTPClient(config) - def _handle_request_chain(self, chain: RequestChain): - response = None + def _handle_request_chain( + self, + chain: "RequestChain[ReturnType]", + ) -> "ReturnType": + response: "Optional[ReturnType]" = None + for request in chain.requests(): - if isinstance(request, HTTPRequest): - response = self._http.execute_request(request) - chain.previous = response - elif isinstance(request, RequestChain): + if isinstance(request, RequestChain): response = self._handle_request_chain(request) chain.previous = response + elif isinstance(request, HTTPRequest): + response = self._http.execute_request(request) + chain.previous = response elif isinstance(request, Delay): time.sleep(request.seconds) - if chain.result: + + if chain.result is not None: return chain.result - return response - def get_ipa_version(self): - return self._http.execute_request( - GraphQLRequest("query getIPAVersion {\n ipaVersion\n}\n") - )["ipaVersion"] + return cast("ReturnType", response) + + def get_ipa_version(self) -> str: + return self.call(GetIPAVersion()) - def call(self, request: Union[HTTPRequest, RequestChain]): + def call( + self, + request: "Union[HTTPRequest[ReturnType], RequestChain[ReturnType]]", + ) -> "ReturnType": """ Make a call to the Indico IPA Platform @@ -78,10 +108,14 @@ def call(self, request: Union[HTTPRequest, RequestChain]): if isinstance(request, RequestChain): return self._handle_request_chain(request) - elif request and isinstance(request, HTTPRequest): + elif isinstance(request, HTTPRequest): return self._http.execute_request(request) + else: + raise ValueError( + "Invalid request type! Must be one of HTTPRequest or RequestChain." 
+            )
 
-    def paginate(self, request: PagedRequest):
+    def paginate(self, request: "PagedRequest[ReturnType]") -> "Iterator[ReturnType]":
        """
        Provides a generator that continues paging through responses
        Available with List<> Requests that offer pagination
@@ -120,54 +154,63 @@ class AsyncIndicoClient:
         RuntimeError: If api_token_path does not exist.
     """
 
-    def __init__(self, config: Optional[IndicoConfig] = None):
+    def __init__(self, config: "Optional[IndicoConfig]" = None):
         if not config:
             config = IndicoConfig()
         if not config.verify_ssl:
             urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+
         self.config = config
         self._http = AIOHTTPClient(config)
-        self._created = False
+        self._created: bool = False
 
-    async def __aenter__(self):
+    async def __aenter__(self) -> "Self":
         return await self.create()
 
-    async def __aexit__(self, exc_type, exc, tb):
+    async def __aexit__(
+        self,
+        exc_type: "Optional[Type[BaseException]]",
+        exc: "Optional[BaseException]",
+        tb: "Optional[TracebackType]",
+    ) -> None:
         await self.cleanup()
 
-    async def create(self):
+    async def create(self) -> "Self":
         await self._http.get_short_lived_access_token()
         self._created = True
         return self
 
-    async def cleanup(self):
+    async def cleanup(self) -> None:
         await self._http.request_session.close()
 
-    async def _handle_request_chain(self, chain: RequestChain):
-        response = None
+    async def _handle_request_chain(
+        self,
+        chain: "RequestChain[ReturnType]",
+    ) -> "ReturnType":
+        response: "Optional[ReturnType]" = None
+
         for request in chain.requests():
-            if isinstance(request, HTTPRequest):
-                response = await self._http.execute_request(request)
-                chain.previous = response
-            elif isinstance(request, RequestChain):
+            if isinstance(request, RequestChain):
                 response = await self._handle_request_chain(request)
                 chain.previous = response
+            elif isinstance(request, HTTPRequest):
+                response = await self._http.execute_request(request)
+                chain.previous = response
             elif isinstance(request, Delay):
                 await asyncio.sleep(request.seconds)
-        if chain.result:
+
+        if chain.result is not None:
             return chain.result
-        return response
 
-    async def get_ipa_version(self):
-        if not self._created:
-            raise IndicoError("Please .create() your client")
-        return (
-            await self._http.execute_request(
-                GraphQLRequest("query getIPAVersion {ipaVersion}")
-            )
-        )["ipaVersion"]
+        return cast("ReturnType", response)
+
+    async def get_ipa_version(self) -> str:
+        return await self.call(GetIPAVersion())
 
-    async def call(self, request: Union[HTTPRequest, RequestChain]):
+    async def call(
+        self,
+        request: "Union[HTTPRequest[ReturnType], RequestChain[ReturnType]]",
+    ) -> "ReturnType":
         """
         Make a call to the Indico IPA Platform
 
@@ -185,10 +228,16 @@
 
         if isinstance(request, RequestChain):
             return await self._handle_request_chain(request)
-        elif request and isinstance(request, HTTPRequest):
+        elif isinstance(request, HTTPRequest):
            return await self._http.execute_request(request)
+        else:
+            raise ValueError(
+                "Invalid request type! Must be one of HTTPRequest or RequestChain."
+            )
 
-    async def paginate(self, request: PagedRequest):
+    async def paginate(
+        self, request: "PagedRequest[ReturnType]"
+    ) -> "AsyncIterator[ReturnType]":
        """
        Provides a generator that continues paging through responses
        Available with List<> Requests that offer pagination
diff --git a/indico/client/request.py b/indico/client/request.py
index 1db5e11f..7b8e43b9 100644
--- a/indico/client/request.py
+++ b/indico/client/request.py
@@ -1,8 +1,16 @@
+from abc import abstractmethod
 from enum import Enum
-from typing import Any, Dict, Union
+from typing import TYPE_CHECKING, Generic, TypeVar, cast
 
 from indico.errors import IndicoRequestError
 
+if TYPE_CHECKING:  # pragma: no cover
+    from typing import Any, Dict, Iterator, List, Optional, Union
+
+    from indico.typing import AnyDict
+
+ResponseType = TypeVar("ResponseType", covariant=True)
+
 
 class HTTPMethod(Enum):
     GET = "GET"
@@ -13,41 +21,55 @@ class HTTPMethod(Enum):
     OPTIONS = "OPTIONS"
 
 
-class HTTPRequest:
-    def __init__(self, method: HTTPMethod, path: str, **kwargs):
-        self.method = method
-        self.path = path
-        self.kwargs = kwargs
+class HTTPRequest(Generic[ResponseType]):
+    def __init__(self, method: HTTPMethod, path: str, **kwargs: "Any"):
+        self.method: HTTPMethod = method
+        self.path: str = path
+        self._kwargs: "AnyDict" = kwargs
+
+    @property
+    def kwargs(self) -> "AnyDict":
+        return self._kwargs
+
+    def process_response(self, response: "Any") -> "ResponseType":
+        return cast("ResponseType", response)
 
-    def process_response(self, response):
-        return response
 
+class GraphQLRequest(Generic[ResponseType], HTTPRequest[ResponseType]):
+    def __init__(self, query: str, variables: "Optional[AnyDict]" = None):
+        self.query: str = query
+        self.variables: "Optional[AnyDict]" = variables
 
-class GraphQLRequest(HTTPRequest):
-    def __init__(self, query: str, variables: Dict[str, Any] = None):
-        self.query = query
-        self.variables = variables
-        self.method = HTTPMethod.POST
-        self.path = "/graph/api/graphql"
+        super().__init__(HTTPMethod.POST, "/graph/api/graphql")
 
     @property
-    def kwargs(self):
+    def kwargs(self) -> "AnyDict":
         return {"json": {"query": self.query, "variables": self.variables}}
 
-    def process_response(self, response):
-        response = super().process_response(response)
-        errors = response.pop("errors", [])
+    def parse_payload(self, response: "AnyDict") -> "Any":
+        raw_response: "AnyDict" = cast("AnyDict", super().process_response(response))
+        errors: "List[AnyDict]" = raw_response.pop("errors", [])
+
         if errors:
-            extras = {"locations": [error.pop("locations", None) for error in errors]}
+            extras: "Dict[str, List[Any]]" = {
+                "locations": [error.pop("locations", None) for error in errors]
+            }
+
             raise IndicoRequestError(
                 error="\n".join(error["message"] for error in errors),
                 code=400,
                 extras=extras,
             )
-        return response["data"]
+
+        return raw_response["data"]
+
+    def process_response(self, response: "AnyDict") -> "ResponseType":
+        raw_response = self.parse_payload(response)
+        # technically incorrect, but necessary for backwards compatibility
+        return cast("ResponseType", raw_response)
 
-class PagedRequest(GraphQLRequest):
+
+class PagedRequest(GraphQLRequest[ResponseType]):
    """
    To enable pagination, query must include $after as an argument
    and request pageInfo

    Example:
        query ListSubmissions($after: Int){
            submissions(after: $after){
                submissions {
                    id
                }
                pageInfo {
                    endCursor
                    hasNextPage
                }
            }
        }
    """

-    def __init__(self, query: str, variables: Dict[str, Any] = None):
+    def __init__(self, query: str, variables: "Optional[AnyDict]" = None):
+        if variables is None:
+            variables = {}
+
         variables["after"] = None
         self.has_next_page = True
         super().__init__(query, variables=variables)
 
-    def process_response(self, response):
-        response = super().process_response(response)
-        _pg = next(iter(response.values()))["pageInfo"]
+    def parse_payload(self, response: "AnyDict") -> "Any":
+        raw_response: "AnyDict" = cast("AnyDict", super().parse_payload(response))
+
+        _pg = next(iter(raw_response.values())).get("pageInfo")
+        if not _pg:
+            raise ValueError("The supplied GraphQL must include 'pageInfo'.")
+
         self.has_next_page = _pg["hasNextPage"]
-        self.variables["after"] = _pg["endCursor"] if self.has_next_page else None
-        return response
+        cast("AnyDict", self.variables)["after"] = (
+            _pg["endCursor"] if self.has_next_page else None
+        )
+
+        return raw_response
 
 
-class RequestChain:
-    previous: Any = None
-    result: Any = None
+class RequestChain(Generic[ResponseType]):
+    previous: "Any" = None
+    result: "Optional[ResponseType]" = None
 
-    def requests(self):
-        pass
+    @abstractmethod
+    def requests(
+        self,
+    ) -> "Iterator[Union[RequestChain[Any], HTTPRequest[Any], Delay]]":
+        raise NotImplementedError(
+            "RequestChains must define an iterator for their requests; "
+            "otherwise, subclass GraphQLRequest instead."
+        )
 
 
 class Delay:
-    def __init__(self, seconds: Union[int, float] = 2):
+    def __init__(self, seconds: "Union[int, float]" = 2):
         self.seconds = seconds
diff --git a/indico/config/config.py b/indico/config/config.py
index 1f253133..04713351 100644
--- a/indico/config/config.py
+++ b/indico/config/config.py
@@ -1,10 +1,16 @@
 # -*- coding: utf-8 -*-
 import os
-
 from pathlib import Path
+from typing import TYPE_CHECKING
+
 from indico.errors import IndicoInvalidConfigSetting
 
+if TYPE_CHECKING:  # pragma: no cover
+    from typing import Any, Optional, Tuple, Union
+
+    from indico.typing import AnyDict
+
 
 class IndicoConfig:
     """
@@ -24,22 +30,17 @@ class IndicoConfig:
         RuntimeError: If api_token_path does not exist.
""" - host: str - protocol: str - serializer: str - api_token_path: str - api_token: str = None - verify_ssl: bool = True - requests_params: dict = None - _disable_cookie_domain: bool = False - - def __init__(self, **kwargs): - - self.host: str = os.getenv("INDICO_HOST") + def __init__(self, **kwargs: "Any"): + self.host: str = os.getenv("INDICO_HOST", "") self.protocol: str = os.getenv("INDICO_PROTOCOL", "https") self.serializer: str = os.getenv("INDICO_SERIALIZER", "msgpack") - self.api_token_path: str = os.getenv("INDICO_API_TOKEN_PATH", Path.home()) - self.api_token: str = os.getenv("INDICO_API_TOKEN") + self.api_token_path: "Union[str, Path]" = os.getenv( + "INDICO_API_TOKEN_PATH", Path.home() + ) + self.api_token: "Optional[str]" = os.getenv("INDICO_API_TOKEN") + self.verify_ssl: bool = True + self.requests_params: "Optional[AnyDict]" = None + self._disable_cookie_domain: bool = False for key, value in kwargs.items(): if hasattr(self, key): @@ -50,10 +51,12 @@ def __init__(self, **kwargs): if not self.api_token: self.api_token_path, self.api_token = self._resolve_api_token() - def _resolve_api_token(self): + def _resolve_api_token(self) -> "Tuple[Path, str]": path = self.api_token_path + if not isinstance(path, Path): path = Path(path) + if not path.exists(): path = Path.home() if not path.is_file(): @@ -65,5 +68,4 @@ def _resolve_api_token(self): "or provided as indicoio.config.api_token_path" ) - with path.open("r") as f: - return path, f.read().strip() + return path, path.read_text().strip() diff --git a/indico/errors.py b/indico/errors.py index 515ee6f9..f96ce612 100644 --- a/indico/errors.py +++ b/indico/errors.py @@ -1,47 +1,59 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + from typing import Any, Union + + class IndicoError(Exception): - pass + """Base exception for known error modes""" class IndicoRequestError(IndicoError): - def __init__(self, error, code, extras=None): + def __init__( + self, + error: str, + code: "Union[int, str]", + extras: "Any" = None, + ): super().__init__(f"Status: {code}, Error: {error}\n\tExtras: {extras}") class IndicoTimeoutError(IndicoError): - def __init__(self, duration): + def __init__(self, duration: "Union[int, float]"): super().__init__(f"Request timed out after {duration:0.3f} seconds.") class IndicoDecodingError(IndicoError): - def __init__(self, mime, charset, content): + def __init__(self, mime: str, charset: str, content: str): super().__init__(f"Failed to decode with {mime}:{charset}. Content {content}") class IndicoInputError(IndicoError): - def __init__(self, msg): + def __init__(self, msg: str): super().__init__(msg) class IndicoInvalidConfigSetting(IndicoError): def __init__( self, - setting_name, + setting_name: str, ): super().__init__(f"{setting_name} is not a valid configuration setting") class IndicoNotFound(IndicoError): - def __init__(self, cls): + def __init__(self, cls: str): super().__init__(f"Could not find {cls}") class IndicoAuthenticationFailed(IndicoError): - def __init__(self): + def __init__(self) -> None: super().__init__("Failed to authenticate") class IndicoHibernationError(IndicoError): - def __init__(self, after): - self.after = after - super().__init__(f"Platform is currently hibernating. Wait {after} seconds and retry this request.") - + def __init__(self, after: int): + self.after: int = after + super().__init__( + f"Platform is currently hibernating. Wait {after} seconds and retry this request." 
+        )
diff --git a/indico/filters/__init__.py b/indico/filters/__init__.py
index 44d495be..52aa2c54 100644
--- a/indico/filters/__init__.py
+++ b/indico/filters/__init__.py
@@ -1,27 +1,32 @@
 import datetime
-from typing import Any, Iterable, List, Mapping, Union
+from typing import TYPE_CHECKING, Dict
 
 from indico.errors import IndicoInputError
 
+if TYPE_CHECKING:  # pragma: no cover
+    from typing import Any, ClassVar, List, Optional, Tuple, Union
 
-def or_(*args: Iterable[Mapping[str, Any]]):
+    from indico.typing import AnyDict
+
+
+def or_(*args: "Any") -> "Dict[str, List[Any]]":
     return {"OR": list(args)}
 
 
-def and_(*args: Iterable[Mapping[str, Any]]):
+def and_(*args: "Any") -> "Dict[str, List[Any]]":
     return {"AND": list(args)}
 
 
-class Filter(dict):
+class Filter(Dict[str, "Any"]):
     """
     Base filter class that allows users to construct filter statements for
     GraphQL queries. Search keys are constrained by the implementing subclasses
     If multiple arguments are supplied, they are treated as arg1 AND arg2 AND ...
     """
 
-    __options__ = None
+    __options__: "ClassVar[Tuple[Any, ...]]" = tuple()
 
-    def __init__(self, **kwargs):
+    def __init__(self, **kwargs: "Any"):
         kwargs = {k: v for k, v in kwargs.items() if v is not None}
         if not kwargs:
             raise IndicoInputError(f"One of {self.__options__} must be specified")
@@ -49,9 +54,9 @@ class SubmissionReviewFilter(Filter):
 
     def __init__(
         self,
-        rejected: Union[bool, None] = None,
-        created_by: Union[int, None] = None,
-        review_type: Union[str, None] = None,
+        rejected: "Optional[bool]" = None,
+        created_by: "Optional[int]" = None,
+        review_type: "Optional[str]" = None,
     ):
         kwargs = {
             "rejected": rejected,
@@ -62,7 +67,7 @@ def __init__(
         super().__init__(**kwargs)
 
 
-class DateRangeFilter(dict):
+class DateRangeFilter(Dict[str, "Optional[str]"]):
     """
     Create a Filter when querying for Submissions within a certain date range
     Args:
@@ -71,9 +76,9 @@ class DateRangeFilter(dict):
     """
 
     def __init__(
-        self, filter_from: Union[str, None] = None, filter_to: Union[str, None] = None
+        self, filter_from: "Optional[str]" = None, filter_to: "Optional[str]" = None
     ):
-        kwargs = {"from": filter_from, "to": filter_to}
+        kwargs: "Dict[str, Optional[str]]" = {"from": filter_from, "to": filter_to}
         self.update(kwargs)
 
 
@@ -111,17 +116,17 @@ class SubmissionFilter(Filter):
 
     def __init__(
         self,
-        file_type: Union[List[str], None] = None,
-        input_filename: Union[str, None] = None,
-        status: Union[str, None] = None,
-        retrieved: Union[bool, None] = None,
-        reviews: Union[SubmissionReviewFilter, None] = None,
-        review_in_progress: Union[bool, None] = None,
-        files_deleted: Union[bool, None] = None,
-        created_at: Union[DateRangeFilter, None] = None,
-        updated_at: Union[DateRangeFilter, None] = None,
+        file_type: "Optional[List[str]]" = None,
+        input_filename: "Optional[str]" = None,
+        status: "Optional[str]" = None,
+        retrieved: "Optional[bool]" = None,
+        reviews: "Optional[SubmissionReviewFilter]" = None,
+        review_in_progress: "Optional[bool]" = None,
+        files_deleted: "Optional[bool]" = None,
+        created_at: "Optional[DateRangeFilter]" = None,
+        updated_at: "Optional[DateRangeFilter]" = None,
     ):
-        kwargs = {
+        kwargs: "AnyDict" = {
             "filetype": file_type,
             "inputFilename": input_filename,
             "status": status.upper() if status else status,
@@ -155,12 +160,12 @@ class ModelGroupExampleFilter(Filter):
 
     def __init__(
         self,
-        file_name: Union[str, None] = None,
-        partial: Union[bool, None] = None,
-        status: Union[str, None] = None,
-        text_search: Union[str, None] = None,
+        file_name: "Optional[str]" = None,
+        partial: "Optional[bool]" = None,
+        status: "Optional[str]" = None,
+        text_search: "Optional[str]" = None,
     ):
-        kwargs = {
+        kwargs: "Dict[str, Optional[Union[bool, str]]]" = {
             "fileName": file_name,
             "partial": partial,
             "textSearch": text_search,
@@ -184,9 +189,14 @@ class UserMetricsFilter(Filter):
     __options__ = ("user_id", "user_email")
 
     def __init__(
-        self, user_id: Union[int, None] = None, user_email: Union[str, None] = None
+        self,
+        user_id: "Optional[int]" = None,
+        user_email: "Optional[str]" = None,
     ):
-        kwargs = {"userId": user_id, "userEmail": user_email}
+        kwargs: "Dict[str, Optional[Union[int, str]]]" = {
+            "userId": user_id,
+            "userEmail": user_email,
+        }
 
         super().__init__(**kwargs)
 
@@ -219,15 +229,19 @@ class DocumentReportFilter(Filter):
 
     def __init__(
         self,
-        submission_id: Union[int, None] = None,
-        workflow_id: Union[int, None] = None,
-        status: Union[str, None] = None,
-        created_at_start_date: Union[datetime.datetime, None] = None,
-        created_at_end_date: Union[datetime.datetime, None] = None,
-        updated_at_start_date: Union[datetime.datetime, None] = None,
-        updated_at_end_date: Union[datetime.datetime, None] = None,
+        submission_id: "Optional[int]" = None,
+        workflow_id: "Optional[int]" = None,
+        status: "Optional[str]" = None,
+        created_at_start_date: "Optional[datetime.datetime]" = None,
+        created_at_end_date: "Optional[datetime.datetime]" = None,
+        updated_at_start_date: "Optional[datetime.datetime]" = None,
+        updated_at_end_date: "Optional[datetime.datetime]" = None,
    ):
-        kwargs = {"workflowId": workflow_id, "id": submission_id, "status": status}
+        kwargs: "AnyDict" = {
+            "workflowId": workflow_id,
+            "id": submission_id,
+            "status": status,
+        }
         if created_at_end_date and not created_at_start_date:
             raise IndicoInputError("Must specify created_at_start_date")
         if created_at_start_date:
diff --git a/indico/http/client.py b/indico/http/client.py
index 0e67dfbc..a4ab3513 100644
--- a/indico/http/client.py
+++ b/indico/http/client.py
@@ -1,15 +1,14 @@
 import asyncio
-import http.cookiejar
 import logging
 from contextlib import contextmanager
 from copy import deepcopy
+from http.cookiejar import DefaultCookiePolicy
 from pathlib import Path
-from typing import Optional, Union
+from typing import TYPE_CHECKING, cast
 
 import aiohttp
 import requests
 
-from indico.client.request import HTTPRequest
 from indico.config import IndicoConfig
 from indico.errors import (
     IndicoAuthenticationFailed,
@@ -20,45 +19,63 @@
 
 from .retry import aioretry
 
+if TYPE_CHECKING:  # pragma: no cover
+    from http.cookiejar import Cookie
+    from io import IOBase
+    from typing import Any, Dict, Iterator, List, Optional, Union
+    from urllib.request import Request
+
+    from indico.client.request import HTTPRequest, ResponseType
+    from indico.typing import AnyDict
+
+
 logger = logging.getLogger(__file__)
 
 
-class CookiePolicyOverride(http.cookiejar.DefaultCookiePolicy):
-    def set_ok(self, cookie, request):
+class CookiePolicyOverride(DefaultCookiePolicy):
+    def set_ok(self, cookie: "Cookie", request: "Request") -> bool:
         return True
 
-    def return_ok(self, cookie, request):
+    def return_ok(self, cookie: "Cookie", request: "Request") -> bool:
         return True
 
-    def path_return_ok(self, path, request):
+    def path_return_ok(self, path: str, request: "Request") -> bool:
         return True
 
-    def domain_return_ok(self, domain, request):
+    def domain_return_ok(self, domain: str, request: "Request") -> bool:
         return True
 
 
 class HTTPClient:
-    def __init__(self, config: IndicoConfig = None):
-        self.config = config
+    def __init__(self, config: "Optional[IndicoConfig]" = None):
+        self.config = config or IndicoConfig()
         self.base_url = f"{self.config.protocol}://{self.config.host}"
 
         self.request_session = requests.Session()
-        if config and isinstance(config.requests_params, dict):
-            for param in config.requests_params.keys():
-                setattr(self.request_session, param, config.requests_params[param])
+        if isinstance(self.config.requests_params, dict):
+            for param in self.config.requests_params.keys():
+                setattr(self.request_session, param, self.config.requests_params[param])
         self.request_session.cookies.set_policy(CookiePolicyOverride())
+
         self.get_short_lived_access_token()
 
-    def post(self, *args, json: Union[dict, list] = None, **kwargs):
+    def post(
+        self,
+        *args: "Any",
+        json: "Optional[Union[AnyDict, List[Any]]]" = None,
+        **kwargs: "Any",
+    ) -> "Any":
         return self._make_request("post", *args, json=json, **kwargs)
 
-    def get(self, *args, params: dict = None, **kwargs):
+    def get(
+        self, *args: "Any", params: "Optional[AnyDict]" = None, **kwargs: "Any"
+    ) -> "Any":
         return self._make_request("post", *args, params=params, **kwargs)
 
-    def get_short_lived_access_token(self):
-        # If the cookie here is already due to _disable_cookie_domain set and we try to pop it later
-        # it will error out because we have two cookies with the same name. We just remove the old one
-        # here as we are about to refresh it.
+    def get_short_lived_access_token(self) -> "AnyDict":
+        # If the cookie here is already due to _disable_cookie_domain set and we try to
+        # pop it later it will error out because we have two cookies with the same
+        # name. We just remove the old one here as we are about to refresh it.
         if "auth_token" in self.request_session.cookies:
             self.request_session.cookies.pop("auth_token")
 
@@ -75,12 +92,13 @@ def get_short_lived_access_token(self):
             raise IndicoAuthenticationFailed()
         self.request_session.cookies.pop("auth_token")
         self.request_session.cookies.set_cookie(
-            requests.cookies.create_cookie(name="auth_token", value=value)
+            # must ignore because untyped in typeshed
+            requests.cookies.create_cookie(name="auth_token", value=value)  # type: ignore
         )
 
-        return r
+        return cast("AnyDict", r)
 
-    def execute_request(self, request: HTTPRequest):
+    def execute_request(self, request: "HTTPRequest[ResponseType]") -> "ResponseType":
         return request.process_response(
             self._make_request(
                 method=request.method.value.lower(), path=request.path, **request.kwargs
             )
         )
 
     @contextmanager
-    def _handle_files(self, req_kwargs):
+    def _handle_files(self, req_kwargs: "AnyDict") -> "Iterator[AnyDict]":
         streams = None
         # deepcopying buffers is not supported
         # so, remove "streams" before the deepcopy.
@@ -97,11 +115,11 @@ def _handle_files(self, req_kwargs): streams = req_kwargs["streams"].copy() del req_kwargs["streams"] - new_kwargs = deepcopy(req_kwargs) + new_kwargs: "AnyDict" = deepcopy(req_kwargs) - files = [] + files: "List[IOBase]" = [] file_arg = {} - dup_counts = {} + dup_counts: "Dict[str, int]" = {} if "files" in new_kwargs and new_kwargs["files"] is not None: for filepath in new_kwargs["files"]: path = Path(filepath) @@ -129,19 +147,20 @@ def _handle_files(self, req_kwargs): new_kwargs["files"] = file_arg - yield new_kwargs - - if files: - [f.close() for f in files] + try: + yield new_kwargs + finally: + for f in files: + f.close() def _make_request( self, method: str, path: str, - headers: dict = None, - _refreshed=False, - **request_kwargs, - ): + headers: "Optional[Dict[str, str]]" = None, + _refreshed: bool = False, + **request_kwargs: "Any", + ) -> "Any": logger.debug( f"[{method}] {path}\n\t Headers: {headers}\n\tRequest Args:{request_kwargs}" ) @@ -178,7 +197,9 @@ def _make_request( extras=repr(response.content), ) - content = deserialize(response, force_json=json, force_decompress=decompress) + content: "Any" = deserialize( + response, force_json=json, force_decompress=decompress + ) if response.status_code >= 400: if isinstance(content, dict): @@ -194,16 +215,17 @@ def _make_request( raise IndicoRequestError( error=error, code=response.status_code, extras=extras ) + return content -class AIOHTTPClient(HTTPClient): +class AIOHTTPClient: """ Beta client with a minimal set of features that can execute requests using the aiohttp library """ - def __init__(self, config: Optional[IndicoConfig] = None): + def __init__(self, config: "Optional[IndicoConfig]" = None): """ Config options specific to aiohttp unsafe - allows interacting with IP urls @@ -212,25 +234,34 @@ def __init__(self, config: Optional[IndicoConfig] = None): self.base_url = f"{self.config.protocol}://{self.config.host}" self.request_session = aiohttp.ClientSession() - if config and isinstance(config.requests_params, dict): - for param in config.requests_params.keys(): - setattr(self.request_session, param, config.requests_params[param]) + if isinstance(self.config.requests_params, dict): + for param in self.config.requests_params.keys(): + setattr(self.request_session, param, self.config.requests_params[param]) - async def post(self, *args, json: Union[dict, list] = None, **kwargs): + async def post( + self, + *args: "Any", + json: "Optional[Union[AnyDict, List[Any]]]" = None, + **kwargs: "Any", + ) -> "Any": return await self._make_request("post", *args, json=json, **kwargs) - async def get(self, *args, params: dict = None, **kwargs): + async def get( + self, *args: "Any", params: "Optional[AnyDict]" = None, **kwargs: "Any" + ) -> "Any": return await self._make_request("post", *args, params=params, **kwargs) - async def get_short_lived_access_token(self): + async def get_short_lived_access_token(self) -> "AnyDict": r = await self.post( "/auth/users/refresh_token", headers={"Authorization": f"Bearer {self.config.api_token}"}, _refreshed=True, ) - return r + return cast("AnyDict", r) - async def execute_request(self, request: HTTPRequest): + async def execute_request( + self, request: "HTTPRequest[ResponseType]" + ) -> "ResponseType": return request.process_response( await self._make_request( method=request.method.value.lower(), path=request.path, **request.kwargs @@ -238,10 +269,12 @@ async def execute_request(self, request: HTTPRequest): ) @contextmanager - def _handle_files(self, req_kwargs): + def 
_handle_files( + self, req_kwargs: "AnyDict" + ) -> "Iterator[List[aiohttp.FormData]]": files = [] file_args = [] - dup_counts = {} + dup_counts: "Dict[str, int]" = {} for filepath in req_kwargs.pop("files", []) or []: data = aiohttp.FormData() path = Path(filepath) @@ -269,7 +302,7 @@ def _handle_files(self, req_kwargs): data.add_field( "file", stream, - filename=filename + f"({dup_counts[filename]})" + _add_suffix, + filename=filename + f"({dup_counts[filename]})", ) dup_counts[filename] += 1 else: @@ -277,20 +310,21 @@ def _handle_files(self, req_kwargs): dup_counts[filename] = 1 file_args.append(data) - yield file_args + try: + yield file_args + finally: + for f in files: + f.close() - if files: - [f.close() for f in files] - - @aioretry((aiohttp.ClientConnectionError, aiohttp.ServerDisconnectedError)) + @aioretry(aiohttp.ClientConnectionError, aiohttp.ServerDisconnectedError) async def _make_request( self, method: str, path: str, - headers: dict = None, - _refreshed=False, - **request_kwargs, - ): + headers: "Optional[Dict[str, str]]" = None, + _refreshed: bool = False, + **request_kwargs: "Any", + ) -> "Any": logger.debug( f"[{method}] {path}\n\t Headers: {headers}\n\tRequest Args:{request_kwargs}" ) @@ -308,6 +342,7 @@ async def _make_request( ) ) return [resp for resp_set in resps for resp in resp_set] + async with getattr(self.request_session, method)( f"{self.base_url}{path}", headers=headers, @@ -330,12 +365,12 @@ async def _make_request( if response.status >= 500: raise IndicoRequestError( - code=response.status_code, + code=response.status, error=response.reason, extras=repr(response.content), ) - content = await aio_deserialize( + content: "Any" = await aio_deserialize( response, force_json=json, force_decompress=decompress ) @@ -351,6 +386,7 @@ async def _make_request( extras = None raise IndicoRequestError( - error=error, code=response.status_code, extras=extras + error=error, code=response.status, extras=extras ) + return content diff --git a/indico/http/retry.py b/indico/http/retry.py index b9c9cd0d..85ed1850 100644 --- a/indico/http/retry.py +++ b/indico/http/retry.py @@ -1,23 +1,30 @@ import asyncio -from random import randint import time -import typing as t from functools import wraps +from random import randint +from typing import TYPE_CHECKING + +if TYPE_CHECKING: # pragma: no cover + from typing import Awaitable, Callable, Optional, Tuple, Type, TypeVar, Union + from typing_extensions import ParamSpec + + P = ParamSpec("P") + T = TypeVar("T") def retry( - ExceptionTypes: t.Type[Exception], tries: int = 3, delay: int = 1, backoff: int = 2 -) -> t.Callable: + *ExceptionTypes: "Type[Exception]", tries: int = 3, delay: int = 1, backoff: int = 2 +) -> "Callable[[Callable[P, T]], Callable[P, T]]": """ Retry with exponential backoff Original from: http://wiki.python.org/moin/PythonDecoratorLibrary#Retry """ - def retry_decorator(f: t.Callable) -> t.Any: + def retry_decorator(f: "Callable[P, T]") -> "Callable[P, T]": @wraps(f) - def retry_fn(*args: t.Any, **kwargs: t.Any) -> t.Any: + def retry_fn(*args: "P.args", **kwargs: "P.kwargs") -> "T": n_tries, n_delay = tries, delay while n_tries > 1: try: @@ -32,13 +39,14 @@ def retry_fn(*args: t.Any, **kwargs: t.Any) -> t.Any: return retry_decorator + def aioretry( - ExceptionTypes: t.Type[Exception], + *ExceptionTypes: "Type[Exception]", tries: int = 3, - delay: t.Union[int, t.Tuple[int, int]] = 1, + delay: "Union[int, Tuple[int, int]]" = 1, backoff: int = 2, - condition: t.Optional[t.Callable[[Exception], bool]] = None, -) -> 
t.Callable: + condition: "Optional[Callable[[Exception], bool]]" = None, +) -> "Callable[[Callable[P, Awaitable[T]]], Callable[P, Awaitable[T]]]": """ Retry with exponential backoff @@ -51,9 +59,9 @@ def aioretry( same time across multiple concurrent function calls """ - def retry_decorator(f: t.Callable) -> t.Callable: + def retry_decorator(f: "Callable[P, Awaitable[T]]") -> "Callable[P, Awaitable[T]]": @wraps(f) - async def retry_fn(*args: t.Any, **kwargs: t.Any) -> t.Any: + async def retry_fn(*args: "P.args", **kwargs: "P.kwargs") -> "T": n_tries = tries if isinstance(delay, tuple): # pick a random number to sleep diff --git a/indico/http/serialization.py b/indico/http/serialization.py index c3e52174..7e22f92a 100644 --- a/indico/http/serialization.py +++ b/indico/http/serialization.py @@ -1,29 +1,48 @@ """ Handles deserialization / decoding of responses """ -import cgi + import gzip import io import json import logging import traceback from collections import defaultdict +from email.message import Message +from typing import TYPE_CHECKING import msgpack from indico.errors import IndicoDecodingError +if TYPE_CHECKING: # pragma: no cover + from typing import Any, Callable, Dict, Mapping, Optional, Tuple + + from aiohttp import ClientResponse + from requests import Response + logger = logging.getLogger(__name__) -def decompress(response): +def _parse_header(header: str) -> "Tuple[str, Dict[str, str]]": + """Parse a header and return as a tuple of a main value and additional params""" + m = Message() + m["content-type"] = header + params = m.get_params(failobj=[]) + return params[0][0], dict(params[1:]) + + +def decompress(response: "Response") -> bytes: response.raw.decode_content = True - value = io.BytesIO(response.raw.data).getvalue() + value: bytes = io.BytesIO(response.raw.data).getvalue() return gzip.decompress(value) -def deserialize(response, force_json=False, force_decompress=False): - content_type, params = cgi.parse_header(response.headers.get("Content-Type")) +def deserialize( + response: "Response", force_json: bool = False, force_decompress: bool = False +) -> "Any": + content_type, params = _parse_header(response.headers["Content-Type"]) + content: bytes if force_decompress or content_type in ["application/x-gzip", "application/gzip"]: content = decompress(response) @@ -44,14 +63,17 @@ def deserialize(response, force_json=False, force_decompress=False): content_type, charset, content.decode("ascii", "ignore") ) -async def aio_deserialize(response, force_json=False, force_decompress=False): - content_type, params = cgi.parse_header(response.headers.get("Content-Type")) - content = await response.read() + +async def aio_deserialize( + response: "ClientResponse", force_json: bool = False, force_decompress: bool = False +) -> "Any": + content_type, params = _parse_header(response.headers["Content-Type"]) + content: bytes = await response.read() if force_decompress or content_type in ["application/x-gzip", "application/gzip"]: - content = gzip.decompress(io.BytesIO(content).get_value()) + content = gzip.decompress(io.BytesIO(content).getvalue()) - charset = params.get("charset", "utf-8") + charset: str = params.get("charset", "utf-8") # For storage object for example where the content is json based on url ending if force_json: @@ -65,31 +87,34 @@ async def aio_deserialize(response, force_json=False, force_decompress=False): content_type, charset, content.decode("ascii", "ignore") ) -def raw_bytes(content, *args, **kwargs): + +def raw_bytes( + content: bytes, charset: 
"Optional[str]", *args: "Any", **kwargs: "Any" +) -> bytes: return content -def msgpack_deserialization(content, charset): +def msgpack_deserialization(content: bytes, charset: "Optional[str]" = None) -> "Any": return msgpack.unpackb(content) -def json_deserialization(content, charset="utf-8"): +def json_deserialization(content: bytes, charset: str = "utf-8") -> "Any": return json.loads(content.decode(charset)) -def text_deserialization(content, charset="utf-8"): +def text_deserialization(content: bytes, charset: str = "utf-8") -> str: return content.decode(charset) -def image_serialization(content, charset=None): +def image_serialization(content: bytes, charset: "Optional[str]" = None) -> bytes: return content -def zip_serialization(content, charset=None): +def zip_serialization(content: bytes, charset: "Optional[str]" = None) -> bytes: return content -_SERIALIZATION_FNS = defaultdict( +_SERIALIZATION_FNS: "Mapping[str, Callable[[bytes, str], Any]]" = defaultdict( lambda: text_deserialization, { "application/pdf": raw_bytes, @@ -104,7 +129,6 @@ def zip_serialization(content, charset=None): "application/vnd.ms-excel": raw_bytes, "application/msexcel": raw_bytes, "application/excel": raw_bytes, - "application/msexcel": raw_bytes, "application/x-dos_ms_excel": raw_bytes, "application/x-excel": raw_bytes, "application/x-ms-excel": raw_bytes, @@ -116,7 +140,6 @@ def zip_serialization(content, charset=None): "application/vnd.openxmlformats-officedocument.presentationml.presentation": raw_bytes, "application/mspowerpoint": raw_bytes, "application/powerpoint": raw_bytes, - "application/vnd.ms-powerpoint": raw_bytes, "application/x-mspowerpoint": raw_bytes, "image/png": image_serialization, "image/jpeg": image_serialization, diff --git a/indico/py.typed b/indico/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/indico/queries/__init__.py b/indico/queries/__init__.py index 7b793bd4..29a467f8 100644 --- a/indico/queries/__init__.py +++ b/indico/queries/__init__.py @@ -1,13 +1,13 @@ from .datasets import * -from .model_groups import * -from .jobs import * from .documents import * +from .example import * +from .export import * +from .forms import * +from .integration import * +from .jobs import * +from .model_groups import * +from .questionnaire import * from .storage import * from .submission import * from .workflow import * -from .forms import * -from .export import * from .workflow_components import * -from .example import * -from .questionnaire import * -from .integration import * \ No newline at end of file diff --git a/indico/queries/custom_blueprint.py b/indico/queries/custom_blueprint.py index e9833bbb..1a49ef57 100644 --- a/indico/queries/custom_blueprint.py +++ b/indico/queries/custom_blueprint.py @@ -1,18 +1,24 @@ -import typing as t +from typing import TYPE_CHECKING from indico import GraphQLRequest from indico.errors import IndicoInputError -from indico.types.workflow import ComponentFamily from indico.types.custom_blueprint import TaskBlueprint +from indico.types.workflow import ComponentFamily + +if TYPE_CHECKING: # pragma: no cover + from typing import List, Optional + + from indico.typing import AnyDict, Payload + -SUPPORTED_CUSTOM_COMPONENT_FAMILIES = [ +SUPPORTED_CUSTOM_COMPONENT_FAMILIES: "List[ComponentFamily]" = [ ComponentFamily.OUTPUT, ComponentFamily.FILTER, ComponentFamily.MODEL, ] -class RegisterCustomBlueprint(GraphQLRequest): +class RegisterCustomBlueprint(GraphQLRequest["TaskBlueprint"]): """ Mutation to register a custom blueprint, making it available 
in the gallery to add to workflows @@ -70,15 +76,15 @@ class RegisterCustomBlueprint(GraphQLRequest): def __init__( self, - component_family: ComponentFamily, + component_family: "ComponentFamily", name: str, description: str, - config: t.Dict, - tags: t.List[str], + config: "AnyDict", + tags: "List[str]", footer: str = "", - icon: str = None, - all_access: bool = None, - dataset_ids: t.List[int] = None, + icon: "Optional[str]" = None, + all_access: "Optional[bool]" = None, + dataset_ids: "Optional[List[int]]" = None, ): if ( not component_family @@ -133,5 +139,5 @@ def __init__( }, ) - def process_response(self, response) -> TaskBlueprint: - return TaskBlueprint(**super().process_response(response)[self.mutation_name]) + def process_response(self, response: "Payload") -> "TaskBlueprint": + return TaskBlueprint(**super().parse_payload(response)[self.mutation_name]) diff --git a/indico/queries/datasets.py b/indico/queries/datasets.py index 7addd2ee..f8d68b09 100644 --- a/indico/queries/datasets.py +++ b/indico/queries/datasets.py @@ -3,9 +3,8 @@ import json import tempfile from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import TYPE_CHECKING -import deprecation import jsons import pandas as pd @@ -20,16 +19,16 @@ from indico.errors import IndicoInputError, IndicoNotFound from indico.filters import DatasetFilter from indico.queries.storage import UploadBatched, UploadImages -from indico.types.dataset import ( - Dataset, - OcrEngine, - OcrInputLanguage, - OmnipageOcrOptionsInput, - ReadApiOcrOptionsInput, -) +from indico.types.dataset import Dataset, OcrEngine, OcrInputLanguage + +if TYPE_CHECKING: # pragma: no cover + from typing import Iterator, List, Optional, Union + from indico.types.dataset import OmnipageOcrOptionsInput, ReadApiOcrOptionsInput + from indico.typing import AnyDict, Payload -class ListDatasets(PagedRequest): + +class ListDatasets(PagedRequest["List[Dataset]"]): """ List all of your datasets @@ -71,7 +70,7 @@ class ListDatasets(PagedRequest): def __init__( self, *, - filters: Optional[Union[Dict, DatasetFilter]] = None, + filters: "Optional[Union[AnyDict, DatasetFilter]]" = None, limit: int = 100, order_by: str = "ID", desc: bool = False, @@ -86,12 +85,12 @@ def __init__( }, ) - def process_response(self, response) -> List[Dataset]: - response = super().process_response(response) + def process_response(self, response: "Payload") -> "List[Dataset]": + response = super().parse_payload(response) return [Dataset(**dataset) for dataset in response["datasetsPage"]["datasets"]] -class GetDataset(GraphQLRequest): +class GetDataset(GraphQLRequest["Dataset"]): """ Retrieve a dataset description object @@ -125,8 +124,8 @@ class GetDataset(GraphQLRequest): def __init__(self, id: int): super().__init__(self.query, variables={"id": id}) - def process_response(self, response) -> Dataset: - response = super().process_response(response) + def process_response(self, response: "Payload") -> "Dataset": + response = super().parse_payload(response) if "dataset" not in response or not isinstance(response["dataset"], dict): raise IndicoNotFound("Failed to find dataset") return Dataset(**response["dataset"]) @@ -165,7 +164,7 @@ class GetDatasetFileStatus(GetDataset): """ -class GetDatasetStatus(GraphQLRequest): +class GetDatasetStatus(GraphQLRequest[str]): """ Get the status of a dataset @@ -187,11 +186,12 @@ class GetDatasetStatus(GraphQLRequest): def __init__(self, id: int): super().__init__(self.query, variables={"id": id}) - def process_response(self, 
response) -> str: - return response["data"]["dataset"]["status"] + def process_response(self, response: "Payload") -> str: + status: str = super().parse_payload(response)["dataset"]["status"] + return status -class CreateDataset(RequestChain): +class CreateDataset(RequestChain["Dataset"]): """ Create a dataset and upload the associated files. @@ -213,21 +213,19 @@ class CreateDataset(RequestChain): """ - previous = None - def __init__( self, name: str, - files: List[str], + files: "Union[str, List[str]]", wait: bool = True, dataset_type: str = "TEXT", from_local_images: bool = False, image_filename_col: str = "filename", batch_size: int = 20, - ocr_engine: OcrEngine = None, - omnipage_ocr_options: OmnipageOcrOptionsInput = None, - read_api_ocr_options: ReadApiOcrOptionsInput = None, - request_interval: Union[int, float] = 5, + ocr_engine: "Optional[OcrEngine]" = None, + omnipage_ocr_options: "Optional[OmnipageOcrOptionsInput]" = None, + read_api_ocr_options: "Optional[ReadApiOcrOptionsInput]" = None, + request_interval: "Union[int, float]" = 5, ): self.files = files self.name = name @@ -246,8 +244,15 @@ def __init__( ) super().__init__() - def requests(self): + def requests( + self, + ) -> "Iterator[Union[UploadBatched, _UploadDatasetFiles, CreateEmptyDataset, _AddFiles, GetDatasetFileStatus, Delay, GetDataset]]": if self.from_local_images: + if not isinstance(self.files, str): + raise ValueError( + "'files' should be a string path when using `from_local_images`." + ) + self.dataset_type = "IMAGE" # Assume image filenames are in the same directory as the csv with # image labels and that there is a column representing their name @@ -267,11 +272,15 @@ def requests(self): df.to_csv(image_csv_path) yield _UploadDatasetFiles(files=[image_csv_path]) else: + if not isinstance(self.files, list): + raise ValueError("'files' should be a list of paths.") + yield UploadBatched( files=self.files, batch_size=self.batch_size, request_cls=_UploadDatasetFiles, ) + file_metadata = self.previous yield CreateEmptyDataset( name=self.name, @@ -294,7 +303,7 @@ def requests(self): yield GetDataset(id=dataset_id) -class RemoveDatasetFile(GraphQLRequest): +class RemoveDatasetFile(GraphQLRequest["Dataset"]): """ Remove a file from a dataset by ID. To retrieve a list of files in a dataset, see `GetDatasetFileStatus`. 
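For reference, a minimal usage sketch of the stricter `CreateDataset` signature above, assuming a reachable host and token file (both placeholders): `files` must now be a `List[str]`, except with `from_local_images=True`, where it is a single CSV path.

```python
from indico import IndicoClient, IndicoConfig
from indico.queries import CreateDataset

# Placeholder host and token path; substitute your own deployment values.
client = IndicoClient(
    config=IndicoConfig(host="app.indico.io", api_token_path="indico_api_token.txt")
)

# `files` is a List[str] here; a bare string is only valid together with
# from_local_images=True and now raises ValueError otherwise.
dataset = client.call(
    CreateDataset(name="my-dataset", files=["docs/a.pdf", "docs/b.pdf"], wait=True)
)
print(dataset.id, dataset.status)
```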
@@ -322,18 +331,18 @@ def __init__(self, dataset_id: int, file_id: int): variables={"datasetId": dataset_id, "fileId": file_id}, ) - def process_response(self, response): - return Dataset(**super().process_response(response)["deleteDatasetFile"]) + def process_response(self, response: "Payload") -> "Dataset": + return Dataset(**super().parse_payload(response)["deleteDatasetFile"]) -class _UploadDatasetFiles(HTTPRequest): - def __init__(self, files: List[str]): +class _UploadDatasetFiles(HTTPRequest["List[AnyDict]"]): + def __init__(self, files: "List[str]"): super().__init__( method=HTTPMethod.POST, path="/storage/files/upload", files=files ) -class DeleteDataset(GraphQLRequest): +class DeleteDataset(GraphQLRequest[bool]): """ Delete a dataset @@ -352,14 +361,15 @@ class DeleteDataset(GraphQLRequest): } """ - def __init__(self, id): + def __init__(self, id: int): super().__init__(self.query, variables={"id": id}) - def process_response(self, response): - return super().process_response(response)["deleteDataset"]["success"] + def process_response(self, response: "Payload") -> bool: + status: bool = super().parse_payload(response)["deleteDataset"]["success"] + return status -class CreateEmptyDataset(GraphQLRequest): +class CreateEmptyDataset(GraphQLRequest["Dataset"]): query = """ mutation($name: String!, $datasetType: DatasetType, $config: DataConfigInput) { createDataset(name: $name, datasetType: $datasetType, config: $config ) { @@ -372,10 +382,10 @@ class CreateEmptyDataset(GraphQLRequest): def __init__( self, name: str, - dataset_type: str = None, - ocr_engine: OcrEngine = None, - omnipage_ocr_options: OmnipageOcrOptionsInput = None, - readapi_ocr_options: ReadApiOcrOptionsInput = None, + dataset_type: "Optional[str]" = None, + ocr_engine: "Optional[OcrEngine]" = None, + omnipage_ocr_options: "Optional[OmnipageOcrOptionsInput]" = None, + readapi_ocr_options: "Optional[ReadApiOcrOptionsInput]" = None, ): if not dataset_type: dataset_type = "TEXT" @@ -401,11 +411,11 @@ def __init__( }, ) - def process_response(self, response): - return Dataset(**super().process_response(response)["createDataset"]) + def process_response(self, response: "Payload") -> "Dataset": + return Dataset(**super().parse_payload(response)["createDataset"]) -class _AddFiles(GraphQLRequest): +class _AddFiles(GraphQLRequest["Dataset"]): query = """ mutation AddFiles($datasetId: Int!, $metadata: JSONString!, $autoprocess: Boolean){ addDatasetFiles(datasetId: $datasetId, metadataList: $metadata, autoprocess: $autoprocess) { @@ -415,7 +425,7 @@ class _AddFiles(GraphQLRequest): } """ - def __init__(self, dataset_id: int, metadata: List[str], autoprocess: bool): + def __init__(self, dataset_id: int, metadata: "List[str]", autoprocess: bool): super().__init__( self.query, variables={ @@ -425,11 +435,11 @@ def __init__(self, dataset_id: int, metadata: List[str], autoprocess: bool): }, ) - def process_response(self, response): - return Dataset(**super().process_response(response)["addDatasetFiles"]) + def process_response(self, response: "Payload") -> "Dataset": + return Dataset(**super().parse_payload(response)["addDatasetFiles"]) -class AddDatasetFiles(RequestChain): +class AddDatasetFiles(RequestChain["Dataset"]): """ Add files to a dataset. 
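Similarly, a sketch of the `AddDatasetFiles` chain introduced here, reusing `client` from the sketch above (dataset id and path are placeholders); with `wait=True` the chain polls `GetDatasetFileStatus` until every file reaches a terminal status.

```python
from indico.queries import AddDatasetFiles

dataset = client.call(
    AddDatasetFiles(dataset_id=123, files=["docs/c.pdf"], autoprocess=True, wait=True)
)
# Each file carries the status the chain waited on.
print([f.status for f in dataset.files])
```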
@@ -446,12 +456,10 @@ class AddDatasetFiles(RequestChain): Dataset """ - previous = None - def __init__( self, dataset_id: int, - files: List[str], + files: "List[str]", autoprocess: bool = True, wait: bool = True, batch_size: int = 20, @@ -468,7 +476,9 @@ def __init__( ) super().__init__() - def requests(self): + def requests( + self, + ) -> "Iterator[Union[UploadBatched, _AddFiles, GetDatasetFileStatus, Delay]]": yield UploadBatched( files=self.files, batch_size=self.batch_size, @@ -484,15 +494,15 @@ def requests(self): while not all( f.status in self.expected_statuses for f in self.previous.files ): - yield GetDatasetFileStatus(id=self.previous.id) yield Delay() + yield GetDatasetFileStatus(id=self.previous.id) # Alias for backwards compatibility AddFiles = AddDatasetFiles -class GetAvailableOcrEngines(GraphQLRequest): +class GetAvailableOcrEngines(GraphQLRequest["List[OcrEngine]"]): """ Fetches and lists the available OCR engines """ @@ -505,15 +515,15 @@ class GetAvailableOcrEngines(GraphQLRequest): } }""" - def __init__(self): + def __init__(self) -> None: super().__init__(self.query) - def process_response(self, response): - engines = super().process_response(response)["ocrOptions"]["engines"] + def process_response(self, response: "Payload") -> "List[OcrEngine]": + engines = super().parse_payload(response)["ocrOptions"]["engines"] return [OcrEngine[e["name"]] for e in engines] -class GetOcrEngineLanguageCodes(GraphQLRequest): +class GetOcrEngineLanguageCodes(GraphQLRequest["List[OcrInputLanguage]"]): """ Fetches and lists the available languages by name and code for the given OCR Engine @@ -533,12 +543,12 @@ class GetOcrEngineLanguageCodes(GraphQLRequest): } }""" - def __init__(self, engine: OcrEngine): + def __init__(self, engine: "OcrEngine"): self.engine = engine super().__init__(self.query) - def process_response(self, response): - data = super().process_response(response)["ocrOptions"]["engines"] + def process_response(self, response: "Payload") -> "List[OcrInputLanguage]": + data = super().parse_payload(response)["ocrOptions"]["engines"] engine_laguages = next( x["languages"] for x in data if x["name"] == self.engine.name ) diff --git a/indico/queries/document_report.py b/indico/queries/document_report.py index 84d15f07..65be0032 100644 --- a/indico/queries/document_report.py +++ b/indico/queries/document_report.py @@ -1,18 +1,23 @@ """""" -from typing import List, Union + +from typing import TYPE_CHECKING, List from indico import PagedRequest from indico.filters import DocumentReportFilter from indico.types import BaseType from indico.types.document_report import DocumentReport +if TYPE_CHECKING: # pragma: no cover + from typing import Optional, Union + + from indico.typing import AnyDict, Payload + class _DocumentReportList(BaseType): submissions: List[DocumentReport] - pass -class GetDocumentReport(PagedRequest): +class GetDocumentReport(PagedRequest["List[DocumentReport]"]): """ Query to generate a Document Report, otherwise known as a log of past submissions. 
""" @@ -53,12 +58,19 @@ class GetDocumentReport(PagedRequest): """ def __init__( - self, filters: Union[dict, DocumentReportFilter] = None, limit: int = None, all_submissions = False + self, + filters: "Optional[Union[AnyDict, DocumentReportFilter]]" = None, + limit: "Optional[int]" = None, + all_submissions: bool = False, ): - variables = {"filters": filters, "limit": limit, "allSubmissions": all_submissions} + variables = { + "filters": filters, + "limit": limit, + "allSubmissions": all_submissions, + } super().__init__(self.query, variables=variables) - def process_response(self, response): + def process_response(self, response: "Payload") -> "List[DocumentReport]": return _DocumentReportList( - **super().process_response(response)["submissionsLog"] + **super().parse_payload(response)["submissionsLog"] ).submissions diff --git a/indico/queries/documents.py b/indico/queries/documents.py index 4d14d24f..9532e7e4 100644 --- a/indico/queries/documents.py +++ b/indico/queries/documents.py @@ -1,15 +1,19 @@ # -*- coding: utf-8 -*- import json -from typing import List +from typing import TYPE_CHECKING -from indico.client.request import RequestChain, GraphQLRequest, HTTPMethod, HTTPRequest +from indico.client.request import GraphQLRequest, RequestChain +from indico.queries.storage import UploadBatched, UploadDocument from indico.types.jobs import Job -from indico.queries.storage import UploadDocument, UploadBatched +if TYPE_CHECKING: # pragma: no cover + from typing import Iterator, List, Optional, Union -class _DocumentExtraction(GraphQLRequest): + from indico.typing import AnyDict, Payload + +class _DocumentExtraction(GraphQLRequest["List[Job]"]): query = """ mutation($files: [FileInput], $jsonConfig: JSONString, $ocrEngine: OCREngine) { documentExtraction(files: $files, jsonConfig: $jsonConfig, ocrEngine: $ocrEngine) { @@ -18,22 +22,34 @@ class _DocumentExtraction(GraphQLRequest): } """ - def __init__(self, files, json_config={"preset_config": "legacy"}, ocr_engine=None): - if json_config and type(json_config) == dict: - json_config = json.dumps(json_config) + def __init__( + self, + files: "List[AnyDict]", + json_config: "Optional[Union[AnyDict, str]]" = {"preset_config": "legacy"}, + ocr_engine: "Optional[str]" = None, + ): + json_config_json: "Optional[str]" = None + if json_config: + if isinstance(json_config, dict): + json_config_json = json.dumps(json_config) + else: + json_config_json = json_config + super().__init__( - query=self.query, variables={"files": files, "jsonConfig": json_config, "ocrEngine": ocr_engine} + query=self.query, + variables={ + "files": files, + "jsonConfig": json_config_json, + "ocrEngine": ocr_engine, + }, ) - def process_response(self, response): - jobs = super().process_response(response)["documentExtraction"]["jobIds"] - if jobs: - return [Job(id=j) for j in jobs] - else: - return [] + def process_response(self, response: "Payload") -> "List[Job]": + jobs = super().parse_payload(response)["documentExtraction"]["jobIds"] or set() + return [Job(id=j) for j in jobs] -class DocumentExtraction(RequestChain): +class DocumentExtraction(RequestChain["Job"]): """ Extract raw text from PDF or TIF files. 
@@ -67,14 +83,20 @@ class DocumentExtraction(RequestChain): """ def __init__( - self, files: List[str], json_config: dict = None, upload_batch_size: int = None, ocr_engine: str = "OMNIPAGE" + self, + files: "List[str]", + json_config: "Optional[AnyDict]" = None, + upload_batch_size: "Optional[int]" = None, + ocr_engine: str = "OMNIPAGE", ): self.files = files self.json_config = json_config self.upload_batch_size = upload_batch_size self.ocr_engine = ocr_engine - def requests(self): + def requests( + self, + ) -> "Iterator[Union[UploadBatched, UploadDocument, _DocumentExtraction]]": if self.upload_batch_size: yield UploadBatched( files=self.files, @@ -83,4 +105,9 @@ def requests(self): ) else: yield UploadDocument(files=self.files) - yield _DocumentExtraction(files=self.previous, json_config=self.json_config, ocr_engine=self.ocr_engine) + + yield _DocumentExtraction( + files=self.previous, + json_config=self.json_config, + ocr_engine=self.ocr_engine, + ) diff --git a/indico/queries/example.py b/indico/queries/example.py index 746d8099..9b52066f 100644 --- a/indico/queries/example.py +++ b/indico/queries/example.py @@ -1,11 +1,16 @@ -from typing import Dict, List, Union +from typing import TYPE_CHECKING -from indico.client.request import GraphQLRequest, RequestChain, PagedRequest +from indico.client.request import PagedRequest from indico.filters import ModelGroupExampleFilter -from indico.types import Example, model_group +from indico.types.questionnaire import Example +if TYPE_CHECKING: # pragma: no cover + from typing import List, Optional, Union -class ListModelGroupExamples(PagedRequest): + from indico.typing import AnyDict, Payload + + +class ListModelGroupExamples(PagedRequest["List[Example]"]): """ List all examples associated with a given model group ID. 
Supports pagination (limit becomes page_size) @@ -37,22 +42,22 @@ endCursor hasNextPage } - } } - } } """ def __init__( self, *, - model_group_id: int = None, - filters: Union[Dict, ModelGroupExampleFilter] = None, + model_group_id: "Optional[int]" = None, + filters: "Optional[Union[AnyDict, ModelGroupExampleFilter]]" = None, limit: int = 1000, order_by: str = "ID", desc: bool = True, - after: int = None, - before: int = None, + after: "Optional[int]" = None, + before: "Optional[int]" = None, ): variables = { "modelGroupId": model_group_id, @@ -68,9 +73,6 @@ def __init__( variables=variables, ) - def process_response(self, response) -> List[Example]: - response = response["data"]["modelGroups"]["modelGroups"][0] - _pg = next(iter(response.values()))["pageInfo"] - self.has_next_page = _pg["hasNextPage"] - self.variables["after"] = _pg["endCursor"] if self.has_next_page else None - return [Example(**s) for s in response["pagedExamples"]["examples"]] + def process_response(self, response: "Payload") -> "List[Example]": + example_page = super().parse_payload(response)["modelGroups"]["modelGroups"][0] + return [Example(**s) for s in example_page["pagedExamples"]["examples"]] diff --git a/indico/queries/export.py b/indico/queries/export.py index bd8cfdf7..c21a84b3 100644 --- a/indico/queries/export.py +++ b/indico/queries/export.py @@ -1,16 +1,21 @@ import io import warnings -from typing import List, Union +from typing import TYPE_CHECKING import pandas as pd from indico.client import Delay, GraphQLRequest, RequestChain -from indico.errors import IndicoNotFound, IndicoRequestError +from indico.errors import IndicoInputError, IndicoRequestError from indico.queries.storage import RetrieveStorageObject from indico.types.export import Export, LabelResolutionStrategy +if TYPE_CHECKING: # pragma: no cover + from typing import Iterator, List, Optional, Union -class _CreateExport(GraphQLRequest): + from indico.typing import Payload + + +class _CreateExport(GraphQLRequest["Export"]): query = """ mutation CreateExport( $datasetId: Int!, @@ -51,22 +56,36 @@ def __init__( self, dataset_id: int, labelset_id: int, - column_ids: List[int] = None, - model_ids: List[int] = None, - frozen_labelset_ids: List[int] = None, - combine_labels: LabelResolutionStrategy = LabelResolutionStrategy.ALL.name, - file_info: bool = None, - anonymoous: bool = None, - anonymous: bool = None, + column_ids: "Optional[List[int]]" = None, + model_ids: "Optional[List[int]]" = None, + frozen_labelset_ids: "Optional[List[int]]" = None, + combine_labels: "Union[LabelResolutionStrategy, str]" = LabelResolutionStrategy.ALL, + file_info: "Optional[bool]" = None, + anonymoous: "Optional[bool]" = None, + anonymous: "Optional[bool]" = None, ): if anonymoous: warnings.warn( "Argument anonymoous is deprecated and will be removed in future versions. Use argument anonymous instead." ) if anonymous: - raise IndicoRequestError("Cannot use both anonymoous and anonymous.") - else: - anonymous = anonymoous + raise IndicoInputError("Cannot use both anonymoous and anonymous.") + anonymous = anonymoous + + if isinstance(combine_labels, LabelResolutionStrategy): + combine_labels = combine_labels.name + else: + if combine_labels not in LabelResolutionStrategy.__members__: + raise IndicoInputError( + "Invalid resolution strategy. Use the LabelResolutionStrategy enum" + " instead." + ) + + warnings.warn( + "String-based values for combine_labels are deprecated.
Use the" + " LabelResolutionStrategy enum instead." + ) + super().__init__( self.query, variables={ @@ -81,12 +100,12 @@ def __init__( }, ) - def process_response(self, response): - response = super().process_response(response) + def process_response(self, response: "Payload") -> "Export": + response = super().parse_payload(response) return Export(**response["createExport"]) -class GetExport(GraphQLRequest): +class GetExport(GraphQLRequest["Export"]): """ Get information on an Export job @@ -122,18 +141,18 @@ class GetExport(GraphQLRequest): def __init__(self, export_id: int): super().__init__(self.query, variables={"exportIds": [export_id]}) - def process_response(self, response): - response = super().process_response(response) + def process_response(self, response: "Payload") -> "Export": + response = super().parse_payload(response) return Export(**response["exports"]["exports"][0]) class _RetrieveExport(RetrieveStorageObject): - def process_response(self, response): - response = super().process_response(response) - return pd.read_csv(io.StringIO(response)) + def process_response(self, response: "Payload") -> "pd.DataFrame": + raw_response: str = super().process_response(response) + return pd.read_csv(io.StringIO(raw_response)) -class DownloadExport(RequestChain): +class DownloadExport(RequestChain["pd.DataFrame"]): """ Download an export from an Indico storage url @@ -148,7 +167,9 @@ class DownloadExport(RequestChain): IndicoRequestError if the Export job is not complete or failed """ - def __init__(self, export_id: int = None, export: Export = None): + def __init__( + self, export_id: "Optional[int]" = None, export: "Optional[Export]" = None + ): if not export_id and not export: raise IndicoRequestError( code="FAILURE", @@ -157,11 +178,11 @@ def __init__(self, export_id: int = None, export: Export = None): self.export_id = export_id self.export = export - def requests(self): + def requests(self) -> "Iterator[Union[GetExport, _RetrieveExport]]": if self.export_id: yield GetExport(self.export_id) - export = self.export or self.previous + export = self.export or self.previous if export.status != "COMPLETE": raise IndicoRequestError( code="400", @@ -171,7 +192,7 @@ def requests(self): yield _RetrieveExport(export.download_url) -class CreateExport(RequestChain): +class CreateExport(RequestChain["Export"]): """ Create an export job for a dataset. 
@@ -192,20 +213,18 @@ class CreateExport(RequestChain): """ - previous = None - def __init__( self, dataset_id: int, labelset_id: int, - column_ids: List[int] = None, - model_ids: List[int] = None, - frozen_labelset_ids: List[int] = None, - combine_labels: LabelResolutionStrategy = LabelResolutionStrategy.ALL.name, + column_ids: "Optional[List[int]]" = None, + model_ids: "Optional[List[int]]" = None, + frozen_labelset_ids: "Optional[List[int]]" = None, + combine_labels: "Union[LabelResolutionStrategy, str]" = LabelResolutionStrategy.ALL, file_info: bool = False, anonymous: bool = False, wait: bool = True, - request_interval: Union[int, float] = 5, + request_interval: "Union[int, float]" = 5, ): self.dataset_id = dataset_id self.labelset_id = labelset_id @@ -219,7 +238,7 @@ def __init__( self.request_interval = request_interval super().__init__() - def requests(self): + def requests(self) -> "Iterator[Union[_CreateExport, GetExport, Delay]]": yield _CreateExport( dataset_id=self.dataset_id, labelset_id=self.labelset_id, diff --git a/indico/queries/forms.py b/indico/queries/forms.py index c91917d2..d98b3e99 100644 --- a/indico/queries/forms.py +++ b/indico/queries/forms.py @@ -1,13 +1,16 @@ -from typing import List -from pathlib import Path +from typing import TYPE_CHECKING -from indico.client.request import RequestChain, GraphQLRequest +from indico.client.request import GraphQLRequest, RequestChain from indico.queries.storage import UploadBatched, UploadDocument -from indico.queries.jobs import Job +from indico.types import Job +if TYPE_CHECKING: # pragma: no cover + from typing import Iterator, List, Optional, Union -class _FormPreprocessing(GraphQLRequest): + from indico.typing import AnyDict, Payload + +class _FormPreprocessing(GraphQLRequest["List[Job]"]): query = """ mutation($files: [FileInput]) { activeFormFields( @@ -18,19 +21,16 @@ class _FormPreprocessing(GraphQLRequest): } """ - def __init__(self, files): + def __init__(self, files: "List[AnyDict]"): super().__init__(query=self.query, variables={"files": files}) - def process_response(self, response): - jobs = super().process_response(response)["activeFormFields"]["jobIds"] - if jobs: - return [Job(id=j) for j in jobs] - else: - return [] + def process_response(self, response: "Payload") -> "List[Job]": + jobs = super().parse_payload(response)["activeFormFields"]["jobIds"] or set() + return [Job(id=j) for j in jobs] # TODO: move into indico-client -class FormPreprocessing(RequestChain): +class FormPreprocessing(RequestChain["List[Job]"]): """ Attempt to auto-detect form fields and labels @@ -42,12 +42,17 @@ class FormPreprocessing(RequestChain): """ def __init__( - self, files: List[str], json_config: dict = None, upload_batch_size: int = None + self, + files: "List[str]", + json_config: "Optional[AnyDict]" = None, + upload_batch_size: "Optional[int]" = None, ): self.files = files self.upload_batch_size = upload_batch_size - def requests(self): + def requests( + self, + ) -> "Iterator[Union[UploadBatched, UploadDocument, _FormPreprocessing]]": if self.upload_batch_size: yield UploadBatched( files=self.files, @@ -56,4 +61,5 @@ def requests(self): ) else: yield UploadDocument(files=self.files) + yield _FormPreprocessing(files=self.previous) diff --git a/indico/queries/integration.py b/indico/queries/integration.py index 077cd4a5..51ef3104 100644 --- a/indico/queries/integration.py +++ b/indico/queries/integration.py @@ -1,16 +1,17 @@ -import jsons -from typing import List +from typing import TYPE_CHECKING -from indico import 
GraphQLRequest, RequestChain -from indico.errors import IndicoInputError -from indico.types.integration import ( - ExchangeIntegration, - ExchangeIntegrationConfiguration, - ExchangeIntegrationCredentials, -) +from indico import GraphQLRequest +from indico.types.integration import ExchangeIntegration +if TYPE_CHECKING: # pragma: no cover + from indico.types.integration import ( + ExchangeIntegrationConfiguration, + ExchangeIntegrationCredentials, + ) + from indico.typing import AnyDict, Payload # noqa: F401 -class AddExchangeIntegration(GraphQLRequest): + +class AddExchangeIntegration(GraphQLRequest["ExchangeIntegration"]): """ Mutation to add a Microsoft Exchange integration to a workflow @@ -42,8 +43,8 @@ class AddExchangeIntegration(GraphQLRequest): def __init__( self, - config: ExchangeIntegrationConfiguration, - credentials: ExchangeIntegrationCredentials, + config: "ExchangeIntegrationConfiguration", + credentials: "ExchangeIntegrationCredentials", workflow_id: int, ): super().__init__( @@ -55,15 +56,15 @@ def __init__( }, ) - def process_response(self, response) -> ExchangeIntegration: + def process_response(self, response: "Payload") -> "ExchangeIntegration": return ExchangeIntegration( - **super().process_response(response)["addExchangeIntegrationToWorkflow"][ + **super().parse_payload(response)["addExchangeIntegrationToWorkflow"][ "integration" ] ) -class StartIntegration(GraphQLRequest): +class StartIntegration(GraphQLRequest["AnyDict"]): """ Mutation to start an existing integration. Once an integration is started, documents will be submitted to the associated workflow. @@ -90,7 +91,7 @@ def __init__( ) -class DeleteIntegration(GraphQLRequest): +class DeleteIntegration(GraphQLRequest["AnyDict"]): """ Mutation to delete an existing Integration. @@ -114,14 +115,14 @@ def __init__(self, integration_id: int): ) -class PauseIntegration(GraphQLRequest): +class PauseIntegration(GraphQLRequest["AnyDict"]): """ Mutation to pause an existing Integration. 
Args: integration_id(int): id of the integration to pause """ - + query = """ mutation PauseIntegration($integrationId: Int!){ pauseWorkflowIntegration(integrationId: $integrationId){ @@ -129,7 +130,7 @@ class PauseIntegration(GraphQLRequest): } } """ - + def __init__(self, integration_id: int): super().__init__( self.query, diff --git a/indico/queries/jobs.py b/indico/queries/jobs.py index 506c5e40..b87eed39 100644 --- a/indico/queries/jobs.py +++ b/indico/queries/jobs.py @@ -1,12 +1,17 @@ # -*- coding: utf-8 -*- -from typing import Union +from typing import TYPE_CHECKING from indico.client.request import Delay, GraphQLRequest, RequestChain from indico.types.jobs import Job from indico.types.utils import Timer +if TYPE_CHECKING: # pragma: no cover + from typing import Iterator, Optional, Union -class _JobStatus(GraphQLRequest): + from indico.typing import Payload + + +class _JobStatus(GraphQLRequest["Job"]): query = """ query JobStatus($id: String) { job(id: $id) { @@ -17,14 +22,14 @@ class _JobStatus(GraphQLRequest): } """ - def __init__(self, id): + def __init__(self, id: str): super().__init__(self.query, variables={"id": id}) - def process_response(self, response): - return Job(**super().process_response(response)["job"]) + def process_response(self, response: "Payload") -> "Job": + return Job(**super().parse_payload(response)["job"]) -class _JobStatusWithResult(GraphQLRequest): +class _JobStatusWithResult(GraphQLRequest["Job"]): query = """ query JobStatus($id: String) { job(id: $id) { @@ -36,14 +41,14 @@ class _JobStatusWithResult(GraphQLRequest): } """ - def __init__(self, id): + def __init__(self, id: str): super().__init__(self.query, variables={"id": id}) - def process_response(self, response): - return Job(**super().process_response(response)["job"]) + def process_response(self, response: "Payload") -> "Job": + return Job(**super().parse_payload(response)["job"]) -class JobStatus(RequestChain): +class JobStatus(RequestChain["Job"]): """ Status of a Job in the Indico Platform. 
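A sketch of polling with the now-optional `timeout` (the `job_id` below is assumed to come from an earlier mutation): when the timer expires, the chain raises `IndicoTimeoutError`.

```python
from indico.errors import IndicoTimeoutError
from indico.queries import JobStatus

try:
    # Poll every half second, giving up after 60 seconds.
    job = client.call(JobStatus(id=job_id, wait=True, request_interval=0.5, timeout=60))
    print(job.status, job.ready)
except IndicoTimeoutError:
    print("job did not finish within 60 seconds")
```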
@@ -67,26 +72,28 @@ class JobStatus(RequestChain): completed after `timeout` seconds """ - previous: Job = None + previous: "Job" def __init__( self, id: str, wait: bool = True, - request_interval: Union[int, float] = 0.2, - timeout: Union[int, float] = None, + request_interval: "Union[int, float]" = 0.2, + timeout: "Optional[Union[int, float]]" = None, ): self.id = id self.wait = wait self.request_interval = request_interval self.timeout = timeout - def requests(self): + def requests(self) -> "Iterator[Union[_JobStatus, Delay, _JobStatusWithResult]]": yield _JobStatus(id=self.id) + if self.wait: - timer = None + timer: "Optional[Timer]" = None if self.timeout is not None: timer = Timer(self.timeout) + # Check status of job until done if wait == True while not ( (self.previous.status in ["SUCCESS"] and self.previous.ready) @@ -103,4 +110,5 @@ timer.check() yield Delay(seconds=self.request_interval) yield _JobStatus(id=self.id) + yield _JobStatusWithResult(id=self.id) diff --git a/indico/queries/model_groups/metrics.py b/indico/queries/model_groups/metrics.py index ddde4224..b032eeb8 100644 --- a/indico/queries/model_groups/metrics.py +++ b/indico/queries/model_groups/metrics.py @@ -1,11 +1,18 @@ +import json +from typing import TYPE_CHECKING + from indico.client.request import GraphQLRequest, RequestChain -from indico.types.model_metrics import SequenceMetrics +from indico.errors import IndicoInputError from indico.queries.model_groups import GetModelGroup -from indico.errors import IndicoRequestError -import json +from indico.types.model_metrics import SequenceMetrics + +if TYPE_CHECKING: # pragma: no cover + from typing import Dict, Iterator, Type, Union + + from indico.typing import AnyDict, Payload -class AnnotationModelGroupMetrics(GraphQLRequest): +class AnnotationModelGroupMetrics(GraphQLRequest["SequenceMetrics"]): """ Get metrics for annotation or "sequence" models. Metrics for the most recently successfully trained model of the model group are returned. @@ -58,15 +65,15 @@ class AnnotationModelGroupMetrics(GraphQLRequest): def __init__(self, model_group_id: int): super().__init__(self.query, variables={"modelGroupId": model_group_id}) - def process_response(self, response): + def process_response(self, response: "Payload") -> "SequenceMetrics": return SequenceMetrics( - **super().process_response(response)["modelGroups"]["modelGroups"][0][ + **super().parse_payload(response)["modelGroups"]["modelGroups"][0][ "selectedModel" ]["evaluation"]["metrics"] ) -class ObjectDetectionMetrics(GraphQLRequest): +class ObjectDetectionMetrics(GraphQLRequest["AnyDict"]): """ Get metrics for a trained object detection model. Metrics for the most recently successfully trained model of the model group are returned.
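A sketch of fetching metrics through the chain (the model group id is a placeholder; the import path follows this module's location): the chain resolves the model group's task type first and, per the change in the following hunk, raises `IndicoInputError` for unsupported task types.

```python
from indico.queries.model_groups.metrics import GetModelGroupMetrics

# Returns SequenceMetrics for annotation models, a dict for object detection.
metrics = client.call(GetModelGroupMetrics(model_group_id=789))
```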
@@ -97,21 +104,22 @@ class ObjectDetectionMetrics(GraphQLRequest): def __init__(self, model_group_id: int): super().__init__(self.query, variables={"modelGroupId": model_group_id}) - def process_response(self, response): - return json.loads( + def process_response(self, response: "Payload") -> "AnyDict": + raw_response: "AnyDict" = json.loads( super().process_response(response)["modelGroups"]["modelGroups"][0][ "selectedModel" ]["evaluation"]["metrics"] ) + return raw_response -task_type_query_mapping = { +task_type_query_mapping: "Dict[str, Type[Union[AnnotationModelGroupMetrics, ObjectDetectionMetrics]]]" = { "ANNOTATION": AnnotationModelGroupMetrics, "OBJECT_DETECTION": ObjectDetectionMetrics, } -class GetModelGroupMetrics(RequestChain): +class GetModelGroupMetrics(RequestChain["Union[SequenceMetrics, AnyDict]"]): """ Args: model_group_id (int): ModelGroup id @@ -124,12 +132,14 @@ def __init__(self, model_group_id: int): self.model_group_id = model_group_id super().__init__() - def requests(self): + def requests( + self, + ) -> "Iterator[Union[GetModelGroup, AnnotationModelGroupMetrics, ObjectDetectionMetrics]]": yield GetModelGroup(id=self.model_group_id) if self.previous.task_type not in task_type_query_mapping: - raise IndicoRequestError( + raise IndicoInputError( "Metrics queries are only supported for annotation and object detection at this time." ) - metrics_query_fn = task_type_query_mapping.get(self.previous.task_type) + metrics_query_fn = task_type_query_mapping[self.previous.task_type] yield metrics_query_fn(self.model_group_id) return self.previous diff --git a/indico/queries/model_groups/model_groups.py b/indico/queries/model_groups/model_groups.py index 7a833328..f1a97b3e 100644 --- a/indico/queries/model_groups/model_groups.py +++ b/indico/queries/model_groups/model_groups.py @@ -1,18 +1,23 @@ import json -from typing import Any, Dict, List, Optional, Union - -import deprecation +from typing import TYPE_CHECKING from indico.client.request import Delay, GraphQLRequest, RequestChain from indico.errors import IndicoNotFound -from indico.queries.workflow_components import AddModelGroupComponent + +# backwards compat +from indico.queries.workflow_components import AddModelGroupComponent # noqa: F401 from indico.types.jobs import Job from indico.types.model import Model, ModelOptions from indico.types.model_group import ModelGroup from indico.types.utils import cc_to_snake +if TYPE_CHECKING: # pragma: no cover + from typing import Any, Iterator, List, Optional, Union + + from indico.typing import AnyDict, Payload + -class GetModelGroup(RequestChain): +class GetModelGroup(RequestChain[ModelGroup]): """ Get an object describing a model group @@ -26,23 +31,26 @@ class GetModelGroup(RequestChain): """ def __init__( - self, id: int, wait: bool = False, request_interval: Union[int, float] = 5 + self, id: int, wait: bool = False, request_interval: "Union[int, float]" = 5 ): self.id = id self.wait = wait self.request_interval = request_interval - def requests(self): + def requests( + self, + ) -> "Iterator[Union[GetModelGroupSelectedModelStatus, Delay, _GetModelGroup]]": if self.wait: req = GetModelGroupSelectedModelStatus(id=self.id) yield req while self.previous not in ["FAILED", "COMPLETE", "NOT_ENOUGH_DATA"]: yield Delay(seconds=self.request_interval) yield req + yield _GetModelGroup(id=self.id) -class _GetModelGroup(GraphQLRequest): +class _GetModelGroup(GraphQLRequest[ModelGroup]): """ Get an object describing a model group @@ -73,19 +81,18 @@ class _GetModelGroup(GraphQLRequest): 
def __init__(self, id: int): super().__init__(query=self.query, variables={"id": id}) - def process_response(self, response): + def process_response(self, response: "Payload") -> "ModelGroup": try: - mg = ModelGroup( - **super().process_response(response)["modelGroups"]["modelGroups"][0] + return ModelGroup( + **super().parse_payload(response)["modelGroups"]["modelGroups"][0] ) except IndexError: raise IndicoNotFound( "ModelGroup not found. Please check the ID you are using." ) - return mg -class GetTrainingModelWithProgress(GraphQLRequest): +class GetTrainingModelWithProgress(GraphQLRequest[Model]): """ Get progress (percent complete) of a training model group @@ -115,13 +122,12 @@ class GetTrainingModelWithProgress(GraphQLRequest): def __init__(self, id: int): super().__init__(query=self.query, variables={"id": id}) - def process_response(self, response): - response = super().process_response(response) - model_groups = response["modelGroups"]["modelGroups"] + def process_response(self, response: "Payload") -> "Model": + model_groups = super().parse_payload(response)["modelGroups"]["modelGroups"] if len(model_groups) != 1: raise IndicoNotFound("Model Group") - models = model_groups[0]["models"] + models = model_groups[0]["models"] last = max(models, key=lambda m: m["id"]) if not last: raise IndicoNotFound("Training Model") @@ -129,7 +135,7 @@ def process_response(self, response): return Model(**last) -class GetModelGroupSelectedModelStatus(GraphQLRequest): +class GetModelGroupSelectedModelStatus(GraphQLRequest[str]): """ Get the status string of the selected model for the given model group id @@ -157,14 +163,14 @@ class GetModelGroupSelectedModelStatus(GraphQLRequest): def __init__(self, id: int): super().__init__(query=self.query, variables={"id": id}) - def process_response(self, response): + def process_response(self, response: "Payload") -> str: mg = ModelGroup( - **super().process_response(response)["modelGroups"]["modelGroups"][0] + **super().parse_payload(response)["modelGroups"]["modelGroups"][0] ) return mg.selected_model.status -class _ModelGroupPredict(GraphQLRequest): +class _ModelGroupPredict(GraphQLRequest[Job]): query = """ mutation ModelGroupPredict() { modelPredict() { @@ -175,7 +181,7 @@ class _ModelGroupPredict(GraphQLRequest): query_args = {"modelId": "Int!", "data": "[String]", "predictOptions": "JSONString"} - def _args_strings(self, **kwargs): + def _args_strings(self, **kwargs: "Any") -> str: args = [k for k in self.query_args.keys() if kwargs.get(cc_to_snake(k))] query_args_string = ",".join(f"${k}: {self.query_args[k]}" for k in args) @@ -186,12 +192,21 @@ def _args_strings(self, **kwargs): return query - def __init__(self, model_id: int, data: List[str], predict_options: Dict = None): + def __init__( + self, + model_id: int, + data: "List[str]", + predict_options: "Optional[Union[str, AnyDict]]" = None, + ): + predict_options_json: "Optional[str]" = None if predict_options: - predict_options = json.dumps(predict_options) + if isinstance(predict_options, dict): + predict_options_json = json.dumps(predict_options) + else: + predict_options_json = predict_options query = self._args_strings( - model_id=model_id, data=data, predict_options=predict_options + model_id=model_id, data=data, predict_options=predict_options_json ) super().__init__( @@ -199,15 +214,15 @@ def __init__(self, model_id: int, data: List[str], predict_options: Dict = None) variables={ "modelId": model_id, "data": data, - "predictOptions": predict_options, + "predictOptions": predict_options_json, 
}, ) - def process_response(self, response): - return Job(**super().process_response(response)["modelPredict"]) + def process_response(self, response: "Payload") -> "Job": + return Job(**super().parse_payload(response)["modelPredict"]) -class ModelGroupPredict(RequestChain): +class ModelGroupPredict(RequestChain[Job]): """ Generate predictions from a model group on new data @@ -223,22 +238,21 @@ def __init__( self, model_id: int, - data: List[str], + data: "List[str]", load: bool = True, - predict_options: Dict = None, + predict_options: "Optional[AnyDict]" = None, ): self.model_id = model_id self.data = data self.predict_options = predict_options - def requests(self): - + def requests(self) -> "Iterator[_ModelGroupPredict]": yield _ModelGroupPredict( model_id=self.model_id, data=self.data, predict_options=self.predict_options ) -class UpdateModelGroupSettings(GraphQLRequest): +class UpdateModelGroupSettings(GraphQLRequest["ModelOptions"]): """ Updates an existing model group component in the platform. @@ -287,27 +301,35 @@ def __init__( self, model_group_id: int, - model_training_options: Optional[Dict[str, Any]] = None, - predict_options: Optional[Dict[str, Any]] = None, + model_training_options: "Optional[Union[str, AnyDict]]" = None, + predict_options: "Optional[Union[str, AnyDict]]" = None, ): + model_training_options_json: "Optional[str]" = None if model_training_options: - model_training_options = json.dumps(model_training_options) + if isinstance(model_training_options, dict): + model_training_options_json = json.dumps(model_training_options) + else: + model_training_options_json = model_training_options + predict_options_json: "Optional[str]" = None if predict_options: - predict_options = json.dumps(predict_options) + if isinstance(predict_options, dict): + predict_options_json = json.dumps(predict_options) + else: + predict_options_json = predict_options super().__init__( self.query, variables={ "modelGroupId": model_group_id, - "modelTrainingOptions": model_training_options, - "predictOptions": predict_options, + "modelTrainingOptions": model_training_options_json, + "predictOptions": predict_options_json, }, ) - def process_response(self, response): + def process_response(self, response: "Payload") -> "ModelOptions": return ModelOptions( - **super().process_response(response)["updateModelGroupSettings"][ + **super().parse_payload(response)["updateModelGroupSettings"][ "modelOptions" ] ) diff --git a/indico/queries/questionnaire.py b/indico/queries/questionnaire.py index e7a3bf55..9dd16a89 100644 --- a/indico/queries/questionnaire.py +++ b/indico/queries/questionnaire.py @@ -1,11 +1,16 @@ -from typing import List, Optional +from typing import TYPE_CHECKING from indico.client.request import Delay, GraphQLRequest, RequestChain from indico.errors import IndicoError, IndicoNotFound from indico.types.questionnaire import Example, Questionnaire +if TYPE_CHECKING: # pragma: no cover + from typing import Iterator, List, Optional, Union -class AddLabels(GraphQLRequest): + from indico.typing import AnyDict, Payload + + +class AddLabels(GraphQLRequest["AnyDict"]): """ Add labels to an existing labelset.
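A sketch of `UpdateModelGroupSettings` with the corrected JSON handling above (the id and option keys are placeholders; valid keys depend on the model type): dicts are JSON-encoded before being sent, while pre-encoded JSON strings now pass through unchanged.

```python
from indico.queries.model_groups.model_groups import UpdateModelGroupSettings

# Placeholder option keys shown for illustration only.
options = client.call(
    UpdateModelGroupSettings(
        model_group_id=42,
        model_training_options={"auto_negative_scaling": True},
        predict_options={"confidence_threshold": 0.5},
    )
)
```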
@@ -34,8 +39,8 @@ class AddLabels(GraphQLRequest): def __init__( self, labelset_id: int, - labels: List[dict], - model_group_id: int = None, + labels: "List[AnyDict]", + model_group_id: "Optional[int]" = None, ): super().__init__( query=self.query, @@ -47,7 +52,7 @@ def __init__( ) -class GetQuestionnaireExamples(GraphQLRequest): +class GetQuestionnaireExamples(GraphQLRequest["List[Example]"]): """ Gets unlabeled examples from a Questionnaire. @@ -84,7 +89,7 @@ def __init__( self, questionnaire_id: int, num_examples: int, - datafile_id: Optional[int] = None, + datafile_id: "Optional[int]" = None, ): super().__init__( query=self.query, @@ -95,11 +100,11 @@ def __init__( }, ) - def process_response(self, response): + def process_response(self, response: "Payload") -> "List[Example]": try: - examples = [ + return [ Example(**e) - for e in super().process_response(response)["questionnaires"][ + for e in super().parse_payload(response)["questionnaires"][ "questionnaires" ][0]["examples"] ] @@ -107,10 +112,9 @@ def process_response(self, response): raise IndicoNotFound( "Examples not found. Please check the ID you are using." ) - return examples -class _GetQuestionnaire(GraphQLRequest): +class _GetQuestionnaire(GraphQLRequest["Questionnaire"]): """ Gets a questionnaire from an ID. @@ -153,16 +157,17 @@ def __init__(self, questionnaire_id: int): variables={"questionnaire_id": questionnaire_id}, ) - def process_response(self, response): - questionnaire_list = super().process_response(response)["questionnaires"][ + def process_response(self, response: "Payload") -> "Questionnaire": + questionnaire_list: "List[AnyDict]" = super().parse_payload(response)[ "questionnaires" - ] + ]["questionnaires"] if not questionnaire_list or not questionnaire_list[0]: raise IndicoError("Cannot find questionnaire") + return Questionnaire(**questionnaire_list[0]) -class GetQuestionnaire(RequestChain): +class GetQuestionnaire(RequestChain["Questionnaire"]): """ Gets a questionnaire from an ID. 
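A sketch of the questionnaire labeling loop using the typed requests above (ids are placeholders, and the label payload shape is hypothetical; the real schema depends on the labelset's task type):

```python
from indico.queries.questionnaire import AddLabels, GetQuestionnaireExamples

examples = client.call(GetQuestionnaireExamples(questionnaire_id=11, num_examples=5))
# Hypothetical label payload shape, shown for illustration only.
labels = [{"exampleId": e.id, "targets": [{"label": "class A"}]} for e in examples]
client.call(AddLabels(labelset_id=22, labels=labels, model_group_id=33))
```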
@@ -174,15 +179,15 @@ class GetQuestionnaire(RequestChain): Questionnaire object """ - previous = None + previous: "Questionnaire" def __init__(self, questionnaire_id: int, wait: bool = True): self.questionnaire_id = questionnaire_id self.wait = wait - def requests(self): + def requests(self) -> "Iterator[Union[_GetQuestionnaire, Delay]]": yield _GetQuestionnaire(questionnaire_id=self.questionnaire_id) if self.wait: while self.previous.questions_status == "STARTED": - yield _GetQuestionnaire(questionnaire_id=self.questionnaire_id) yield Delay() + yield _GetQuestionnaire(questionnaire_id=self.questionnaire_id) diff --git a/indico/queries/storage.py b/indico/queries/storage.py index d40bc409..b7bcf121 100644 --- a/indico/queries/storage.py +++ b/indico/queries/storage.py @@ -1,14 +1,19 @@ import io import json -from typing import List, Dict +from typing import TYPE_CHECKING + from indico.client.request import HTTPMethod, HTTPRequest, RequestChain -from indico.errors import IndicoRequestError, IndicoInputError +from indico.errors import IndicoInputError, IndicoRequestError + +if TYPE_CHECKING: # pragma: no cover + from typing import Any, Dict, Iterator, List, Optional, Type, Union + from indico.typing import AnyDict URL_PREFIX = "indico-file:///storage" -class RetrieveStorageObject(HTTPRequest): +class RetrieveStorageObject(HTTPRequest["Any"]): """ Retrieve an object stored on the Indico Platform @@ -23,15 +28,18 @@ class RetrieveStorageObject(HTTPRequest): contents (dict): Contents of the storage object, most often JSON """ - def __init__(self, storage_object): - if type(storage_object) == dict: + def __init__(self, storage_object: "Union[AnyDict, str]"): + if isinstance(storage_object, dict): try: url = storage_object["url"] except KeyError: raise IndicoRequestError( code="FAILURE", - error="Unable to retrieve result. Please check the status of the job object. If the status is \ - 'FAILURE', check the job object result for more detailed information.", + error=( + "Unable to retrieve result. Please check the status of the job" + " object. If the status is 'FAILURE', check the job object" + " result for more detailed information." + ), ) else: url = storage_object @@ -40,7 +48,7 @@ def __init__(self, storage_object): super().__init__(method=HTTPMethod.GET, path=url) -class UploadDocument(HTTPRequest): +class UploadDocument(HTTPRequest["List[AnyDict]"]): """ Upload an object stored on the Indico Platform @@ -55,14 +63,21 @@ class UploadDocument(HTTPRequest): files: Storage object to be used for further processing requests (e.g., document extraction). 
""" - def __init__(self, files: List[str] = None, streams: Dict[str, io.BufferedIOBase] = None): - - if (files is None and streams is None) or (files is not None and streams is not None): + def __init__( + self, + files: "Optional[List[str]]" = None, + streams: "Optional[Dict[str, io.BufferedIOBase]]" = None, + ): + if (files is None and streams is None) or ( + files is not None and streams is not None + ): raise IndicoInputError("Must define one of files or streams, but not both.") - super().__init__(HTTPMethod.POST, "/storage/files/store", files=files, streams=streams) + super().__init__( + HTTPMethod.POST, "/storage/files/store", files=files, streams=streams + ) - def process_response(self, uploaded_files: List[dict]): + def process_response(self, response: "List[AnyDict]") -> "List[AnyDict]": files = [ { "filename": f["name"], @@ -74,12 +89,12 @@ def process_response(self, uploaded_files: List[dict]): } ), } - for f in uploaded_files + for f in response ] return files -class UploadBatched(RequestChain): +class UploadBatched(RequestChain["List[AnyDict]"]): """ Batch uploading of files to the Indico Platform @@ -94,16 +109,16 @@ class UploadBatched(RequestChain): def __init__( self, - files: List[str], + files: "List[str]", batch_size: int = 20, - request_cls: HTTPRequest = UploadDocument, + request_cls: "Type[Any]" = UploadDocument, ): - self.result = None + self.result: "Optional[List[Any]]" = None self.files = files self.batch_size = batch_size self.request_cls = request_cls - def requests(self): + def requests(self) -> "Iterator[Any]": self.result = [] for i in range(0, len(self.files), self.batch_size): yield self.request_cls(self.files[i : i + self.batch_size]) @@ -121,13 +136,13 @@ class CreateStorageURLs(UploadDocument): urls: list of storage urls to be use for further processing requests (e.g., form extraction) """ - def process_response(self, uploaded_files: List[dict]) -> List[str]: - errors = [f["error"] for f in uploaded_files if f.get("error")] + def process_response(self, response: "Any") -> "Any": + errors = [f["error"] for f in response if f.get("error")] if errors: raise IndicoInputError( "\n".join(error for error in errors), ) - urls = [URL_PREFIX + f["path"] for f in uploaded_files] + urls = [URL_PREFIX + f["path"] for f in response] return urls diff --git a/indico/queries/submission.py b/indico/queries/submission.py index 1e150a1e..f89d9d3a 100644 --- a/indico/queries/submission.py +++ b/indico/queries/submission.py @@ -1,7 +1,7 @@ import json from functools import partial from operator import eq, ne -from typing import Dict, List, Union +from typing import TYPE_CHECKING from indico.client.request import Delay, GraphQLRequest, PagedRequest, RequestChain from indico.errors import IndicoInputError, IndicoTimeoutError @@ -11,8 +11,13 @@ from indico.types.submission import VALID_SUBMISSION_STATUSES from indico.types.utils import Timer +if TYPE_CHECKING: # pragma: no cover + from typing import Iterator, List, Optional, Union -class ListSubmissions(PagedRequest): + from indico.typing import AnyDict, Payload + + +class ListSubmissions(PagedRequest["List[Submission]"]): """ List all Submissions visible to the authenticated user by most recent. 
Supports pagination (limit becomes page_size) @@ -125,10 +130,10 @@ class ListSubmissions(PagedRequest): def __init__( self, *, - submission_ids: List[int] = None, - workflow_ids: List[int] = None, - filters: Union[Dict, SubmissionFilter] = None, - limit: int = 1000, + submission_ids: "Optional[List[int]]" = None, + workflow_ids: "Optional[List[int]]" = None, + filters: "Optional[Union[AnyDict, SubmissionFilter]]" = None, + limit: "Optional[int]" = 1000, order_by: str = "ID", desc: bool = True, ): @@ -144,14 +149,14 @@ def __init__( }, ) - def process_response(self, response) -> List[Submission]: + def process_response(self, response: "Payload") -> "List[Submission]": return [ Submission(**s) - for s in super().process_response(response)["submissions"]["submissions"] + for s in super().parse_payload(response)["submissions"]["submissions"] ] -class GetSubmission(GraphQLRequest): +class GetSubmission(GraphQLRequest["Submission"]): """ Retrieve a Submission by id @@ -233,11 +238,11 @@ class GetSubmission(GraphQLRequest): def __init__(self, submission_id: int): super().__init__(self.query, variables={"submissionId": submission_id}) - def process_response(self, response) -> Submission: - return Submission(**(super().process_response(response)["submission"])) + def process_response(self, response: "Payload") -> "Submission": + return Submission(**(super().parse_payload(response)["submission"])) -class WaitForSubmissions(RequestChain): +class WaitForSubmissions(RequestChain["List[Submission]"]): """ Given submission_ids, wait for all to finish processing """ @@ -282,7 +287,7 @@ class WaitForSubmissions(RequestChain): } """ - def __init__(self, submission_ids: List[int], timeout: Union[int, float] = 60): + def __init__(self, submission_ids: "List[int]", timeout: "Union[int, float]" = 60): if not submission_ids: raise IndicoInputError("Please provide submission ids") @@ -293,7 +298,7 @@ def __init__(self, submission_ids: List[int], timeout: Union[int, float] = 60): ListSubmissions, submission_ids=self.submission_ids, limit=None ) - def requests(self) -> List[Submission]: + def requests(self) -> "Iterator[ListSubmissions]": timer = Timer(self.timeout) while True: @@ -303,7 +308,7 @@ def requests(self) -> List[Submission]: break -class UpdateSubmission(GraphQLRequest): +class UpdateSubmission(GraphQLRequest["Submission"]): """ Update the retrieval status of a Submission by id @@ -356,11 +361,11 @@ def __init__(self, submission_id: int, retrieved: bool): variables={"submissionId": submission_id, "retrieved": retrieved}, ) - def process_response(self, response) -> Submission: - return Submission(**(super().process_response(response)["updateSubmission"])) + def process_response(self, response: "Payload") -> "Submission": + return Submission(**(super().parse_payload(response)["updateSubmission"])) -class GenerateSubmissionResult(GraphQLRequest): +class GenerateSubmissionResult(GraphQLRequest["Job"]): query = """ mutation CreateSubmissionResults($submissionId: Int!) 
{ submissionResults(submissionId: $submissionId) { @@ -370,16 +375,16 @@ class GenerateSubmissionResult(GraphQLRequest): """ - def __init__(self, submission: Union[int, Submission]): + def __init__(self, submission: "Union[int, Submission]"): submission_id = submission if isinstance(submission, int) else submission.id super().__init__(self.query, variables={"submissionId": submission_id}) - def process_response(self, response) -> Job: - response = super().process_response(response)["submissionResults"] + def process_response(self, response: "Payload") -> "Job": + response = super().parse_payload(response)["submissionResults"] return Job(id=response["jobId"]) -class SubmissionResult(RequestChain): +class SubmissionResult(RequestChain["Job"]): """ Generate a result file for a Submission @@ -404,15 +409,13 @@ class SubmissionResult(RequestChain): IndicoInputError: The requested Submission is not `check_status` """ - previous: Submission = None - def __init__( self, - submission: Union[int, Submission], - check_status: str = None, + submission: "Union[int, Submission]", + check_status: "Optional[str]" = None, wait: bool = False, - timeout: Union[int, float] = 30, - request_interval: Union[int, float] = 5, + timeout: "Union[int, float]" = 30, + request_interval: "Union[int, float]" = 5, ): self.submission_id = ( submission if isinstance(submission, int) else submission.id @@ -431,7 +434,9 @@ def __init__( else partial(ne, "PROCESSING") ) - def requests(self) -> Union[Job, str]: + def requests( + self, + ) -> "Iterator[Union[GetSubmission, Delay, GenerateSubmissionResult, JobStatus]]": timer = Timer(self.timeout) timer.check() yield GetSubmission(self.submission_id) @@ -452,7 +457,7 @@ def requests(self) -> Union[Job, str]: yield JobStatus(id=self.previous.id, wait=True, timeout=self.timeout) -class SubmitReview(GraphQLRequest): +class SubmitReview(GraphQLRequest["Job"]): """ Submit an "Auto" Review for a submission. Requires that the submission be in PENDING_AUTO_REVIEW status. 
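A sketch of the `SubmissionResult` chain above (the submission id is a placeholder; `client` as before): with `wait=True` it polls `JobStatus` internally and raises `IndicoTimeoutError` after `timeout` seconds.

```python
from indico.queries import RetrieveStorageObject, SubmissionResult

job = client.call(SubmissionResult(1234, check_status="COMPLETE", wait=True))
result = client.call(RetrieveStorageObject(job.result))
```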
@@ -489,19 +494,24 @@ class SubmitReview(GraphQLRequest): def __init__( self, - submission: Union[int, Submission], - changes: Dict | List = None, + submission: "Union[int, Submission]", + changes: "Optional[Union[str, AnyDict, List[AnyDict]]]" = None, rejected: bool = False, - force_complete: bool = None, + force_complete: "Optional[bool]" = None, ): + changes_json: "Optional[str]" = None submission_id = submission if isinstance(submission, int) else submission.id if not changes and not rejected: raise IndicoInputError("Must provide changes or reject=True") - elif changes and isinstance(changes, (dict, list)): - changes = json.dumps(changes) + elif changes: + if isinstance(changes, (dict, list)): + changes_json = json.dumps(changes) + else: + changes_json = changes + _vars = { "submissionId": submission_id, - "changes": changes, + "changes": changes_json, "rejected": rejected, } @@ -519,12 +529,12 @@ def __init__( super().__init__(query, variables=_vars) - def process_response(self, response) -> Job: - response = super().process_response(response)["submitAutoReview"] + def process_response(self, response: "Payload") -> "Job": + response = super().parse_payload(response)["submitAutoReview"] return Job(id=response["jobId"]) -class GetReviews(GraphQLRequest): +class GetReviews(GraphQLRequest["List[SubmissionReviewFull]"]): """ Given a submission Id, return all the full Review objects back with changes @@ -560,14 +570,14 @@ class GetReviews(GraphQLRequest): def __init__(self, submission_id: int): super().__init__(self.query, variables={"submissionId": submission_id}) - def process_response(self, response) -> List[SubmissionReviewFull]: + def process_response(self, response: "Payload") -> "List[SubmissionReviewFull]": return [ SubmissionReviewFull(**review) - for review in super().process_response(response)["submission"]["reviews"] + for review in super().parse_payload(response)["submission"]["reviews"] ] -class RetrySubmission(GraphQLRequest): +class RetrySubmission(GraphQLRequest["List[Submission]"]): """ Given a list of submission ids, retry those failed submissions. 
Submissions must be in a failed state and cannot be requested before @@ -599,14 +609,13 @@ class RetrySubmission(GraphQLRequest): } """ - def __init__(self, submission_ids: List[int]): + def __init__(self, submission_ids: "List[int]"): if submission_ids is None or len(submission_ids) < 1: raise IndicoInputError("You must specify submission ids") super().__init__(self.query, variables={"submissionIds": submission_ids}) - def process_response(self, response) -> List[Submission]: + def process_response(self, response: "Payload") -> "List[Submission]": return [ - Submission(**s) - for s in super().process_response(response)["retrySubmissions"] + Submission(**s) for s in super().parse_payload(response)["retrySubmissions"] ] diff --git a/indico/queries/usermetrics.py b/indico/queries/usermetrics.py index 3ae35619..0b0de318 100644 --- a/indico/queries/usermetrics.py +++ b/indico/queries/usermetrics.py @@ -1,19 +1,27 @@ -import datetime -from typing import List, Union, Dict +from typing import TYPE_CHECKING, List -from indico.client.request import ( - GraphQLRequest, - PagedRequest, -) +from indico.client.request import GraphQLRequest, PagedRequest from indico.filters import UserMetricsFilter from indico.types import BaseType -from indico.types.user_metrics import UserSummary, UserSnapshot, UserChangelog, UserChangelogReport +from indico.types.user_metrics import ( + UserChangelog, + UserChangelogReport, + UserSnapshot, + UserSummary, +) + +if TYPE_CHECKING: # pragma: no cover + from datetime import datetime + from typing import Optional, Union + + from indico.typing import AnyDict, Payload class _PagedUserSnapshots(BaseType): """ Class to hold paged snapshot data to make parsing easier. """ + results: List[UserSnapshot] @@ -21,10 +29,11 @@ class _PagedUserChangelog(BaseType): """ Class to hold paged snapshot data to make parsing easier. """ + results: List[UserChangelog] -class GetUserSummary(GraphQLRequest): +class GetUserSummary(GraphQLRequest["UserSummary"]): """ Requests summary-level information of users in the app on a specific date. @@ -32,6 +41,7 @@ class GetUserSummary(GraphQLRequest): date (datetime): specific day to summarize. """ + query = """ query GetUserSummary($date: Date){ userSummary(date: $date){ @@ -47,17 +57,17 @@ class GetUserSummary(GraphQLRequest): } """ - def __init__(self, date=None): + def __init__(self, date: "Optional[datetime]" = None): if date is not None: - super().__init__(self.query, variables={"date": date.strftime('%Y-%m-%d')}) + super().__init__(self.query, variables={"date": date.strftime("%Y-%m-%d")}) else: super().__init__(self.query) - def process_response(self, response) -> UserSummary: - return UserSummary(**super().process_response(response)["userSummary"]) + def process_response(self, response: "Payload") -> "UserSummary": + return UserSummary(**super().parse_payload(response)["userSummary"]) -class GetUserSnapshots(PagedRequest): +class GetUserSnapshots(PagedRequest["List[UserSnapshot]"]): """ Requests paged detailed information about app users on a specific date. @@ -67,6 +77,7 @@ class GetUserSnapshots(PagedRequest): date (datetime): specific day to query. limit (int): limit how many come back per query or per page. 
""" + query = """ query GetUserSnapshot($date: Date, $filters: UserReportFilter, $after: Int, $limit: Int){ userSnapshot(date: $date, filters: $filters, after: $after, limit: $limit){ @@ -81,7 +92,7 @@ class GetUserSnapshots(PagedRequest): datasetId role } - + } pageInfo{ startCursor @@ -93,29 +104,38 @@ class GetUserSnapshots(PagedRequest): } """ - def __init__(self, *, date: datetime, filters: Union[Dict, UserMetricsFilter] = None, limit: int = None): + def __init__( + self, + *, + date: "datetime", + filters: "Optional[Union[AnyDict, UserMetricsFilter]]" = None, + limit: "Optional[int]" = None, + ): variables = { - "date": date.strftime('%Y-%m-%d') if date is not None else None, + "date": date.strftime("%Y-%m-%d") if date is not None else None, "filters": filters, - "limit": limit + "limit": limit, } super().__init__(self.query, variables=variables) - def process_response(self, response) -> List[UserSnapshot]: - return _PagedUserSnapshots(**super().process_response(response)["userSnapshot"]).results + def process_response(self, response: "Payload") -> "List[UserSnapshot]": + return _PagedUserSnapshots( + **super().parse_payload(response)["userSnapshot"] + ).results -class GetUserChangelog(PagedRequest): +class GetUserChangelog(PagedRequest["List[UserChangelog]"]): """ - Gets paged detailed information about app users. + Gets paged detailed information about app users. + + Args: + filters (UserSnapshotFilter): filter the query based on UserMetricsFilter criteria. + start_date (datetime): specific start date for query. + end_date (datetime): specific end date for query. + limit (int): limit how many come back per query or per page. + """ - Args: - filters (UserSnapshotFilter): filter the query based on UserMetricsFilter criteria. - start_date (datetime): specific start date for query. - end_date (datetime): specific end date for query. - limit (int): limit how many come back per query or per page. 
- """ query = """ query GetUserChangelog($sdate: Date, $edate: Date, $filters: UserReportFilter, $after: Int, $limit: Int){ userChangelog(startDate: $sdate, endDate: $edate, filters: $filters, after:$after, limit:$limit){ @@ -140,32 +160,43 @@ class GetUserChangelog(PagedRequest): } """ - def __init__(self, *, start_date: datetime, end_date: datetime, filters: Union[Dict, UserMetricsFilter] = None, - limit: int = None): + def __init__( + self, + *, + start_date: "datetime", + end_date: "datetime", + filters: "Optional[Union[AnyDict, UserMetricsFilter]]" = None, + limit: "Optional[int]" = None, + ): variables = { - "sdate": start_date.strftime('%Y-%m-%d') if start_date is not None else None, - "edate": end_date.strftime('%Y-%m-%d') if end_date is not None else None, + "sdate": start_date.strftime("%Y-%m-%d") + if start_date is not None + else None, + "edate": end_date.strftime("%Y-%m-%d") if end_date is not None else None, "filters": filters, - "limit": limit + "limit": limit, } super().__init__(self.query, variables=variables) - def process_response(self, response) -> List[UserSnapshot]: - return _PagedUserChangelog(**super().process_response(response)["userChangelog"]).results + def process_response(self, response: "Payload") -> "List[UserChangelog]": + return _PagedUserChangelog( + **super().parse_payload(response)["userChangelog"] + ).results -class GenerateChangelogReport(GraphQLRequest): +class GenerateChangelogReport(GraphQLRequest["UserChangelogReport"]): """ - Creates a job to generate a report of detailed information about app users - - Args: - filters (UserSnapshotFilter): filter the query based on UserMetricsFilter criteria. - start_date (datetime): specific start date for query. - end_date (datetime): specific end date for query. - report_format (str): specific format of the report, JSON or CSV. + Creates a job to generate a report of detailed information about app users + + Args: + filters (UserSnapshotFilter): filter the query based on UserMetricsFilter criteria. + start_date (datetime): specific start date for query. + end_date (datetime): specific end date for query. + report_format (str): specific format of the report, JSON or CSV. 
+ + """ - """ query = """ mutation GenerateChangeReport ($sdate: Date, $edate: Date, $filters: UserReportFilter){ userChangelogReport( @@ -179,15 +210,25 @@ class GenerateChangelogReport(GraphQLRequest): } """ - def __init__(self, *, start_date: datetime, end_date: datetime, filters: Union[Dict, UserMetricsFilter] = None, - report_format: str = "csv"): + def __init__( + self, + *, + start_date: "datetime", + end_date: "datetime", + filters: "Optional[Union[AnyDict, UserMetricsFilter]]" = None, + report_format: str = "csv", + ): variables = { - "sdate": start_date.strftime('%Y-%m-%d') if start_date is not None else None, - "edate": end_date.strftime('%Y-%m-%d') if end_date is not None else None, + "sdate": start_date.strftime("%Y-%m-%d") + if start_date is not None + else None, + "edate": end_date.strftime("%Y-%m-%d") if end_date is not None else None, "filters": filters, - "format": report_format + "format": report_format, } super().__init__(self.query, variables=variables) - def process_response(self, response) -> List[UserSnapshot]: - return UserChangelogReport(**super().process_response(response)["userChangelogReport"]) + def process_response(self, response: "Payload") -> "UserChangelogReport": + return UserChangelogReport( + **super().parse_payload(response)["userChangelogReport"] + ) diff --git a/indico/queries/workflow.py b/indico/queries/workflow.py index 73c03dde..a95a9ecd 100644 --- a/indico/queries/workflow.py +++ b/indico/queries/workflow.py @@ -1,16 +1,20 @@ import io -import json import tempfile -from typing import Dict, List, Union +from typing import TYPE_CHECKING from indico.client.request import Delay, GraphQLRequest, RequestChain from indico.errors import IndicoError, IndicoInputError from indico.queries.storage import UploadBatched, UploadDocument -from indico.types import SUBMISSION_RESULT_VERSIONS, Job, Submission, Workflow +from indico.types import SUBMISSION_RESULT_VERSIONS, Submission, Workflow from indico.types.utils import cc_to_snake, snake_to_cc +if TYPE_CHECKING: # pragma: no cover + from typing import Any, ClassVar, Dict, Iterator, List, Optional, Union -class ListWorkflows(GraphQLRequest): + from indico.typing import Payload + + +class ListWorkflows(GraphQLRequest["List[Workflow]"]): """ List all workflows visible to authenticated user @@ -79,9 +83,9 @@ class ListWorkflows(GraphQLRequest): def __init__( self, *, - dataset_ids: List[int] = None, - workflow_ids: List[int] = None, - limit=100, + dataset_ids: "Optional[List[int]]" = None, + workflow_ids: "Optional[List[int]]" = None, + limit: int = 100, ): super().__init__( self.query, @@ -92,14 +96,14 @@ def __init__( }, ) - def process_response(self, response) -> List[Workflow]: + def process_response(self, response: "Payload") -> "List[Workflow]": return [ Workflow(**w) - for w in super().process_response(response)["workflows"]["workflows"] + for w in super().parse_payload(response)["workflows"]["workflows"] ] -class GetWorkflow(ListWorkflows): +class GetWorkflow(GraphQLRequest["Workflow"]): """ Query for Workflow by id @@ -111,13 +115,16 @@ class GetWorkflow(ListWorkflows): """ def __init__(self, workflow_id: int): - super().__init__(workflow_ids=[workflow_id]) + super().__init__( + ListWorkflows.query, + variables={"datasetIds": None, "workflowIds": [workflow_id], "limit": 100}, + ) - def process_response(self, response) -> Workflow: - return super().process_response(response)[0] + def process_response(self, response: "Payload") -> "Workflow": + return 
Workflow(**super().parse_payload(response)["workflows"]["workflows"][0]) -class _ToggleReview(GraphQLRequest): +class _ToggleReview(GraphQLRequest["Workflow"]): toggle = "enableReview" query_name = "toggleWorkflowReview" query = """ @@ -139,8 +146,8 @@ def __init__(self, workflow_id: int, enable_review: bool): variables={"workflowId": workflow_id, "reviewState": enable_review}, ) - def process_response(self, response) -> Workflow: - return Workflow(**super().process_response(response)[self.query_name]) + def process_response(self, response: "Payload") -> "Workflow": + return Workflow(**super().parse_payload(response)[self.query_name]) class _ToggleAutoReview(_ToggleReview): @@ -148,7 +155,7 @@ class _ToggleAutoReview(_ToggleReview): query_name = "toggleWorkflowAutoReview" -class UpdateWorkflowSettings(RequestChain): +class UpdateWorkflowSettings(RequestChain["Workflow"]): """ Mutation to toggle review and auto-review on a workflow @@ -165,9 +172,9 @@ class UpdateWorkflowSettings(RequestChain): def __init__( self, - workflow: Union[Workflow, int], - enable_review: bool = None, - enable_auto_review: bool = None, + workflow: "Union[Workflow, int]", + enable_review: "Optional[bool]" = None, + enable_auto_review: "Optional[bool]" = None, ): self.workflow_id = workflow.id if isinstance(workflow, Workflow) else workflow if enable_review is None and enable_auto_review is None: @@ -176,14 +183,14 @@ def __init__( self.enable_review = enable_review self.enable_auto_review = enable_auto_review - def requests(self): + def requests(self) -> "Iterator[_ToggleReview]": if self.enable_review is not None: yield _ToggleReview(self.workflow_id, self.enable_review) if self.enable_auto_review is not None: yield _ToggleAutoReview(self.workflow_id, self.enable_auto_review) -class _WorkflowSubmission(GraphQLRequest): +class _WorkflowSubmission(GraphQLRequest["Union[List[Submission], List[int]]"]): query = """ mutation workflowSubmissionMutation({signature}) {{ {mutation_name}({args}) {{ @@ -230,14 +237,14 @@ class _WorkflowSubmission(GraphQLRequest): def __init__( self, detailed_response: bool, - **kwargs, + **kwargs: "Any", ): self.workflow_id = kwargs["workflow_id"] # construct mutation signature and args based on provided kwargs to ensure # backwards-compatible graphql calls # - # inputFiles, bundle, and resultVersion only avaliable on IPA 4.9.0+ + # inputFiles, bundle, and resultVersion only available on IPA 4.9.0+ subq = ( self.files_subquery if kwargs.get("bundle") or kwargs.get("result_version") @@ -249,11 +256,13 @@ def __init__( else self.query ) - args = [ + args_list: "List[str]" = [ _arg for _arg in self.mutation_args.keys() if kwargs.get(cc_to_snake(_arg)) ] - signature = ",".join(f"${_arg}: {self.mutation_args[_arg]}" for _arg in args) - args = ",".join(f"{_arg}: ${_arg}" for _arg in args) + signature: str = ",".join( + f"${_arg}: {self.mutation_args[_arg]}" for _arg in args_list + ) + args: str = ",".join(f"{_arg}: ${_arg}" for _arg in args_list) super().__init__( query=q.format( @@ -262,15 +271,17 @@ def __init__( variables={snake_to_cc(var): val for var, val in kwargs.items()}, ) - def process_response(self, response): - response = super().process_response(response)[self.mutation_name] + def process_response( + self, response: "Payload" + ) -> "Union[List[Submission], List[int]]": + response = super().parse_payload(response)[self.mutation_name] if "submissions" in response: return [Submission(**s) for s in response["submissions"]] if not response["submissionIds"]: raise IndicoError(f"Failed to 
submit to workflow {self.workflow_id}") else: - return response["submissionIds"] - return [Job(id=job_id) for job_id in response["jobIds"]] + sub_ids: "List[int]" = response["submissionIds"] + return sub_ids class _WorkflowUrlSubmission(_WorkflowSubmission): @@ -279,7 +290,7 @@ class _WorkflowUrlSubmission(_WorkflowSubmission): del mutation_args["files"] -class WorkflowSubmission(RequestChain): +class WorkflowSubmission(RequestChain["Union[List[Submission], List[int]]"]): f""" Submit files to a workflow for processing. One of `files`, `urls`, `stream`, or `text` is required. @@ -306,17 +317,17 @@ class WorkflowSubmission(RequestChain): """ - detailed_response = False + detailed_response: "ClassVar[bool]" = False def __init__( self, workflow_id: int, - files: List[str] = None, - urls: List[str] = None, + files: "Optional[List[str]]" = None, + urls: "Optional[List[str]]" = None, submission: bool = True, bundle: bool = False, - result_version: str = None, - streams: Dict[str, io.BufferedIOBase] = None, + result_version: "Optional[str]" = None, + streams: "Optional[Dict[str, io.BufferedIOBase]]" = None, text: str = "", batch_size: int = 10, ): @@ -330,8 +341,6 @@ def __init__( if streams is not None: self.streams = streams.copy() self.has_streams = True - else: - self.streams = None self.text = text self.batch_size = batch_size if not submission: @@ -353,7 +362,9 @@ def __init__( "Only one of 'files' or 'streams', 'urls', or 'text' may be specified" ) - def requests(self): + def requests( + self, + ) -> "Iterator[Union[UploadBatched, UploadDocument, _WorkflowSubmission]]": if self.files: yield UploadBatched(files=self.files, batch_size=self.batch_size) yield _WorkflowSubmission( @@ -418,15 +429,15 @@ class WorkflowSubmissionDetailed(WorkflowSubmission): """ - detailed_response = True + detailed_response: "ClassVar[bool]" = True def __init__( self, workflow_id: int, - files: List[str] = None, - urls: List[str] = None, + files: "Optional[List[str]]" = None, + urls: "Optional[List[str]]" = None, bundle: bool = False, - result_version: str = None, + result_version: "Optional[str]" = None, ): super().__init__( workflow_id, @@ -438,7 +449,7 @@ def __init__( ) -class _AddDataToWorkflow(GraphQLRequest): +class _AddDataToWorkflow(GraphQLRequest["Workflow"]): query = """ mutation addDataToWorkflow($workflowId: Int!) { addDataToWorkflow(workflowId: $workflowId){ @@ -457,13 +468,13 @@ def __init__(self, workflow_id: int): variables={"workflowId": workflow_id}, ) - def process_response(self, response) -> Workflow: + def process_response(self, response: "Payload") -> "Workflow": return Workflow( - **super().process_response(response)["addDataToWorkflow"]["workflow"] + **super().parse_payload(response)["addDataToWorkflow"]["workflow"] ) -class AddDataToWorkflow(RequestChain): +class AddDataToWorkflow(RequestChain["Workflow"]): """ Mutation to update data in a workflow, presumably after new data is added to the dataset. @@ -482,16 +493,16 @@ def __init__(self, workflow_id: int, wait: bool = False): self.workflow_id = workflow_id self.wait = wait - def requests(self): + def requests(self) -> "Iterator[Union[_AddDataToWorkflow, Delay, GetWorkflow]]": yield _AddDataToWorkflow(self.workflow_id) if self.wait: while self.previous.status != "COMPLETE": - yield GetWorkflow(workflow_id=self.workflow_id) yield Delay() + yield GetWorkflow(workflow_id=self.workflow_id) -class CreateWorkflow(GraphQLRequest): +class CreateWorkflow(GraphQLRequest["Workflow"]): """ Mutation to create workflow given an existing dataset. 
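With requests now parameterized as `GraphQLRequest[T]` / `RequestChain[T]`, a strict mypy run can infer the result type of each call. A minimal sketch of the payoff, assuming `IndicoClient.call` is annotated to return the request's type parameter (that annotation lives outside these hunks); the host, token path, and file name below are placeholders:

```python
from indico import IndicoClient, IndicoConfig
from indico.queries import GetWorkflow, WorkflowSubmission

# placeholder connection details
client = IndicoClient(
    IndicoConfig(host="indico.host", api_token_path="./indico_api_token.txt")
)

workflow = client.call(GetWorkflow(workflow_id=1))  # inferred as Workflow
result = client.call(
    WorkflowSubmission(workflow_id=workflow.id, files=["./doc.pdf"])
)  # inferred as Union[List[Submission], List[int]]
```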
@@ -542,13 +553,11 @@ def __init__(self, dataset_id: int, name: str): variables={"datasetId": dataset_id, "name": name}, ) - def process_response(self, response) -> Workflow: - return Workflow( - **super().process_response(response)["createWorkflow"]["workflow"] - ) + def process_response(self, response: "Payload") -> "Workflow": + return Workflow(**super().parse_payload(response)["createWorkflow"]["workflow"]) -class DeleteWorkflow(GraphQLRequest): +class DeleteWorkflow(GraphQLRequest[bool]): """ Mutation to delete workflow given workflow id. Note that this operation includes deleting all components and models associated with this workflow. @@ -568,5 +577,6 @@ class DeleteWorkflow(GraphQLRequest): def __init__(self, workflow_id: int): super().__init__(self.query, variables={"workflowId": workflow_id}) - def process_response(self, response) -> bool: - return super().process_response(response)["deleteWorkflow"]["success"] \ No newline at end of file + def process_response(self, response: "Payload") -> bool: + status: bool = super().parse_payload(response)["deleteWorkflow"]["success"] + return status diff --git a/indico/queries/workflow_components.py b/indico/queries/workflow_components.py index 205a9924..bc9fffb0 100644 --- a/indico/queries/workflow_components.py +++ b/indico/queries/workflow_components.py @@ -1,4 +1,4 @@ -from typing import List +from typing import TYPE_CHECKING import jsons @@ -11,8 +11,13 @@ Workflow, ) +if TYPE_CHECKING: # pragma: no cover + from typing import Iterator, List, Optional, Union -class _AddWorkflowComponent(GraphQLRequest): + from indico.typing import AnyDict, Payload + + +class _AddWorkflowComponent(GraphQLRequest["Workflow"]): query = """mutation addWorkflowComponent($afterComponentId:Int, $afterComponentLinkId: Int, $component: JSONString!, $workflowId: Int!){ addWorkflowComponent(afterComponentId: $afterComponentId, component: $component, @@ -65,10 +70,10 @@ class _AddWorkflowComponent(GraphQLRequest): def __init__( self, - after_component_id: int, - after_component_link: int, + after_component_id: "Optional[int]", + after_component_link: "Optional[int]", workflow_id: int, - component: dict, + component: "AnyDict", ): super().__init__( self.query, @@ -80,13 +85,13 @@ def __init__( }, ) - def process_response(self, response) -> Workflow: + def process_response(self, response: "Payload") -> "Workflow": return Workflow( - **super().process_response(response)["addWorkflowComponent"]["workflow"] + **super().parse_payload(response)["addWorkflowComponent"]["workflow"] ) -class AddLinkedLabelComponent(RequestChain): +class AddLinkedLabelComponent(RequestChain["Workflow"]): """ Adds a linked label transformer that groups together labels @@ -104,10 +109,9 @@ def __init__( workflow_id: int, labelset_id: int, model_group_id: int, - groups: List[LinkedLabelGroup], - after_component_link_id: int = None, + groups: "List[LinkedLabelGroup]", + after_component_link_id: "Optional[int]" = None, ): - self.workflow_id = workflow_id self.after_component_id = after_component_id self.after_component_link_id = after_component_link_id @@ -120,7 +124,7 @@ def __init__( }, } - def __groups_to_json(self, group: LinkedLabelGroup): + def __groups_to_json(self, group: "LinkedLabelGroup") -> "AnyDict": return { "name": group.name, "strategy": group.strategy.name.lower(), @@ -128,7 +132,7 @@ def __groups_to_json(self, group: LinkedLabelGroup): "strategy_settings": group.strategy_settings, } - def requests(self): + def requests(self) -> "Iterator[_AddWorkflowComponent]": yield 
_AddWorkflowComponent( after_component_id=self.after_component_id, after_component_link=self.after_component_link_id, @@ -137,7 +141,7 @@ def requests(self): ) -class AddContentLengthFilterComponent(RequestChain): +class AddContentLengthFilterComponent(RequestChain["Workflow"]): """ Adds a content length filter. @@ -152,9 +156,9 @@ def __init__( self, workflow_id: int, after_component_id: int, - after_component_link_id: int = None, - minimum: int = None, - maximum: int = None, + after_component_link_id: "Optional[int]" = None, + minimum: "Optional[int]" = None, + maximum: "Optional[int]" = None, ): self.workflow_id = workflow_id self.after_component_id = after_component_id @@ -166,7 +170,7 @@ def __init__( "config": {"minimum": minimum, "maximum": maximum}, } - def requests(self): + def requests(self) -> "Iterator[_AddWorkflowComponent]": yield _AddWorkflowComponent( after_component_id=self.after_component_id, after_component_link=self.after_component_link_id, @@ -175,7 +179,7 @@ def requests(self): ) -class AddLinkClassificationComponent(RequestChain): +class AddLinkClassificationComponent(RequestChain["Workflow"]): """ Adds a link classification model component with filtered classes. @@ -192,9 +196,9 @@ def __init__( workflow_id: int, after_component_id: int, model_group_id: int, - filtered_classes: List[List[str]], - labels: str = None, - after_component_link_id: int = None, + filtered_classes: "List[List[str]]", + labels: "Optional[str]" = None, + after_component_link_id: "Optional[int]" = None, ): self.workflow_id = workflow_id self.after_component_id = after_component_id @@ -208,7 +212,7 @@ def __init__( }, } - def requests(self): + def requests(self) -> "Iterator[_AddWorkflowComponent]": yield _AddWorkflowComponent( after_component_id=self.after_component_id, after_component_link=self.after_component_link_id, @@ -217,7 +221,7 @@ def requests(self): ) -class AddModelGroupComponent(GraphQLRequest): +class AddModelGroupComponent(GraphQLRequest["Workflow"]): """ Adds a new model group to a workflow, optionally with a customized questionnaire. Available on 5.0+ only. @@ -308,15 +312,14 @@ def __init__( dataset_id: int, name: str, source_column_id: int, - after_component_id: int = None, - after_link_id: int = None, - labelset_column_id: int = None, - new_labelset_args: NewLabelsetArguments = None, - new_questionnaire_args: NewQuestionnaireArguments = None, - model_training_options: str = None, - model_type: str = None, + after_component_id: "Optional[int]" = None, + after_link_id: "Optional[int]" = None, + labelset_column_id: "Optional[int]" = None, + new_labelset_args: "Optional[NewLabelsetArguments]" = None, + new_questionnaire_args: "Optional[NewQuestionnaireArguments]" = None, + model_training_options: "Optional[Union[str, AnyDict]]" = None, + model_type: "Optional[str]" = None, ): - if labelset_column_id is not None and new_labelset_args is not None: raise IndicoInputError( "Cannot define both labelset_column_id and new_labelset_args, must be one " @@ -327,8 +330,12 @@ def __init__( "Must define one of either labelset_column_id or new_labelset_args." 
) + model_training_options_json: "Optional[str]" = None if model_training_options: - model_training_options = jsons.dumps(model_training_options) + if isinstance(model_training_options, dict): + model_training_options_json = jsons.dumps(model_training_options) + else: + model_training_options_json = model_training_options super().__init__( self.query, @@ -340,7 +347,7 @@ def __init__( "labelsetColumnId": labelset_column_id, "afterComponentId": after_component_id, "afterLinkId": after_link_id, - "modelTrainingOptions": model_training_options, + "modelTrainingOptions": model_training_options_json, "modelType": model_type, "newLabelsetArgs": ( self.__labelset_to_json(new_labelset_args) @@ -355,7 +362,7 @@ def __init__( }, ) - def __labelset_to_json(self, labelset: NewLabelsetArguments): + def __labelset_to_json(self, labelset: "NewLabelsetArguments") -> "AnyDict": return { "name": labelset.name, "numLabelersRequired": labelset.num_labelers_required, @@ -364,7 +371,9 @@ def __labelset_to_json(self, labelset: NewLabelsetArguments): "targetNames": labelset.target_names, } - def __questionnaire_to_json(self, questionnaire: NewQuestionnaireArguments): + def __questionnaire_to_json( + self, questionnaire: "NewQuestionnaireArguments" + ) -> "AnyDict": return { "instructions": questionnaire.instructions, "forceTextMode": questionnaire.force_text_mode, @@ -372,13 +381,13 @@ def __questionnaire_to_json(self, questionnaire: NewQuestionnaireArguments): "users": questionnaire.users, } - def process_response(self, response) -> Workflow: + def process_response(self, response: "Payload") -> "Workflow": return Workflow( - **super().process_response(response)["addModelGroupComponent"]["workflow"] + **super().parse_payload(response)["addModelGroupComponent"]["workflow"] ) -class DeleteWorkflowComponent(GraphQLRequest): +class DeleteWorkflowComponent(GraphQLRequest["Workflow"]): """ Deletes a component from a workflow. If the component has an associated model, the model is deleted as well. Available on 5.3+ only. 
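A quick sketch of the widened `model_training_options` parameter on `AddModelGroupComponent` above: a dict is serialized internally with `jsons.dumps`, while a pre-encoded JSON string is passed through unchanged, so both forms should yield the same `modelTrainingOptions` variable. The ids, name, and option key here are placeholders, not documented training options:

```python
from indico.queries import AddModelGroupComponent

common = dict(
    workflow_id=1,
    dataset_id=2,
    name="Invoice extraction",
    source_column_id=3,
    labelset_column_id=4,
)

# dict form, serialized by the client
req_a = AddModelGroupComponent(**common, model_training_options={"max_iterations": 100})

# equivalent pre-serialized form, passed through as-is
req_b = AddModelGroupComponent(**common, model_training_options='{"max_iterations": 100}')
```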
@@ -444,7 +453,7 @@ def __init__(self, workflow_id: int, component_id: int): variables={"workflowId": workflow_id, "componentId": component_id}, ) - def process_response(self, response) -> Workflow: + def process_response(self, response: "Payload") -> "Workflow": return Workflow( - **super().process_response(response)["deleteWorkflowComponent"]["workflow"] + **super().parse_payload(response)["deleteWorkflowComponent"]["workflow"] ) diff --git a/indico/queries/workflow_metrics.py b/indico/queries/workflow_metrics.py index 832d589b..ac5924b1 100644 --- a/indico/queries/workflow_metrics.py +++ b/indico/queries/workflow_metrics.py @@ -1,9 +1,13 @@ from datetime import datetime -from indico.errors import IndicoInputError +from typing import TYPE_CHECKING, List + from indico.client.request import GraphQLRequest +from indico.errors import IndicoInputError from indico.types import BaseType -from indico.types.workflow_metrics import WorkflowMetricsOptions, WorkflowMetrics -from typing import List +from indico.types.workflow_metrics import WorkflowMetrics, WorkflowMetricsOptions + +if TYPE_CHECKING: # pragma: no cover + from indico.typing import Payload class _WorkflowMetric(BaseType): @@ -14,7 +18,7 @@ class _TopWorkflowMetric(BaseType): workflows: List[_WorkflowMetric] -class GetWorkflowMetrics(GraphQLRequest): +class GetWorkflowMetrics(GraphQLRequest["List[WorkflowMetrics]"]): """ Requests detailed workflow metric data, including daily and total submission counts, review queue counts, and straight through processing details. Query can be configured to include only specific metrics by passing in one of WorkflowOptions for SUBMISSIONS, REVIEW, or STRAIGHT_THROUGH_PROCESSING. @@ -26,10 +30,11 @@ class GetWorkflowMetrics(GraphQLRequest): workflow_ids (List[int]): ids of specific workflows to query. 
""" + __MAP_WORKFLOW_KEYS = { WorkflowMetricsOptions.SUBMISSIONS: """ firstSubmittedDate - submissions { + submissions { aggregate { submitted completed @@ -98,7 +103,7 @@ class GetWorkflowMetrics(GraphQLRequest): reviewDenom autoReviewDenom reviewStpPct - autoReviewStpPct + autoReviewStpPct } daily { date @@ -115,7 +120,7 @@ class GetWorkflowMetrics(GraphQLRequest): """, WorkflowMetricsOptions.TIME_ON_TASK: """ timeOnTask { - aggregate { + aggregate { avgMinsPerDoc avgMinsPerDocReview avgMinsPerDocExceptions @@ -127,8 +132,7 @@ class GetWorkflowMetrics(GraphQLRequest): avgMinsPerDocExceptions } } - """ - + """, } query = """ query ($workflowIds: [Int]!, $startDate: Date, $endDate:Date) { @@ -144,25 +148,43 @@ class GetWorkflowMetrics(GraphQLRequest): } """ - def __init__(self, options: List[WorkflowMetricsOptions], start_date: datetime, end_date: datetime, - workflow_ids: List[int]): + def __init__( + self, + options: "List[WorkflowMetricsOptions]", + start_date: datetime, + end_date: datetime, + workflow_ids: "List[int]", + ): self.query = self.__map_query_values(options) if workflow_ids is None or start_date is None: raise IndicoInputError("Must specify date and workflow id") if end_date is None: end_date = datetime.now() - super().__init__(self.query, variables={"startDate": start_date.strftime('%Y-%m-%d'), - "endDate": end_date.strftime('%Y-%m-%d'), "workflowIds": workflow_ids}) - def process_response(self, response) -> List[WorkflowMetrics]: - list_of_metrics = _TopWorkflowMetric(**super().process_response(response)["workflows"]).workflows - return list(map(lambda x: x.metrics, list_of_metrics)) + super().__init__( + self.query, + variables={ + "startDate": start_date.strftime("%Y-%m-%d"), + "endDate": end_date.strftime("%Y-%m-%d"), + "workflowIds": workflow_ids, + }, + ) - def __map_query_values(self, options: List[WorkflowMetricsOptions]): - daily = ' ' + def __map_query_values(self, options: "List[WorkflowMetricsOptions]") -> str: + daily = " " if len(options) < 1: - daily = ' '.join([self.__MAP_WORKFLOW_KEYS[a] for a in self.__MAP_WORKFLOW_KEYS.keys()]) + daily = " ".join( + [self.__MAP_WORKFLOW_KEYS[a] for a in self.__MAP_WORKFLOW_KEYS.keys()] + ) else: - daily = ' '.join([self.__MAP_WORKFLOW_KEYS[a] for a in options]) - query = self.query.replace("__QUERY_OPTS__", daily) + daily = " ".join([self.__MAP_WORKFLOW_KEYS[a] for a in options]) + + query: str = self.query.replace("__QUERY_OPTS__", daily) return query + + def process_response(self, response: "Payload") -> "List[WorkflowMetrics]": + list_of_metrics = _TopWorkflowMetric( + **super().parse_payload(response)["workflows"] + ).workflows + + return [x.metrics for x in list_of_metrics] diff --git a/indico/types/__init__.py b/indico/types/__init__.py index d434b223..eb11c3a3 100644 --- a/indico/types/__init__.py +++ b/indico/types/__init__.py @@ -1,11 +1,15 @@ -from .datafile import * -from .dataset import * -from .jobs import * -from .model_group import * -from .model import * -from .output_file import * -from .submission_file import * -from .submission import * -from .workflow import * -from .questionnaire import Example -from .integration import * \ No newline at end of file +# the order of these imports is super important for preventing import cycles +# isort: skip_file + +from .base import * # noqa +from .datafile import * # noqa +from .dataset import * # noqa +from .integration import * # noqa +from .jobs import * # noqa +from .model import * # noqa +from .model_group import * # noqa +from .output_file import * # noqa 
+from .questionnaire import Example # noqa +from .submission_file import * # noqa +from .submission import * # noqa +from .workflow import * # noqa diff --git a/indico/types/base.py b/indico/types/base.py index 95479a6f..dcc5212f 100644 --- a/indico/types/base.py +++ b/indico/types/base.py @@ -1,70 +1,50 @@ -import inspect import json from datetime import datetime -from typing import Any, List - -from indico.types.utils import cc_to_snake - -generic_alias_cls = type(List[Any]) - - -def list_subtype(cls): - if not issubclass(type(cls), generic_alias_cls): - return None - origin = getattr(cls, "__origin__", getattr(cls, "__extra__", None)) - if issubclass(origin, list) and cls.__args__: - return cls.__args__[0] - return None - - -def valid_type(v): - if v is None: - return False - - return ( - (inspect.isclass(v) and issubclass(v, BaseType)) - or v in [str, int, float, bool, JSONType, datetime] - or valid_type(list_subtype(v)) +from typing import TYPE_CHECKING, Any, cast, final + +from pydantic import ( + AliasGenerator, + BaseModel, + BeforeValidator, + ConfigDict, + field_validator, +) +from pydantic.alias_generators import to_camel +from typing_extensions import Annotated + +if TYPE_CHECKING: # pragma: no cover + from pydantic import ValidationInfo + + +class BaseType(BaseModel): + model_config = ConfigDict( + # alias fields so that they can be provided using their GraphQL + # field names (camel case) or by their python names (snake case) + alias_generator=AliasGenerator(validation_alias=to_camel), + populate_by_name=True, + # by default, pydantic warns if a model defines a field prefixed with 'model_'. + # we have a lot of those, so we can omit that namespace from the warnings. + # pydantic will still throw an error if we define a field that clashes with an + # built-in internal one, though. + protected_namespaces=(), ) - -class BaseType: - def _get_attrs(self): - classes = inspect.getmro(self.__class__) - props = dict() - for c in classes: - if not getattr(c, "__annotations__", None): - continue - props.update({k: v for k, v in c.__annotations__.items() if valid_type(v)}) - return props - - def __init__(self, **kwargs): - attrs = self._get_attrs() - for k, v in kwargs.items(): - k = cc_to_snake(k) - if k in attrs: - attr_type = attrs[k] - if ( - v is not None - and inspect.isclass(attr_type) - and issubclass(attr_type, BaseType) - ): - v = attrs[k](**v) - - if attr_type == JSONType and v is not None: - v = json.loads(v) - - if attr_type == datetime: - try: - v = datetime.fromtimestamp(float(v)) - except ValueError: - v = datetime.fromisoformat(v) - - subtype = list_subtype(attr_type) - if subtype and issubclass(subtype, BaseType): - v = [subtype(**x) for x in v] - setattr(self, k, v) - - -class JSONType: - pass + @field_validator("*", mode="before") + @classmethod + @final + def _validate_legacy_datetime(cls, v: Any, info: "ValidationInfo") -> Any: + # pydantic forces a UTC timezone when validating datetimes provided in epoch + # time. 
this isn't backwards compat, so we override the built-in validation for + # all fields defined with a datetime annotation with our legacy parsing + if cls.model_fields[cast(str, info.field_name)].annotation == datetime: + try: + v = datetime.fromtimestamp(float(v)) + except ValueError: + v = datetime.fromisoformat(v) + + return v + + +# this really should be dict[str, Any], but that breaks static typing since we'd supply +# the field as a string but the model's type would be annotated as a dict +JSONType = Annotated[Any, BeforeValidator(json.loads)] diff --git a/indico/types/custom_blueprint.py b/indico/types/custom_blueprint.py index 42db3a0c..36e02f44 100644 --- a/indico/types/custom_blueprint.py +++ b/indico/types/custom_blueprint.py @@ -1,4 +1,5 @@ import typing as t + from indico.types import BaseType @@ -11,4 +12,5 @@ class TaskBlueprint(BaseType): footer: str tags: t.List[str] enabled: bool - config: t.Dict + # should this be JSONType? + config: t.Dict # type: ignore diff --git a/indico/types/dataset.py b/indico/types/dataset.py index 04ff426a..721b921e 100644 --- a/indico/types/dataset.py +++ b/indico/types/dataset.py @@ -1,9 +1,9 @@ from enum import Enum from typing import List +from indico.errors import IndicoInputError from indico.types.base import BaseType from indico.types.datafile import Datafile -from indico.errors import IndicoInputError class DataColumn(BaseType): @@ -57,12 +57,16 @@ class Dataset(BaseType): def labelset_by_name(self, name: str) -> LabelSet: if name not in [lab.name for lab in self.labelsets]: - raise IndicoInputError(f"No labelset found for {name}. Current labelset names include {[lab.name for lab in self.labelsets]}.") + raise IndicoInputError( + f"No labelset found for {name}. Current labelset names include {[lab.name for lab in self.labelsets]}." + ) return next(lab for lab in self.labelsets if lab.name == name) def datacolumn_by_name(self, name: str) -> DataColumn: if name not in [datacol.name for datacol in self.datacolumns]: - raise IndicoInputError(f"No datacolumn found for {name}. Current datacolumn names include {[datacol.name for datacol in self.datacolumns]}.") + raise IndicoInputError( + f"No datacolumn found for {name}. Current datacolumn names include {[datacol.name for datacol in self.datacolumns]}." + ) return next(datacol for datacol in self.datacolumns if datacol.name == name) @@ -70,15 +74,18 @@ class TableReadOrder(Enum): ROW = 0 COLUMN = 1 + class OcrEngine(Enum): """ Enum representing available OCR engines. """ + OMNIPAGE = 0 READAPI = 1 READAPI_V2 = 2 READAPI_TABLES_V1 = 3 + class OmnipageOcrOptionsInput(BaseType): """ Omnipage specific OCR options for dataset creation. @@ -95,6 +102,7 @@ class OmnipageOcrOptionsInput(BaseType): table_read_order(TableReadOrder): Read table by row or column. """ + auto_rotate: bool single_column: bool upscale_images: bool @@ -105,6 +113,7 @@ class OmnipageOcrOptionsInput(BaseType): native_pdf: bool table_read_order: TableReadOrder + class ReadApiOcrOptionsInput(BaseType): """ Read API OCR options. @@ -115,20 +124,23 @@ class ReadApiOcrOptionsInput(BaseType): upscale_images(bool): Scale up low resolution images. languages(List[str]): List of languages to use. """ + auto_rotate: bool single_column: bool upscale_images: bool languages: List[str] + class OcrInputLanguage(BaseType): name: str code: str -class OcrOptionsInput(): + +class OcrOptionsInput: """ Input options for OCR engine. 
""" + ocr_engine: OcrEngine omnipage_options: OmnipageOcrOptionsInput readapi_options: ReadApiOcrOptionsInput - \ No newline at end of file diff --git a/indico/types/document_report.py b/indico/types/document_report.py index 4fa19584..3db2a588 100644 --- a/indico/types/document_report.py +++ b/indico/types/document_report.py @@ -1,7 +1,6 @@ from typing import List -from indico.types import BaseType -from indico.types import SubmissionFile +from indico.types import BaseType, SubmissionFile class DocumentReport(BaseType): @@ -9,6 +8,7 @@ class DocumentReport(BaseType): A report about a submission """ + dataset_id: int workflow_id: int submission_id: int @@ -22,4 +22,3 @@ class DocumentReport(BaseType): retrieved: bool input_files: List[SubmissionFile] files_deleted: bool - diff --git a/indico/types/export.py b/indico/types/export.py index 43d2c67c..f22521f3 100644 --- a/indico/types/export.py +++ b/indico/types/export.py @@ -1,5 +1,5 @@ -from typing import List from enum import Enum +from typing import List from indico.types.base import BaseType diff --git a/indico/types/integration.py b/indico/types/integration.py index d67d13f5..90e7db60 100644 --- a/indico/types/integration.py +++ b/indico/types/integration.py @@ -1,27 +1,32 @@ from datetime import datetime from enum import Enum + from indico.types import BaseType class IntegrationType(Enum): EXCHANGE = 1 + class ExchangeIntegrationCredentials(BaseType): """ Credentials needed to connect a Microsoft Exchange server to an Indico workflow using """ + client_id: str client_secret: str tenant_id: str + class ExchangeIntegrationConfiguration(BaseType): """ Configuration options available for an integration with Microsoft Exchange """ + user_id: str folder_id: str - folder_name: str = None - filters: str = None + folder_name: str + filters: str class Integration(BaseType): @@ -29,7 +34,7 @@ class Integration(BaseType): An integration pulls document from a third-party data source and submits them to a workflow Args: - id(int): ID of the integration + id(int): ID of the integration workflow_id(int): ID of the workflow to submit to. enabled(bool): Whether Indico is currently sending documents from this datasource to the workflow. created_at(datetime): When this integration was created. 
@@ -42,5 +47,6 @@ class Integration(BaseType): created_at: datetime integration_type: IntegrationType + class ExchangeIntegration(Integration): config: ExchangeIntegrationConfiguration diff --git a/indico/types/jobs.py b/indico/types/jobs.py index c4982612..93bba7e2 100644 --- a/indico/types/jobs.py +++ b/indico/types/jobs.py @@ -1,5 +1,10 @@ # -*- coding: utf-8 -*- + +from typing import Optional + +from pydantic import Field + from indico.types.base import BaseType, JSONType @@ -17,13 +22,7 @@ class Job(BaseType): ready (bool): """ - id: int - status: str - result: JSONType - ready: bool - - def __init__(self, **kwargs): - if "jobId" in kwargs: - kwargs["id"] = kwargs["jobId"] - del kwargs["jobId"] - super().__init__(**kwargs) + id: int = Field(validation_alias="jobId") + status: Optional[str] = None + result: Optional[JSONType] = None + ready: Optional[bool] = None diff --git a/indico/types/model.py b/indico/types/model.py index 3cc672ff..305c1d9d 100644 --- a/indico/types/model.py +++ b/indico/types/model.py @@ -1,6 +1,3 @@ -import json -from typing import Any, Dict, Optional - from indico.types.base import BaseType, JSONType diff --git a/indico/types/model_group.py b/indico/types/model_group.py index 0de0898e..2352624c 100644 --- a/indico/types/model_group.py +++ b/indico/types/model_group.py @@ -33,6 +33,7 @@ class ModelGroup(BaseType): class ModelTaskType(Enum): """A list of valid task types for a model group.""" + CLASSIFICATION = 1 FORM_EXTRACTION = 2 OBJECT_DETECTION = 3 @@ -54,16 +55,16 @@ class ModelType(Enum): class NewQuestionnaireArguments(BaseType): """instructions: String -Questionnaire instructions + Questionnaire instructions -forceTextMode: Boolean = false -Always use Text Labeling UI + forceTextMode: Boolean = false + Always use Text Labeling UI -showPredictions: Boolean = true -Show predictions at the global level + showPredictions: Boolean = true + Show predictions at the global level -users: [Int] -User IDs to add to the questionnaire""" + users: [Int] + User IDs to add to the questionnaire""" instructions: str force_text_mode: bool = False @@ -71,9 +72,15 @@ class NewQuestionnaireArguments(BaseType): users: List[int] -class NewLabelsetArguments(): - def __init__(self, name: str, task_type: ModelTaskType, - target_names: List[str], datacolumn_id: int, num_labelers_required: int = 1): +class NewLabelsetArguments: + def __init__( + self, + name: str, + task_type: ModelTaskType, + target_names: List[str], + datacolumn_id: int, + num_labelers_required: int = 1, + ): self.name = name self.num_labelers_required = num_labelers_required self.task_type = task_type diff --git a/indico/types/model_metrics.py b/indico/types/model_metrics.py index 55bbd1eb..0825de4e 100644 --- a/indico/types/model_metrics.py +++ b/indico/types/model_metrics.py @@ -1,4 +1,5 @@ from typing import List + from indico.types.base import BaseType diff --git a/indico/types/output_file.py b/indico/types/output_file.py index 36f1fc8f..2b2e1aee 100644 --- a/indico/types/output_file.py +++ b/indico/types/output_file.py @@ -1,4 +1,4 @@ -import datetime +from datetime import datetime from indico.types import BaseType diff --git a/indico/types/questionnaire.py b/indico/types/questionnaire.py index 9fd67e02..e536468b 100644 --- a/indico/types/questionnaire.py +++ b/indico/types/questionnaire.py @@ -1,4 +1,5 @@ from typing import List + from indico.types.base import BaseType diff --git a/indico/types/submission.py b/indico/types/submission.py index f45ef22c..d79304e5 100644 --- a/indico/types/submission.py +++ 
b/indico/types/submission.py @@ -1,7 +1,7 @@ -import datetime -from typing import Optional +from datetime import datetime +from typing import List -from indico.types import BaseType, List, JSONType +from indico.types import BaseType, JSONType from . import OutputFile, SubmissionFile @@ -65,6 +65,7 @@ class SubmissionReview(BaseType): review_type: str notes: str + class SubmissionReviewFull(BaseType): f""" Information about a submission's Reviews. Includes changes diff --git a/indico/types/submission_file.py b/indico/types/submission_file.py index e9109307..c2c47ad4 100644 --- a/indico/types/submission_file.py +++ b/indico/types/submission_file.py @@ -1,3 +1,5 @@ +from typing import Any, Dict + from indico.types import BaseType @@ -25,3 +27,4 @@ class SubmissionFile(BaseType): filetype: str file_size: int num_pages: int + meta: Dict[str, Any] diff --git a/indico/types/user_metrics.py b/indico/types/user_metrics.py index 1acf1ff2..76ed1df4 100644 --- a/indico/types/user_metrics.py +++ b/indico/types/user_metrics.py @@ -1,27 +1,29 @@ -import datetime +from datetime import datetime +from typing import List from indico.types.base import BaseType -from typing import List class AppRoles(BaseType): """Info about roles. Name and how many users have this role. - + Attributes: role (str): A role name count (int): Count of how many users have said role. """ + role: str count: int class UserDatasets(BaseType): """Dataset Id and roles assigned to user in that dataset. - + Attributes: dataset_id (int): The id of a particular dataset role (str): A dataset assigned to said dataset. """ + dataset_id: int role: str @@ -34,13 +36,14 @@ class UserSummaryCounts(BaseType): enabled (int): How many user accounts are enabled. disabled (int): How many user accounts are disabled. """ + enabled: int disabled: int class UserSnapshot(BaseType): """Individual information about a user and their dataset access - + Attributes: id (int): The user's id. name (str): The user's name. @@ -50,6 +53,7 @@ class UserSnapshot(BaseType): roles (List[str]): List of roles assigned to this user. datasets (List[UserDatasets]): List of datasets this user can access. """ + id: int name: str email: str @@ -61,29 +65,31 @@ class UserSnapshot(BaseType): class UserSummary(BaseType): """Summary data on users and app roles - + Attributes: users (UserSummarycounts): Counts of enabled/disabled users. app_roles (List[AppRoles]): List of all available user roles. """ + users: UserSummaryCounts app_roles: List[AppRoles] class DatasetRole(BaseType): """Dataset role information - + Attributes: dataset_id (int): Id of a particular dataset. role (str): Role name which has access to this dataset. """ + dataset_id: int role: str class UserChangelog(BaseType): """Log entry of a change made to a user's permission - + Attributes: id (str): Id of the log entry date (datetime): Time of long entry. @@ -100,6 +106,7 @@ class UserChangelog(BaseType): changes_made (List[str]): A list of changed made. """ + id: str date: datetime user_id: int @@ -117,8 +124,9 @@ class UserChangelog(BaseType): class UserChangelogReport(BaseType): """Job id of a request for a changelog report file for download - + Attributes: job_id (str): The job id. Use for fetching JobStatus. 
""" + job_id: str diff --git a/indico/types/utils.py b/indico/types/utils.py index b48e3cc4..3e5b0c01 100644 --- a/indico/types/utils.py +++ b/indico/types/utils.py @@ -1,32 +1,40 @@ import re -from typing import Union import time +from typing import TYPE_CHECKING + +from pydantic.alias_generators import to_snake + from indico.errors import IndicoTimeoutError -_cc_to_snake_re = re.compile(r"(? str: + return to_snake(string) -def _camel(match): +def _camel(match: "Match[str]") -> str: return match.group(1) + match.group(2).upper() -def snake_to_cc(string: str): +# pydantic's `to_camel` behaves slightly differently that what we want, in that it +# doesn't parse strings that start with an underscore, hence why we do not use it here +def snake_to_cc(string: str) -> str: return re.sub(_snake_to_cc_re, _camel, string, 0) - class Timer: - def __init__(self, timeout: Union[int, float]): - self.timeout = timeout - self.start = time.time() - self.elapsed = 0 + def __init__(self, timeout: "Union[int, float]"): + self.timeout: "Union[int, float]" = timeout + self.start: float = time.time() + self.elapsed: float = 0 - def check(self): + def check(self) -> "Optional[NoReturn]": self.elapsed = time.time() - self.start if self.timeout < self.elapsed: - raise IndicoTimeoutError(self.elapsed) \ No newline at end of file + raise IndicoTimeoutError(self.elapsed) + + return None diff --git a/indico/types/workflow.py b/indico/types/workflow.py index 7ceda61c..d71f2c4f 100644 --- a/indico/types/workflow.py +++ b/indico/types/workflow.py @@ -1,7 +1,11 @@ -from enum import Enum -from typing import List, Union from datetime import datetime -from indico.types import BaseType, ModelGroup, ModelTaskType +from enum import Enum +from typing import TYPE_CHECKING, List + +from indico.types import BaseType, ModelGroup + +if TYPE_CHECKING: # pragma: no cover + from indico.typing import AnyDict class WorkflowComponent(BaseType): @@ -17,8 +21,6 @@ class WorkflowComponent(BaseType): task_type: str model_type: str model_group: ModelGroup - task_type: str - model_type: str minimum: int maximum: int @@ -71,7 +73,7 @@ def model_group_by_name(self, name: str) -> WorkflowComponent: class LinkedLabelStrategy(Enum): - BY_ROW = 0, + BY_ROW = (0,) BY_KEY = 1 @@ -81,7 +83,7 @@ def __init__( name: str, strategy: LinkedLabelStrategy, class_ids: List[int], - strategy_settings: dict, + strategy_settings: "AnyDict", ): self.name = name self.strategy = strategy @@ -90,8 +92,8 @@ def __init__( class ComponentFamily(Enum): - MODEL = 0, - FILTER = 1, - TRANSFORMER = 2, - REVIEW = 3, + MODEL = (0,) + FILTER = (1,) + TRANSFORMER = (2,) + REVIEW = (3,) OUTPUT = 4 diff --git a/indico/types/workflow_metrics.py b/indico/types/workflow_metrics.py index fc83ee73..e9824515 100644 --- a/indico/types/workflow_metrics.py +++ b/indico/types/workflow_metrics.py @@ -1,7 +1,8 @@ from enum import Enum -from indico.types.base import BaseType from typing import List +from indico.types.base import BaseType + class WorkflowMetricsOptions(Enum): """ @@ -13,6 +14,7 @@ class WorkflowMetricsOptions(Enum): STRAIGHT_THROUGH_PROCESSING: Generates StraightThroughProcessing metrics in the response. TIME_ON_TASK: Average time spent by reviewers on documents. """ + SUBMISSIONS = 1 REVIEW = 2 STRAIGHT_THROUGH_PROCESSING = 3 @@ -31,6 +33,7 @@ class TimeOnTaskMetric(BaseType): avg_min_per_doc_exceptions(float): Average amount of minutes reviewers spend on ducments for this workflow in the exceptions queue. 
""" + avg_mins_per_doc: float avg_mins_per_doc_review: float avg_min_per_doc_exceptions: float @@ -39,7 +42,7 @@ class TimeOnTaskMetric(BaseType): class DailyTimeOnTaskMetric(TimeOnTaskMetric): """ Daily time on task metrics. - + Attributes: avg_mins_per_doc(float): Average amount of minutes reviewers spend on documents in this workflow across review and exceptions queue. @@ -49,6 +52,7 @@ class DailyTimeOnTaskMetric(TimeOnTaskMetric): the exceptions queue. date(str): Date. """ + date: str @@ -60,6 +64,7 @@ class TimeOnTaskMetrics(BaseType): aggregate(TimeOnTaskMetric): Aggregate level time on task metrics. daily(List[DailyTimeOnTaskMetric): Daily level time on task metrics. """ + aggregate: TimeOnTaskMetric daily: List[DailyTimeOnTaskMetric] @@ -77,6 +82,7 @@ class SubmissionMetric(BaseType): rejected_in_review(int): Number of items rejected in the exceptions queue. date(str): Date. """ + submitted: int completed: int completed_in_review: int @@ -98,6 +104,7 @@ class DailySubmissionMetric(SubmissionMetric): rejected_in_review(int): Number of items rejected in the exceptions queue. date(str): Date. """ + date: str @@ -109,6 +116,7 @@ class SubmissionMetrics(BaseType): aggregate(SubmissionMetric): Aggregate level submission metrics. daily(List[SubmissionMetric]): Daily submission metrics. """ + aggregate: SubmissionMetric daily: List[DailySubmissionMetric] @@ -122,6 +130,7 @@ class DailyQueueMetric(BaseType): hours_on_queue(float): Cumulative hours items wait on queue for review. avg_age_in_queue(float): Average cumulative age of items waiting in review queues. """ + date: str subs_on_queue: int hours_on_queue: float @@ -134,6 +143,7 @@ class QueueMetrics(BaseType): Attributes: daily_cumulative(List[DailyQueueMetric]): list of cumulative queue metrics per day """ + daily_cumulative: List[DailyQueueMetric] @@ -143,6 +153,7 @@ class PredictionMetric(BaseType): Attributes: num_preds(int): the total number of model-generated predictions. """ + num_preds: int @@ -153,6 +164,7 @@ class DailyPredictionMetric(PredictionMetric): date(str): the date of the prediction. num_preds(int): the total number of model-generated predictions. """ + date: str @@ -164,6 +176,7 @@ class PredictionMetrics(BaseType): aggregate(PredictionMetric): Total number of predictions generated for this workflow. daily(List[PredictionMetrics]): Number of predictions generated per-date. """ + aggregate: PredictionMetric daily: List[PredictionMetric] @@ -180,6 +193,7 @@ class StpMetric(BaseType): review_stp_pct(float): Percent of human accepted model predictions. Present if auto review is disabled. auto_review_stp_pct(float): Percent of auto review labels accepted. Present if auto review is enabled. """ + review_numerator: int auto_review_numerator: int review_denom: int @@ -201,6 +215,7 @@ class DailyStpMetric(StpMetric): auto_review_stp_pct(float): Percent of auto review labels accepted. Present if auto review is enabled. date(str): The date these metrics are applicable. """ + date: str @@ -213,6 +228,7 @@ class ClassStpMetrics(BaseType): aggregate(StpMetric): Aggregate level metrics about this class. daily(List[DailyStpMetric]): Per-date STP metrics for this class. """ + class_name: str aggregate: StpMetric daily: List[DailyStpMetric] @@ -229,6 +245,7 @@ class ModelStpMetrics(BaseType): daily(List[DailyStpMetric]): Daily STP metrics for the model. class_metrics(List[ClassStpMetrics]): Metrics per model class. 
""" + model_group_id: int name: str aggregate: StpMetric @@ -243,6 +260,7 @@ class WorkflowStpMetrics(BaseType): Attributes: daily(List[DailyStpMetric]): List of daily metrics. """ + daily: List[DailyStpMetric] @@ -254,6 +272,7 @@ class StraightThroughProcessing(BaseType): workflow(WorkflowStpMetrics): Daily aggregate workflow level STP metrics. model(List[ModelStpMetrics]): Model STP metrics (including class STP). """ + workflow: WorkflowStpMetrics model: List[ModelStpMetrics] @@ -272,6 +291,7 @@ class WorkflowMetrics(BaseType): first_submitted_date(str): The earliest date of submission to this workflow. """ + workflow_id: int time_on_task: TimeOnTaskMetrics submissions: SubmissionMetrics diff --git a/indico/typing.py b/indico/typing.py new file mode 100644 index 00000000..da1dbca4 --- /dev/null +++ b/indico/typing.py @@ -0,0 +1,4 @@ +from typing import Any, Dict + +AnyDict = Dict[str, Any] +Payload = AnyDict diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..13fcd96c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,9 @@ +[tool.isort] +known_first_party = ["indico"] +profile = "black" + +[tool.mypy] +strict = true +packages = ["indico"] +mypy_path = "./stubs:${MYPYPATH}" +plugins = ["pydantic.mypy"] diff --git a/readme_docs/Dockerfile b/readme_docs/Dockerfile index a8d4861b..153c0f1d 100644 --- a/readme_docs/Dockerfile +++ b/readme_docs/Dockerfile @@ -3,5 +3,5 @@ FROM python:3.9.15 COPY . /indico-client WORKDIR /indico-client RUN python3 setup.py install -RUN apt-get update && apt-get install python3-sphinx jq vim -y -CMD ["sleep", "infinity"] \ No newline at end of file +RUN apt-get update && apt-get install python3-sphinx jq vim -y +CMD ["sleep", "infinity"] diff --git a/readme_docs/build_docs.sh b/readme_docs/build_docs.sh index 84bcf64e..5ae086c9 100644 --- a/readme_docs/build_docs.sh +++ b/readme_docs/build_docs.sh @@ -16,6 +16,6 @@ for filename in $(find . 
-maxdepth 1 -type f); do if [[ $filename == *"-types.md" ]]; then mv $filename ./types elif [[ $filename != *"/types.md" ]] && [[ $filename != *"/classes.md" ]] && [[ $filename != *"/index.md" ]] ; then - mv $filename ./classes + mv $filename ./classes fi -done \ No newline at end of file +done diff --git a/scripts/test.sh b/scripts/test.sh index 390f4866..1980860d 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -2,4 +2,4 @@ #put token in user's home dir printenv INDICO_API_TOKEN > ~/indico_api_token.txt #install additional test reqs and run smoke tests -tox -p \ No newline at end of file +tox -p diff --git a/setup.cfg b/setup.cfg index e304d5ae..96490348 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,4 +3,4 @@ VCS = git style = pep440 versionfile_source = indico/_version.py versionfile_build = indico/_version.py -tag_prefix = \ No newline at end of file +tag_prefix = diff --git a/setup.py b/setup.py index 10a46e6c..bf5eac87 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,12 @@ """ Setup for indico apis """ -import versioneer + from pathlib import Path -from setuptools import setup, find_packages + +from setuptools import find_packages, setup + +import versioneer setup( name="indico-client", @@ -16,7 +19,6 @@ url="https://github.com/IndicoDataSolutions/indico-client-python", author="indico", author_email="engineering@indico.io", - tests_require=["pytest>=5.2.1", "requests-mock>=1.8.0", "pytest-asyncio"], install_requires=[ "msgpack>=0.5.6", "msgpack-numpy==0.4.4.3", @@ -28,6 +30,7 @@ 'importlib-metadata ~= 1.0 ; python_version < "3.8"', "deprecation>=2.1.0", "jsons", - "aiohttp[speedups]" + "aiohttp[speedups]", + "pydantic>=2", ], ) diff --git a/stubs/jsons.pyi b/stubs/jsons.pyi new file mode 100644 index 00000000..9ab15a66 --- /dev/null +++ b/stubs/jsons.pyi @@ -0,0 +1,10 @@ +# the source lib is missing its py.typed marker + +from typing import Callable + +KEY_TRANSFORMER_CAMELCASE: Callable[[str], str] + +def dump( + obj: object, key_transformer: Callable[[str], str], strip_nulls: bool +) -> object: ... +def dumps(obj: object) -> str: ... diff --git a/stubs/msgpack.pyi b/stubs/msgpack.pyi new file mode 100644 index 00000000..4ee86cb3 --- /dev/null +++ b/stubs/msgpack.pyi @@ -0,0 +1,3 @@ +from typing import Any + +def unpackb(packed: bytes, **kwargs: Any) -> Any: ... diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index d35a4dc8..59ff0601 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,11 +1,13 @@ import logging -import pytest import os +import pytest + from indico.client import IndicoClient from indico.queries.datasets import GetAvailableOcrEngines -from .data.datasets import * + from .data.async_datasets import * +from .data.datasets import * logging.getLogger("indico").setLevel(logging.DEBUG) diff --git a/tests/integration/data/AirlineComplaints.csv b/tests/integration/data/AirlineComplaints.csv index 1a08dd96..82588230 100644 --- a/tests/integration/data/AirlineComplaints.csv +++ b/tests/integration/data/AirlineComplaints.csv @@ -873,32 +873,32 @@ How dangerous is traveling with the weather like this?,You’re worried about th Is the weather going to affect the flight,You’re worried about the weather and its implications on your flight Will the weather implicate my flight,You’re worried about the weather and its implications on your flight Is it safe to fly in bad weather?,You’re worried about the weather and its implications on your flight -Do you have a lost & found department? 
I lost my bag.,Your bag was lost -Who do I contact about lost Luggage?,Your bag was lost -"I lost my bag, it would help me get it",Your bag was lost -Where can I find my bag?,Your bag was lost -Can you help me find my bag?,Your bag was lost -there will be some way to get my bag,Your bag was lost -"I just arrived to my destination, but my bag is lost, could you help me?",Your bag was lost -Where can I report a lost bag?,Your bag was lost -"Please help, my purse was stolen!",Your bag was lost -"Please, I lost my purse, could you tell me where I can get help?",Your bag was lost -"My bag is lost, What am I supposed to do?",Your bag was lost -My bag is missing and i have very important documents in it.,Your bag was lost -I lost my package could help me get it,Your bag was lost -hi could you please tell me again the information that i did not understand,Your bag was lost -my bag was lost I could help her,Your bag was lost -How may I find my lost bag?,Your bag was lost -i have lost my bag so is there any possibility to repack it,Your bag was lost -What happened with my bag?,Your bag was lost -Can you help me get a new flight?,Your bag was lost -where the lost objects department is located,Your bag was lost -How do I file the complaint?,Your bag was lost -Please help me my bag was lost.,Your bag was lost -"I can not find my bag, help",Your bag was lost -I need you to please respond for my bag,Your bag was lost -My luggage is gone. Do I have insurance?,Your bag was lost -Can I get a refund for it?,Your bag was lost +Do you have a lost & found department? I lost my bag.,Your bag was lost +Who do I contact about lost Luggage?,Your bag was lost +"I lost my bag, it would help me get it",Your bag was lost +Where can I find my bag?,Your bag was lost +Can you help me find my bag?,Your bag was lost +there will be some way to get my bag,Your bag was lost +"I just arrived to my destination, but my bag is lost, could you help me?",Your bag was lost +Where can I report a lost bag?,Your bag was lost +"Please help, my purse was stolen!",Your bag was lost +"Please, I lost my purse, could you tell me where I can get help?",Your bag was lost +"My bag is lost, What am I supposed to do?",Your bag was lost +My bag is missing and i have very important documents in it.,Your bag was lost +I lost my package could help me get it,Your bag was lost +hi could you please tell me again the information that i did not understand,Your bag was lost +my bag was lost I could help her,Your bag was lost +How may I find my lost bag?,Your bag was lost +i have lost my bag so is there any possibility to repack it,Your bag was lost +What happened with my bag?,Your bag was lost +Can you help me get a new flight?,Your bag was lost +where the lost objects department is located,Your bag was lost +How do I file the complaint?,Your bag was lost +Please help me my bag was lost.,Your bag was lost +"I can not find my bag, help",Your bag was lost +I need you to please respond for my bag,Your bag was lost +My luggage is gone. Do I have insurance?,Your bag was lost +Can I get a refund for it?,Your bag was lost My flight was canceled. 
Can I get a new one?,Your flight was cancelled and you need to get on another one I need to travel urgently in the next,Your flight was cancelled and you need to get on another one I would like to book another flight because the other one was cancelled.,Your flight was cancelled and you need to get on another one @@ -1802,35 +1802,35 @@ if it is safe to travel with the weather as well ??,You’re worried about the w Tell me if the weather going to affect the flight,You’re worried about the weather and its implications on your flight Tell me if the weather implicate my flight,You’re worried about the weather and its implications on your flight Is the turbulence from storm cells really dangerous?,You’re worried about the weather and its implications on your flight -I lost my bag on a plane. Can you help me?,Your bag was lost -Lost two suitcases with things,Your bag was lost -I lost my bag I could get another because I can not find it,Your bag was lost -Where is lost and found?,Your bag was lost -I lost my bag. Can you help me?,Your bag was lost -where could you go for information,Your bag was lost -"My bag is missing, can you check where it is?",Your bag was lost -Do you have a department for lost items?,Your bag was lost -"Could you help, my bag is lost",Your bag was lost -Where do lost objects go?,Your bag was lost -How do I not get my bag?,Your bag was lost -What should I do to find my lost bag?,Your bag was lost -"My bag was lost , how do you could help me?",Your bag was lost -i think i lost my package,Your bag was lost -hi could you please tell me again the information that i did not understand,Your bag was lost -I think I lost my bag,Your bag was lost -How may I find my lost bag?,Your bag was lost -it is black and have a white line in one side,Your bag was lost -How long will it take to get my bag back?,Your bag was lost -I need you to find me my bag,Your bag was lost -Can you help me get a new flight?,Your bag was lost -Can you help me I lost my bag,Your bag was lost -who can inform me of my lost bag on the flight,Your bag was lost -who is in charge must be responsible,Your bag was lost -I really need help my bag was lost.,Your bag was lost -Could you help me locate my bag?,Your bag was lost -They are supposed to have to be careful where the luggage is supposed to be safe,Your bag was lost -Who do I look for information about lost luggage?,Your bag was lost -Are you responsible for my lost yewelry?,Your bag was lost +I lost my bag on a plane. Can you help me?,Your bag was lost +Lost two suitcases with things,Your bag was lost +I lost my bag I could get another because I can not find it,Your bag was lost +Where is lost and found?,Your bag was lost +I lost my bag. 
Can you help me?,Your bag was lost +where could you go for information,Your bag was lost +"My bag is missing, can you check where it is?",Your bag was lost +Do you have a department for lost items?,Your bag was lost +"Could you help, my bag is lost",Your bag was lost +Where do lost objects go?,Your bag was lost +How do I not get my bag?,Your bag was lost +What should I do to find my lost bag?,Your bag was lost +"My bag was lost , how do you could help me?",Your bag was lost +i think i lost my package,Your bag was lost +hi could you please tell me again the information that i did not understand,Your bag was lost +I think I lost my bag,Your bag was lost +How may I find my lost bag?,Your bag was lost +it is black and have a white line in one side,Your bag was lost +How long will it take to get my bag back?,Your bag was lost +I need you to find me my bag,Your bag was lost +Can you help me get a new flight?,Your bag was lost +Can you help me I lost my bag,Your bag was lost +who can inform me of my lost bag on the flight,Your bag was lost +who is in charge must be responsible,Your bag was lost +I really need help my bag was lost.,Your bag was lost +Could you help me locate my bag?,Your bag was lost +They are supposed to have to be careful where the luggage is supposed to be safe,Your bag was lost +Who do I look for information about lost luggage?,Your bag was lost +Are you responsible for my lost yewelry?,Your bag was lost Did you organized something in place canceled flight?,Your flight was cancelled and you need to get on another one "I go with children, I will have a room in a hotel",Your flight was cancelled and you need to get on another one Could you help me to get a new flight?,Your flight was cancelled and you need to get on another one @@ -2733,35 +2733,35 @@ I do not think it's convenient to travel like this,You’re worried about the we Want to know if the weather going to affect the flight,You’re worried about the weather and its implications on your flight Want to know if the weather implicate my flight,You’re worried about the weather and its implications on your flight Are airports dangerous for landings and take-offs?,You’re worried about the weather and its implications on your flight -I lost my bag on a plane. What I must do to try to find it?,Your bag was lost -"They had a name tag, please check me for the ticket number",Your bag was lost -I would bring a bag since I can not get mine,Your bag was lost -How can I file a lost baggage claim?,Your bag was lost -My luggage did not arrive. can you help me?,Your bag was lost -will be held responsible for the loss,Your bag was lost -"I am at the luggage claim, still waiting for my bag, I think it is lost.",Your bag was lost -I need to find my lost bag.,Your bag was lost -"I lost somewhere in my bag, help me please!",Your bag was lost -"If they can't find it, what would happen, is there a refund?",Your bag was lost -How long should I wait to find out?,Your bag was lost -"Please, I need help. 
My bag is lost?",Your bag was lost -"you have total responsability over my luggage,and it was lost.",Your bag was lost -I do not remember where I put my package I could not find it could help me,Your bag was lost -hi could you please tell me again the information that i did not understand,Your bag was lost -I do not remember where my bag could fit,Your bag was lost -How may I find my lost bag?,Your bag was lost -it is contain my money and other documents,Your bag was lost -Why did you lost my bag?,Your bag was lost -my bag is not to be missed,Your bag was lost -Can you help me get a new flight?,Your bag was lost -I lost my bag. Please help me,Your bag was lost -"I lost my bag, who can help me with this?",Your bag was lost -Can you help me to find my bag It's lost.,Your bag was lost -"I need you to help me locate my bag, please",Your bag was lost -please check my luggage my boldsa has to appear,Your bag was lost -I have goods of importance in my suitcase. What should I do?,Your bag was lost -Where can I fill a form for refund?,Your bag was lost -can you tell me where my bag is ?,Your bag was lost +I lost my bag on a plane. What I must do to try to find it?,Your bag was lost +"They had a name tag, please check me for the ticket number",Your bag was lost +I would bring a bag since I can not get mine,Your bag was lost +How can I file a lost baggage claim?,Your bag was lost +My luggage did not arrive. can you help me?,Your bag was lost +will be held responsible for the loss,Your bag was lost +"I am at the luggage claim, still waiting for my bag, I think it is lost.",Your bag was lost +I need to find my lost bag.,Your bag was lost +"I lost somewhere in my bag, help me please!",Your bag was lost +"If they can't find it, what would happen, is there a refund?",Your bag was lost +How long should I wait to find out?,Your bag was lost +"Please, I need help. My bag is lost?",Your bag was lost +"you have total responsability over my luggage,and it was lost.",Your bag was lost +I do not remember where I put my package I could not find it could help me,Your bag was lost +hi could you please tell me again the information that i did not understand,Your bag was lost +I do not remember where my bag could fit,Your bag was lost +How may I find my lost bag?,Your bag was lost +it is contain my money and other documents,Your bag was lost +Why did you lost my bag?,Your bag was lost +my bag is not to be missed,Your bag was lost +Can you help me get a new flight?,Your bag was lost +I lost my bag. Please help me,Your bag was lost +"I lost my bag, who can help me with this?",Your bag was lost +Can you help me to find my bag It's lost.,Your bag was lost +"I need you to help me locate my bag, please",Your bag was lost +please check my luggage my boldsa has to appear,Your bag was lost +I have goods of importance in my suitcase. What should I do?,Your bag was lost +Where can I fill a form for refund?,Your bag was lost +can you tell me where my bag is ?,Your bag was lost Can you help me? 
My flight is canceled but I need to travel to designated place as soon as posible.,Your flight was cancelled and you need to get on another one "please be quick, you need to be on time",Your flight was cancelled and you need to get on another one I am in need of getting a new flight because my previous one was cancelled,Your flight was cancelled and you need to get on another one diff --git a/tests/integration/data/TooSmall.csv b/tests/integration/data/TooSmall.csv index e2e1e1ec..fdae9468 100644 --- a/tests/integration/data/TooSmall.csv +++ b/tests/integration/data/TooSmall.csv @@ -3,4 +3,4 @@ ID,Target_1,Text 1,You are threatening to never to use this airline again,I will never use this airline again. 2,You are threatening to never to use this airline again,"BECAUSE THEY TREAT ME SO EVIL, I WILL NOT USE THIS COMPANY TO STOP FLYING" 3,You are threatening to never to use this airline again,If I have to pay any additional fees I will not fly more for this company. -4,You are threatening to never to use this airline again,"If you do not give me what I'm asking for, I'll book another flight with another airline" \ No newline at end of file +4,You are threatening to never to use this airline again,"If you do not give me what I'm asking for, I'll book another flight with another airline" diff --git a/tests/integration/data/async_datasets.py b/tests/integration/data/async_datasets.py index 877b72c1..47da2cbe 100644 --- a/tests/integration/data/async_datasets.py +++ b/tests/integration/data/async_datasets.py @@ -1,10 +1,10 @@ -import pytest_asyncio import time from pathlib import Path + +import pytest_asyncio + from indico.client import AsyncIndicoClient -from indico.queries import ( - CreateDataset, -) +from indico.queries import CreateDataset @pytest_asyncio.fixture(scope="module") diff --git a/tests/integration/data/datasets.py b/tests/integration/data/datasets.py index ad604fc9..05ca1a31 100644 --- a/tests/integration/data/datasets.py +++ b/tests/integration/data/datasets.py @@ -1,19 +1,21 @@ import os -import pytest import time from pathlib import Path + +import pytest + from indico.client import IndicoClient from indico.queries import ( + AddExchangeIntegration, + AddModelGroupComponent, CreateDataset, CreateWorkflow, - AddModelGroupComponent, GetModelGroup, - AddExchangeIntegration, - StartIntegration, GetWorkflow, + StartIntegration, ) from indico.queries.workflow_components import _AddWorkflowComponent -from indico.types import ModelGroup, Dataset, Workflow, Integration +from indico.types import Dataset, Integration, ModelGroup, Workflow PUBLIC_URL = "https://github.com/IndicoDataSolutions/indico-client-python/raw/master/tests/integration/data/" @@ -270,19 +272,14 @@ def exchange_integration_to_delete(org_annotate_workflow: Workflow) -> Integrati creds = { "clientId": os.getenv("EXCH_CLIENT_ID"), "clientSecret": os.getenv("EXCH_CLIENT_SECRET"), - "tenantId": os.getenv("EXCH_TENANT_ID") + "tenantId": os.getenv("EXCH_TENANT_ID"), } - config = { - "userId": os.getenv("EXCH_USER_ID"), - "folderId": "mailFolders('inbox')" - } + config = {"userId": os.getenv("EXCH_USER_ID"), "folderId": "mailFolders('inbox')"} integ: Integration = client.call( AddExchangeIntegration( - workflow_id=org_annotate_workflow.id, - config=config, - credentials=creds + workflow_id=org_annotate_workflow.id, config=config, credentials=creds ) ) @@ -295,24 +292,15 @@ def started_exchange_integration(org_annotate_workflow: Workflow) -> Integration creds = { "clientId": os.getenv("EXCH_CLIENT_ID"), "clientSecret": 
os.getenv("EXCH_CLIENT_SECRET"), - "tenantId": os.getenv("EXCH_TENANT_ID") + "tenantId": os.getenv("EXCH_TENANT_ID"), } - config = { - "userId": os.getenv("EXCH_USER_ID"), - "folderId": "mailFolders('inbox')" - } + config = {"userId": os.getenv("EXCH_USER_ID"), "folderId": "mailFolders('inbox')"} integ: Integration = client.call( AddExchangeIntegration( - workflow_id=org_annotate_workflow.id, - config=config, - credentials=creds - ) - ) - client.call( - StartIntegration( - integration_id=integ.id + workflow_id=org_annotate_workflow.id, config=config, credentials=creds ) ) + client.call(StartIntegration(integration_id=integ.id)) yield integ diff --git a/tests/integration/data/pdf_links.csv b/tests/integration/data/pdf_links.csv index 56ebd22b..0bb82d55 100644 --- a/tests/integration/data/pdf_links.csv +++ b/tests/integration/data/pdf_links.csv @@ -1,4 +1,4 @@ pdf_link https://storage.googleapis.com/indico-testing/test-pdfs/financial_disclosures/Lee-Ebony-Letise.pdf https://storage.googleapis.com/indico-testing/test-pdfs/financial_disclosures/Matthew-Miller.pdf -https://storage.googleapis.com/indico-testing/test-pdfs/financial_disclosures/Mia-Palmieri-Heck.pdf \ No newline at end of file +https://storage.googleapis.com/indico-testing/test-pdfs/financial_disclosures/Mia-Palmieri-Heck.pdf diff --git a/tests/integration/queries/test_custom_blueprint.py b/tests/integration/queries/test_custom_blueprint.py index 19ef5324..cb5e917d 100644 --- a/tests/integration/queries/test_custom_blueprint.py +++ b/tests/integration/queries/test_custom_blueprint.py @@ -1,4 +1,5 @@ import pytest + from indico.client import IndicoClient from indico.errors import IndicoInputError from indico.queries.custom_blueprint import RegisterCustomBlueprint @@ -8,34 +9,38 @@ def test_register_blueprint(indico): client = IndicoClient() bpreq = RegisterCustomBlueprint( - component_family=ComponentFamily.OUTPUT, - name="Meowtput", - description="adds 'cat' to the result file", + component_family=ComponentFamily.OUTPUT, + name="Meowtput", + description="adds 'cat' to the result file", config={ "inputs": [{"name": "result_file", "ioClass": "PartitionStream"}], "outputs": [{"name": "result_file", "ioClass": "PartitionStream"}], - "submissionLauncher": {"service": "customizer_default", "name": "cat_output"} + "submissionLauncher": { + "service": "customizer_default", + "name": "cat_output", + }, }, - tags= ["custom"], + tags=["custom"], all_access=True, - footer="version 2" + footer="version 2", ) bp = client.call(bpreq) print(bp.__dict__) assert bp.id + def test_register_blueprint_bad_config(indico): with pytest.raises(IndicoInputError): RegisterCustomBlueprint( - component_family=ComponentFamily.OUTPUT, - name="Meowtput", - description="adds 'cat' to the result file", + component_family=ComponentFamily.OUTPUT, + name="Meowtput", + description="adds 'cat' to the result file", config={ "inputs": [{"name": "result_file", "ioClass": "PartitionStream"}], "outputs": [{"name": "result_file", "ioClass": "PartitionStream"}], - "submissionLauncher": {} + "submissionLauncher": {}, }, - tags= ["custom"], + tags=["custom"], all_access=True, - footer="version 2" + footer="version 2", ) diff --git a/tests/integration/queries/test_document.py b/tests/integration/queries/test_document.py index 013115a3..a378d575 100644 --- a/tests/integration/queries/test_document.py +++ b/tests/integration/queries/test_document.py @@ -1,12 +1,14 @@ -import unittest -import pytest import os +import unittest from pathlib import Path + +import pytest + from 
indico.client import IndicoClient from indico.queries import ( - RetrieveStorageObject, - JobStatus, DocumentExtraction, + JobStatus, + RetrieveStorageObject, UploadBatched, UploadDocument, ) @@ -144,7 +146,10 @@ def test_document_extraction_batched(indico): assert job.ready is True assert isinstance(job.result["url"], str) -@unittest.skip("Expected to fail pending https://indicodata.atlassian.net/browse/SUP-437") + +@unittest.skip( + "Expected to fail pending https://indicodata.atlassian.net/browse/SUP-437" +) def test_document_extraction_images(indico): client = IndicoClient() dataset_filepath = str(Path(__file__).parents[1]) + "/data/mock.pdf" diff --git a/tests/integration/queries/test_document_report.py b/tests/integration/queries/test_document_report.py index b681fd89..aaed150a 100644 --- a/tests/integration/queries/test_document_report.py +++ b/tests/integration/queries/test_document_report.py @@ -2,9 +2,9 @@ from typing import List from indico import IndicoClient +from indico.filters import DocumentReportFilter from indico.queries.document_report import GetDocumentReport from indico.types.document_report import DocumentReport -from indico.filters import DocumentReportFilter def test_fetch_docs(indico): diff --git a/tests/integration/queries/test_export.py b/tests/integration/queries/test_export.py index 440ee66b..73094027 100644 --- a/tests/integration/queries/test_export.py +++ b/tests/integration/queries/test_export.py @@ -1,13 +1,16 @@ -import pytest import json +import re from unittest import mock + +import pytest + from indico.client import IndicoClient -from indico.types.dataset import Dataset -from indico.types.export import Export from indico.errors import IndicoRequestError from indico.queries.export import CreateExport, DownloadExport +from indico.types.dataset import Dataset +from indico.types.export import Export + from ..data.datasets import airlines_dataset -import re def test_create_and_download_export(airlines_dataset: Dataset): diff --git a/tests/integration/queries/test_integration.py b/tests/integration/queries/test_integration.py index c6ee9f91..fa8c0a2b 100644 --- a/tests/integration/queries/test_integration.py +++ b/tests/integration/queries/test_integration.py @@ -3,23 +3,23 @@ from indico.client import IndicoClient from indico.queries import ( AddExchangeIntegration, - StartIntegration, DeleteIntegration, PauseIntegration, + StartIntegration, ) - -from indico.types import Workflow, Integration, ModelGroup +from indico.types import Integration, ModelGroup, Workflow from tests.integration.data.datasets import ( - airlines_workflow, airlines_dataset, - org_annotate_workflow, + airlines_workflow, + exchange_integration_to_delete, org_annotate_dataset, org_annotate_exchange_integration, org_annotate_model_group, - exchange_integration_to_delete, - started_exchange_integration + org_annotate_workflow, + started_exchange_integration, ) + def test_add_integration(airlines_workflow: Workflow): client = IndicoClient() creds = { @@ -48,31 +48,19 @@ def test_start_integration( integ = org_annotate_exchange_integration assert not integ.enabled client = IndicoClient() - resp = client.call( - StartIntegration( - integration_id=integ.id - ) - ) + resp = client.call(StartIntegration(integration_id=integ.id)) assert resp["startWorkflowIntegration"]["success"] - + def test_delete_integration(exchange_integration_to_delete: Integration): integ = exchange_integration_to_delete client = IndicoClient() - resp = client.call( - DeleteIntegration( - integration_id=integ.id - ) - ) + 
resp = client.call(DeleteIntegration(integration_id=integ.id)) assert resp["deleteWorkflowIntegration"]["success"] def test_pause_integration(started_exchange_integration: Integration): integ = started_exchange_integration client = IndicoClient() - resp = client.call( - PauseIntegration( - integration_id=integ.id - ) - ) + resp = client.call(PauseIntegration(integration_id=integ.id)) assert resp["pauseWorkflowIntegration"]["success"] diff --git a/tests/integration/queries/test_job.py b/tests/integration/queries/test_job.py index d9e66858..279c7db3 100644 --- a/tests/integration/queries/test_job.py +++ b/tests/integration/queries/test_job.py @@ -3,9 +3,9 @@ import pytest from indico.client import IndicoClient -from indico.queries import JobStatus, DocumentExtraction -from indico.types.jobs import Job from indico.errors import IndicoTimeoutError +from indico.queries import DocumentExtraction, JobStatus +from indico.types.jobs import Job def test_job_wait_on_success(indico): diff --git a/tests/integration/queries/test_model_group.py b/tests/integration/queries/test_model_group.py index 39df02b8..ce4e076e 100644 --- a/tests/integration/queries/test_model_group.py +++ b/tests/integration/queries/test_model_group.py @@ -26,20 +26,6 @@ from indico.types.model import Model, ModelOptions, TrainingProgress from indico.types.model_group import ModelGroup -from ..data.datasets import ( - airlines_dataset, - airlines_model_group, - airlines_workflow, - cats_dogs_image_dataset, - cats_dogs_image_workflow, - cats_dogs_modelgroup, - org_annotate_dataset, - org_annotate_model_group, - org_annotate_workflow, - too_small_dataset, - too_small_workflow, -) - def test_get_missing_model_group(indico): client = IndicoClient() diff --git a/tests/integration/queries/test_submission.py b/tests/integration/queries/test_submission.py index 20da011e..9341e667 100644 --- a/tests/integration/queries/test_submission.py +++ b/tests/integration/queries/test_submission.py @@ -1,9 +1,10 @@ -import pytest from datetime import datetime -from indico.client import IndicoClient, IndicoConfig +import pytest + +from indico.client import IndicoClient from indico.filters import DateRangeFilter, SubmissionFilter, SubmissionReviewFilter -from indico.queries import ListSubmissions, GetReviews +from indico.queries import GetReviews, ListSubmissions from indico.types import SubmissionReviewFull @@ -17,7 +18,9 @@ def test_list_submissions(indico): def test_list_submissions_filter_filetype(indico): client = IndicoClient() - subs = client.call(ListSubmissions(filters=SubmissionFilter(file_type=["PDF"]), limit=10)) + subs = client.call( + ListSubmissions(filters=SubmissionFilter(file_type=["PDF"]), limit=10) + ) assert len(subs) > 0 for sub in subs: sub_filetype = sub.input_filename.split(".")[-1].upper() @@ -26,15 +29,16 @@ def test_list_submissions_filter_filetype(indico): @pytest.mark.parametrize( "_input_filename, _should_contain", - [ - ("pdf", True), - ("randomstring", False) - ], + [("pdf", True), ("randomstring", False)], ) def test_list_submissions_filter_filename(indico, _input_filename, _should_contain): client = IndicoClient() - subs = client.call(ListSubmissions(filters=SubmissionFilter(input_filename=_input_filename), limit=10)) + subs = client.call( + ListSubmissions( + filters=SubmissionFilter(input_filename=_input_filename), limit=10 + ) + ) for sub in subs: assert (_input_filename in sub.input_filename) == _should_contain @@ -44,13 +48,18 @@ def test_list_submissions_filter_reviews(indico): review_filter = 
SubmissionReviewFilter(rejected=False) - subs = client.call(ListSubmissions(filters=SubmissionFilter(reviews=review_filter), limit=10)) + subs = client.call( + ListSubmissions(filters=SubmissionFilter(reviews=review_filter), limit=10) + ) assert len(subs) >= 0 + def test_list_submissions_filter_reviews_in_progress(indico): client = IndicoClient() - subs = client.call(ListSubmissions(filters=SubmissionFilter(review_in_progress=False), limit=10)) + subs = client.call( + ListSubmissions(filters=SubmissionFilter(review_in_progress=False), limit=10) + ) assert len(subs) > 0 @@ -59,11 +68,15 @@ def test_list_submissions_filter_created_at(indico): date_filter = DateRangeFilter( filter_from=datetime(year=2020, month=2, day=2).strftime("%Y-%m-%d"), - filter_to=datetime.now().strftime("%Y-%m-%d") + filter_to=datetime.now().strftime("%Y-%m-%d"), + ) + subs = client.call( + ListSubmissions(filters=SubmissionFilter(created_at=date_filter), limit=10) ) - subs = client.call(ListSubmissions(filters=SubmissionFilter(created_at=date_filter), limit=10)) assert len(subs) > 0 - subs = client.call(ListSubmissions(filters=SubmissionFilter(updated_at=date_filter), limit=10)) + subs = client.call( + ListSubmissions(filters=SubmissionFilter(updated_at=date_filter), limit=10) + ) assert len(subs) > 0 @@ -75,6 +88,6 @@ def test_get_reviews(indico): for sub in subs: reviews = client.call(GetReviews(sub.id)) for review in reviews: - assert (isinstance(review, SubmissionReviewFull)) + assert isinstance(review, SubmissionReviewFull) if not review.rejected: assert isinstance(review.changes, (list, dict)) diff --git a/tests/integration/queries/test_user_metrics.py b/tests/integration/queries/test_user_metrics.py index a232d94f..5c725291 100644 --- a/tests/integration/queries/test_user_metrics.py +++ b/tests/integration/queries/test_user_metrics.py @@ -1,15 +1,17 @@ +from datetime import datetime + import pytest + from indico.client import IndicoClient -from indico.filters import or_, UserMetricsFilter +from indico.filters import UserMetricsFilter, or_ from indico.queries import JobStatus, RetrieveStorageObject -from indico.types.user_metrics import UserSummary from indico.queries.usermetrics import ( - GetUserSummary, - GetUserSnapshots, - GetUserChangelog, GenerateChangelogReport, + GetUserChangelog, + GetUserSnapshots, + GetUserSummary, ) -from datetime import datetime +from indico.types.user_metrics import UserSummary def test_fetch_summary(indico): diff --git a/tests/integration/queries/test_workflow.py b/tests/integration/queries/test_workflow.py index 5982d106..48573577 100644 --- a/tests/integration/queries/test_workflow.py +++ b/tests/integration/queries/test_workflow.py @@ -380,7 +380,6 @@ def test_workflow_submission_auto_review_v1( org_annotate_workflow, org_annotate_model_group, ): - client = IndicoClient() wf = client.call( @@ -426,7 +425,6 @@ def test_workflow_submission_auto_review_v3_result( org_annotate_workflow, org_annotate_model_group, ): - client = IndicoClient() wf = client.call( diff --git a/tests/integration/queries/test_workflow_component.py b/tests/integration/queries/test_workflow_component.py index 27a1878a..09398b70 100644 --- a/tests/integration/queries/test_workflow_component.py +++ b/tests/integration/queries/test_workflow_component.py @@ -1,5 +1,5 @@ from indico.client import IndicoClient -from indico.queries import DeleteWorkflowComponent, GetWorkflow, GetDataset +from indico.queries import DeleteWorkflowComponent, GetDataset, GetWorkflow from indico.queries.workflow_components import 
AddLinkClassificationComponent from indico.types import ModelGroup, ModelTaskType, NewLabelsetArguments @@ -55,6 +55,7 @@ def test_delete_workflow_component( assert mg_comp_id not in {c.id for c in wf.components} assert len(wf.component_links) == num_links - 1 + def test_add_many_filtered_classes(indico, org_annotate_dataset): client = IndicoClient() workflowreq = CreateWorkflow( @@ -63,7 +64,9 @@ def test_add_many_filtered_classes(indico, org_annotate_dataset): wf = client.call(workflowreq) mg_name = f"TestAnnotationModelGroup-{int(time.time())}" labelset_name = "test-filtered-classes" - after_component_id = next(c.id for c in wf.components if c.component_type == "INPUT_OCR_EXTRACTION") + after_component_id = next( + c.id for c in wf.components if c.component_type == "INPUT_OCR_EXTRACTION" + ) source_column_id = org_annotate_dataset.datacolumn_by_name( "News Headlines w/Company Names" ).id @@ -81,24 +84,24 @@ def test_add_many_filtered_classes(indico, org_annotate_dataset): ), ) wf = client.call(modelgroupreq) - mg = next(c for c in wf.components if c.component_type == 'MODEL_GROUP') + mg = next(c for c in wf.components if c.component_type == "MODEL_GROUP") after_component_id = mg.id - - classes_to_filter = [ - ["type 1"], - ["type 3"] - ] + + classes_to_filter = [["type 1"], ["type 3"]] filtered = AddLinkClassificationComponent( - workflow_id=wf.id, - after_component_id=after_component_id, - model_group_id=mg.model_group.id, - filtered_classes=classes_to_filter, - labels="actual", - ) + workflow_id=wf.id, + after_component_id=after_component_id, + model_group_id=mg.model_group.id, + filtered_classes=classes_to_filter, + labels="actual", + ) wf = client.call(filtered) - new_component = next(c.id for c in wf.components if c.component_type == "LINK_CLASSIFICATION_MODEL") + new_component = next( + c.id for c in wf.components if c.component_type == "LINK_CLASSIFICATION_MODEL" + ) assert new_component is not None + def test_add_single_filtered_class(indico, org_annotate_dataset): client = IndicoClient() workflowreq = CreateWorkflow( @@ -107,7 +110,9 @@ def test_add_single_filtered_class(indico, org_annotate_dataset): wf = client.call(workflowreq) mg_name = f"TestAnnotationModelGroup-{int(time.time())}" labelset_name = "test-filtered-classes" - after_component_id = next(c.id for c in wf.components if c.component_type == "INPUT_OCR_EXTRACTION") + after_component_id = next( + c.id for c in wf.components if c.component_type == "INPUT_OCR_EXTRACTION" + ) source_column_id = org_annotate_dataset.datacolumn_by_name( "News Headlines w/Company Names" ).id @@ -125,21 +130,21 @@ def test_add_single_filtered_class(indico, org_annotate_dataset): ), ) wf = client.call(modelgroupreq) - mg = next(c for c in wf.components if c.component_type == 'MODEL_GROUP') + mg = next(c for c in wf.components if c.component_type == "MODEL_GROUP") after_component_id = mg.id - - classes_to_filter = [ - ["type 1"] - ] + + classes_to_filter = [["type 1"]] filtered = AddLinkClassificationComponent( - workflow_id=wf.id, - after_component_id=after_component_id, - model_group_id=mg.model_group.id, - filtered_classes=classes_to_filter, - labels="actual", - ) + workflow_id=wf.id, + after_component_id=after_component_id, + model_group_id=mg.model_group.id, + filtered_classes=classes_to_filter, + labels="actual", + ) wf = client.call(filtered) - new_component = next(c.id for c in wf.components if c.component_type == 'LINK_CLASSIFICATION_MODEL') + new_component = next( + c.id for c in wf.components if c.component_type == 
"LINK_CLASSIFICATION_MODEL" + ) assert new_component is not None @@ -151,7 +156,9 @@ def test_add_bad_syntax_filtered_classes(indico, org_annotate_dataset): wf = client.call(workflowreq) mg_name = f"TestAnnotationModelGroup-{int(time.time())}" labelset_name = "test-filtered-classes" - after_component_id = next(c.id for c in wf.components if c.component_type == "INPUT_OCR_EXTRACTION") + after_component_id = next( + c.id for c in wf.components if c.component_type == "INPUT_OCR_EXTRACTION" + ) source_column_id = org_annotate_dataset.datacolumn_by_name( "News Headlines w/Company Names" ).id @@ -169,20 +176,16 @@ def test_add_bad_syntax_filtered_classes(indico, org_annotate_dataset): ), ) wf = client.call(modelgroupreq) - mg = next(c for c in wf.components if c.component_type == 'MODEL_GROUP') + mg = next(c for c in wf.components if c.component_type == "MODEL_GROUP") after_component_id = mg.id - - classes_to_filter = [ - ["type 1"], - ["type 2"] - ] + + classes_to_filter = [["type 1"], ["type 2"]] filtered = AddLinkClassificationComponent( - workflow_id=wf.id, - after_component_id=after_component_id, - model_group_id=6108, - filtered_classes=[classes_to_filter], - labels="actual", - ) + workflow_id=wf.id, + after_component_id=after_component_id, + model_group_id=6108, + filtered_classes=[classes_to_filter], + labels="actual", + ) with pytest.raises(Exception): wf = client.call(filtered) - \ No newline at end of file diff --git a/tests/integration/queries/test_workflow_metrics.py b/tests/integration/queries/test_workflow_metrics.py index 0a372caa..e647e8b5 100644 --- a/tests/integration/queries/test_workflow_metrics.py +++ b/tests/integration/queries/test_workflow_metrics.py @@ -1,32 +1,36 @@ -from pathlib import Path import time +from datetime import datetime +from pathlib import Path from typing import List import pytest + from indico.client import IndicoClient from indico.queries import ( JobStatus, RetrieveStorageObject, + SubmitReview, UpdateWorkflowSettings, - WorkflowSubmission, WaitForSubmissions, - SubmitReview, + WorkflowSubmission, ) - -from indico.types.workflow_metrics import WorkflowMetrics, WorkflowMetricsOptions from indico.queries.workflow_metrics import GetWorkflowMetrics -from datetime import datetime +from indico.types.workflow_metrics import WorkflowMetrics, WorkflowMetricsOptions + from ..data.datasets import * # noqa from ..data.datasets import PUBLIC_URL -import time @pytest.fixture -def workflow(indico, org_annotate_dataset, org_annotate_workflow, org_annotate_model_group): +def workflow( + indico, org_annotate_dataset, org_annotate_workflow, org_annotate_model_group +): client = IndicoClient() wf = client.call( - UpdateWorkflowSettings(org_annotate_workflow.id, enable_review=True, enable_auto_review=True) + UpdateWorkflowSettings( + org_annotate_workflow.id, enable_review=True, enable_auto_review=True + ) ) assert wf.review_enabled and wf.auto_review_enabled diff --git a/tests/integration/test_base_aioclient.py b/tests/integration/test_base_aioclient.py index 0e8c9f70..cd3fedc0 100644 --- a/tests/integration/test_base_aioclient.py +++ b/tests/integration/test_base_aioclient.py @@ -1,10 +1,9 @@ import pytest from indico.client import AsyncIndicoClient, IndicoConfig -from indico.client.request import HTTPRequest, HTTPMethod, GraphQLRequest +from indico.client.request import GraphQLRequest, HTTPMethod, HTTPRequest from indico.errors import IndicoAuthenticationFailed - pytestmark = pytest.mark.asyncio(scope="module") diff --git a/tests/integration/test_base_client.py 
b/tests/integration/test_base_client.py index 6885697f..6843dc30 100644 --- a/tests/integration/test_base_client.py +++ b/tests/integration/test_base_client.py @@ -1,7 +1,7 @@ import pytest from indico.client import IndicoClient, IndicoConfig -from indico.client.request import HTTPRequest, HTTPMethod, GraphQLRequest +from indico.client.request import GraphQLRequest, HTTPMethod, HTTPRequest from indico.errors import IndicoAuthenticationFailed from tests.integration.data.datasets import org_annotate_dataset diff --git a/tests/unit/client/test_aioclient.py b/tests/unit/client/test_aioclient.py index 613cc5d7..cece915a 100644 --- a/tests/unit/client/test_aioclient.py +++ b/tests/unit/client/test_aioclient.py @@ -1,7 +1,6 @@ import pytest - -from indico.client import AsyncIndicoClient, HTTPRequest, HTTPMethod, GraphQLRequest +from indico.client import AsyncIndicoClient, GraphQLRequest, HTTPMethod, HTTPRequest from indico.config import IndicoConfig from indico.errors import IndicoError @@ -52,18 +51,17 @@ async def test_client_basic_http_request(indico_request, auth, indico_test_confi await client.cleanup() async with AsyncIndicoClient(config=indico_test_config) as client: - response = await client.call( HTTPRequest(method=HTTPMethod.GET, path="/users/details") ) assert response == {"test": True} + async def test_client_creation_error_handling(indico_test_config): client = AsyncIndicoClient() with pytest.raises(IndicoError): - await client.call( - HTTPRequest(method=HTTPMethod.GET, path="/users/details") - ) + await client.call(HTTPRequest(method=HTTPMethod.GET, path="/users/details")) + async def test_client_graphql_text_request(indico_request, auth, indico_test_config): client = await AsyncIndicoClient(config=indico_test_config).create() diff --git a/tests/unit/client/test_client.py b/tests/unit/client/test_client.py index fddac7ef..98b8f3ab 100644 --- a/tests/unit/client/test_client.py +++ b/tests/unit/client/test_client.py @@ -1,8 +1,8 @@ import unittest.mock -import pytest +import pytest -from indico.client import IndicoClient, HTTPRequest, HTTPMethod, GraphQLRequest +from indico.client import GraphQLRequest, HTTPMethod, HTTPRequest, IndicoClient from indico.config import IndicoConfig @@ -10,6 +10,7 @@ def indico_test_config(): return IndicoConfig(protocol="mock", host="mock") + @pytest.fixture(scope="function") def indico_request(requests_mock, indico_test_config): def new_request_mock(method, path, *args, **kwargs): @@ -70,8 +71,13 @@ def test_client_verify_true_request(indico_request, auth, indico_test_config): def test_client_verify_false_request(indico_request, auth, indico_test_config): - client = IndicoClient(IndicoConfig(verify_ssl=False, host=indico_test_config.host, - protocol=indico_test_config.protocol)) + client = IndicoClient( + IndicoConfig( + verify_ssl=False, + host=indico_test_config.host, + protocol=indico_test_config.protocol, + ) + ) indico_request( "post", "/graph/api/graphql", @@ -89,8 +95,13 @@ def test_client_verify_false_request(indico_request, auth, indico_test_config): def test_client_requests_params(indico_request, auth, indico_test_config): - client = IndicoClient(IndicoConfig(requests_params={"verify": False}, host=indico_test_config.host, - protocol=indico_test_config.protocol)) + client = IndicoClient( + IndicoConfig( + requests_params={"verify": False}, + host=indico_test_config.host, + protocol=indico_test_config.protocol, + ) + ) indico_request( "post", "/graph/api/graphql", @@ -105,6 +116,7 @@ def test_client_requests_params(indico_request, auth, 
indico_test_config): ) assert response == {"datasets": []} + def test_client_get_ipa_version(indico_request, auth, indico_test_config): client = IndicoClient(config=indico_test_config) indico_request( diff --git a/tests/unit/client/test_config.py b/tests/unit/client/test_config.py index 605bac46..4c88f830 100644 --- a/tests/unit/client/test_config.py +++ b/tests/unit/client/test_config.py @@ -1,5 +1,6 @@ import os from unittest import mock + from indico.config import IndicoConfig diff --git a/tests/unit/client/test_request.py b/tests/unit/client/test_request.py index da6c3f61..bf10f837 100644 --- a/tests/unit/client/test_request.py +++ b/tests/unit/client/test_request.py @@ -1,7 +1,9 @@ +from datetime import datetime + import pytest -from indico.client.request import HTTPMethod, HTTPRequest, GraphQLRequest -from datetime import datetime +from indico.client.request import GraphQLRequest, HTTPMethod, HTTPRequest + def test_http_request_properties(): data = {"test", "something"} @@ -31,6 +33,7 @@ def test_graphql_request_properties(): assert req.method == HTTPMethod.POST assert req.kwargs["json"] == {"query": query, "variables": variables} + def test_graphql_request_properties_datetime(): query = "query($ids: List(Int)) { dataset(ids: $ids) { name } }" variables = {"date": [datetime.now()]} diff --git a/tests/unit/http/mocked_data/application_json_ b/tests/unit/http/mocked_data/application_json_ index e310edf6..5263b62e 100644 --- a/tests/unit/http/mocked_data/application_json_ +++ b/tests/unit/http/mocked_data/application_json_ @@ -1 +1 @@ -{"code": 401, "content": {}, "error_class": "ServiceException", "error_type": "LoginRequired", "message": "Login is required to access this route."} \ No newline at end of file +{"code": 401, "content": {}, "error_class": "ServiceException", "error_type": "LoginRequired", "message": "Login is required to access this route."} diff --git a/tests/unit/http/mocked_data/application_json_utf-8 b/tests/unit/http/mocked_data/application_json_utf-8 index e310edf6..5263b62e 100644 --- a/tests/unit/http/mocked_data/application_json_utf-8 +++ b/tests/unit/http/mocked_data/application_json_utf-8 @@ -1 +1 @@ -{"code": 401, "content": {}, "error_class": "ServiceException", "error_type": "LoginRequired", "message": "Login is required to access this route."} \ No newline at end of file +{"code": 401, "content": {}, "error_class": "ServiceException", "error_type": "LoginRequired", "message": "Login is required to access this route."} diff --git a/tests/unit/http/mocked_data/application_octet-stream_ b/tests/unit/http/mocked_data/application_octet-stream_ index dbf091df..21f9a10d 100644 --- a/tests/unit/http/mocked_data/application_octet-stream_ +++ b/tests/unit/http/mocked_data/application_octet-stream_ @@ -20,7 +20,7 @@ endobj << /Type /Pages /Count 2 -/Kids [ 4 0 R 6 0 R ] +/Kids [ 4 0 R 6 0 R ] >> endobj @@ -30,7 +30,7 @@ endobj /Parent 3 0 R /Resources << /Font << -/F1 9 0 R +/F1 9 0 R >> /ProcSet 8 0 R >> @@ -103,7 +103,7 @@ endobj /Parent 3 0 R /Resources << /Font << -/F1 9 0 R +/F1 9 0 R >> /ProcSet 8 0 R >> diff --git a/tests/unit/http/mocked_data/application_pdf_ b/tests/unit/http/mocked_data/application_pdf_ index dbf091df..21f9a10d 100644 --- a/tests/unit/http/mocked_data/application_pdf_ +++ b/tests/unit/http/mocked_data/application_pdf_ @@ -20,7 +20,7 @@ endobj << /Type /Pages /Count 2 -/Kids [ 4 0 R 6 0 R ] +/Kids [ 4 0 R 6 0 R ] >> endobj @@ -30,7 +30,7 @@ endobj /Parent 3 0 R /Resources << /Font << -/F1 9 0 R +/F1 9 0 R >> /ProcSet 8 0 R >> @@ -103,7 +103,7 
@@ endobj /Parent 3 0 R /Resources << /Font << -/F1 9 0 R +/F1 9 0 R >> /ProcSet 8 0 R >> diff --git a/tests/unit/http/mocked_data/text_html_ISO-8859-1 b/tests/unit/http/mocked_data/text_html_ISO-8859-1 index 0831b180..e14e957d 100644 --- a/tests/unit/http/mocked_data/text_html_ISO-8859-1 +++ b/tests/unit/http/mocked_data/text_html_ISO-8859-1 @@ -5,4 +5,4 @@ if (!iesg){document.f&&document.f.q.focus();document.gbqf&&document.gbqf.q.focus })();
\ No newline at end of file +(function(){google.spjs=false;google.snet=true;google.em=[];google.emw=false;})();(function(){var pmc='{\x22CaHQXQ\x22:{},\x22Qnk92g\x22:{},\x22RWGcrA\x22:{},\x22U5B21g\x22:{},\x22YFCs/g\x22:{},\x22ZI/YVQ\x22:{},\x22d\x22:{},\x22mVopag\x22:{},\x22sb_he\x22:{\x22agen\x22:true,\x22cgen\x22:true,\x22client\x22:\x22heirloom-hp\x22,\x22dh\x22:true,\x22dhqt\x22:true,\x22ds\x22:\x22\x22,\x22ffql\x22:\x22en\x22,\x22fl\x22:true,\x22host\x22:\x22google.com\x22,\x22isbh\x22:28,\x22jsonp\x22:true,\x22msgs\x22:{\x22cibl\x22:\x22Clear Search\x22,\x22dym\x22:\x22Did you mean:\x22,\x22lcky\x22:\x22I\\u0026#39;m Feeling Lucky\x22,\x22lml\x22:\x22Learn more\x22,\x22oskt\x22:\x22Input tools\x22,\x22psrc\x22:\x22This search was removed from your \\u003Ca href\x3d\\\x22/history\\\x22\\u003EWeb History\\u003C/a\\u003E\x22,\x22psrl\x22:\x22Remove\x22,\x22sbit\x22:\x22Search by image\x22,\x22srch\x22:\x22Google Search\x22},\x22ovr\x22:{},\x22pq\x22:\x22\x22,\x22refpd\x22:true,\x22rfs\x22:[],\x22sbpl\x22:24,\x22sbpr\x22:24,\x22scd\x22:10,\x22sce\x22:5,\x22stok\x22:\x22okHA0IecwTcCeg7yXOR1yDbhnhY\x22,\x22uhde\x22:false}}';google.pmc=JSON.parse(pmc);})(); diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py index c7ed782e..412b37d6 100644 --- a/tests/unit/test_filters.py +++ b/tests/unit/test_filters.py @@ -1,8 +1,9 @@ -import pytest from datetime import datetime +import pytest + from indico.errors import IndicoInputError -from indico.filters import Filter, SubmissionFilter, and_, or_, DocumentReportFilter +from indico.filters import DocumentReportFilter, Filter, SubmissionFilter, and_, or_ def test_filter(): @@ -31,11 +32,11 @@ def test_doc_report_filter(): todays_date = datetime.now().strftime("%Y-%m-%d") filter_opts = DocumentReportFilter(created_at_start_date=datetime(2021, 7, 1)) - assert filter_opts['createdAt']['to'] == todays_date + assert filter_opts["createdAt"]["to"] == todays_date with pytest.raises(IndicoInputError): DocumentReportFilter(created_at_end_date=datetime.now()) filter_opts = DocumentReportFilter(updated_at_start_date=datetime(2021, 8, 1)) - assert filter_opts['updatedAt']['to'] == todays_date + assert filter_opts["updatedAt"]["to"] == todays_date with pytest.raises(IndicoInputError): DocumentReportFilter(updated_at_end_date=datetime.now()) diff --git a/tests/unit/types/test_base_type.py b/tests/unit/types/test_base_type.py index b6d13c6a..3d7bd06b 100644 --- a/tests/unit/types/test_base_type.py +++ b/tests/unit/types/test_base_type.py @@ -1,6 +1,8 @@ -from typing import List -from indico.types.base import BaseType, JSONType from datetime import datetime +from typing import Dict, List + +from indico.types.base import BaseType, JSONType + def test_setting_attributes_from_dict(): class A(BaseType): @@ -35,6 +37,7 @@ class B(BaseType): assert x.id == 1 assert x.a[0].id == 2 + def test_nested_empty_list_base_type(): class A(BaseType): id: int @@ -59,9 +62,7 @@ class B(BaseType): assert x.a[0] == "this is meta" - def test_nested_empty_list_simple_type(): - class B(BaseType): id: int a: List[str] @@ -93,6 +94,15 @@ def test_timestamp_to_datetime_field(): class A(BaseType): created_at: datetime - x = A(**{"createdAt": '1590169591.582852'}) + x = A(**{"createdAt": "1590169591.582852"}) assert x.created_at == datetime.fromtimestamp(1590169591.582852) + + +def test_generic_dict_field(): + class A(BaseType): + meta: Dict[str, str] + + x = A(meta={"foo": "bar"}) + + assert x.meta == {"foo": "bar"} diff --git a/tox.ini b/tox.ini index d8383ef1..569b0d52 100644 
--- a/tox.ini +++ b/tox.ini @@ -6,9 +6,16 @@ envlist = py38,py39,py310,py311,py312 parallel_show_output = true # install pytest in the virtualenv where commands will be executed deps = - pytest - pytest-asyncio + pytest < 8 + pytest-asyncio > 0.21 requests-mock >= 1.8.0 + mypy == 1.8 + typing_extensions + pandas-stubs + types-requests + types-pytz + types-openpyxl commands = # NOTE: you can run any command line tool here - not just tests + mypy --config-file=pyproject.toml pytest -sv tests/unit
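
A note on the typing stubs added above (`stubs/jsons.pyi`, `stubs/msgpack.pyi`): neither `jsons` nor `msgpack` ships a `py.typed` marker, so without local stubs mypy treats every call into those libraries as untyped `Any`. The stubs declare only the handful of symbols this client actually uses. For `mypy --config-file=pyproject.toml` to find them, the config presumably points `mypy_path` at the `stubs/` directory; that section of `pyproject.toml` is not shown in this diff. Below is a minimal sketch of the call shape the `jsons.pyi` stub covers, using the real `jsons` API; the `Payload` class is a hypothetical stand-in, not part of the client.

```python
import jsons


class Payload:
    """Hypothetical stand-in for an object the client would serialise."""

    def __init__(self) -> None:
        self.user_name = "alice"
        self.api_token = None


# Matches the stubbed dump() signature: keys are camelCased and
# None-valued fields are dropped, yielding {"userName": "alice"}.
data = jsons.dump(
    Payload(),
    key_transformer=jsons.KEY_TRANSFORMER_CAMELCASE,
    strip_nulls=True,
)

# dumps() serialises any object straight to a JSON string.
print(jsons.dumps(data))
```

Because a `.pyi` file replaces the module's entire interface during type checking, any `jsons` or `msgpack` symbol not declared in the stub is flagged by mypy, which keeps the checked surface deliberately narrow.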