Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: set/get progress #130

Merged
merged 18 commits into from
Jun 13, 2024
Merged
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ jobs:
strategy:
matrix:
py_version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
pydantic_ver: ["<2", ">=2,<3"]
pydantic_ver: ["<2", ">=2.5,<3"]
Sobes76rus marked this conversation as resolved.
Show resolved Hide resolved
os: [ubuntu-latest, windows-latest]
runs-on: "${{ matrix.os }}"
steps:
Expand Down
28 changes: 27 additions & 1 deletion taskiq/abc/result_backend.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar

from taskiq.result import TaskiqResult

if TYPE_CHECKING: # pragma: no cover
from taskiq.depends.progress_tracker import TaskProgress


_ReturnType = TypeVar("_ReturnType")


Expand Down Expand Up @@ -50,3 +54,25 @@ async def get_result(
:param with_logs: if True it will download task's logs.
:return: task's return value.
"""

async def set_progress(
Sobes76rus marked this conversation as resolved.
Show resolved Hide resolved
self,
task_id: str,
progress: "TaskProgress[Any]",
) -> None:
"""
Saves progress.

:param task_id: task's id.
:param progress: progress of execution.
"""

async def get_progress(
self,
task_id: str,
) -> "Optional[TaskProgress[Any]]":
"""
Gets progress.

:param task_id: task's id.
"""
35 changes: 34 additions & 1 deletion taskiq/brokers/inmemory_broker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import asyncio
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncGenerator, Set, TypeVar
from typing import Any, AsyncGenerator, Optional, Set, TypeVar

from taskiq.abc.broker import AsyncBroker
from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult
from taskiq.depends.progress_tracker import TaskProgress
from taskiq.events import TaskiqEvents
from taskiq.exceptions import TaskiqError
from taskiq.message import BrokerMessage
Expand All @@ -27,6 +28,7 @@ class InmemoryResultBackend(AsyncResultBackend[_ReturnType]):
def __init__(self, max_stored_results: int = 100) -> None:
self.max_stored_results = max_stored_results
self.results: OrderedDict[str, TaskiqResult[_ReturnType]] = OrderedDict()
self.progress: OrderedDict[str, TaskProgress[Any]] = OrderedDict()

async def set_result(self, task_id: str, result: TaskiqResult[_ReturnType]) -> None:
"""
Expand Down Expand Up @@ -79,6 +81,37 @@ async def get_result(
"""
return self.results[task_id]

async def set_progress(
self,
task_id: str,
progress: TaskProgress[Any],
) -> None:
"""
Set progress of task exection.
Sobes76rus marked this conversation as resolved.
Show resolved Hide resolved

:param task_id: task id
:param progress: task execution progress
"""
if (
self.max_stored_results != -1
and len(self.progress) >= self.max_stored_results
):
self.progress.popitem(last=False)

self.progress[task_id] = progress

async def get_progress(
self,
task_id: str,
) -> Optional[TaskProgress[Any]]:
"""
Get progress of task execution.

:param task_id: task id
:return: progress or None
"""
return self.progress.get(task_id)


class InMemoryBroker(AsyncBroker):
"""
Expand Down
72 changes: 72 additions & 0 deletions taskiq/depends/progress_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import enum
from typing import Generic, Optional, Union

from taskiq_dependencies import Depends
from typing_extensions import TypeVar

from taskiq.compat import IS_PYDANTIC2
from taskiq.context import Context

if IS_PYDANTIC2:
from pydantic import BaseModel as GenericModel
else:
from pydantic.generics import GenericModel # type: ignore[no-redef]


_ProgressType = TypeVar("_ProgressType")


class TaskState(str, enum.Enum):
"""State of task execution."""

STARTED = "STARTED"
FAILURE = "FAILURE"
SUCCESS = "SUCCESS"
RETRY = "RETRY"


class TaskProgress(GenericModel, Generic[_ProgressType]):
"""Progress of task execution."""

state: Union[TaskState, str]
meta: Optional[_ProgressType]


class ProgressTracker(Generic[_ProgressType]):
"""Task's dependency to set progress."""

def __init__(
self,
context: Context = Depends(),
) -> None:
self.context = context

async def set_progress(
self,
state: Union[TaskState, str],
meta: Optional[_ProgressType] = None,
) -> None:
"""Set progress.

:param state: TaskState or str
:param meta: progress data
"""
if meta is None:
progress = await self.get_progress()
meta = progress.meta if progress else None

progress = TaskProgress(
state=state,
meta=meta,
)

await self.context.broker.result_backend.set_progress(
self.context.message.task_id,
progress,
)

async def get_progress(self) -> Optional[TaskProgress[_ProgressType]]:
"""Get progress."""
return await self.context.broker.result_backend.get_progress(
self.context.message.task_id,
)
26 changes: 25 additions & 1 deletion taskiq/task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
from abc import ABC, abstractmethod
from time import time
from typing import TYPE_CHECKING, Any, Coroutine, Generic, TypeVar, Union
from typing import TYPE_CHECKING, Any, Coroutine, Generic, Optional, Union

from typing_extensions import TypeVar

from taskiq.exceptions import (
ResultGetError,
Expand All @@ -11,6 +13,7 @@

if TYPE_CHECKING: # pragma: no cover
from taskiq.abc.result_backend import AsyncResultBackend
from taskiq.depends.progress_tracker import TaskProgress
from taskiq.result import TaskiqResult

_ReturnType = TypeVar("_ReturnType")
Expand Down Expand Up @@ -65,6 +68,19 @@ def wait_result(
:return: TaskiqResult.
"""

@abstractmethod
def get_progress(
self,
) -> Union[
"Optional[TaskProgress[Any]]",
Coroutine[Any, Any, "Optional[TaskProgress[Any]]"],
]:
"""
Get task progress.

:return: task's progress.
"""


class AsyncTaskiqTask(_Task[_ReturnType]):
"""AsyncTask for AsyncResultBackend."""
Expand Down Expand Up @@ -137,3 +153,11 @@ async def wait_result(
if 0 < timeout < time() - start_time:
raise TaskiqResultTimeoutError
return await self.get_result(with_logs=with_logs)

async def get_progress(self) -> "Optional[TaskProgress[Any]]":
"""
Get task progress.

:return: task's progress.
"""
return await self.result_backend.get_progress(self.task_id)
121 changes: 121 additions & 0 deletions tests/depends/test_progress_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Optional

import pytest
from pydantic import ValidationError

from taskiq import (
AsyncTaskiqDecoratedTask,
InMemoryBroker,
TaskiqDepends,
TaskiqMessage,
)
from taskiq.abc import AsyncBroker
from taskiq.depends.progress_tracker import ProgressTracker, TaskState
from taskiq.receiver import Receiver


def get_receiver(
broker: Optional[AsyncBroker] = None,
no_parse: bool = False,
max_async_tasks: Optional[int] = None,
) -> Receiver:
"""
Returns receiver with custom broker and args.

:param broker: broker, defaults to None
:param no_parse: parameter to taskiq_args, defaults to False
:param cli_args: Taskiq worker CLI arguments.
:return: new receiver.
"""
if broker is None:
broker = InMemoryBroker()
return Receiver(
broker,
executor=ThreadPoolExecutor(max_workers=10),
validate_params=not no_parse,
max_async_tasks=max_async_tasks,
)


def get_message(
task: AsyncTaskiqDecoratedTask[Any, Any],
task_id: Optional[str] = None,
*args: Any,
labels: Optional[Dict[str, str]] = None,
**kwargs: Dict[str, Any],
) -> TaskiqMessage:
if labels is None:
labels = {}
return TaskiqMessage(
task_id=task_id or task.broker.id_generator(),
task_name=task.task_name,
labels=labels,
args=list(args),
kwargs=kwargs,
)


@pytest.mark.anyio
@pytest.mark.parametrize(
"state,meta",
[
(TaskState.STARTED, "hello world!"),
("retry", "retry error!"),
("custom state", {"Complex": "Value"}),
],
)
async def test_progress_tracker_ctx_raw(state: Any, meta: Any) -> None:
broker = InMemoryBroker()

@broker.task
async def test_func(tes_val: ProgressTracker[Any] = TaskiqDepends()) -> None:
await tes_val.set_progress(state, meta)

kicker = await test_func.kiq()
result = await kicker.wait_result()

assert not result.is_err
progress = await broker.result_backend.get_progress(kicker.task_id)
assert progress is not None
assert progress.meta == meta
assert progress.state == state


@pytest.mark.anyio
async def test_progress_tracker_ctx_none() -> None:
broker = InMemoryBroker()

@broker.task
async def test_func() -> None:
pass

kicker = await test_func.kiq()
result = await kicker.wait_result()

assert not result.is_err
progress = await broker.result_backend.get_progress(kicker.task_id)
assert progress is None


@pytest.mark.anyio
@pytest.mark.parametrize(
"state,meta",
[
(("state", "error"), 1),
],
)
async def test_progress_tracker_validation_error(state: Any, meta: Any) -> None:
broker = InMemoryBroker()

@broker.task
async def test_func(progress: ProgressTracker[int] = TaskiqDepends()) -> None:
await progress.set_progress(state, meta) # type: ignore

kicker = await test_func.kiq()
result = await kicker.wait_result()
with pytest.raises(ValidationError):
result.raise_for_error()

progress = await broker.result_backend.get_progress(kicker.task_id)
assert progress is None
Loading