diff --git a/examples/poll_auto_review.py b/examples/poll_auto_review.py
new file mode 100644
index 0000000..8a22599
--- /dev/null
+++ b/examples/poll_auto_review.py
@@ -0,0 +1,60 @@
+"""
+A feature-complete auto review polling script that incorporates asyncio,
+automatic retry, result file dataclasses, and ETL output dataclasses.
+
+Max workers, spawn rate, ETL output loading, and retry behavior are all configurable.
+See the `AutoReviewPoller` class definition.
+"""
+
+import asyncio
+import sys
+
+from indico import IndicoConfig
+
+from indico_toolkit.polling import AutoReviewPoller, AutoReviewed
+from indico_toolkit.etloutput import EtlOutput
+from indico_toolkit.results import Result, Document
+
+
+async def auto_review(
+    result: Result, etl_outputs: dict[Document, EtlOutput]
+) -> AutoReviewed:
+    """
+    Apply auto review rules to predictions and determine straight-through processing.
+    Any IO performed (network requests, file reads/writes, etc.) must be awaited to
+    avoid blocking the asyncio loop that schedules this coroutine.
+    """
+    predictions = result.pre_review
+
+    # Apply auto review rules.
+
+    return AutoReviewed(
+        changes=predictions.to_changes(result),
+        reject=False,  # Defaults to `False` and may be omitted.
+        stp=False,  # Defaults to `False` and may be omitted.
+    )
+
+
+if __name__ == "__main__":
+    import logging
+
+    logging.basicConfig(
+        format=(
+            r"[%(asctime)s] "
+            r"%(name)s:%(funcName)s():%(lineno)s: "
+            r"%(levelname)s %(message)s"
+        ),
+        level=logging.INFO,
+        force=True,
+    )
+    workflow_id = int(sys.argv[1])
+    asyncio.run(
+        AutoReviewPoller(
+            IndicoConfig(
+                host="try.indico.io",
+                api_token_path="indico_api_token.txt",
+            ),
+            workflow_id,
+            auto_review,
+        ).poll_forever()
+    )
diff --git a/indico_toolkit/polling/__init__.py b/indico_toolkit/polling/__init__.py
new file mode 100644
index 0000000..10201a4
--- /dev/null
+++ b/indico_toolkit/polling/__init__.py
@@ -0,0 +1,8 @@
+from .autoreview import AutoReviewed, AutoReviewPoller
+from .downstream import DownstreamPoller
+
+__all__ = (
+    "AutoReviewed",
+    "AutoReviewPoller",
+    "DownstreamPoller",
+)
diff --git a/indico_toolkit/polling/autoreview.py b/indico_toolkit/polling/autoreview.py
new file mode 100644
index 0000000..c93fa9a
--- /dev/null
+++ b/indico_toolkit/polling/autoreview.py
@@ -0,0 +1,203 @@
+import asyncio
+import logging
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from indico import AsyncIndicoClient, IndicoConfig  # type: ignore[import-untyped]
+from indico.queries import (  # type: ignore[import-untyped]
+    GetSubmission,
+    JobStatus,
+    RetrieveStorageObject,
+    SubmitReview,
+)
+
+from .. import etloutput, results
+from ..etloutput import EtlOutput
+from ..results import Document, Result
+from ..retry import retry
+from .queries import SubmissionIdsPendingAutoReview
+
+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable
+    from typing import Any, NoReturn, TypeAlias
+
+    SubmissionId: TypeAlias = int
+    Worker: TypeAlias = asyncio.Task[None]
+    WorkerQueue: TypeAlias = asyncio.Queue[tuple[SubmissionId, Worker]]
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AutoReviewed:
+    changes: "dict[str, Any] | list[dict[str, Any]]"
+    reject: bool = False
+    stp: bool = False
+
+
+class AutoReviewPoller:
+    """
+    Polls for submissions requiring auto review, processes them,
+    and submits the review results concurrently.
+    """
+
+    def __init__(
+        self,
+        config: IndicoConfig,
+        workflow_id: int,
+        auto_review: "Callable[[Result, dict[Document, EtlOutput]], Awaitable[AutoReviewed]]",  # noqa: E501
+        *,
+        worker_count: int = 8,
+        spawn_rate: float = 1,
+        poll_delay: float = 30,
+        load_etl_output: bool = True,
+        load_text: bool = True,
+        load_tokens: bool = True,
+        load_tables: bool = False,
+        retry_count: int = 4,
+        retry_wait: float = 1,
+        retry_backoff: float = 4,
+        retry_jitter: float = 0.5,
+    ):
+        self._config = config
+        self._workflow_id = workflow_id
+        self._auto_review = auto_review
+        self._worker_count = worker_count
+        self._spawn_rate = spawn_rate
+        self._poll_delay = poll_delay
+        self._load_etl_output = load_etl_output
+        self._load_text = load_text
+        self._load_tokens = load_tokens
+        self._load_tables = load_tables
+
+        self._retry = retry(
+            Exception,
+            count=retry_count,
+            wait=retry_wait,
+            backoff=retry_backoff,
+            jitter=retry_jitter,
+        )
+        self._worker_slots = asyncio.Semaphore(worker_count)
+        self._worker_queue: "WorkerQueue" = asyncio.Queue(1)
+        self._processing_submission_ids: "set[SubmissionId]" = set()
+
+    async def poll_forever(self) -> "NoReturn":  # type: ignore[misc]
+        logger.info(
+            "Starting auto review poller for: "
+            f"host={self._config.host} "
+            f"workflow_id={self._workflow_id} "
+            f"worker_count={self._worker_count}"
+        )
+
+        async with AsyncIndicoClient(self._config) as client:
+            self._client_call = self._retry(client.call)
+            await asyncio.gather(
+                self._spawn_workers(),
+                *(self._reap_workers() for _ in range(self._worker_count)),
+            )
+
+    async def _retrieve_storage_object(self, url: str) -> object:
+        return await self._client_call(RetrieveStorageObject(url))
+
+    async def _spawn_workers(self) -> None:
+        """
+        Poll for submissions pending auto review and spawn workers to process them.
+        `self._worker_slots` limits the number of workers that can run concurrently.
+        Submission IDs in progress are tracked with `self._processing_submission_ids`.
+        """
+        logger.info(
+            f"Polling submissions pending auto review every {self._poll_delay} seconds"
+        )
+
+        while True:
+            try:
+                submission_ids: "set[SubmissionId]" = await self._client_call(
+                    SubmissionIdsPendingAutoReview(self._workflow_id)
+                )
+            except Exception:
+                logger.exception("Error occurred while polling submissions")
+                await asyncio.sleep(self._poll_delay)
+                continue
+
+            submission_ids -= self._processing_submission_ids
+
+            if not submission_ids:
+                await asyncio.sleep(self._poll_delay)
+                continue
+
+            for submission_id in submission_ids:
+                await self._worker_slots.acquire()
+                logger.info(f"Spawning worker for {submission_id=}")
+                self._processing_submission_ids.add(submission_id)
+                worker = asyncio.create_task(self._worker(submission_id))
+                await self._worker_queue.put((submission_id, worker))
+                await asyncio.sleep(1 / self._spawn_rate)
+
+    async def _worker(self, submission_id: "SubmissionId") -> None:
+        """
+        Process a single submission: retrieve the submission metadata, result file,
+        and ETL output; call `self._auto_review`; and submit the changes.
+        """
+        logger.info(f"Retrieving metadata for {submission_id=}")
+        submission = await self._client_call(GetSubmission(submission_id))
+
+        logger.info(f"Retrieving results for {submission_id=}")
+        result = await results.load_async(
+            submission.result_file,
+            reader=self._retrieve_storage_object,
+        )
+
+        if self._load_etl_output:
+            logger.info(f"Retrieving ETL output for {submission_id=}")
+            etl_outputs = {
+                document: await etloutput.load_async(
+                    document.etl_output_url,
+                    reader=self._retrieve_storage_object,
+                    text=self._load_text,
+                    tokens=self._load_tokens,
+                    tables=self._load_tables,
+                )
+                for document in result.documents
+            }
+        else:
+            logger.info(f"Skipping ETL output for {submission_id=}")
+            etl_outputs = {}
+
+        logger.info(f"Applying auto review for {submission_id=}")
+        auto_reviewed = await self._auto_review(result, etl_outputs)
+
+        logger.info(f"Submitting auto review for {submission_id=}")
+        job = await self._client_call(
+            SubmitReview(
+                submission_id,
+                changes=auto_reviewed.changes,
+                rejected=auto_reviewed.reject,
+                force_complete=auto_reviewed.stp,
+            )
+        )
+        job = await self._client_call(JobStatus(job.id))
+
+        if job.status == "SUCCESS":
+            logger.info(f"Completed auto review of {submission_id=}")
+        else:
+            logger.error(
+                f"Submit failed for {submission_id=}: "
+                f"{job.status=!r} {job.result=!r}"
+            )
+
+    async def _reap_workers(self) -> None:
+        """
+        Reap completed workers, releasing their slots for new tasks. Log errors for
+        submissions that failed to process. Remove their submission IDs from
+        `self._processing_submission_ids` so they can be retried.
+        """
+        while True:
+            submission_id, worker = await self._worker_queue.get()
+
+            try:
+                await worker
+            except Exception:
+                logger.exception(f"Error occurred while processing {submission_id=}")
+
+            self._processing_submission_ids.remove(submission_id)
+            self._worker_slots.release()
diff --git a/indico_toolkit/polling/downstream.py b/indico_toolkit/polling/downstream.py
new file mode 100644
index 0000000..b22b7f1
--- /dev/null
+++ b/indico_toolkit/polling/downstream.py
@@ -0,0 +1,145 @@
+import asyncio
+import logging
+from typing import TYPE_CHECKING
+
+from indico import AsyncIndicoClient, IndicoConfig  # type: ignore[import-untyped]
+from indico.queries import (  # type: ignore[import-untyped]
+    GetSubmission,
+    UpdateSubmission,
+)
+from indico.types import Submission  # type: ignore[import-untyped]
+
+from ..retry import retry
+from .queries import SubmissionIdsPendingDownstream
+
+if TYPE_CHECKING:
+    from collections.abc import Awaitable, Callable
+    from typing import NoReturn, TypeAlias
+
+    SubmissionId: TypeAlias = int
+    Worker: TypeAlias = asyncio.Task[None]
+    WorkerQueue: TypeAlias = asyncio.Queue[tuple[SubmissionId, Worker]]
+
+logger = logging.getLogger(__name__)
+
+
+class DownstreamPoller:
+    """
+    Polls for completed and failed submissions pending downstream delivery, processes
+    them concurrently, and marks them as retrieved.
+    """
+
+    def __init__(
+        self,
+        config: IndicoConfig,
+        workflow_id: int,
+        downstream: "Callable[[Submission], Awaitable[None]]",
+        *,
+        worker_count: int = 8,
+        spawn_rate: float = 1,
+        poll_delay: float = 30,
+        retry_count: int = 4,
+        retry_wait: float = 1,
+        retry_backoff: float = 4,
+        retry_jitter: float = 0.5,
+    ):
+        self._config = config
+        self._workflow_id = workflow_id
+        self._downstream = downstream
+        self._worker_count = worker_count
+        self._spawn_rate = spawn_rate
+        self._poll_delay = poll_delay
+
+        self._retry = retry(
+            Exception,
+            count=retry_count,
+            wait=retry_wait,
+            backoff=retry_backoff,
+            jitter=retry_jitter,
+        )
+        self._worker_slots = asyncio.Semaphore(worker_count)
+        self._worker_queue: "WorkerQueue" = asyncio.Queue(1)
+        self._processing_submission_ids: "set[SubmissionId]" = set()
+
+    async def poll_forever(self) -> "NoReturn":  # type: ignore[misc]
+        logger.info(
+            "Starting downstream poller for: "
+            f"host={self._config.host} "
+            f"workflow_id={self._workflow_id} "
+            f"worker_count={self._worker_count}"
+        )
+
+        async with AsyncIndicoClient(self._config) as client:
+            self._client_call = self._retry(client.call)
+            await asyncio.gather(
+                self._spawn_workers(),
+                *(self._reap_workers() for _ in range(self._worker_count)),
+            )
+
+    async def _spawn_workers(self) -> None:
+        """
+        Poll for completed and failed submissions and spawn workers to send them
+        downstream. `self._worker_slots` limits the number of workers that can run
+        concurrently. Submission IDs in progress are tracked with
+        `self._processing_submission_ids`.
+        """
+        logger.info(
+            f"Polling submissions pending downstream every {self._poll_delay} seconds"
+        )
+
+        while True:
+            try:
+                submission_ids: "set[SubmissionId]" = await self._client_call(
+                    SubmissionIdsPendingDownstream(self._workflow_id)
+                )
+            except Exception:
+                logger.exception("Error occurred while polling submissions")
+                await asyncio.sleep(self._poll_delay)
+                continue
+
+            submission_ids -= self._processing_submission_ids
+
+            if not submission_ids:
+                await asyncio.sleep(self._poll_delay)
+                continue
+
+            for submission_id in submission_ids:
+                await self._worker_slots.acquire()
+                logger.info(f"Spawning worker for {submission_id=}")
+                self._processing_submission_ids.add(submission_id)
+                worker = asyncio.create_task(self._worker(submission_id))
+                await self._worker_queue.put((submission_id, worker))
+                await asyncio.sleep(1 / self._spawn_rate)
+
+    async def _worker(self, submission_id: "SubmissionId") -> None:
+        """
+        Process a single submission by retrieving submission metadata and calling
+        `self._downstream`. Once completed, mark the submission as retrieved.
+        """
+        logger.info(f"Retrieving metadata for {submission_id=}")
+        submission = await self._client_call(GetSubmission(submission_id))
+
+        logger.info(f"Sending {submission_id=} downstream")
+        await self._downstream(submission)
+
+        logger.info(f"Marking {submission_id=} retrieved")
+        await self._client_call(UpdateSubmission(submission_id, retrieved=True))
+
+        logger.info(f"Completed downstream of {submission_id=}")
+
+    async def _reap_workers(self) -> None:
+        """
+        Reap completed workers, releasing their slots for new tasks. Log errors for
+        submissions that failed to process. Remove their submission IDs from
+        `self._processing_submission_ids` so they can be retried.
+        """
+        while True:
+            submission_id, worker = await self._worker_queue.get()
+
+            try:
+                await worker
+            except Exception:
+                logger.exception(f"Error occurred while processing {submission_id=}")
+
+            self._processing_submission_ids.remove(submission_id)
+            self._worker_slots.release()
diff --git a/indico_toolkit/polling/queries.py b/indico_toolkit/polling/queries.py
new file mode 100644
index 0000000..bd629f5
--- /dev/null
+++ b/indico_toolkit/polling/queries.py
@@ -0,0 +1,68 @@
+from typing import TYPE_CHECKING
+
+from indico.queries import GraphQLRequest  # type: ignore[import-untyped]
+
+if TYPE_CHECKING:
+    from typing import Any
+
+
+class SubmissionIdsPendingAutoReview(GraphQLRequest):  # type: ignore[misc]
+    QUERY = """
+    query SubmissionIdsPendingAutoReview($workflowIds: [Int]) {
+        submissions(
+            desc: false
+            filters: { status: PENDING_AUTO_REVIEW }
+            limit: 1000
+            orderBy: ID
+            workflowIds: $workflowIds
+        ) {
+            submissions {
+                id
+            }
+        }
+    }
+    """
+
+    def __init__(self, workflow_id: int):
+        super().__init__(self.QUERY, {"workflowIds": [workflow_id]})
+
+    def process_response(self, response: "Any") -> set[int]:
+        response = super().process_response(response)
+        return {
+            submission["id"] for submission in response["submissions"]["submissions"]
+        }
+
+
+class SubmissionIdsPendingDownstream(GraphQLRequest):  # type: ignore[misc]
+    QUERY = """
+    query SubmissionIdsPendingDownstream($workflowIds: [Int]) {
+        submissions(
+            desc: false
+            filters: {
+                AND: {
+                    retrieved: false
+                    OR: [
+                        { status: COMPLETE }
+                        { status: FAILED }
+                    ]
+                }
+            }
+            limit: 1000
+            orderBy: ID
+            workflowIds: $workflowIds
+        ) {
+            submissions {
+                id
+            }
+        }
+    }
+    """
+
+    def __init__(self, workflow_id: int):
+        super().__init__(self.QUERY, {"workflowIds": [workflow_id]})
+
+    def process_response(self, response: "Any") -> set[int]:
+        response = super().process_response(response)
+        return {
+            submission["id"] for submission in response["submissions"]["submissions"]
+        }
diff --git a/pyproject.toml b/pyproject.toml
index 2919a7a..cbb7c57 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,7 @@ home-page = "https://github.com/IndicoDataSolutions/Indico-Solutions-Toolkit"
 classifiers = [ "License :: OSI Approved :: MIT License",]
 description-file = "README.md"
 requires = [
-    "indico-client>=5.1.4",
+    "indico-client>=6.1.0",
     "plotly==5.2.1",
     "tqdm==4.50.0",
     "faker==13.3.3",
diff --git a/requirements.txt b/requirements.txt
index 0e8b39a..f2cb2f0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-indico-client>=5.1.4
+indico-client>=6.1.0
 python-dateutil==2.8.1
 pytz==2021.1
 pytest==8.3.4
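As a companion to examples/poll_auto_review.py, a minimal usage sketch for `DownstreamPoller` might look like the following. `send_downstream` is a hypothetical placeholder for deployment-specific delivery logic (an API call, file write, queue publish, etc.), and the host and token path are assumptions borrowed from the auto review example:

import asyncio
import sys

from indico import IndicoConfig
from indico.types import Submission

from indico_toolkit.polling import DownstreamPoller


async def send_downstream(submission: Submission) -> None:
    # Hypothetical delivery step: push the completed or failed submission to a
    # downstream system. Any IO performed here must be awaited to avoid
    # blocking the asyncio loop that schedules this coroutine.
    ...


if __name__ == "__main__":
    workflow_id = int(sys.argv[1])
    asyncio.run(
        DownstreamPoller(
            IndicoConfig(
                host="try.indico.io",  # Assumed host, as in the auto review example.
                api_token_path="indico_api_token.txt",
            ),
            workflow_id,
            send_downstream,
        ).poll_forever()
    )

Failed submissions are delivered alongside completed ones, so `send_downstream` implementations that only expect successful results should check `submission.status` before forwarding.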