Skip to content

Commit

Permalink
update load_from_h5file to support multiple fields
Browse files Browse the repository at this point in the history
  • Loading branch information
TomDonoghue committed Apr 20, 2024
1 parent 4c0d5ca commit cf5ea37
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 17 deletions.
33 changes: 20 additions & 13 deletions convnwb/io/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,8 +372,8 @@ def load_jsons_to_df(files, folder=None):
## HDF5 FILE SUPPORT, INCLUDING CONTEXT MANAGERS

@check_dependency(h5py, 'h5py')
def read_h5file(file_name, folder=None, ext='.h5', **kwargs):
"""Read a hdf5 file.
def access_h5file(file_name, folder=None, ext='.h5', mode='r', **kwargs):
"""Access a HDF5 file.
Parameters
----------
Expand All @@ -383,6 +383,8 @@ def read_h5file(file_name, folder=None, ext='.h5', **kwargs):
Folder to open the file from.
ext : str, optional default: '.h5'
The extension to check and use for the file.
mode : {'r', 'r+', 'w', 'w-', 'x', 'a'}, optional, default: 'r'
Mode in which to access the file. See h5py.File for details.
**kwargs
Additional keyword arguments to pass into h5py.File.
Expand All @@ -396,15 +398,15 @@ def read_h5file(file_name, folder=None, ext='.h5', **kwargs):
This function is a wrapper for `h5py.File`.
"""

h5file = h5py.File(check_ext(check_folder(file_name, folder), ext), 'r', **kwargs)
h5file = h5py.File(check_ext(check_folder(file_name, folder), ext), mode, **kwargs)

return h5file


@contextmanager
@check_dependency(h5py, 'h5py')
def open_h5file(file_name, folder=None, ext='.h5', **kwargs):
"""Context manager to open a hdf5 file.
def open_h5file(file_name, folder=None, mode='r', ext='.h5', **kwargs):
"""Context manager to open a HDF5 file.
Parameters
----------
Expand All @@ -427,7 +429,7 @@ def open_h5file(file_name, folder=None, ext='.h5', **kwargs):
This function is a wrapper for `h5py.File`, creating a context manager.
"""

h5file = read_h5file(file_name, folder, ext, **kwargs)
# Forward `mode` explicitly: without it the file is always opened with access_h5file's default ('r').
h5file = access_h5file(file_name, folder, ext=ext, mode=mode, **kwargs)

try:
yield h5file
Expand All @@ -436,13 +438,13 @@ def open_h5file(file_name, folder=None, ext='.h5', **kwargs):


@check_dependency(h5py, 'h5py')
def load_from_h5file(field, file_name, folder=None, ext='.h5', **kwargs):
"""Load a specified field from a HDF5 file.
def load_from_h5file(fields, file_name, folder=None, ext='.h5', **kwargs):
"""Load one or more specified field(s) from a HDF5 file.
Parameters
----------
field : str
Name of the field to load from the HDF5 file.
fields : str or list of str
Name(s) of the field(s) to load from the HDF5 file.
file_name : str
File name of the h5file to open.
folder : str or Path, optional
Expand All @@ -454,8 +456,9 @@ def load_from_h5file(field, file_name, folder=None, ext='.h5', **kwargs):
Returns
-------
data
data : dict
Loaded data field(s) from the file.
Each key is a field name, and each value is the data loaded for that field.
Notes
-----
Expand All @@ -464,7 +467,11 @@ def load_from_h5file(field, file_name, folder=None, ext='.h5', **kwargs):
Files with multiple fields should be opened and accessed with `open_h5file`.
"""

fields = [fields] if isinstance(fields, str) else fields

outputs = {}
with open_h5file(file_name, folder, ext=ext, **kwargs) as h5file:
output = h5file[field][:]
for field in fields:
outputs[field] = h5file[field][:]

return output
return outputs
3 changes: 2 additions & 1 deletion convnwb/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,5 @@ def th5file():
"""Save out a test HDF5 file."""

with h5py.File(TEST_PATHS['file'] / "test_hdf5.h5", "w") as h5file:
dset = h5file.create_dataset("dataset", (100,), dtype='i')
dset1 = h5file.create_dataset("data", (50,), dtype='i')
dset2 = h5file.create_dataset("data2", (50,), dtype='f')
12 changes: 9 additions & 3 deletions convnwb/tests/io/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,10 @@ def test_load_jsons_to_df():
out = load_jsons_to_df(TEST_FILE_PATH)
assert isinstance(out, pd.DataFrame)

def test_read_h5file():
def test_access_h5file():

f_name = 'test_hdf5'
h5file = read_h5file(f_name, TEST_FILE_PATH)
h5file = access_h5file(f_name, TEST_FILE_PATH)
assert h5file
h5file.close()

Expand All @@ -136,5 +136,11 @@ def test_open_h5file():
def test_load_from_h5file():

f_name = 'test_hdf5'
dataset = load_from_h5file('dataset', f_name, TEST_FILE_PATH)

# Test loading single field
dataset = load_from_h5file('data', f_name, TEST_FILE_PATH)
assert dataset is not None

# Test loading multiple fields
datasets = load_from_h5file(['data', 'data2'], f_name, TEST_FILE_PATH)
assert datasets is not None

0 comments on commit cf5ea37

Please sign in to comment.