@@ -538,6 +538,11 @@ def shuffle(self) -> None:
538
538
if all (isinstance (idx , slice ) for idx in self ._group_indices ):
539
539
return
540
540
541
+ if TYPE_CHECKING :
542
+ for idx in self ._group_indices :
543
+ assert not isinstance (idx , slice )
544
+ indices : tuple [list [int ]] = self ._group_indices # type: ignore[assignment]
545
+
541
546
was_array = isinstance (self ._obj , DataArray )
542
547
as_dataset = self ._obj ._to_temp_dataset () if was_array else self ._obj
543
548
@@ -547,20 +552,24 @@ def shuffle(self) -> None:
547
552
shuffled [name ] = var
548
553
continue
549
554
shuffled_data = shuffle_array (
550
- var ._data , list (self . _group_indices ), axis = var .get_axis_num (dim )
555
+ var ._data , list (indices ), axis = var .get_axis_num (dim )
551
556
)
552
557
shuffled [name ] = var ._replace (data = shuffled_data )
553
558
554
559
# Replace self._group_indices with slices
555
560
slices = []
556
561
start = 0
557
562
for idxr in self ._group_indices :
563
+ if TYPE_CHECKING :
564
+ assert not isinstance (idxr , slice )
558
565
slices .append (slice (start , start + len (idxr )))
559
566
start += len (idxr )
560
567
# TODO: we have now broken the invariant
561
568
# self._group_indices ≠ self.groupers[0].group_indices
562
569
self ._group_indices = tuple (slices )
563
570
if was_array :
571
+ if TYPE_CHECKING :
572
+ assert isinstance (self ._obj , DataArray )
564
573
self ._obj = self ._obj ._from_temp_dataset (shuffled )
565
574
else :
566
575
self ._obj = shuffled
0 commit comments