Merge pull request #208 from IndicoDataSolutions/mawelborn/polling
Auto Review and Downstream Polling
Showing 7 changed files with 486 additions and 2 deletions.
@@ -0,0 +1,60 @@
"""
A feature-complete auto review polling script that incorporates asyncio,
automatic retry, result file dataclasses, and etl output dataclasses.
Max workers, spawn rate, etl output loading, and retry behavior are all configurable.
See the `AutoReviewPoller` class definition.
"""

import asyncio
import sys

from indico import IndicoConfig

from indico_toolkit.polling import AutoReviewPoller, AutoReviewed
from indico_toolkit.etloutput import EtlOutput
from indico_toolkit.results import Result, Document


async def auto_review(
    result: Result, etl_outputs: dict[Document, EtlOutput]
) -> AutoReviewed:
    """
    Apply auto review rules to predictions and determine straight through processing.
    Any IO performed (network requests, file reads/writes, etc) must be awaited to
    avoid blocking the asyncio loop that schedules this coroutine.
    """
    predictions = result.pre_review

    # Apply auto review rules.

    return AutoReviewed(
        changes=predictions.to_changes(result),
        reject=False,  # Defaults to `False` and may be omitted.
        stp=False,  # Defaults to `False` and may be omitted.
    )


if __name__ == "__main__":
    import logging

    logging.basicConfig(
        format=(
            r"[%(asctime)s] "
            r"%(name)s:%(funcName)s():%(lineno)s: "
            r"%(levelname)s %(message)s"
        ),
        level=logging.INFO,
        force=True,
    )
    workflow_id = int(sys.argv[1])
    asyncio.run(
        AutoReviewPoller(
            IndicoConfig(
                host="try.indico.io",
                api_token_path="indico_api_token.txt",
            ),
            workflow_id,
            auto_review,
        ).poll_forever()
    )
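For reference, here is one way the `# Apply auto review rules.` placeholder might be filled in. This is a minimal sketch only: it assumes each prediction exposes a `confidence` score and an `accept()` helper, neither of which is confirmed by this diff.

async def auto_review(
    result: Result, etl_outputs: dict[Document, EtlOutput]
) -> AutoReviewed:
    predictions = result.pre_review

    # Hypothetical rule: auto-accept predictions the model is highly
    # confident about. `confidence` and `accept()` are assumed helpers.
    all_confident = True
    for prediction in predictions:
        if prediction.confidence >= 0.99:
            prediction.accept()
        else:
            all_confident = False

    # Enable straight through processing only when every rule passed.
    return AutoReviewed(
        changes=predictions.to_changes(result),
        stp=all_confident,
    )

This variant omits `reject`, which defaults to `False`, and sets `stp` so that fully confident submissions can skip human review.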
@@ -0,0 +1,8 @@
from .autoreview import AutoReviewed, AutoReviewPoller
from .downstream import DownstreamPoller

__all__ = (
    "AutoReviewed",
    "AutoReviewPoller",
    "DownstreamPoller",
)
@@ -0,0 +1,203 @@
import asyncio
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING

from indico import AsyncIndicoClient, IndicoConfig  # type: ignore[import-untyped]
from indico.queries import (  # type: ignore[import-untyped]
    GetSubmission,
    JobStatus,
    RetrieveStorageObject,
    SubmitReview,
)

from .. import etloutput, results
from ..etloutput import EtlOutput
from ..results import Document, Result
from ..retry import retry
from .queries import SubmissionIdsPendingAutoReview

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable
    from typing import Any, NoReturn, TypeAlias

    SubmissionId: TypeAlias = int
    Worker: TypeAlias = asyncio.Task[None]
    WorkerQueue: TypeAlias = asyncio.Queue[tuple[SubmissionId, Worker]]

logger = logging.getLogger(__name__)


@dataclass
class AutoReviewed:
    changes: "dict[str, Any] | list[dict[str, Any]]"
    reject: bool = False
    stp: bool = False


class AutoReviewPoller:
    """
    Polls for submissions requiring auto review, processes them,
    and submits the review results concurrently.
    """

    def __init__(
        self,
        config: IndicoConfig,
        workflow_id: int,
        auto_review: "Callable[[Result, dict[Document, EtlOutput]], Awaitable[AutoReviewed]]",  # noqa: E501
        *,
        worker_count: int = 8,
        spawn_rate: float = 1,
        poll_delay: float = 30,
        load_etl_output: bool = True,
        load_text: bool = True,
        load_tokens: bool = True,
        load_tables: bool = False,
        retry_count: int = 4,
        retry_wait: float = 1,
        retry_backoff: float = 4,
        retry_jitter: float = 0.5,
    ):
        self._config = config
        self._workflow_id = workflow_id
        self._auto_review = auto_review
        self._worker_count = worker_count
        self._spawn_rate = spawn_rate
        self._poll_delay = poll_delay
        self._load_etl_output = load_etl_output
        self._load_text = load_text
        self._load_tokens = load_tokens
        self._load_tables = load_tables

        self._retry = retry(
            Exception,
            count=retry_count,
            wait=retry_wait,
            backoff=retry_backoff,
            jitter=retry_jitter,
        )
        self._worker_slots = asyncio.Semaphore(worker_count)
        self._worker_queue: "WorkerQueue" = asyncio.Queue(1)
        self._processing_submission_ids: "set[SubmissionId]" = set()
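Note on the retry settings: with the defaults (`count=4`, `wait=1`, `backoff=4`, `jitter=0.5`), a failing API call is retried up to four times, with waits growing geometrically to roughly 1, 4, 16, and 64 seconds before jitter is applied. This assumes the conventional `wait * backoff**attempt` schedule; the authoritative behavior is the toolkit's `retry` helper in `..retry`.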
    async def poll_forever(self) -> "NoReturn":  # type: ignore[misc]
        logger.info(
            "Starting auto review poller for: "
            f"host={self._config.host} "
            f"workflow_id={self._workflow_id} "
            f"worker_count={self._worker_count}"
        )

        async with AsyncIndicoClient(self._config) as client:
            self._client_call = self._retry(client.call)
            await asyncio.gather(
                self._spawn_workers(),
                *(self._reap_workers() for _ in range(self._worker_count)),
            )

    async def _retrieve_storage_object(self, url: str) -> object:
        return await self._client_call(RetrieveStorageObject(url))

    async def _spawn_workers(self) -> None:
        """
        Poll for submissions pending auto review and spawn workers to process them.
        `self._worker_slots` limits the number of workers that can run concurrently.
        Submission IDs in progress are tracked with `self._processing_submission_ids`.
        """
        logger.info(
            f"Polling submissions pending auto review every {self._poll_delay} seconds"
        )

        while True:
            try:
                submission_ids: "set[SubmissionId]" = await self._client_call(
                    SubmissionIdsPendingAutoReview(self._workflow_id)
                )
            except Exception:
                logger.exception("Error occurred while polling submissions")
                await asyncio.sleep(self._poll_delay)
                continue

            submission_ids -= self._processing_submission_ids

            if not submission_ids:
                await asyncio.sleep(self._poll_delay)
                continue

            for submission_id in submission_ids:
                await self._worker_slots.acquire()
                logger.info(f"Spawning worker for {submission_id=}")
                self._processing_submission_ids.add(submission_id)
                worker = asyncio.create_task(self._worker(submission_id))
                await self._worker_queue.put((submission_id, worker))
                await asyncio.sleep(1 / self._spawn_rate)

    async def _worker(self, submission_id: "SubmissionId") -> None:
        """
        Process a single submission by retrieving submission metadata, the result file,
        and etl output, calling `self._auto_review`, and submitting changes.
        """
        logger.info(f"Retrieving metadata for {submission_id=}")
        submission = await self._client_call(GetSubmission(submission_id))

        logger.info(f"Retrieving results for {submission_id=}")
        result = await results.load_async(
            submission.result_file,
            reader=self._retrieve_storage_object,
        )

        if self._load_etl_output:
            logger.info(f"Retrieving etl output for {submission_id=}")
            etl_outputs = {
                document: await etloutput.load_async(
                    document.etl_output_url,
                    reader=self._retrieve_storage_object,
                    text=self._load_text,
                    tokens=self._load_tokens,
                    tables=self._load_tables,
                )
                for document in result.documents
            }
        else:
            logger.info(f"Skipping etl output for {submission_id=}")
            etl_outputs = {}

        logger.info(f"Applying auto review for {submission_id=}")
        auto_reviewed = await self._auto_review(result, etl_outputs)

        logger.info(f"Submitting auto review for {submission_id=}")
        job = await self._client_call(
            SubmitReview(
                submission_id,
                changes=auto_reviewed.changes,
                rejected=auto_reviewed.reject,
                force_complete=auto_reviewed.stp,
            )
        )
        job = await self._client_call(JobStatus(job.id))

        if job.status == "SUCCESS":
            logger.info(f"Completed auto review of {submission_id=}")
        else:
            logger.error(
                f"Submit failed for {submission_id=}: "
                f"{job.status=!r} {job.result=!r}"
            )

    async def _reap_workers(self) -> None:
        """
        Reap completed workers, releasing their slots for new tasks. Log errors for
        submissions that failed to process, and remove their submission IDs from
        `self._processing_submission_ids` so they can be retried.
        """
        while True:
            submission_id, worker = await self._worker_queue.get()

            try:
                await worker
            except Exception:
                logger.exception(f"Error occurred while processing {submission_id=}")

            self._processing_submission_ids.remove(submission_id)
            self._worker_slots.release()
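The spawn/reap split above is a general asyncio pattern: a semaphore caps the number of in-flight workers, and a maxsize-1 queue hands each spawned task to a reaper that awaits it, logs any failure, and releases the slot. Below is a self-contained sketch of the same pattern, stripped of the Indico-specific calls; the names and numbers are illustrative only.

import asyncio


async def spawner(slots: asyncio.Semaphore, queue: asyncio.Queue) -> None:
    for item in range(5):
        await slots.acquire()  # Block until a worker slot is free.
        task = asyncio.create_task(asyncio.sleep(0.1))  # Stand-in for real work.
        await queue.put((item, task))  # Hand off to a reaper.
    await queue.put(None)  # Sentinel: no more work.


async def reaper(slots: asyncio.Semaphore, queue: asyncio.Queue) -> None:
    while (handoff := await queue.get()) is not None:
        item, task = handoff
        try:
            await task  # Worker exceptions surface here, not in the spawner.
        except Exception:
            pass  # The real poller logs and leaves the item eligible for retry.
        finally:
            slots.release()  # Free the slot so the spawner can continue.


async def main() -> None:
    slots, queue = asyncio.Semaphore(2), asyncio.Queue(1)
    await asyncio.gather(spawner(slots, queue), reaper(slots, queue))


asyncio.run(main())

Because exceptions are re-raised where a task is awaited, the reaper (rather than the spawner) is the one place worker failures are observed, which is why the real `_reap_workers` owns both the error logging and the slot release.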