Skip to content

Commit

Permalink
fix: update prov attributes combine (#1116)
Browse files Browse the repository at this point in the history
* fix: fix how provenance data are combined

* test: add test to check actual prov vals

* fix: remove ECHODATA_FILENAME, use ED_FILENAME

* tests: simplify 'test_combine_echodata_combined_append' (lsetiawan#1)

Simplified by using common code for repeated blocks; plus other readability tweaks

---------

Co-authored-by: Emilio Mayorga <[email protected]>
Refs: #1115
  • Loading branch information
lsetiawan and emiliom authored Aug 15, 2023
1 parent ad17757 commit 16a2b16
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 33 deletions.
102 changes: 94 additions & 8 deletions echopype/echodata/combine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import itertools
import re
from collections import ChainMap
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
from warnings import warn
Expand Down Expand Up @@ -643,6 +644,50 @@ def _capture_prov_attrs(
return prov_ds


def _get_prov_attrs(
    ds: xr.Dataset, is_combined: bool = True
) -> Optional[Dict[str, List[Dict[str, str]]]]:
    """
    Rebuild per-group attribute dictionaries from a Provenance dataset.

    This function is meant to be used on an already combined dataset.

    Parameters
    ----------
    ds : xr.Dataset
        The Provenance group dataset to read the attributes from
    is_combined : bool
        The flag to indicate if ``ds`` comes from a combined dataset

    Returns
    -------
    Optional[Dict[str, List[Dict[str, str]]]]
        ``None`` when ``is_combined`` is false. Otherwise a dictionary keyed
        by echodata group name (taken from each variable's ``ED_GROUP``
        attribute); each value is a list with one dictionary per position
        along the variables' values, mapping variable name to the value
        found at that position.
    """
    # Nothing to extract unless the dataset is flagged as combined
    if not is_combined:
        return None

    # Group name -> one list per data variable, each list holding
    # single-entry ``{variable name: value}`` dictionaries
    per_group = {}
    for var_name, data_array in ds.data_vars.items():
        # Only variables tagged with their echodata group are considered
        if ED_GROUP not in data_array.attrs:
            continue
        group_name = data_array.attrs[ED_GROUP]
        per_group.setdefault(group_name, []).append(
            [{var_name: value} for value in data_array.values]
        )

    # Merge the attributes for each group so the result matches the
    # attributes dict layout used for later merging
    merged = {}
    for group_name, var_lists in per_group.items():
        # Pivot the per-variable lists: every column gathers the
        # single-entry dictionaries found at the same position
        columns = pd.DataFrame.from_dict(var_lists).to_dict(orient="list")
        merged[group_name] = [dict(ChainMap(*column)) for column in columns.values()]
    return merged


def _combine(
sonar_model: str,
eds: List[EchoData] = [],
Expand Down Expand Up @@ -680,7 +725,27 @@ def _combine(
attrs_dict = {}

# Check if input data are combined datasets
all_combined = all(ed["Provenance"].attrs.get("is_combined", False) for ed in eds)
# Create combined mapping for later use
combined_mapping = []
for idx, ed in enumerate(eds):
is_combined = ed["Provenance"].attrs.get("is_combined", False)
combined_mapping.append(
{
"is_combined": is_combined,
"attrs_dict": _get_prov_attrs(ed["Provenance"], is_combined),
"echodata_filename": [str(s) for s in ed["Provenance"][ED_FILENAME].values]
if is_combined
else [echodata_filenames[idx]],
}
)
# Get single boolean value to see if there's any combined files
any_combined = any(d["is_combined"] for d in combined_mapping)

if any_combined:
# Fetches the true echodata filenames if there are any combined files
echodata_filenames = list(
itertools.chain.from_iterable([d[ED_FILENAME] for d in combined_mapping])
)

# Create Echodata tree dict
tree_dict = {}
Expand All @@ -697,8 +762,28 @@ def _combine(
]

if ds_list:
# Get all of the keys and attributes
ds_attrs = [ds.attrs for ds in ds_list]
if not any_combined:
# Get all of the keys and attributes
# for regular non combined echodata object
ds_attrs = [ds.attrs for ds in ds_list]
else:
# If there are any combined files,
# iterate through from mapping above
ds_attrs = []
for idx, ds in enumerate(ds_list):
# Retrieve the echodata attrs dict
# parsed from provenance group above
ed_attrs_dict = combined_mapping[idx]["attrs_dict"]
if ed_attrs_dict is not None:
# Set attributes to the appropriate group
# from echodata attrs provenance,
# set default empty dict for missing group
attrs = ed_attrs_dict.get(ed_group, {})
else:
# This is for non combined echodata object
attrs = [ds.attrs]
ds_attrs += attrs

# Attribute holding
attrs_dict[ed_group] = ds_attrs

Expand Down Expand Up @@ -753,15 +838,16 @@ def _combine(
# Data holding
tree_dict[ed_group] = combined_ds

if not all_combined:
# Capture provenance for all the attributes
prov_ds = _capture_prov_attrs(attrs_dict, echodata_filenames, sonar_model)
# Capture provenance for all the attributes
prov_ds = _capture_prov_attrs(attrs_dict, echodata_filenames, sonar_model)
if not any_combined:
# Update the provenance dataset with the captured data
prov_ds = tree_dict["Provenance"].assign(prov_ds)
else:
prov_ds = tree_dict["Provenance"]
prov_ds = tree_dict["Provenance"].drop_dims(ED_FILENAME).assign(prov_ds)

# Update filenames to iter integers
prov_ds[FILENAMES] = prov_ds[FILENAMES].copy(data=np.arange(*prov_ds[FILENAMES].shape))
prov_ds[FILENAMES] = prov_ds[FILENAMES].copy(data=np.arange(*prov_ds[FILENAMES].shape)) # noqa
tree_dict["Provenance"] = prov_ds

return tree_dict
Expand Down
80 changes: 55 additions & 25 deletions echopype/tests/echodata/test_echodata_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,26 @@ def attr_time_to_dt(time_str):
assert test_ds.identical(combined_group.drop_dims(grp_drop_dims))


def _check_prov_ds(prov_ds, eds):
    """Checks the Provenance dataset against source_filenames variable
    and global attributes in the original echodata object

    For every index along the ``echodata_filename`` dimension, the
    ``source_filenames`` variable is compared with the first entry of
    ``source_filenames`` in the ``Provenance`` group of ``eds[i]``. Every
    other variable is compared, as a string, with the attribute of the same
    name in the group given by the variable's ``echodata_group`` attribute.

    Parameters
    ----------
    prov_ds
        The ``Provenance`` dataset of a combined echodata object
    eds
        The original echodata objects, indexable in the same order as
        the ``echodata_filename`` dimension of ``prov_ds``

    Raises
    ------
    AssertionError
        If any value in ``prov_ds`` differs from its original counterpart
    """
    # ``Dataset.sizes`` is the mapping of dimension name to length;
    # dict-style access on ``Dataset.dims`` is deprecated in recent xarray
    for i in range(prov_ds.sizes["echodata_filename"]):
        ed_ds = eds[i]
        # Select the slice that belongs to the i-th original file
        one_ds = prov_ds.isel(echodata_filename=i, filenames=i)
        for key, value in one_ds.data_vars.items():
            if key == "source_filenames":
                # Source filenames live in the Provenance group itself
                expected_files = ed_ds["Provenance"][key].isel(filenames=0).values
                assert np.array_equal(expected_files, value.values)
                continue

            # Remaining variables hold attributes captured from the group
            # recorded in the variable's own attributes
            ed_group = value.attrs.get("echodata_group")
            expected_val = str(ed_ds[ed_group].attrs[key])
            assert str(value.values) == expected_val


@pytest.mark.parametrize("test_param", [
"single",
"multi",
Expand Down Expand Up @@ -263,6 +283,17 @@ def test_combine_echodata_combined_append(ek60_multi_test_data, test_param, sona
# First combined file
combined_ed = echopype.combine_echodata(eds[:2])
combined_ed.to_zarr(first_zarr, overwrite=True)

def _check_prov_ds_and_dims(sel_comb_ed, n_val_expected):
prov_ds = sel_comb_ed["Provenance"]
for _, n_val in prov_ds.dims.items():
assert n_val == n_val_expected
_check_prov_ds(prov_ds, eds)

# Checks for Provenance group
# Both dims of filenames and echodata filename should be 2
expected_n_vals = 2
_check_prov_ds_and_dims(combined_ed, expected_n_vals)

# Second combined file
combined_ed_other = echopype.combine_echodata(eds[2:])
Expand All @@ -271,48 +302,47 @@ def test_combine_echodata_combined_append(ek60_multi_test_data, test_param, sona
combined_ed = echopype.open_converted(first_zarr)
combined_ed_other = echopype.open_converted(second_zarr)

# Set expected values for Provenance
if test_param == "single":
data_inputs = [combined_ed, eds[2]]
expected_n_vals = 3
elif test_param == "multi":
data_inputs = [combined_ed, eds[2], eds[3]]
expected_n_vals = 4
else:
data_inputs = [combined_ed, combined_ed_other]
combined_ed2 = echopype.combine_echodata(
data_inputs
)
expected_n_vals = 4

combined_ed2 = echopype.combine_echodata(data_inputs)

# Verify that combined objects are all EchoData objects
assert isinstance(combined_ed, EchoData)
assert isinstance(combined_ed_other, EchoData)
assert isinstance(combined_ed2, EchoData)

# Ensure that they're from the same file source
assert eds[0]['Provenance'].source_filenames[0].values == combined_ed['Provenance'].source_filenames[0].values
assert eds[1]['Provenance'].source_filenames[0].values == combined_ed['Provenance'].source_filenames[1].values
assert eds[2]['Provenance'].source_filenames[0].values == combined_ed2['Provenance'].source_filenames[2].values
if test_param != "single":
assert eds[3]['Provenance'].source_filenames[0].values == combined_ed2['Provenance'].source_filenames[3].values
group_path = "Provenance"
for i in range(4):
ds_i = eds[i][group_path]
select_comb_ds = combined_ed[group_path] if i < 2 else combined_ed2[group_path]
if i < 3 or (i == 3 and test_param != "single"):
assert ds_i.source_filenames[0].values == select_comb_ds.source_filenames[i].values

# Check beam_group1. Should be exactly same xr dataset
group_path = "Sonar/Beam_group1"
ds0 = eds[0][group_path]
filt_ds0 = combined_ed[group_path].sel(ping_time=ds0.ping_time)
assert filt_ds0.identical(ds0) is True

ds1 = eds[1][group_path]
filt_ds1 = combined_ed[group_path].sel(ping_time=ds1.ping_time)
assert filt_ds1.identical(ds1) is True

ds2 = eds[2][group_path]
filt_ds2 = combined_ed2[group_path].sel(ping_time=ds2.ping_time)
assert filt_ds2.identical(ds2) is True

if test_param != "single":
ds3 = eds[3][group_path]
filt_ds3 = combined_ed2[group_path].sel(ping_time=ds3.ping_time)
assert filt_ds3.identical(ds3) is True
for i in range(4):
ds_i = eds[i][group_path]
select_comb_ds = combined_ed[group_path] if i < 2 else combined_ed2[group_path]
if i < 3 or (i == 3 and test_param != "single"):
filt_ds_i = select_comb_ds.sel(ping_time=ds_i.ping_time)
assert filt_ds_i.identical(ds_i) is True

filt_combined = combined_ed2[group_path].sel(ping_time=combined_ed[group_path].ping_time)
assert filt_combined.identical(combined_ed[group_path])
assert filt_combined.identical(combined_ed[group_path]) is True

# Checks for Provenance group
# Both dims of filenames and echodata filename should be expected_n_vals
_check_prov_ds_and_dims(combined_ed2, expected_n_vals)


def test_combine_echodata_channel_selection():
Expand Down

0 comments on commit 16a2b16

Please sign in to comment.