diff --git a/heat/core/io.py b/heat/core/io.py
index 427c7b8d4..c981b9712 100644
--- a/heat/core/io.py
+++ b/heat/core/io.py
@@ -38,6 +38,29 @@
     "load_npy_from_path",
 ]
 
+
+def size_from_slice(size: int, s: slice) -> Tuple[int, int]:
+    """
+    Determines the size and the start index of a slice object.
+
+    Parameters
+    ----------
+    size : int
+        The size of the array the slice object is applied to.
+    s : slice
+        The slice object to determine the size of.
+
+    Returns
+    -------
+    int
+        The size of the sliced object.
+    int
+        The start index of the slice object.
+    """
+    new_range = range(size)[s]
+    return len(new_range), new_range.start if len(new_range) > 0 else 0
+
+
 try:
     import netCDF4 as nc
 except ImportError:
@@ -490,6 +513,7 @@ def load_hdf5(
     dataset: str,
     dtype: datatype = types.float32,
     load_fraction: float = 1.0,
+    slices: Optional[Tuple[slice]] = None,
     split: Optional[int] = None,
     device: Optional[str] = None,
     comm: Optional[Communication] = None,
@@ -509,6 +533,8 @@ def load_hdf5(
         if 1. (default), the whole dataset is loaded from the file specified in path
         else, the dataset is loaded partially, with the fraction of the dataset (along the split axis) specified by load_fraction
         If split is None, load_fraction is automatically set to 1., i.e. the whole dataset is loaded.
+    slices : tuple of slice objects, optional
+        Load only the specified slices of the dataset.
     split : int or None, optional
         The axis along which the data is distributed among the processing cores.
     device : str, optional
@@ -563,6 +589,26 @@ def load_hdf5(
     with h5py.File(path, "r") as handle:
         data = handle[dataset]
         gshape = data.shape
+        new_gshape = tuple()
+        offsets = [0] * len(gshape)
+        if slices is not None:
+            if len(slices) != len(gshape):
+                raise ValueError(
+                    f"Number of slices ({len(slices)}) does not match the number of dimensions ({len(gshape)})"
+                )
+            for i, s in enumerate(slices):
+                if s:
+                    if s.step is not None and s.step != 1:
+                        raise ValueError("Slices with step != 1 are not supported")
+                    new_axis_size, offset = size_from_slice(gshape[i], s)
+                    new_gshape += (new_axis_size,)
+                    offsets[i] = offset
+                else:
+                    new_gshape += (gshape[i],)
+                    offsets[i] = 0
+
+            gshape = new_gshape
+
         if split is not None:
             gshape = list(gshape)
             gshape[split] = int(gshape[split] * load_fraction)
@@ -570,6 +616,13 @@
         dims = len(gshape)
         split = sanitize_axis(gshape, split)
         _, _, indices = comm.chunk(gshape, split)
+
+        if slices is not None:
+            new_indices = tuple()
+            for offset, index in zip(offsets, indices):
+                new_indices += (slice(index.start + offset, index.stop + offset),)
+            indices = new_indices
+
         balanced = True
         if split is None:
             data = torch.tensor(
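Usage sketch (not part of the patch; the file name data.h5 and the dataset shape are illustrative assumptions):

    import heat as ht

    # Load rows 0..49 and all columns of a hypothetical 2-D dataset "data",
    # distributed along axis 0. Internally, load_hdf5 shrinks the global shape
    # with size_from_slice, chunks the sliced shape across the MPI ranks, and
    # then shifts each rank's read window by the per-axis slice offsets.
    x = ht.load_hdf5("data.h5", "data", split=0, slices=(slice(0, 50), slice(None)))

    # size_from_slice maps (axis length, slice) to (sliced length, start offset):
    ht.io.size_from_slice(100, slice(10, 20))    # (10, 10)
    ht.io.size_from_slice(100, slice(None))      # (100, 0)
    ht.io.size_from_slice(100, slice(200, 210))  # (0, 0), empty slice
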
note(f"Expected new size: {expected_new_size}, new size: {new_size}") + note(f"Expected offset: {expected_offset}, offset: {offset}") + self.assertEqual(expected_new_size, new_size) + self.assertEqual(expected_offset, offset) + # catch-all loading def test_load(self): # HDF5 @@ -892,3 +913,37 @@ def test_load_multiple_csv_exception(self): ht.MPI_WORLD.Barrier() if ht.MPI_WORLD.rank == 0: shutil.rmtree(os.path.join(os.getcwd(), "heat/datasets/csv_tests")) + + +@unittest.skipIf(not ht.io.supports_hdf5(), reason="Requires HDF5") +@pytest.mark.parametrize("axis", [None, 0, 1]) +@pytest.mark.parametrize( + "slices", + [ + (slice(0, 50, None), slice(None, None, None)), + (slice(0, 50, None), slice(0, 2, None)), + (slice(50, 100, None), slice(None, None, None)), + (slice(None, None, None), slice(2, 4, None)), + ], +) +def test_load_partial_hdf5(axis, slices): + print("axis: ", axis) + HDF5_PATH = os.path.join(os.getcwd(), "heat/datasets/iris.h5") + HDF5_DATASET = "data" + expect_error = False + for s in slices: + if s and s.step not in [None, 1]: + expect_error = True + break + + if expect_error: + with pytest.raises(ValueError): + sliced_iris = ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis, slices=slices) + else: + original_iris = ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis) + expected_iris = original_iris[slices] + sliced_iris = ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis, slices=slices) + print("Original shape: " + str(original_iris.shape)) + print("Sliced shape: " + str(sliced_iris.shape)) + print("Expected shape: " + str(expected_iris.shape)) + assert not ht.equal(sliced_iris, expected_iris) diff --git a/setup.py b/setup.py index bb71bf1c3..14e2b9f1e 100644 --- a/setup.py +++ b/setup.py @@ -44,9 +44,9 @@ "docutils": ["docutils>=0.16"], "hdf5": ["h5py>=2.8.0"], "netcdf": ["netCDF4>=1.5.6"], - "dev": ["pre-commit>=1.18.3"], + "dev": ["pre-commit>=1.18.3", "pytest>=8.0", "hypothesis<=6.100"], "examples": ["scikit-learn>=0.24.0", "matplotlib>=3.1.0"], - "cb": ["perun>=0.2.0"], + "cb": ["perun>=0.8"], "pandas": ["pandas>=1.4"], }, )