From 86dc83c7421a5d660b2e465ca56d5ec938962239 Mon Sep 17 00:00:00 2001 From: Hendrik Makait Date: Mon, 12 Aug 2024 15:43:36 +0200 Subject: [PATCH] MINOR: Extract truncation logic out of partial concatenation in P2P rechunking (#8826) --- distributed/shuffle/_rechunk.py | 62 ++++++++++++++++----------------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/distributed/shuffle/_rechunk.py b/distributed/shuffle/_rechunk.py index 1ff0daa97c..99f9d39061 100644 --- a/distributed/shuffle/_rechunk.py +++ b/distributed/shuffle/_rechunk.py @@ -399,20 +399,21 @@ def _construct_graph(self) -> _T_LowLevelGraph: chunked_shape = tuple(len(axis) for axis in self.chunks) for ndpartial in _split_partials(_old_to_new, chunked_shape): - output_count = np.sum(self.keepmap[ndpartial.new]) + partial_keepmap = self.keepmap[ndpartial.new] + output_count = np.sum(partial_keepmap) if output_count == 0: continue elif output_count == 1: # Single output chunk - # TODO: Create new partial that contains ONLY the relevant chunk + ndindex = np.argwhere(partial_keepmap)[0] + ndpartial = _truncate_partial(ndindex, ndpartial, _old_to_new) + dsk.update( partial_concatenate( input_name=self.name_input, input_chunks=self.chunks_input, ndpartial=ndpartial, token=self.token, - keepmap=self.keepmap, - old_to_new=_old_to_new, ) ) else: @@ -516,8 +517,6 @@ def partial_concatenate( input_chunks: ChunkedAxes, ndpartial: _NDPartial, token: str, - keepmap: np.ndarray, - old_to_new: list[Any], ) -> dict[Key, Any]: import numpy as np @@ -528,31 +527,6 @@ def partial_concatenate( slice_group = f"rechunk-slice-{token}" - partial_keepmap = keepmap[ndpartial.new] - assert np.sum(partial_keepmap) == 1 - - ndindex = np.argwhere(partial_keepmap)[0] - - partial_per_axis = [] - for axis_index, index in enumerate(ndindex): - slc = slice( - ndpartial.new[axis_index].start + index, - ndpartial.new[axis_index].start + index + 1, - ) - first_old_chunk, first_old_slice = old_to_new[axis_index][slc.start][0] - last_old_chunk, last_old_slice = old_to_new[axis_index][slc.stop - 1][-1] - partial_per_axis.append( - _Partial( - old=slice(first_old_chunk, last_old_chunk + 1), - new=slc, - left_start=first_old_slice.start, - right_stop=last_old_slice.stop, - ) - ) - - old, new, left_starts, right_stops = zip(*partial_per_axis) - ndpartial = _NDPartial(old, new, left_starts, right_stops, ndpartial.ix) - old_offset = tuple(slice_.start for slice_ in ndpartial.old) shape = tuple(slice_.stop - slice_.start for slice_ in ndpartial.old) @@ -588,6 +562,32 @@ def partial_concatenate( return dsk +def _truncate_partial( + ndindex: NDIndex, + ndpartial: _NDPartial, + old_to_new: list[Any], +) -> _NDPartial: + partial_per_axis = [] + for axis_index, index in enumerate(ndindex): + slc = slice( + ndpartial.new[axis_index].start + index, + ndpartial.new[axis_index].start + index + 1, + ) + first_old_chunk, first_old_slice = old_to_new[axis_index][slc.start][0] + last_old_chunk, last_old_slice = old_to_new[axis_index][slc.stop - 1][-1] + partial_per_axis.append( + _Partial( + old=slice(first_old_chunk, last_old_chunk + 1), + new=slc, + left_start=first_old_slice.start, + right_stop=last_old_slice.stop, + ) + ) + + old, new, left_starts, right_stops = zip(*partial_per_axis) + return _NDPartial(old, new, left_starts, right_stops, ndpartial.ix) + + def _compute_partial_old_chunks( partial: _NDPartial, chunks: ChunkedAxes ) -> ChunkedAxes: