@@ -147,6 +147,75 @@ def _get_chunk_slicer(dim: Hashable, chunk_index: Mapping, chunk_bounds: Mapping
147
147
return slice (None )
148
148
149
149
150
+ def subset_dataset_to_block (
151
+ graph : dict , gname : str , dataset : Dataset , input_chunk_bounds , chunk_index
152
+ ):
153
+ """
154
+ Creates a task that subsets an xarray dataset to a block determined by chunk_index.
155
+ Block extents are determined by input_chunk_bounds.
156
+ Also subtasks that subset the constituent variables of a dataset.
157
+ """
158
+ import dask
159
+
160
+ # this will become [[name1, variable1],
161
+ # [name2, variable2],
162
+ # ...]
163
+ # which is passed to dict and then to Dataset
164
+ data_vars = []
165
+ coords = []
166
+
167
+ chunk_tuple = tuple (chunk_index .values ())
168
+ chunk_dims_set = set (chunk_index )
169
+ variable : Variable
170
+ for name , variable in dataset .variables .items ():
171
+ # make a task that creates tuple of (dims, chunk)
172
+ if dask .is_dask_collection (variable .data ):
173
+ # get task name for chunk
174
+ chunk = (
175
+ variable .data .name ,
176
+ * tuple (chunk_index [dim ] for dim in variable .dims ),
177
+ )
178
+
179
+ chunk_variable_task = (f"{ name } -{ gname } -{ chunk [0 ]!r} " ,) + chunk_tuple
180
+ graph [chunk_variable_task ] = (
181
+ tuple ,
182
+ [variable .dims , chunk , variable .attrs ],
183
+ )
184
+ else :
185
+ assert name in dataset .dims or variable .ndim == 0
186
+
187
+ # non-dask array possibly with dimensions chunked on other variables
188
+ # index into variable appropriately
189
+ subsetter = {
190
+ dim : _get_chunk_slicer (dim , chunk_index , input_chunk_bounds )
191
+ for dim in variable .dims
192
+ }
193
+ if set (variable .dims ) < chunk_dims_set :
194
+ this_var_chunk_tuple = tuple (chunk_index [dim ] for dim in variable .dims )
195
+ else :
196
+ this_var_chunk_tuple = chunk_tuple
197
+
198
+ chunk_variable_task = (
199
+ f"{ name } -{ gname } -{ dask .base .tokenize (subsetter )} " ,
200
+ ) + this_var_chunk_tuple
201
+ # We are including a dimension coordinate,
202
+ # minimize duplication by not copying it in the graph for every chunk.
203
+ if variable .ndim == 0 or chunk_variable_task not in graph :
204
+ subset = variable .isel (subsetter )
205
+ graph [chunk_variable_task ] = (
206
+ tuple ,
207
+ [subset .dims , subset ._data , subset .attrs ],
208
+ )
209
+
210
+ # this task creates dict mapping variable name to above tuple
211
+ if name in dataset ._coord_names :
212
+ coords .append ([name , chunk_variable_task ])
213
+ else :
214
+ data_vars .append ([name , chunk_variable_task ])
215
+
216
+ return (Dataset , (dict , data_vars ), (dict , coords ), dataset .attrs )
217
+
218
+
150
219
def map_blocks (
151
220
func : Callable [..., T_Xarray ],
152
221
obj : DataArray | Dataset ,
@@ -451,75 +520,6 @@ def _wrapper(
451
520
dim : np .cumsum ((0 ,) + chunks_v ) for dim , chunks_v in output_chunks .items ()
452
521
}
453
522
454
- def subset_dataset_to_block (
455
- graph : dict , gname : str , dataset : Dataset , input_chunk_bounds , chunk_index
456
- ):
457
- """
458
- Creates a task that subsets an xarray dataset to a block determined by chunk_index.
459
- Block extents are determined by input_chunk_bounds.
460
- Also subtasks that subset the constituent variables of a dataset.
461
- """
462
-
463
- # this will become [[name1, variable1],
464
- # [name2, variable2],
465
- # ...]
466
- # which is passed to dict and then to Dataset
467
- data_vars = []
468
- coords = []
469
-
470
- chunk_tuple = tuple (chunk_index .values ())
471
- chunk_dims_set = set (chunk_index )
472
- variable : Variable
473
- for name , variable in dataset .variables .items ():
474
- # make a task that creates tuple of (dims, chunk)
475
- if dask .is_dask_collection (variable .data ):
476
- # get task name for chunk
477
- chunk = (
478
- variable .data .name ,
479
- * tuple (chunk_index [dim ] for dim in variable .dims ),
480
- )
481
-
482
- chunk_variable_task = (f"{ name } -{ gname } -{ chunk [0 ]!r} " ,) + chunk_tuple
483
- graph [chunk_variable_task ] = (
484
- tuple ,
485
- [variable .dims , chunk , variable .attrs ],
486
- )
487
- else :
488
- assert name in dataset .dims or variable .ndim == 0
489
-
490
- # non-dask array possibly with dimensions chunked on other variables
491
- # index into variable appropriately
492
- subsetter = {
493
- dim : _get_chunk_slicer (dim , chunk_index , input_chunk_bounds )
494
- for dim in variable .dims
495
- }
496
- if set (variable .dims ) < chunk_dims_set :
497
- this_var_chunk_tuple = tuple (
498
- chunk_index [dim ] for dim in variable .dims
499
- )
500
- else :
501
- this_var_chunk_tuple = chunk_tuple
502
-
503
- chunk_variable_task = (
504
- f"{ name } -{ gname } -{ dask .base .tokenize (subsetter )} " ,
505
- ) + this_var_chunk_tuple
506
- # We are including a dimension coordinate,
507
- # minimize duplication by not copying it in the graph for every chunk.
508
- if variable .ndim == 0 or chunk_variable_task not in graph :
509
- subset = variable .isel (subsetter )
510
- graph [chunk_variable_task ] = (
511
- tuple ,
512
- [subset .dims , subset ._data , subset .attrs ],
513
- )
514
-
515
- # this task creates dict mapping variable name to above tuple
516
- if name in dataset ._coord_names :
517
- coords .append ([name , chunk_variable_task ])
518
- else :
519
- data_vars .append ([name , chunk_variable_task ])
520
-
521
- return (Dataset , (dict , data_vars ), (dict , coords ), dataset .attrs )
522
-
523
523
include_variables = set (template .variables ) - set (coordinates .indexes )
524
524
# iterate over all possible chunk combinations
525
525
for chunk_tuple in itertools .product (* ichunk .values ()):
0 commit comments