Skip to content

Commit 00c7082

Browse files
authored
Added non-string labels in tasks. (#243)
1 parent 74467d0 commit 00c7082

File tree

11 files changed

+107
-16
lines changed

11 files changed

+107
-16
lines changed

taskiq/kicker.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from taskiq.abc.middleware import TaskiqMiddleware
2020
from taskiq.compat import model_dump
2121
from taskiq.exceptions import SendTaskError
22+
from taskiq.labels import prepare_label
2223
from taskiq.message import TaskiqMessage
2324
from taskiq.scheduler.created_schedule import CreatedSchedule
2425
from taskiq.scheduler.scheduled_task import CronSpec, ScheduledTask
@@ -245,12 +246,14 @@ def _prepare_message(
245246
formatted_args = []
246247
formatted_kwargs = {}
247248
labels = {}
249+
labels_types = {}
248250
for arg in args:
249251
formatted_args.append(self._prepare_arg(arg))
250252
for kwarg_name, kwarg_val in kwargs.items():
251253
formatted_kwargs[kwarg_name] = self._prepare_arg(kwarg_val)
254+
252255
for label, label_val in self.labels.items():
253-
labels[label] = str(label_val)
256+
labels[label], labels_types[label] = prepare_label(label_val)
254257

255258
task_id = self.custom_task_id
256259
if task_id is None:
@@ -260,6 +263,7 @@ def _prepare_message(
260263
task_id=task_id,
261264
task_name=self.task_name,
262265
labels=labels,
266+
labels_types=labels_types,
263267
args=formatted_args,
264268
kwargs=formatted_kwargs,
265269
)

taskiq/labels.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import base64
2+
import enum
3+
from typing import Any, Callable, Dict, Optional, Tuple
4+
5+
6+
class LabelType(enum.IntEnum):
7+
"""Possible label types."""
8+
9+
ANY = enum.auto()
10+
INT = enum.auto()
11+
STR = enum.auto()
12+
FLOAT = enum.auto()
13+
BOOL = enum.auto()
14+
BYTES = enum.auto()
15+
16+
17+
_LABEL_PARSERS: Dict[LabelType, Callable[[str], Any]] = {
18+
LabelType.INT: int,
19+
LabelType.STR: str,
20+
LabelType.FLOAT: float,
21+
LabelType.BOOL: lambda x: x.lower() == "true",
22+
LabelType.BYTES: base64.b64decode,
23+
LabelType.ANY: lambda x: x,
24+
}
25+
26+
27+
def prepare_label(label_value: Any) -> Tuple[str, int]:
28+
"""
29+
Prepare label value for serialization.
30+
31+
:param label_value: label value to prepare.
32+
:return: tuple of prepared label value and its type.
33+
"""
34+
var_type = type(label_value)
35+
if var_type in (int, str, float, bool):
36+
return str(label_value), LabelType[var_type.__name__.upper()].value
37+
if var_type == bytes:
38+
return base64.b64encode(label_value).decode(), LabelType.BYTES.value
39+
return str(label_value), LabelType.ANY.value
40+
41+
42+
def parse_label(label_value: Any, label_type: Optional[int] = None) -> Any:
43+
"""
44+
Parse label value from serialized format.
45+
46+
:param label_value: label value to parse.
47+
:param label_type: label type.
48+
:return: parsed label value.
49+
"""
50+
if label_type is None:
51+
return label_value
52+
label_type = LabelType(label_type)
53+
if label_type in _LABEL_PARSERS:
54+
return _LABEL_PARSERS[label_type](label_value)
55+
raise ValueError(f"Unsupported label type: {label_type}")

taskiq/message.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
from typing import Any, Dict, List
1+
from typing import Any, Dict, List, Optional
22

33
from pydantic import BaseModel
44

5+
from taskiq.labels import parse_label
6+
57

68
class TaskiqMessage(BaseModel):
79
"""
@@ -15,9 +17,23 @@ class TaskiqMessage(BaseModel):
1517
task_id: str
1618
task_name: str
1719
labels: Dict[str, Any]
20+
labels_types: Optional[Dict[str, int]] = None
1821
args: List[Any]
1922
kwargs: Dict[str, Any]
2023

24+
def parse_labels(self) -> None:
25+
"""
26+
Parse labels.
27+
28+
:return: None
29+
"""
30+
if self.labels_types is None:
31+
return
32+
33+
for label, label_type in self.labels_types.items():
34+
if label in self.labels:
35+
self.labels[label] = parse_label(self.labels[label], label_type)
36+
2137

2238
class BrokerMessage(BaseModel):
2339
"""Format of messages for brokers."""

taskiq/middlewares/retry_middleware.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from copy import deepcopy
21
from logging import getLogger
32
from typing import Any
43

54
from taskiq.abc.middleware import TaskiqMiddleware
65
from taskiq.exceptions import NoResultError
6+
from taskiq.kicker import AsyncKicker
77
from taskiq.message import TaskiqMessage
88
from taskiq.result import TaskiqResult
99

@@ -47,25 +47,32 @@ async def on_error(
4747
return
4848

4949
retry_on_error = message.labels.get("retry_on_error")
50+
if isinstance(retry_on_error, str):
51+
retry_on_error = retry_on_error.lower() == "true"
52+
5053
if retry_on_error is None:
51-
retry_on_error = "true" if self.default_retry_label else "false"
54+
retry_on_error = self.default_retry_label
5255
# Check if retrying is enabled for the task.
53-
if retry_on_error.lower() != "true":
56+
if not retry_on_error:
5457
return
55-
new_msg = deepcopy(message)
58+
59+
kicker: AsyncKicker[Any, Any] = AsyncKicker(
60+
task_name=message.task_name,
61+
broker=self.broker,
62+
labels=message.labels,
63+
).with_task_id(message.task_id)
5664

5765
# Getting number of previous retries.
58-
retries = int(new_msg.labels.get("_retries", 0)) + 1
59-
new_msg.labels["_retries"] = str(retries)
60-
max_retries = int(new_msg.labels.get("max_retries", self.default_retry_count))
66+
retries = int(message.labels.get("_retries", 0)) + 1
67+
kicker.with_labels(_retries=retries)
68+
max_retries = int(message.labels.get("max_retries", self.default_retry_count))
6169

6270
if retries < max_retries:
6371
logger.info(
6472
"Task '%s' invocation failed. Retrying.",
6573
message.task_name,
6674
)
67-
broker_message = self.broker.formatter.dumps(message=new_msg)
68-
await self.broker.kick(broker_message)
75+
await kicker.kiq(*message.args, **message.kwargs)
6976

7077
if self.no_result_on_retry:
7178
result.error = NoResultError()

taskiq/receiver/receiver.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ async def callback( # noqa: C901, PLR0912
100100
message_data = message.data if isinstance(message, AckableMessage) else message
101101
try:
102102
taskiq_msg = self.broker.formatter.loads(message=message_data)
103+
taskiq_msg.parse_labels()
103104
except Exception as exc:
104105
logger.warning(
105106
"Cannot parse message: %s. Skipping execution.\n %s",

taskiq/result/v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class TaskiqResult(GenericModel, Generic[_ReturnType]):
3333
log: Optional[str] = None
3434
return_value: _ReturnType
3535
execution_time: float
36-
labels: Dict[str, str] = Field(default_factory=dict)
36+
labels: Dict[str, Any] = Field(default_factory=dict)
3737

3838
error: Optional[BaseException] = None
3939

taskiq/result/v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class TaskiqResult(BaseModel, Generic[_ReturnType]):
2020
log: Optional[str] = None
2121
return_value: _ReturnType
2222
execution_time: float
23-
labels: Dict[str, str] = Field(default_factory=dict)
23+
labels: Dict[str, Any] = Field(default_factory=dict)
2424

2525
error: Optional[BaseException] = None
2626

taskiq/serializers/json_serializer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
11
from json import dumps, loads
2-
from typing import Any
2+
from typing import Any, Callable, Optional
33

44
from taskiq.abc.serializer import TaskiqSerializer
55

66

77
class JSONSerializer(TaskiqSerializer):
88
"""Default taskiq serizalizer."""
99

10+
def __init__(self, default: Optional[Callable[..., None]] = None) -> None:
11+
self.default = default
12+
1013
def dumpb(self, value: Any) -> bytes:
1114
"""
1215
Dumps taskiq message to some broker message format.
1316
1417
:param message: message to send.
1518
:return: Dumped message.
1619
"""
17-
return dumps(value).encode()
20+
return dumps(
21+
value,
22+
default=self.default,
23+
).encode()
1824

1925
def loadb(self, value: bytes) -> Any:
2026
"""

tests/formatters/test_json_formatter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ async def test_json_dumps() -> None:
2222
message=(
2323
b'{"task_id":"task-id","task_name":"task.name",'
2424
b'"labels":{"label1":1,"label2":"text"},'
25+
b'"labels_types":null,'
2526
b'"args":[1,"a"],"kwargs":{"p1":"v1"}}'
2627
),
2728
labels={"label1": 1, "label2": "text"},

tests/formatters/test_proxy_formatter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ async def test_proxy_dumps() -> None:
2121
message=(
2222
b'{"task_id": "task-id", "task_name": "task.name", '
2323
b'"labels": {"label1": 1, "label2": "text"}, '
24+
b'"labels_types": null, '
2425
b'"args": [1, "a"], "kwargs": {"p1": "v1"}}'
2526
),
2627
labels={"label1": 1, "label2": "text"},

0 commit comments

Comments
 (0)