Load slices of hdf5 dataset #1753

Open · wants to merge 3 commits into main
53 changes: 53 additions & 0 deletions heat/core/io.py
@@ -38,6 +38,29 @@
"load_npy_from_path",
]


def size_from_slice(size: int, s: slice) -> Tuple[int, int]:
    """
    Determines the output size and the start offset of a slice applied to an array of the given size.

    Parameters
    ----------
    size : int
        The size of the array the slice object is applied to.
    s : slice
        The slice object to determine the size of.

    Returns
    -------
    int
        The size of the sliced array.
    int
        The start index of the slice object.
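
    Examples
    --------
    >>> size_from_slice(10, slice(2, 8))
    (6, 2)
    >>> size_from_slice(10, slice(None))
    (10, 0)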
"""
new_range = range(size)[s]
return len(new_range), new_range.start if len(new_range) > 0 else 0


try:
import netCDF4 as nc
except ImportError:
@@ -490,6 +513,7 @@ def load_hdf5(
    dataset: str,
    dtype: datatype = types.float32,
    load_fraction: float = 1.0,
    slices: Optional[Tuple[slice, ...]] = None,
    split: Optional[int] = None,
    device: Optional[str] = None,
    comm: Optional[Communication] = None,
@@ -509,6 +533,8 @@
        if 1. (default), the whole dataset is loaded from the file specified in path;
        else, the dataset is loaded partially, with the fraction of the dataset (along the split axis) specified by load_fraction.
        If split is None, load_fraction is automatically set to 1., i.e. the whole dataset is loaded.
    slices : tuple of slice, optional
        Load only the specified slices of the dataset. Slices with a step other than 1 are not supported.
    split : int or None, optional
        The axis along which the data is distributed among the processing cores.
    device : str, optional
@@ -563,13 +589,40 @@
    with h5py.File(path, "r") as handle:
        data = handle[dataset]
        gshape = data.shape
        new_gshape = tuple()
        offsets = [0] * len(gshape)
        if slices is not None:
            if len(slices) != len(gshape):
                raise ValueError(
                    f"Number of slices ({len(slices)}) does not match the number of dimensions ({len(gshape)})"
                )
            # Shrink the global shape axis by axis and remember each start offset.
            for i, s in enumerate(slices):
                if s:
                    if s.step is not None and s.step != 1:
                        raise ValueError("Slices with step != 1 are not supported")
                    new_axis_size, offset = size_from_slice(gshape[i], s)
                    new_gshape += (new_axis_size,)
                    offsets[i] = offset
                else:
                    new_gshape += (gshape[i],)
                    offsets[i] = 0

            gshape = new_gshape

        if split is not None:
            gshape = list(gshape)
            gshape[split] = int(gshape[split] * load_fraction)
            gshape = tuple(gshape)
        dims = len(gshape)
        split = sanitize_axis(gshape, split)
        _, _, indices = comm.chunk(gshape, split)

        if slices is not None:
            # Shift each process's local chunk by the slice offsets so that
            # reads address the correct region of the file.
            new_indices = tuple()
            for offset, index in zip(offsets, indices):
                new_indices += (slice(index.start + offset, index.stop + offset),)
            indices = new_indices

        balanced = True
        if split is None:
            data = torch.tensor(
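For review context, here is a minimal sketch of how the new `slices` parameter would be used. The file path and dataset name are placeholders, not part of this diff, and the shape comment assumes a two-dimensional dataset:

```python
import heat as ht

# Load rows 0-49 and all columns of a 2D HDF5 dataset,
# distributed along axis 0 across the MPI processes.
x = ht.load_hdf5(
    "data.h5",          # hypothetical file path
    dataset="data",     # hypothetical dataset name
    split=0,
    slices=(slice(0, 50), slice(None)),
)
print(x.shape)  # (50, n_cols): the global shape reflects the slice

# A slice with a step other than 1, e.g. slice(0, 50, 2), raises a ValueError.
```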
55 changes: 55 additions & 0 deletions heat/core/tests/test_io.py
@@ -11,6 +11,10 @@
import heat as ht
from .test_suites.basic_test import TestCase

import pytest
from hypothesis import given, note
import hypothesis.strategies as st


class TestIO(TestCase):
    @classmethod
@@ -56,6 +60,23 @@ def tearDown(self):
        # synchronize all nodes
        ht.MPI_WORLD.Barrier()

    @given(size=st.integers(1, 1000), slc=st.slices(1000))
    def test_size_from_slice(self, size, slc):
        expected_sequence = list(range(size))[slc]
        if len(expected_sequence) == 0:
            expected_offset = 0
        else:
            expected_offset = expected_sequence[0]

        expected_new_size = len(expected_sequence)

        new_size, offset = ht.io.size_from_slice(size, slc)
        note(f"Expected sequence: {expected_sequence}")
        note(f"Expected new size: {expected_new_size}, new size: {new_size}")
        note(f"Expected offset: {expected_offset}, offset: {offset}")
        self.assertEqual(expected_new_size, new_size)
        self.assertEqual(expected_offset, offset)

    # catch-all loading
    def test_load(self):
        # HDF5
@@ -892,3 +913,37 @@ def test_load_multiple_csv_exception(self):
        ht.MPI_WORLD.Barrier()
        if ht.MPI_WORLD.rank == 0:
            shutil.rmtree(os.path.join(os.getcwd(), "heat/datasets/csv_tests"))


@unittest.skipIf(not ht.io.supports_hdf5(), reason="Requires HDF5")
@pytest.mark.parametrize("axis", [None, 0, 1])
@pytest.mark.parametrize(
    "slices",
    [
        (slice(0, 50, None), slice(None, None, None)),
        (slice(0, 50, None), slice(0, 2, None)),
        (slice(50, 100, None), slice(None, None, None)),
        (slice(None, None, None), slice(2, 4, None)),
    ],
)
def test_load_partial_hdf5(axis, slices):
    HDF5_PATH = os.path.join(os.getcwd(), "heat/datasets/iris.h5")
    HDF5_DATASET = "data"
    expect_error = False
    for s in slices:
        if s and s.step not in (None, 1):
            expect_error = True
            break

    if expect_error:
        with pytest.raises(ValueError):
            sliced_iris = ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis, slices=slices)
    else:
        # The partially loaded dataset must equal the in-memory slice of the full dataset.
        original_iris = ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis)
        expected_iris = original_iris[slices]
        sliced_iris = ht.load_hdf5(HDF5_PATH, HDF5_DATASET, split=axis, slices=slices)
        assert ht.equal(sliced_iris, expected_iris)
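
As a sanity check of the index arithmetic in `load_hdf5`, this standalone sketch mirrors the offset shift for a slice starting at 50. The axis length and the two-rank chunk layout are assumptions for illustration only, not part of the diff:

```python
# Mirror size_from_slice for slice(50, 100) on an axis of assumed length 150.
size, s = 150, slice(50, 100)
new_range = range(size)[s]
new_size, offset = len(new_range), new_range.start
assert (new_size, offset) == (50, 50)

# comm.chunk would split the sliced axis of length 50 into local chunks,
# e.g. for two ranks (assumed layout):
local_chunks = [slice(0, 25), slice(25, 50)]
# load_hdf5 shifts each chunk by the slice offset before reading the file:
shifted = [slice(c.start + offset, c.stop + offset) for c in local_chunks]
assert shifted == [slice(50, 75), slice(75, 100)]  # actual file regions read
```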
4 changes: 2 additions & 2 deletions setup.py
@@ -44,9 +44,9 @@
"docutils": ["docutils>=0.16"],
"hdf5": ["h5py>=2.8.0"],
"netcdf": ["netCDF4>=1.5.6"],
"dev": ["pre-commit>=1.18.3"],
"dev": ["pre-commit>=1.18.3", "pytest>=8.0", "hypothesis<=6.100"],
"examples": ["scikit-learn>=0.24.0", "matplotlib>=3.1.0"],
"cb": ["perun>=0.2.0"],
"cb": ["perun>=0.8"],
"pandas": ["pandas>=1.4"],
},
)