From 16a2b16b4f939d8c0dbcb277c1c9e0ded0f0e600 Mon Sep 17 00:00:00 2001 From: Don Setiawan Date: Tue, 15 Aug 2023 11:22:21 -0700 Subject: [PATCH] fix: update prov attributes combine (#1116) * fix: fix how provenance data are combined * test: add test to check actual prov vals * fix: remove ECHODATA_FILENAME, use ED_FILENAME * tests: simplify 'test_combine_echodata_combined_append' (lsetiawan/echopype#1) Simplified by using common code for repeated blocks; plus other readability tweaks --------- Co-authored-by: Emilio Mayorga Refs: #1115 --- echopype/echodata/combine.py | 102 ++++++++++++++++-- .../tests/echodata/test_echodata_combine.py | 80 +++++++++----- 2 files changed, 149 insertions(+), 33 deletions(-) diff --git a/echopype/echodata/combine.py b/echopype/echodata/combine.py index bd08f84fc..faf991771 100644 --- a/echopype/echodata/combine.py +++ b/echopype/echodata/combine.py @@ -1,5 +1,6 @@ import itertools import re +from collections import ChainMap from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Tuple, Union from warnings import warn @@ -643,6 +644,50 @@ def _capture_prov_attrs( return prov_ds +def _get_prov_attrs( + ds: xr.Dataset, is_combined: bool = True +) -> Optional[Dict[str, List[Dict[str, str]]]]: + """ + Get the provenance attributes from the dataset. + This function is meant to be used on an already combined dataset. + + Parameters + ---------- + ds : xr.Dataset + The Provenance group dataset to get attributes from + is_combined: bool + The flag to indicate if it's combined + + Returns + ------- + Dict[str, List[Dict[str, str]]] + The provenance attributes + """ + + if is_combined: + attrs_dict = {} + for k, v in ds.data_vars.items(): + # Go through each data variable and extract the attribute values + # based on the echodata group as stored in the variable attribute + if ED_GROUP in v.attrs: + ed_group = v.attrs[ED_GROUP] + if ed_group not in attrs_dict: + attrs_dict[ed_group] = [] + # Store the values as a list of dictionary for each group + attrs_dict[ed_group].append([{k: i} for i in v.values]) + + # Merge the attributes for each group so it matches the + # attributes dict for later merging + return { + ed_group: [ + dict(ChainMap(*v)) + for _, v in pd.DataFrame.from_dict(attrs).to_dict(orient="list").items() + ] + for ed_group, attrs in attrs_dict.items() + } + return None + + def _combine( sonar_model: str, eds: List[EchoData] = [], @@ -680,7 +725,27 @@ def _combine( attrs_dict = {} # Check if input data are combined datasets - all_combined = all(ed["Provenance"].attrs.get("is_combined", False) for ed in eds) + # Create combined mapping for later use + combined_mapping = [] + for idx, ed in enumerate(eds): + is_combined = ed["Provenance"].attrs.get("is_combined", False) + combined_mapping.append( + { + "is_combined": is_combined, + "attrs_dict": _get_prov_attrs(ed["Provenance"], is_combined), + "echodata_filename": [str(s) for s in ed["Provenance"][ED_FILENAME].values] + if is_combined + else [echodata_filenames[idx]], + } + ) + # Get single boolean value to see if there's any combined files + any_combined = any(d["is_combined"] for d in combined_mapping) + + if any_combined: + # Fetches the true echodata filenames if there are any combined files + echodata_filenames = list( + itertools.chain.from_iterable([d[ED_FILENAME] for d in combined_mapping]) + ) # Create Echodata tree dict tree_dict = {} @@ -697,8 +762,28 @@ def _combine( ] if ds_list: - # Get all of the keys and attributes - ds_attrs = [ds.attrs for ds in ds_list] + if not any_combined: + # Get all of the keys and attributes + # for regular non combined echodata object + ds_attrs = [ds.attrs for ds in ds_list] + else: + # If there are any combined files, + # iterate through from mapping above + ds_attrs = [] + for idx, ds in enumerate(ds_list): + # Retrieve the echodata attrs dict + # parsed from provenance group above + ed_attrs_dict = combined_mapping[idx]["attrs_dict"] + if ed_attrs_dict is not None: + # Set attributes to the appropriate group + # from echodata attrs provenance, + # set default empty dict for missing group + attrs = ed_attrs_dict.get(ed_group, {}) + else: + # This is for non combined echodata object + attrs = [ds.attrs] + ds_attrs += attrs + # Attribute holding attrs_dict[ed_group] = ds_attrs @@ -753,15 +838,16 @@ def _combine( # Data holding tree_dict[ed_group] = combined_ds - if not all_combined: - # Capture provenance for all the attributes - prov_ds = _capture_prov_attrs(attrs_dict, echodata_filenames, sonar_model) + # Capture provenance for all the attributes + prov_ds = _capture_prov_attrs(attrs_dict, echodata_filenames, sonar_model) + if not any_combined: # Update the provenance dataset with the captured data prov_ds = tree_dict["Provenance"].assign(prov_ds) else: - prov_ds = tree_dict["Provenance"] + prov_ds = tree_dict["Provenance"].drop_dims(ED_FILENAME).assign(prov_ds) + # Update filenames to iter integers - prov_ds[FILENAMES] = prov_ds[FILENAMES].copy(data=np.arange(*prov_ds[FILENAMES].shape)) + prov_ds[FILENAMES] = prov_ds[FILENAMES].copy(data=np.arange(*prov_ds[FILENAMES].shape)) # noqa tree_dict["Provenance"] = prov_ds return tree_dict diff --git a/echopype/tests/echodata/test_echodata_combine.py b/echopype/tests/echodata/test_echodata_combine.py index f1f5b6f33..5359a1b8f 100644 --- a/echopype/tests/echodata/test_echodata_combine.py +++ b/echopype/tests/echodata/test_echodata_combine.py @@ -233,6 +233,26 @@ def attr_time_to_dt(time_str): assert test_ds.identical(combined_group.drop_dims(grp_drop_dims)) +def _check_prov_ds(prov_ds, eds): + """Checks the Provenance dataset against source_filenames variable + and global attributes in the original echodata object""" + for i in range(prov_ds.dims["echodata_filename"]): + ed_ds = eds[i] + one_ds = prov_ds.isel(echodata_filename=i, filenames=i) + for key, value in one_ds.data_vars.items(): + if key == "source_filenames": + ed_group = "Provenance" + assert np.array_equal( + ed_ds[ed_group][key].isel(filenames=0).values, value.values + ) + else: + ed_group = value.attrs.get("echodata_group") + expected_val = ed_ds[ed_group].attrs[key] + if not isinstance(expected_val, str): + expected_val = str(expected_val) + assert str(value.values) == expected_val + + @pytest.mark.parametrize("test_param", [ "single", "multi", @@ -263,6 +283,17 @@ def test_combine_echodata_combined_append(ek60_multi_test_data, test_param, sona # First combined file combined_ed = echopype.combine_echodata(eds[:2]) combined_ed.to_zarr(first_zarr, overwrite=True) + + def _check_prov_ds_and_dims(sel_comb_ed, n_val_expected): + prov_ds = sel_comb_ed["Provenance"] + for _, n_val in prov_ds.dims.items(): + assert n_val == n_val_expected + _check_prov_ds(prov_ds, eds) + + # Checks for Provenance group + # Both dims of filenames and echodata filename should be 2 + expected_n_vals = 2 + _check_prov_ds_and_dims(combined_ed, expected_n_vals) # Second combined file combined_ed_other = echopype.combine_echodata(eds[2:]) @@ -271,48 +302,47 @@ def test_combine_echodata_combined_append(ek60_multi_test_data, test_param, sona combined_ed = echopype.open_converted(first_zarr) combined_ed_other = echopype.open_converted(second_zarr) + # Set expected values for Provenance if test_param == "single": data_inputs = [combined_ed, eds[2]] + expected_n_vals = 3 elif test_param == "multi": data_inputs = [combined_ed, eds[2], eds[3]] + expected_n_vals = 4 else: data_inputs = [combined_ed, combined_ed_other] - combined_ed2 = echopype.combine_echodata( - data_inputs - ) + expected_n_vals = 4 + + combined_ed2 = echopype.combine_echodata(data_inputs) + # Verify that combined objects are all EchoData objects assert isinstance(combined_ed, EchoData) assert isinstance(combined_ed_other, EchoData) assert isinstance(combined_ed2, EchoData) # Ensure that they're from the same file source - assert eds[0]['Provenance'].source_filenames[0].values == combined_ed['Provenance'].source_filenames[0].values - assert eds[1]['Provenance'].source_filenames[0].values == combined_ed['Provenance'].source_filenames[1].values - assert eds[2]['Provenance'].source_filenames[0].values == combined_ed2['Provenance'].source_filenames[2].values - if test_param != "single": - assert eds[3]['Provenance'].source_filenames[0].values == combined_ed2['Provenance'].source_filenames[3].values + group_path = "Provenance" + for i in range(4): + ds_i = eds[i][group_path] + select_comb_ds = combined_ed[group_path] if i < 2 else combined_ed2[group_path] + if i < 3 or (i == 3 and test_param != "single"): + assert ds_i.source_filenames[0].values == select_comb_ds.source_filenames[i].values # Check beam_group1. Should be exactly same xr dataset group_path = "Sonar/Beam_group1" - ds0 = eds[0][group_path] - filt_ds0 = combined_ed[group_path].sel(ping_time=ds0.ping_time) - assert filt_ds0.identical(ds0) is True - - ds1 = eds[1][group_path] - filt_ds1 = combined_ed[group_path].sel(ping_time=ds1.ping_time) - assert filt_ds1.identical(ds1) is True - - ds2 = eds[2][group_path] - filt_ds2 = combined_ed2[group_path].sel(ping_time=ds2.ping_time) - assert filt_ds2.identical(ds2) is True - - if test_param != "single": - ds3 = eds[3][group_path] - filt_ds3 = combined_ed2[group_path].sel(ping_time=ds3.ping_time) - assert filt_ds3.identical(ds3) is True + for i in range(4): + ds_i = eds[i][group_path] + select_comb_ds = combined_ed[group_path] if i < 2 else combined_ed2[group_path] + if i < 3 or (i == 3 and test_param != "single"): + filt_ds_i = select_comb_ds.sel(ping_time=ds_i.ping_time) + assert filt_ds_i.identical(ds_i) is True filt_combined = combined_ed2[group_path].sel(ping_time=combined_ed[group_path].ping_time) - assert filt_combined.identical(combined_ed[group_path]) + assert filt_combined.identical(combined_ed[group_path]) is True + + # Checks for Provenance group + # Both dims of filenames and echodata filename should be expected_n_vals + _check_prov_ds_and_dims(combined_ed2, expected_n_vals) def test_combine_echodata_channel_selection():