diff --git a/tests/unit/mazepa/maker_utils.py b/tests/unit/mazepa/maker_utils.py index 5abbd0429..f9f09c2e8 100644 --- a/tests/unit/mazepa/maker_utils.py +++ b/tests/unit/mazepa/maker_utils.py @@ -3,9 +3,9 @@ from zetta_utils.mazepa.tasks import _TaskableOperation -def make_test_task(fn, id_, worker_type=None, operation_name="DummyTask"): # TODO: type me +def make_test_task(fn, id_, operation_name="DummyTask"): # TODO: type me return _TaskableOperation( - fn=fn, operation_name=operation_name, id_fn=get_literal_id_fn(id_), worker_type=worker_type + fn=fn, operation_name=operation_name, id_fn=get_literal_id_fn(id_) ).make_task() diff --git a/tests/unit/mazepa/test_task_router.py b/tests/unit/mazepa/test_task_router.py index 3a6263925..05adefbb5 100644 --- a/tests/unit/mazepa/test_task_router.py +++ b/tests/unit/mazepa/test_task_router.py @@ -24,9 +24,9 @@ def test_push_tasks(mocker): queue_a.name = "_type_a" queue_b.name = "_type_b" meq = TaskRouter([queue_a, queue_b]) - task_a = make_test_task(lambda: None, id_="dummy", worker_type="type_a") - task_b = make_test_task(lambda: None, "dummy", worker_type="type_b") - task_bb = make_test_task(lambda: None, "dummy", worker_type="type_b") + task_a = make_test_task(lambda: None, id_="dummy").with_worker_type("type_a") + task_b = make_test_task(lambda: None, "dummy").with_worker_type("type_b") + task_bb = make_test_task(lambda: None, "dummy").with_worker_type("type_b") meq.push([task_a, task_b, task_bb]) queue_a.push.assert_called_with([task_a]) queue_b.push.assert_called_with([task_b, task_bb]) diff --git a/zetta_utils/mazepa/tasks.py b/zetta_utils/mazepa/tasks.py index 7abef9914..545cd720f 100644 --- a/zetta_utils/mazepa/tasks.py +++ b/zetta_utils/mazepa/tasks.py @@ -68,6 +68,9 @@ def _set_up(self, *args: Iterable, **kwargs: dict): self.args = args self.kwargs = kwargs + def with_worker_type(self, worker_type: str | None) -> Task: + return attrs.evolve(self, worker_type=worker_type) + def _call_task_fn(self, debug: bool = True) -> R_co: if debug or self.runtime_limit_sec is None: return_value = self.fn(*self.args, **self.kwargs) @@ -172,7 +175,6 @@ class _TaskableOperation(Generic[P, R_co]): id_fn: Callable[[Callable, list, dict], str] = attrs.field( default=functools.partial(id_generation.generate_invocation_id, prefix="task") ) - worker_type: str | None = None runtime_limit_sec: float | None = None upkeep_interval_sec: float = constants.DEFAULT_UPKEEP_INTERVAL @@ -197,7 +199,6 @@ def make_task( fn=self.fn, operation_name=self.operation_name, id_=id_, - worker_type=self.worker_type, upkeep_settings=upkeep_settings, runtime_limit_sec=self.runtime_limit_sec, )