Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement customizable serializer #214

Merged
merged 1 commit into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 62 additions & 52 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ freezegun = "^1.2.2"
pytest-mock = "^3.11.1"
tzlocal = "^5.0.1"
types-tzlocal = "^5.0.1.1"
types-pytz = "^2023.3.1.1"

[tool.poetry.extras]
zmq = ["pyzmq"]
Expand Down
20 changes: 18 additions & 2 deletions taskiq/abc/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
from typing_extensions import ParamSpec, Self, TypeAlias

from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.acks import AckableMessage
from taskiq.decor import AsyncTaskiqDecoratedTask
from taskiq.events import TaskiqEvents
from taskiq.formatters.json_formatter import JSONFormatter
from taskiq.formatters.proxy_formatter import ProxyFormatter
from taskiq.message import BrokerMessage
from taskiq.result_backends.dummy import DummyResultBackend
from taskiq.serializers.json_serializer import JSONSerializer
from taskiq.state import TaskiqState
from taskiq.utils import maybe_awaitable, remove_suffix
from taskiq.warnings import TaskiqDeprecationWarning
Expand Down Expand Up @@ -97,7 +99,8 @@ def __init__(
self.middlewares: "List[TaskiqMiddleware]" = []
self.result_backend = result_backend
self.decorator_class = AsyncTaskiqDecoratedTask
self.formatter: "TaskiqFormatter" = JSONFormatter()
self.serializer: TaskiqSerializer = JSONSerializer()
self.formatter: "TaskiqFormatter" = ProxyFormatter(self)
self.id_generator = task_id_generator
self.local_task_registry: Dict[str, AsyncTaskiqDecoratedTask[Any, Any]] = {}
# Every event has a list of handlers.
Expand Down Expand Up @@ -479,6 +482,19 @@ def with_event_handlers(
self.event_handlers[event].extend(handlers)
return self

def with_serializer(
self,
serializer: TaskiqSerializer,
) -> "Self": # pragma: no cover
"""
Set a new serializer and return an updated broker.

:param serializer: new serializer.
:return: self
"""
self.serializer = serializer
return self

def _register_task(
self,
task_name: str,
Expand Down
24 changes: 24 additions & 0 deletions taskiq/abc/serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from abc import ABC, abstractmethod
from typing import Any


class TaskiqSerializer(ABC):
"""Custom serializer for brokers."""

@abstractmethod
def dumpb(self, value: Any) -> bytes:
"""
Dump value to bytes for sending through the wire.

:param value: value to encode.
:return: encoded value.
"""

@abstractmethod
def loadb(self, value: bytes) -> Any:
"""
Parse byte-encoded value received from the wire.

:param message: value to parse.
:return: decoded value.
"""
18 changes: 18 additions & 0 deletions taskiq/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@
def parse_obj_as(annot: T, obj: Any) -> T:
return pydantic.TypeAdapter(annot).validate_python(obj)

def model_validate(
model_class: Type[Model],
message: Dict[str, Any],
) -> Model:
return model_class.model_validate(message)

def model_dump(instance: Model) -> Dict[str, Any]:
return instance.model_dump()

def model_validate_json(
model_class: Type[Model],
message: Union[str, bytes, bytearray],
Expand All @@ -37,6 +46,15 @@ def model_copy(
else:
parse_obj_as = pydantic.parse_obj_as # type: ignore

def model_validate(
model_class: Type[Model],
message: Dict[str, Any],
) -> Model:
return model_class.parse_obj(message)

def model_dump(instance: Model) -> Dict[str, Any]:
return instance.dict()

def model_validate_json(
model_class: Type[Model],
message: Union[str, bytes, bytearray],
Expand Down
2 changes: 1 addition & 1 deletion taskiq/formatters/json_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class JSONFormatter(TaskiqFormatter):
"""Default taskiq formatter."""
"""JSON taskiq formatter."""

def dumps(self, message: TaskiqMessage) -> BrokerMessage:
"""
Expand Down
38 changes: 38 additions & 0 deletions taskiq/formatters/proxy_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import TYPE_CHECKING

from taskiq.abc.formatter import TaskiqFormatter
from taskiq.compat import model_dump, model_validate
from taskiq.message import BrokerMessage, TaskiqMessage

if TYPE_CHECKING:
from taskiq.abc.broker import AsyncBroker


class ProxyFormatter(TaskiqFormatter):
"""Default taskiq formatter."""

def __init__(self, broker: "AsyncBroker") -> None:
self.broker = broker

def dumps(self, message: TaskiqMessage) -> BrokerMessage:
"""
Dumps taskiq message to some broker message format.

:param message: message to send.
:return: Dumped message.
"""
return BrokerMessage(
task_id=message.task_id,
task_name=message.task_name,
message=self.broker.serializer.dumpb(model_dump(message)),
labels=message.labels,
)

def loads(self, message: bytes) -> TaskiqMessage:
"""
Loads json from message.

:param message: broker's message.
:return: parsed taskiq message.
"""
return model_validate(TaskiqMessage, self.broker.serializer.loadb(message))
4 changes: 2 additions & 2 deletions taskiq/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TaskiqMessage(BaseModel):

task_id: str
task_name: str
labels: Dict[str, str]
labels: Dict[str, Any]
args: List[Any]
kwargs: Dict[str, Any]

Expand All @@ -25,4 +25,4 @@ class BrokerMessage(BaseModel):
task_id: str
task_name: str
message: bytes
labels: Dict[str, str]
labels: Dict[str, Any]
1 change: 1 addition & 0 deletions taskiq/serializers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Taskiq serializers."""
26 changes: 26 additions & 0 deletions taskiq/serializers/json_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from json import dumps, loads
from typing import Any

from taskiq.abc.serializer import TaskiqSerializer


class JSONSerializer(TaskiqSerializer):
"""Default taskiq serizalizer."""

def dumpb(self, value: Any) -> bytes:
"""
Dumps taskiq message to some broker message format.

:param message: message to send.
:return: Dumped message.
"""
return dumps(value).encode()

def loadb(self, value: bytes) -> Any:
"""
Parse byte-encoded value received from the wire.

:param message: value to parse.
:return: decoded value.
"""
return loads(value.decode())
45 changes: 45 additions & 0 deletions tests/formatters/test_json_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

from taskiq.formatters.json_formatter import JSONFormatter
from taskiq.message import BrokerMessage, TaskiqMessage


@pytest.mark.anyio
async def test_json_dumps() -> None:
fmt = JSONFormatter()
msg = TaskiqMessage(
task_id="task-id",
task_name="task.name",
labels={"label1": 1, "label2": "text"},
args=[1, "a"],
kwargs={"p1": "v1"},
)
expected = BrokerMessage(
task_id="task-id",
task_name="task.name",
message=(
b'{"task_id":"task-id","task_name":"task.name",'
b'"labels":{"label1":1,"label2":"text"},'
b'"args":[1,"a"],"kwargs":{"p1":"v1"}}'
),
labels={"label1": 1, "label2": "text"},
)
assert fmt.dumps(msg) == expected


@pytest.mark.anyio
async def test_json_loads() -> None:
fmt = JSONFormatter()
msg = (
b'{"task_id":"task-id","task_name":"task.name",'
b'"labels":{"label1":1,"label2":"text"},'
b'"args":[1,"a"],"kwargs":{"p1":"v1"}}'
)
expected = TaskiqMessage(
task_id="task-id",
task_name="task.name",
labels={"label1": 1, "label2": "text"},
args=[1, "a"],
kwargs={"p1": "v1"},
)
assert fmt.loads(msg) == expected
47 changes: 47 additions & 0 deletions tests/formatters/test_proxy_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest

from taskiq.brokers.inmemory_broker import InMemoryBroker
from taskiq.message import BrokerMessage, TaskiqMessage


@pytest.mark.anyio
async def test_proxy_dumps() -> None:
# uses json serializer by default
broker = InMemoryBroker()
msg = TaskiqMessage(
task_id="task-id",
task_name="task.name",
labels={"label1": 1, "label2": "text"},
args=[1, "a"],
kwargs={"p1": "v1"},
)
expected = BrokerMessage(
task_id="task-id",
task_name="task.name",
message=(
b'{"task_id": "task-id", "task_name": "task.name", '
b'"labels": {"label1": 1, "label2": "text"}, '
b'"args": [1, "a"], "kwargs": {"p1": "v1"}}'
),
labels={"label1": 1, "label2": "text"},
)
assert broker.formatter.dumps(msg) == expected


@pytest.mark.anyio
async def test_proxy_loads() -> None:
# uses json serializer by default
broker = InMemoryBroker()
msg = (
b'{"task_id":"task-id","task_name":"task.name",'
b'"labels":{"label1":1,"label2":"text"},'
b'"args":[1,"a"],"kwargs":{"p1":"v1"}}'
)
expected = TaskiqMessage(
task_id="task-id",
task_name="task.name",
labels={"label1": 1, "label2": "text"},
args=[1, "a"],
kwargs={"p1": "v1"},
)
assert broker.formatter.loads(msg) == expected
23 changes: 23 additions & 0 deletions tests/serializers/test_json_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

from taskiq.serializers.json_serializer import JSONSerializer


@pytest.mark.anyio
async def test_json_dumpb() -> None:
serizalizer = JSONSerializer()
assert serizalizer.dumpb(None) == b"null" # noqa: PLR2004
assert serizalizer.dumpb(1) == b"1" # noqa: PLR2004
assert serizalizer.dumpb("a") == b'"a"' # noqa: PLR2004
assert serizalizer.dumpb(["a"]) == b'["a"]' # noqa: PLR2004
assert serizalizer.dumpb({"a": "b"}) == b'{"a": "b"}' # noqa: PLR2004


@pytest.mark.anyio
async def test_json_loadb() -> None:
serizalizer = JSONSerializer()
assert serizalizer.loadb(b"null") is None
assert serizalizer.loadb(b"1") == 1
assert serizalizer.loadb(b'"a"') == "a"
assert serizalizer.loadb(b'["a"]') == ["a"]
assert serizalizer.loadb(b'{"a": "b"}') == {"a": "b"}