Skip to content

Commit

Permalink
fix(andata): Fix axis selections to work with CorrData.
Browse files Browse the repository at this point in the history
  • Loading branch information
tristpinsm committed Aug 30, 2022
1 parent d81ad3a commit f0b982c
Showing 1 changed file with 27 additions and 22 deletions.
49 changes: 27 additions & 22 deletions ch_util/andata.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,17 +267,28 @@ def time(self):
return time

@classmethod
def _interpret_and_read(cls, acq_files, start, stop, datasets, out_group, sel=None):
"""Read and concatenate the list of files. Optionally specify one axis on which to make
selections with a tuple like `("axis", selection)`"""
def _interpret_and_read(cls, acq_files, start, stop, datasets, out_group, **kwargs):
"""Read and concatenate the list of files. Keyword args may contain up to one axis selection."""
# Save a reference to the first file to get index map information for
# later.
f_first = acq_files[0]

# Handle axis selections
sel = []
for key in kwargs:
if key[-4:] == "_sel":
sel.append((key[:-4], kwargs[key]))
if len(sel) > 1:
raise ValueError("Cannot handle more than one axis selection.")
elif len(sel) == 0:
sel = None
else:
ax, sel = sel[0]

if sel is None:
andata_objs = [cls(d) for d in acq_files]
else:
andata_objs = [_read_axis_sel(cls, d, sel[0], sel[1]) for d in acq_files]
andata_objs = [_read_axis_sel(cls, d, ax, sel) for d in acq_files]

data = concatenate(
andata_objs,
Expand Down Expand Up @@ -354,11 +365,12 @@ def _from_acq_h5_single(
stop=None,
datasets=None,
out_group=None,
sel=None,
**kwargs,
):
"""Load and concatenate the list of acquisition files into a local array. Optionally
specify one axis on which to make selections with a tuple like `("axis", selection)`"""
"""Load and concatenate the list of acquisition files into a local array.
Axis selections may be supplied as keyword args, but the `BaseData` implementation
only supports up to one axis selection.
"""

# Make sure the input is a sequence and that we have at least one file.
acq_files = tod.ensure_file_list(acq_files)
Expand All @@ -380,7 +392,6 @@ def _from_acq_h5_single(
stop=stop,
datasets=datasets,
out_group=out_group,
sel=sel,
**kwargs,
)

Expand All @@ -400,12 +411,12 @@ def _from_acq_h5_distributed(
stop,
datasets,
comm,
sel=None,
**kwargs,
):
"""Load and concatenate the list of acquisition files into a distributed array. Optionally
specify a selection on the distributed axis with a tuple like `("axis", selection)`. Note
that selections are only allowed along the distributed axis."""
"""Load and concatenate the list of acquisition files into a distributed array.
Axis selections may be supplied as keyword args, but the `BaseData` implementation
only supports up to one axis selection, and it must match the distributed axis.
"""

if cls.distributed_axis is None:
raise RuntimeError(
Expand All @@ -431,16 +442,10 @@ def _from_acq_h5_distributed(
ndist = len(f["index_map/" + ax][:])
ndist = comm.bcast(ndist, root=0)

# Handle selections along the distributed axis
dist_sel = kwargs.get(ax + "_sel", None)

# Calculate the global distributed selection
if sel is not None:
if sel[0] != ax:
raise ValueError(
"For distributed reads, selections are only allowed on the distributed axis. "
f"The distributed axis is {ax} and a selection was passed for {sel[0]}."
)
dist_sel = sel[1]
else:
dist_sel = None
dist_sel = _ensure_1D_selection(dist_sel)
if isinstance(dist_sel, slice):
dist_sel = list(range(*dist_sel.indices(ndist)))
Expand All @@ -451,6 +456,7 @@ def _from_acq_h5_distributed(
local_dist_sel = _ensure_1D_selection(
_convert_to_slice(dist_sel[d_start:d_end])
)
kwargs.update({ax + "_sel": local_dist_sel})

# Load just the local part of the data.
local_data = cls._from_acq_h5_single(
Expand All @@ -459,7 +465,6 @@ def _from_acq_h5_distributed(
stop=stop,
datasets=datasets,
out_group=None,
sel=(ax, local_dist_sel),
**kwargs,
)

Expand Down

0 comments on commit f0b982c

Please sign in to comment.