diff --git a/docs/examples/stats/stats.py b/docs/examples/stats/stats.py new file mode 100644 index 0000000..63ffec6 --- /dev/null +++ b/docs/examples/stats/stats.py @@ -0,0 +1,155 @@ +from typing import Any +import asyncio +import os +import random + +from redis.asyncio import Redis +from taskiq import AsyncBroker, AsyncTaskiqTask, Context, TaskiqDepends, TaskiqResult +from taskiq.api import run_receiver_task +from taskiq.exceptions import ResultGetError +from taskiq.metrics import Stat, Stats +from taskiq_redis import ListQueueBroker, PubSubBroker, RedisAsyncResultBackend +from taskiq_redis.redis_broker import BaseRedisBroker +from taskiq.middlewares.stat_middleware import StatMiddleware + +random.seed() + +redis_result_url = "redis://localhost:6379/0" +redis_stat_url = "redis://localhost:6379/1" + +task_async_result: RedisAsyncResultBackend[Any] = RedisAsyncResultBackend( + redis_url=redis_result_url, +) + +broker = ListQueueBroker(url=redis_result_url).with_result_backend( + task_async_result, +) + +stat_async_result: RedisAsyncResultBackend[Any] = RedisAsyncResultBackend( + redis_url=redis_stat_url, +) + +stat_broker = PubSubBroker(url=redis_stat_url).with_result_backend( + stat_async_result, +) + +stat_middleware = StatMiddleware(stat_broker=stat_broker) +broker.add_middlewares(stat_middleware) + + +async def get_task_result( + broker: AsyncBroker, + task_id: str, +) -> TaskiqResult[Any] | None: + """Get task result from redis by task_id.""" + task = AsyncTaskiqTask(task_id=task_id, result_backend=broker.result_backend) + try: + if task_result := await task.get_result(): + return task_result + except ResultGetError: + pass + return None + + +async def get_keys( + broker: AsyncBroker, + prefix: str, + max_count: int = 50, +) -> list[str]: + """Get redis keys via scan by prefix.""" + keys = [] + if isinstance(broker, BaseRedisBroker) and isinstance( + broker.result_backend, + RedisAsyncResultBackend, + ): + async with Redis(connection_pool=broker.result_backend.redis_pool) as redis: + async for key in redis.scan_iter(f"{prefix}:*"): + keys.append(key.decode() if isinstance(key, bytes) else key) + if len(keys) == max_count: + break + return keys + + +@stat_broker.task() +async def get_stats(context: Context = TaskiqDepends()) -> Stat: + """ + Task to get stats from StatMiddleware instance of each worker process. + + As soon as we use pub-sub broker inside StatMiddleware, and it starts inside each + worker process, the result value will be overwritten inside result_backend. + So we need to change task_id of the task to be able to gather many results + from different workers by adding process pid to the initial task_id. + + To get all results one must scan result_backend with pattern: + task_id:* + and aggregate all results together. 
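    For example (hypothetical ids): if the original task_id is "abc123" and worker
    processes with pids 101 and 102 both answer, the stat result backend ends up
    holding the keys "abc123:101" and "abc123:102"; scanning with the pattern
    "abc123:*" then yields one Stat result per worker, ready to be summed into a
    Stats model.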
+ """ + context.message.task_id = f"{context.message.task_id}:{os.getpid()}" + return stat_middleware.get_stats() + + +@broker.task() +async def get_all_stats(timeout: float = 0.2) -> Stats: + """Gathers results of get_stats task from all running workers.""" + results = {} + if task := await get_stats.kiq(): + task_id = task.task_id + await asyncio.sleep(timeout) + if keys := await get_keys( + broker=stat_broker, + prefix=task_id, + ): + for key in keys: + try: + _, worker_pid = key.split(":") + except ValueError: + continue + if result := await get_task_result(broker=stat_broker, task_id=key): + results[int(worker_pid)] = result.return_value + return Stats(workers=results) + + +@broker.task() +async def demo_task(timeout: float = 0.1) -> bool: + print(f"demo_task(timeout={timeout})") + await asyncio.sleep(timeout) + return random.choice([True, False]) + + +async def main() -> None: + # Emulate we run taskiq worker processes with several workers. + broker.is_worker_process = True + # Await broker startup with stat_middleware startup that starts pub-sub worker + await broker.startup() + # Start random number of workers + worker_tasks = [] + for _ in range(random.randint(2, 5)): + worker_task = asyncio.create_task(run_receiver_task(broker)) + worker_tasks.append(worker_task) + + # Start random number of tasks with random execution time + for _ in range(random.randint(2, 10)): + await demo_task.kiq(timeout=random.random()) + + # Wait a little + await asyncio.sleep(0.5) + + # Start task to gather stats from all workers + get_stats_task = await get_all_stats.kiq() + stats_result = await get_stats_task.wait_result() + if stats_result: + print("Stats of all workers:\n\t", stats_result.return_value) + + # Stop workers. + for worker_task in worker_tasks: + worker_task.cancel() + try: + await worker_task + except asyncio.CancelledError: + print("Worker successfully exited.") + + await broker.shutdown() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index 355e69d..7fbe47d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,8 +51,8 @@ cbor2 = { version = "^5", optional = true } [tool.poetry.dev-dependencies] pytest = "^7.1.2" ruff = "^0" -black = { version = "^22.6.0", allow-prereleases = true } -mypy = "^1" +black = { version = "^23.7.0", allow-prereleases = true } +mypy = "^1.12.0" pre-commit = "^2.20.0" coverage = "^6.4.2" pytest-cov = "^3.0.0" diff --git a/taskiq/metrics.py b/taskiq/metrics.py new file mode 100644 index 0000000..fb939e9 --- /dev/null +++ b/taskiq/metrics.py @@ -0,0 +1,275 @@ +from typing import Any, Self + +from pydantic import BaseModel, Field, computed_field, model_validator + + +class Label(BaseModel): + """Base Label class.""" + + name: str + count: int = 0 + + +class CounterLabel(Label): + """ + Counter label class. + + Has just one count metric. + Used withe count or quantity data. + Attributes: + count: int + """ + + def inc(self) -> None: + """Increments counter by 1.""" + self.count += 1 + + +class AggregateLabel(Label): + """ + Aggregate label class. + + Aggregation counter. + Used for aggregation of data with different values. + Duration or so. 
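    A minimal usage sketch (the task name and durations below are made up):

        label = AggregateLabel(name="demo_task")
        for duration in (0.2, 0.4, 0.6):
            label.aggregate(duration)
        # after the three calls: label.count == 3 and label.average == 0.4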
+ Attributes: + count: int - the quantity of aggregated values + average: float - the average value since start + avg_dev: float - average deviation from average value + max_dev: float - maximum absolute deviation from average value + """ + + average: float = 0.0 + avg_dev: float = 0.0 + max_dev: float = 0.0 + + def aggregate(self, value: float) -> None: + """ + Aggregates value and calculates average, avg_dev, max_dev. + + Aggregation increases count attribute. + :param value: float - value to aggregate + """ + new_count = self.count + 1 + delta = value - self.average + self.average = (self.average * self.count + value) / new_count + self.avg_dev = (self.avg_dev * (self.count - 1) + value) / new_count + self.count = new_count + if abs(delta) > abs(self.max_dev): + self.max_dev = delta + + +class Labels[T: Label](BaseModel): + """ + Base stat class. + + Base collection class for counters. + Implements label(label_name) function. + + Attributes: + name: name of the stat counter + description: description of the counter + labels: list of counter labels + labels_dict: convenience dict to speed up label selection + label: function of selecting or creating label with specified name + reset: function of resetting counter labels and values, used in serialization + """ + + label_class: type[T] = Field( + default=CounterLabel, + exclude=True, + ) + name: str + description: str | None = Field(default=None) + labels: list[T] = Field(default=[]) + labels_dict: dict[str, T] = Field( + default={}, + exclude=True, + ) + + @model_validator(mode="after") + def _post_create(self) -> Self: + if self.labels and not self.labels_dict: + self.labels_dict = {label.name: label for label in self.labels} + return self + + def label(self, label_name: str) -> T: + """Returns Label subclass instance from labels_dict.""" + if label := self.labels_dict.get(label_name): + return label + label = self.label_class(name=label_name) + self.labels_dict[label_name] = label + self.labels.append(label) + return label + + def reset(self) -> None: + """Resets counter.""" + self.labels = [] + self.labels_dict = {} + + +class Counter(Labels[CounterLabel]): + """ + Counter stats class. + + Simple counter class. + Just increments count value. + Counts sum count value of all labels. + """ + + def __init__(self, **kwargs: Any) -> None: + super().__init__(label_class=CounterLabel, **kwargs) + + @computed_field() + def count(self) -> int: + """Returns Counter value.""" + return sum([cnt.count for cnt in self.labels]) + + def __add__(self, other: "Counter") -> "Counter": + if not isinstance(other, Counter): + raise ValueError("other is not instance of type Counter.") + counter = Counter(name=self.name, description=self.description) + for label_name in set( + list(self.labels_dict.keys()) + list(other.labels_dict.keys()), + ): + self_label = self.labels_dict.get(label_name) or self.label_class( + name=label_name, + ) + other_label = other.labels_dict.get(label_name) or other.label_class( + name=label_name, + ) + counter.label(label_name).count = self_label.count + other_label.count + return counter + + +class Aggregator(Labels[AggregateLabel]): + """ + Aggregator stats class. + + Aggregator counter class is more complex stat metric, for using with duration data. 
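    A minimal sketch of two per-worker aggregations merged with "+" (the label name
    and timings are made up):

        first = Aggregator(name="execution time")
        first.label("demo_task").aggregate(0.5)
        second = Aggregator(name="execution time")
        second.label("demo_task").aggregate(1.5)
        merged = first + second
        # merged.label("demo_task").count == 2 and its average == 1.0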
+ """ + + def __init__(self, **kwargs: Any) -> None: + super().__init__(label_class=AggregateLabel, **kwargs) + + def __add__(self, other: "Aggregator") -> "Aggregator": + if not isinstance(other, Aggregator): + raise ValueError("other is not instance of type Aggregator.") + aggregator = Aggregator(name=self.name, description=self.description) + for label_name in set( + list(self.labels_dict.keys()) + list(other.labels_dict.keys()), + ): + self_label = self.labels_dict.get(label_name) or self.label_class( + name=label_name, + ) + other_label = other.labels_dict.get(label_name) or other.label_class( + name=label_name, + ) + agg_label = aggregator.label(label_name) + sum_count = self_label.count + other_label.count + agg_label.count = sum_count + agg_label.average = ( + self_label.average * self_label.count + + other_label.average * other_label.count + ) / sum_count + agg_label.avg_dev = ( + self_label.avg_dev * self_label.count + + other_label.avg_dev * other_label.count + ) / sum_count + agg_label.max_dev = ( + other_label.max_dev + if abs(other_label.max_dev) > abs(self_label.max_dev) + else self_label.max_dev + ) + return aggregator + + @computed_field() + def count(self) -> int: + """Returns sum count value of all labels.""" + return sum([cnt.count for cnt in self.labels]) + + @computed_field() + def average(self) -> float: + """Returns average value of all labels.""" + sum_count = sum([cnt.count for cnt in self.labels]) + if sum_count: + return sum([cnt.average * cnt.count for cnt in self.labels]) / sum_count + return 0.0 + + @computed_field() + def avg_dev(self) -> float: + """Returns average deviation value of all labels.""" + sum_count = sum([cnt.count for cnt in self.labels]) + if sum_count: + return sum([cnt.avg_dev * cnt.count for cnt in self.labels]) / sum_count + return 0.0 + + @computed_field() + def max_dev(self) -> float: + """Returns maximum deviation value of all labels.""" + devs = [cnt.max_dev for cnt in self.labels] + if devs: + abs_devs = [abs(dev) for dev in devs] + return devs[abs_devs.index(max(abs_devs))] + return 0.0 + + +class Stat(BaseModel): + """ + Stats model for the worker. 
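    Two per-worker Stat objects can be combined with "+", for example (illustrative
    task name):

        first, second = Stat(), Stat()
        first.received_tasks.label("demo_task").inc()
        second.received_tasks.label("demo_task").inc()
        combined = first + second
        # combined.received_tasks.count == 2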
+ + Attributes: + task_errors: Counter for tasks with errors + received_tasks: Counter for all received tasks + execution_time: Aggregator for all executed tasks + """ + + task_errors: Counter = Field( + default=Counter( + name="task errors", + description="Number of tasks with errors", + ), + ) + received_tasks: Counter = Field( + default=Counter( + name="received tasks", + description="Number of received tasks", + ), + ) + execution_time: Aggregator = Field( + default=Aggregator( + name="execution time", + description="Time of function execution", + ), + ) + + def __add__(self, other: "Stat") -> "Stat": + if not isinstance(other, Stat): + raise ValueError("other is not instance of type Stat.") + stat = self.model_copy(deep=True) + stat.task_errors = self.task_errors + other.task_errors + stat.received_tasks = self.received_tasks + other.received_tasks + stat.execution_time = self.execution_time + other.execution_time + return stat + + def reset(self) -> None: + """Resets all counters.""" + self.task_errors.reset() + self.received_tasks.reset() + self.execution_time.reset() + + +class Stats(Stat): + """Summarized statistics for workers.""" + + workers: dict[int, Stat] | None = {} + + @model_validator(mode="after") + def _post_create(self) -> Self: + if self.workers: + self.reset() + for worker in self.workers.values(): + self.task_errors += worker.task_errors + self.received_tasks += worker.received_tasks + self.execution_time += worker.execution_time + return self diff --git a/taskiq/middlewares/stat_middleware.py b/taskiq/middlewares/stat_middleware.py new file mode 100644 index 0000000..3bc0364 --- /dev/null +++ b/taskiq/middlewares/stat_middleware.py @@ -0,0 +1,93 @@ +import asyncio +from asyncio import Task +from logging import getLogger +from typing import Any + +from taskiq.abc import AsyncBroker +from taskiq.abc.middleware import TaskiqMiddleware +from taskiq.api import run_receiver_task +from taskiq.message import TaskiqMessage +from taskiq.metrics import Stat +from taskiq.result import TaskiqResult + +logger = getLogger(__name__) + + +class StatMiddleware(TaskiqMiddleware): + """ + Middleware that gathers stat info. + + StatMiddleware runs on every worker it was added to. + So to gather statistics from each worker we need additional PUB-SUB broker. + + This middleware starts internal worker process for stat_broker. + + If stat_broker was passed to __init__, and middleware starts by broker process, + it also starts another worker for stat_broker to be able to reply on get_stats_task + of each worker process. + :param stat_broker: Any AsyncBroker (better be pub-sub type) + """ + + def __init__(self, stat_broker: AsyncBroker | None = None) -> None: + super().__init__() + self.stat_broker: AsyncBroker | None = stat_broker + self.stat_worker: Task[Any] | None = None + self.stat_metrics: Stat = Stat() + + async def startup(self) -> None: + """ + Startup event trigger. + + If stat_broker was passed to __init__ and we start on the worker side, + then we start stat_broker worker task to be able to run stat_broker tasks. 
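    A minimal wiring sketch (mirrors docs/examples/stats/stats.py; the Redis URLs
    are assumptions for the example):

        from taskiq_redis import ListQueueBroker, PubSubBroker
        from taskiq.middlewares.stat_middleware import StatMiddleware

        broker = ListQueueBroker(url="redis://localhost:6379/0")
        stat_broker = PubSubBroker(url="redis://localhost:6379/1")
        broker.add_middlewares(StatMiddleware(stat_broker=stat_broker))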
+ """ + logger.info("StatMiddleware startup") + if self.broker.is_worker_process and self.stat_broker: + self.stat_broker.is_worker_process = True + await self.stat_broker.startup() + self.stat_worker = asyncio.create_task(run_receiver_task(self.stat_broker)) + + async def shutdown(self) -> None: + """Shutdown event trigger.""" + logger.info("StatMiddleware shutdown") + if self.stat_worker: + self.stat_worker.cancel() + if self.stat_broker: + await self.stat_broker.shutdown() + + def get_stats(self) -> Stat: + """Returns dump of counters.""" + return self.stat_metrics + + def pre_execute( + self, + message: TaskiqMessage, + ) -> TaskiqMessage: + """ + Function to track received tasks. + + This function increments a counter of received tasks, + when called. + + :param message: current message. + :return: message + """ + self.stat_metrics.received_tasks.label(message.task_name).inc() + return message + + def post_execute( + self, + message: TaskiqMessage, + result: TaskiqResult[Any], + ) -> None: + """ + This function tracks number of errors and success executions. + + :param message: received message. + :param result: result of the execution. + """ + if result.is_err: + self.stat_metrics.task_errors.label(message.task_name).inc() + self.stat_metrics.execution_time.label(message.task_name).aggregate( + result.execution_time, + ) diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..d8d3e33 --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,91 @@ +import random + +from taskiq.metrics import Aggregator, Counter, Stat, Stats + +random.seed() + +data = [5, 7, 9, 0, 1, 3, 5, 5, 3, 2, 6, 9, 5, 1, 7, 9, 9, 4, 3, 8, 7, 5, 9, 0, 2, 4, 5] +labels = ["label 1", "label 2", "label 3", "label 4", "label 5"] + + +def test_counter() -> None: + metrix = {} + for worker_pid in (1, 2): + counter = Counter( + name=f"test counter for worker {worker_pid}", + description="test counter description", + ) + metrix[worker_pid] = counter + for _ in data: + label = random.choice(labels) + counter.label(label).inc() + dump = counter.model_dump() + assert Counter.model_validate(dump) + cnt1, cnt2 = metrix.values() + cnt = cnt1 + cnt2 + assert cnt.count == cnt1.count + cnt2.count # type: ignore + + +def test_aggregator() -> None: + metrix = {} + for worker_pid in (1, 2): + agg = Aggregator( + name=f"test aggregator for worker {worker_pid}", + description="test aggregator description", + ) + metrix[worker_pid] = agg + for val in data: + label = random.choice(labels) + agg.label(label).aggregate(val) + dump = agg.model_dump() + assert Aggregator.model_validate(dump) + agg1, agg2 = metrix.values() + agg = agg1 + agg2 + assert agg.count == agg1.count + agg2.count # type: ignore + + +def test_stat() -> None: + stats: dict[int, Stat] = {} + for worker_pid in (1, 2): + stat = Stat() + stats[worker_pid] = stat + for val in data: + label = random.choice(labels) + stat.received_tasks.label(label).inc() + stat.execution_time.label(label).aggregate(val) + dump = stat.model_dump(exclude_none=True) + assert Stat.model_validate(dump) + st1, st2 = stats.values() + st = st1 + st2 + assert ( + st.received_tasks.count == st1.received_tasks.count + st2.received_tasks.count # type: ignore + ) + assert ( + st.execution_time.count == st1.execution_time.count + st2.execution_time.count # type: ignore + ) + + +def test_stats() -> None: + stats = {} + for worker_pid in [11, 22]: + stat = Stat() + stats[worker_pid] = stat + for val in data: + label = random.choice(labels) + 
stat.received_tasks.label(label).inc() + stat.execution_time.label(label).aggregate(val) + sum_stats = Stats(workers=stats) + dump = sum_stats.model_dump(exclude_none=True) + assert sum_stats.received_tasks.count == sum( + [stat.received_tasks.count for stat in stats.values()], # type: ignore + ) + sum_stats1 = Stats.model_validate(dump) + assert isinstance(sum_stats1, Stats) + assert sum_stats.received_tasks.count == sum_stats1.received_tasks.count + assert sum_stats.execution_time.count == sum_stats1.execution_time.count + assert sum_stats.execution_time.average == sum_stats1.execution_time.average + assert ( + sum_stats.execution_time.average # type: ignore + == sum_stats1.execution_time.average + == sum(data) / len(data) + )
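A quick end-to-end sketch of the aggregation model exercised above (the pids and
counts are made up; it relies only on the Stat and Stats classes added in this diff):

    from taskiq.metrics import Stat, Stats

    first, second = Stat(), Stat()
    for _ in range(3):
        first.received_tasks.label("demo_task").inc()
    for _ in range(2):
        second.received_tasks.label("demo_task").inc()

    summary = Stats(workers={101: first, 102: second})
    # summary.received_tasks.count == 5, while summary.workers keeps the per-pid breakdown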