From ebcd58d58c9c5b449e08e2c2f197c650d303ad4f Mon Sep 17 00:00:00 2001 From: dodam Date: Sat, 15 Feb 2025 16:32:45 -0500 Subject: [PATCH] feat: worker types in subchunkable --- .../common/subchunkable_apply_flow.py | 29 ++++++++++++++----- .../common/volumetric_apply_flow.py | 12 ++++++-- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/zetta_utils/mazepa_layer_processing/common/subchunkable_apply_flow.py b/zetta_utils/mazepa_layer_processing/common/subchunkable_apply_flow.py index e3ebb76bf..419e941dd 100644 --- a/zetta_utils/mazepa_layer_processing/common/subchunkable_apply_flow.py +++ b/zetta_utils/mazepa_layer_processing/common/subchunkable_apply_flow.py @@ -91,6 +91,8 @@ def build_postpad_subchunkable_apply_flow( # pylint: disable=keyword-arg-before allow_cache_up_to_level: int | None = None, print_summary: bool = True, generate_ng_link: bool = False, + op_worker_type: str | None = None, + reduction_worker_type: str | None = None, fn: Callable[P, Tensor] | None = None, fn_semaphores: Sequence[SemaphoreType] | None = None, op: VolumetricOpProtocol[P, None, Any] | None = None, @@ -157,7 +159,7 @@ def build_postpad_subchunkable_apply_flow( # pylint: disable=keyword-arg-before f"\t`processing_blend`:\t\t{tuple(processing_blend)}\t(at all levels)\n" "As core chunk sizes, before padding for crop and blend:\n" f"\t`processing_chunk_sizes`:" - f"\t{', '.join(size.pformat() for size in processing_chunk_sizes)}\n" + f"\t{', '.join(size.pformat() for size in processing_chunk_sizes)}\n" "The bottom level chunk size will be respected to maintain the input size of " f"{tuple(processing_input_sizes[-1])} while the other levels will be " "treated as upper bounds, fitting in as many chunks as possible." 
@@ -180,6 +182,8 @@ def build_postpad_subchunkable_apply_flow( # pylint: disable=keyword-arg-before shrink_processing_chunk=False, auto_divisibility=True, allow_cache_up_to_level=allow_cache_up_to_level, + op_worker_type=op_worker_type, + reduction_worker_type=reduction_worker_type, print_summary=print_summary, generate_ng_link=generate_ng_link, fn=fn, @@ -217,6 +221,8 @@ def build_subchunkable_apply_flow( # pylint: disable=keyword-arg-before-vararg, shrink_processing_chunk: bool = False, auto_divisibility: bool = False, allow_cache_up_to_level: int | None = None, + op_worker_type: str | None = None, + reduction_worker_type: str | None = None, print_summary: bool = True, generate_ng_link: bool = False, fn: Callable[P, Tensor] | None = None, @@ -318,6 +324,10 @@ def build_subchunkable_apply_flow( # pylint: disable=keyword-arg-before-vararg, :param allow_cache_up_to_level: The subchunking level (smallest is 0) where the cache for different remote layers should be cleared after the processing is done. Recommended to keep this at the level of the largest subchunks (default). + :param op_worker_type: The worker type required by the op. The subchunked tasks + will be routed to only the workers that have the requested worker type. + :param reduction_worker_type: The worker type required by the reduction process. The + subchunked tasks will be routed to only the workers that have the requested worker type. :param print_summary: Whether a summary should be printed. :param generate_ng_link: Whether a neuroglancer link should be generated in the summary. Requires ``print_summary``. 
@@ -384,10 +394,7 @@ def build_subchunkable_apply_flow( # pylint: disable=keyword-arg-before-vararg, assert fn is not None op_ = VolumetricCallableOperation[P](fn, fn_semaphores=fn_semaphores) - if op_kwargs is not None: - op_kwargs_ = op_kwargs - else: - op_kwargs_ = {} + op_kwargs_ = op_kwargs if op_kwargs is not None else {} if generate_ng_link and not print_summary: raise ValueError("Cannot use `generate_ng_link` when `print_summary=False`.") @@ -564,6 +571,8 @@ def build_subchunkable_apply_flow( # pylint: disable=keyword-arg-before-vararg, expand_bbox_processing=expand_bbox_processing, shrink_processing_chunk=shrink_processing_chunk, auto_divisibility=auto_divisibility, + op_worker_type=op_worker_type, + reduction_worker_type=reduction_worker_type, print_summary=print_summary, generate_ng_link=generate_ng_link, op_args=op_args, @@ -1068,6 +1077,8 @@ def _build_subchunkable_apply_flow( # pylint: disable=keyword-arg-before-vararg expand_bbox_processing: bool, shrink_processing_chunk: bool, auto_divisibility: bool, + op_worker_type: str | None, + reduction_worker_type: str | None, print_summary: bool, generate_ng_link: bool, op: VolumetricOpProtocol[P, None, Any], @@ -1158,11 +1169,11 @@ def _build_subchunkable_apply_flow( # pylint: disable=keyword-arg-before-vararg error_str = ( "At each level (where the 0-th level is the smallest), the" " `processing_chunk_size[level+1]` + 2*`processing_crop_pad[level+1]` + 2*`processing_blend_pad[level+1]`" - " + `processing_gap` must be" + " + processing_gap must be" f" evenly divisible by the `processing_chunk_size[level]` + processing_gap (processing_gap applies only on top level).\n" f"\nAt level {level}, received:\n" f"`processing_chunk_size[level+1]`:\t\t\t\t\t\t{processing_chunk_size_higher}\n" - f"`processing_gap`:\t\t\t\t\t\t\t\t{processing_gap_higher}\n" + f"`applicable processing_gap`:\t\t\t\t\t\t\t\t{processing_gap_higher}\n" f"`processing_crop_pad[level+1]` ((0, 0, 0) for the top 
level):\t\t\t{processing_crop_pad_higher}\n" f"`processing_blend_pad[level+1]`:\t\t\t\t\t\t{processing_blend_pad_higher}\n" f"Size of the region to be processed for the level:\t\t\t\t{processing_region}\n" @@ -1291,6 +1302,8 @@ def _build_subchunkable_apply_flow( # pylint: disable=keyword-arg-before-vararg force_intermediaries=not (skip_intermediaries), flow_id=flow_id, l0_chunks_per_task=num_chunks_below[-1], + op_worker_type=op_worker_type, + reduction_worker_type=reduction_worker_type, ) """ Iteratively build the hierarchy of schemas @@ -1317,5 +1330,7 @@ def _build_subchunkable_apply_flow( # pylint: disable=keyword-arg-before-vararg force_intermediaries=not (skip_intermediaries), flow_id=flow_id, l0_chunks_per_task=num_chunks_below[-level - 1], + op_worker_type=op_worker_type, + reduction_worker_type=reduction_worker_type, ) return flow_schema(idx, dst, op_args, op_kwargs) diff --git a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py index 09b407228..a914b3466 100644 --- a/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py +++ b/zetta_utils/mazepa_layer_processing/common/volumetric_apply_flow.py @@ -351,6 +351,8 @@ class VolumetricApplyFlowSchema(Generic[P, R_co]): processing_chunker: VolumetricIndexChunker = attrs.field(init=False) flow_id: str = "no_id" l0_chunks_per_task: int = 0 + op_worker_type: str | None = None + reduction_worker_type: str | None = None @property def _intermediaries_are_local(self) -> bool: @@ -473,7 +475,9 @@ def _make_task( VolumetricIndex, VolumetricBasedLayerProtocol | None, dict[str, Any] ], # cannot type with P.kwargs ) -> mazepa.tasks.Task[R_co]: - return self.op.make_task(idx=arg[0], dst=arg[1], **arg[2]) + return self.op.make_task(idx=arg[0], dst=arg[1], **arg[2]).with_worker_type( + self.op_worker_type + ) def make_tasks_without_checkerboarding( self, @@ -712,9 +716,11 @@ def flow( # pylint:disable=too-many-branches, 
too-many-statements idx, mode="exact", stride_start_offset=stride_start_offset ) tasks_reduce = [ - Copy().make_task( + Copy() + .make_task( src=dst_temp, dst=dst.with_procs(read_procs=(), write_procs=()), idx=red_chunk ) + .with_worker_type(self.reduction_worker_type) for red_chunk in red_chunks ] logger.info( @@ -806,7 +812,7 @@ def flow( # pylint:disable=too-many-branches, too-many-statements roi_idx=idx.padded(self.roi_crop_pad + self.processing_blend_pad), dst=dst.with_procs(read_procs=(), write_procs=()), processing_blend_pad=self.processing_blend_pad, - ) + ).with_worker_type(self.reduction_worker_type) for ( red_chunk_task_idxs, red_chunk_temps,