diff --git a/taskiq/kicker.py b/taskiq/kicker.py index 7ec8451f..5a3b9756 100644 --- a/taskiq/kicker.py +++ b/taskiq/kicker.py @@ -19,6 +19,7 @@ from taskiq.abc.middleware import TaskiqMiddleware from taskiq.compat import model_dump from taskiq.exceptions import SendTaskError +from taskiq.labels import prepare_label from taskiq.message import TaskiqMessage from taskiq.scheduler.created_schedule import CreatedSchedule from taskiq.scheduler.scheduled_task import CronSpec, ScheduledTask @@ -245,12 +246,14 @@ def _prepare_message( formatted_args = [] formatted_kwargs = {} labels = {} + labels_types = {} for arg in args: formatted_args.append(self._prepare_arg(arg)) for kwarg_name, kwarg_val in kwargs.items(): formatted_kwargs[kwarg_name] = self._prepare_arg(kwarg_val) + for label, label_val in self.labels.items(): - labels[label] = str(label_val) + labels[label], labels_types[label] = prepare_label(label_val) task_id = self.custom_task_id if task_id is None: @@ -260,6 +263,7 @@ def _prepare_message( task_id=task_id, task_name=self.task_name, labels=labels, + labels_types=labels_types, args=formatted_args, kwargs=formatted_kwargs, ) diff --git a/taskiq/labels.py b/taskiq/labels.py new file mode 100644 index 00000000..6ad3b952 --- /dev/null +++ b/taskiq/labels.py @@ -0,0 +1,55 @@ +import base64 +import enum +from typing import Any, Callable, Dict, Optional, Tuple + + +class LabelType(enum.IntEnum): + """Possible label types.""" + + ANY = enum.auto() + INT = enum.auto() + STR = enum.auto() + FLOAT = enum.auto() + BOOL = enum.auto() + BYTES = enum.auto() + + +_LABEL_PARSERS: Dict[LabelType, Callable[[str], Any]] = { + LabelType.INT: int, + LabelType.STR: str, + LabelType.FLOAT: float, + LabelType.BOOL: lambda x: x.lower() == "true", + LabelType.BYTES: base64.b64decode, + LabelType.ANY: lambda x: x, +} + + +def prepare_label(label_value: Any) -> Tuple[str, int]: + """ + Prepare label value for serialization. + + :param label_value: label value to prepare. + :return: tuple of prepared label value and its type. + """ + var_type = type(label_value) + if var_type in (int, str, float, bool): + return str(label_value), LabelType[var_type.__name__.upper()].value + if var_type == bytes: + return base64.b64encode(label_value).decode(), LabelType.BYTES.value + return str(label_value), LabelType.ANY.value + + +def parse_label(label_value: Any, label_type: Optional[int] = None) -> Any: + """ + Parse label value from serialized format. + + :param label_value: label value to parse. + :param label_type: label type. + :return: parsed label value. + """ + if label_type is None: + return label_value + label_type = LabelType(label_type) + if label_type in _LABEL_PARSERS: + return _LABEL_PARSERS[label_type](label_value) + raise ValueError(f"Unsupported label type: {label_type}") diff --git a/taskiq/message.py b/taskiq/message.py index 129073ff..675f7cf3 100644 --- a/taskiq/message.py +++ b/taskiq/message.py @@ -1,7 +1,9 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from pydantic import BaseModel +from taskiq.labels import parse_label + class TaskiqMessage(BaseModel): """ @@ -15,9 +17,23 @@ class TaskiqMessage(BaseModel): task_id: str task_name: str labels: Dict[str, Any] + labels_types: Optional[Dict[str, int]] = None args: List[Any] kwargs: Dict[str, Any] + def parse_labels(self) -> None: + """ + Parse labels. + + :return: None + """ + if self.labels_types is None: + return + + for label, label_type in self.labels_types.items(): + if label in self.labels: + self.labels[label] = parse_label(self.labels[label], label_type) + class BrokerMessage(BaseModel): """Format of messages for brokers.""" diff --git a/taskiq/middlewares/retry_middleware.py b/taskiq/middlewares/retry_middleware.py index 30ced347..0eaee1bc 100644 --- a/taskiq/middlewares/retry_middleware.py +++ b/taskiq/middlewares/retry_middleware.py @@ -1,9 +1,9 @@ -from copy import deepcopy from logging import getLogger from typing import Any from taskiq.abc.middleware import TaskiqMiddleware from taskiq.exceptions import NoResultError +from taskiq.kicker import AsyncKicker from taskiq.message import TaskiqMessage from taskiq.result import TaskiqResult @@ -47,25 +47,32 @@ async def on_error( return retry_on_error = message.labels.get("retry_on_error") + if isinstance(retry_on_error, str): + retry_on_error = retry_on_error.lower() == "true" + if retry_on_error is None: - retry_on_error = "true" if self.default_retry_label else "false" + retry_on_error = self.default_retry_label # Check if retrying is enabled for the task. - if retry_on_error.lower() != "true": + if not retry_on_error: return - new_msg = deepcopy(message) + + kicker: AsyncKicker[Any, Any] = AsyncKicker( + task_name=message.task_name, + broker=self.broker, + labels=message.labels, + ).with_task_id(message.task_id) # Getting number of previous retries. - retries = int(new_msg.labels.get("_retries", 0)) + 1 - new_msg.labels["_retries"] = str(retries) - max_retries = int(new_msg.labels.get("max_retries", self.default_retry_count)) + retries = int(message.labels.get("_retries", 0)) + 1 + kicker.with_labels(_retries=retries) + max_retries = int(message.labels.get("max_retries", self.default_retry_count)) if retries < max_retries: logger.info( "Task '%s' invocation failed. Retrying.", message.task_name, ) - broker_message = self.broker.formatter.dumps(message=new_msg) - await self.broker.kick(broker_message) + await kicker.kiq(*message.args, **message.kwargs) if self.no_result_on_retry: result.error = NoResultError() diff --git a/taskiq/receiver/receiver.py b/taskiq/receiver/receiver.py index 1400d451..73a56ed3 100644 --- a/taskiq/receiver/receiver.py +++ b/taskiq/receiver/receiver.py @@ -100,6 +100,7 @@ async def callback( # noqa: C901, PLR0912 message_data = message.data if isinstance(message, AckableMessage) else message try: taskiq_msg = self.broker.formatter.loads(message=message_data) + taskiq_msg.parse_labels() except Exception as exc: logger.warning( "Cannot parse message: %s. Skipping execution.\n %s", diff --git a/taskiq/result/v1.py b/taskiq/result/v1.py index f2225916..95297053 100644 --- a/taskiq/result/v1.py +++ b/taskiq/result/v1.py @@ -33,7 +33,7 @@ class TaskiqResult(GenericModel, Generic[_ReturnType]): log: Optional[str] = None return_value: _ReturnType execution_time: float - labels: Dict[str, str] = Field(default_factory=dict) + labels: Dict[str, Any] = Field(default_factory=dict) error: Optional[BaseException] = None diff --git a/taskiq/result/v2.py b/taskiq/result/v2.py index cf49ff25..6294a2e3 100644 --- a/taskiq/result/v2.py +++ b/taskiq/result/v2.py @@ -20,7 +20,7 @@ class TaskiqResult(BaseModel, Generic[_ReturnType]): log: Optional[str] = None return_value: _ReturnType execution_time: float - labels: Dict[str, str] = Field(default_factory=dict) + labels: Dict[str, Any] = Field(default_factory=dict) error: Optional[BaseException] = None diff --git a/taskiq/serializers/json_serializer.py b/taskiq/serializers/json_serializer.py index e7d8d38a..8bf41f87 100644 --- a/taskiq/serializers/json_serializer.py +++ b/taskiq/serializers/json_serializer.py @@ -1,5 +1,5 @@ from json import dumps, loads -from typing import Any +from typing import Any, Callable, Optional from taskiq.abc.serializer import TaskiqSerializer @@ -7,6 +7,9 @@ class JSONSerializer(TaskiqSerializer): """Default taskiq serizalizer.""" + def __init__(self, default: Optional[Callable[..., None]] = None) -> None: + self.default = default + def dumpb(self, value: Any) -> bytes: """ Dumps taskiq message to some broker message format. @@ -14,7 +17,10 @@ def dumpb(self, value: Any) -> bytes: :param message: message to send. :return: Dumped message. """ - return dumps(value).encode() + return dumps( + value, + default=self.default, + ).encode() def loadb(self, value: bytes) -> Any: """ diff --git a/tests/formatters/test_json_formatter.py b/tests/formatters/test_json_formatter.py index 96a85f4e..17a37185 100644 --- a/tests/formatters/test_json_formatter.py +++ b/tests/formatters/test_json_formatter.py @@ -22,6 +22,7 @@ async def test_json_dumps() -> None: message=( b'{"task_id":"task-id","task_name":"task.name",' b'"labels":{"label1":1,"label2":"text"},' + b'"labels_types":null,' b'"args":[1,"a"],"kwargs":{"p1":"v1"}}' ), labels={"label1": 1, "label2": "text"}, diff --git a/tests/formatters/test_proxy_formatter.py b/tests/formatters/test_proxy_formatter.py index 50d179af..8d583f16 100644 --- a/tests/formatters/test_proxy_formatter.py +++ b/tests/formatters/test_proxy_formatter.py @@ -21,6 +21,7 @@ async def test_proxy_dumps() -> None: message=( b'{"task_id": "task-id", "task_name": "task.name", ' b'"labels": {"label1": 1, "label2": "text"}, ' + b'"labels_types": null, ' b'"args": [1, "a"], "kwargs": {"p1": "v1"}}' ), labels={"label1": 1, "label2": "text"}, diff --git a/tox.ini b/tox.ini index 209e011c..141a0fd0 100644 --- a/tox.ini +++ b/tox.ini @@ -10,6 +10,6 @@ env_list = skip_install = true allowlist_externals = poetry commands_pre = - poetry install + poetry install --all-extras commands = poetry run pytest -vv -n auto