diff --git a/zetta_utils/mazepa/tasks.py b/zetta_utils/mazepa/tasks.py index 7abef9914..4d724d9f7 100644 --- a/zetta_utils/mazepa/tasks.py +++ b/zetta_utils/mazepa/tasks.py @@ -172,7 +172,6 @@ class _TaskableOperation(Generic[P, R_co]): id_fn: Callable[[Callable, list, dict], str] = attrs.field( default=functools.partial(id_generation.generate_invocation_id, prefix="task") ) - worker_type: str | None = None runtime_limit_sec: float | None = None upkeep_interval_sec: float = constants.DEFAULT_UPKEEP_INTERVAL @@ -185,8 +184,9 @@ def __call__( def make_task( self, - *args: P.args, - **kwargs: P.kwargs, + *args, + worker_type=None, + **kwargs, ) -> Task[R_co]: id_ = self.id_fn(self.fn, list(args), kwargs) upkeep_settings = TaskUpkeepSettings( @@ -197,7 +197,7 @@ def make_task( fn=self.fn, operation_name=self.operation_name, id_=id_, - worker_type=self.worker_type, + worker_type=worker_type, upkeep_settings=upkeep_settings, runtime_limit_sec=self.runtime_limit_sec, ) @@ -254,7 +254,7 @@ def taskable_operation_cls( *, operation_name: str | None = None, ): - def _make_task(self, *args, **kwargs): + def _make_task(self, *args, worker_type, **kwargs): if operation_name is None: if hasattr(self, "get_operation_name"): operation_name_final = self.get_operation_name() # pragma: no cover @@ -267,7 +267,9 @@ def _make_task(self, *args, **kwargs): operation_name=operation_name_final, # TODO: Other params passed to decorator ).make_task( - *args, **kwargs + *args, + worker_type=worker_type, + **kwargs, ) # pylint: disable=protected-access return task