From eba3cca29795c042fcfd92b9be7507e03901031a Mon Sep 17 00:00:00 2001 From: Don Setiawan Date: Fri, 4 Aug 2023 10:34:45 -0700 Subject: [PATCH] feat: pin netcdf to >1.6 and add explicit encoding [all tests ci] (#1112) * fix(deps): unpin netcdf4 * feat: add explicit encoding for netcdf vars * test(utils): add set_netcdf_encodings test * ci: add error reporter for pr gh workflow * test: update file group check to use datatree * ci: fix error reporter to run only on py39 --- .github/workflows/pr.yaml | 6 +++ .../test_convert_source_target_locs.py | 7 ++-- echopype/tests/utils/test_coding.py | 40 +++++++++++++++++- echopype/utils/coding.py | 42 ++++++++++++++----- requirements.txt | 3 +- 5 files changed, 81 insertions(+), 17 deletions(-) diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index e327098e9..90a07a385 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -63,6 +63,12 @@ jobs: - name: Install dev tools run: | micromamba install -c conda-forge -n ${{ env.CONDA_ENV }} --yes --file requirements-dev.txt + # We only want to install this on one run, because otherwise we'll have + # duplicate annotations. 
+ - name: Install error reporter + if: ${{ matrix.python-version == '3.9' }} + run: | + python -m pip install pytest-github-actions-annotate-failures - name: Install echopype run: | python -m pip install -e .[plot] diff --git a/echopype/tests/convert/test_convert_source_target_locs.py b/echopype/tests/convert/test_convert_source_target_locs.py index 843f7ad83..66cbbfe40 100644 --- a/echopype/tests/convert/test_convert_source_target_locs.py +++ b/echopype/tests/convert/test_convert_source_target_locs.py @@ -12,15 +12,16 @@ import fsspec import xarray as xr import pytest +from datatree import open_datatree from tempfile import TemporaryDirectory from echopype import open_raw from echopype.utils.coding import DEFAULT_ENCODINGS def _check_file_group(data_file, engine, groups): - for g in groups: - ds = xr.open_dataset(data_file, engine=engine, group=g) - + tree = open_datatree(data_file, engine=engine) + for group in groups: + ds = tree[f"/{group}"].ds assert isinstance(ds, xr.Dataset) is True diff --git a/echopype/tests/utils/test_coding.py b/echopype/tests/utils/test_coding.py index 52f9e4aeb..a848b8d61 100644 --- a/echopype/tests/utils/test_coding.py +++ b/echopype/tests/utils/test_coding.py @@ -4,7 +4,7 @@ import math import dask -from echopype.utils.coding import _get_auto_chunk +from echopype.utils.coding import _get_auto_chunk, set_netcdf_encodings @pytest.mark.parametrize( "chunk", @@ -31,3 +31,41 @@ def test__get_auto_chunk(chunk): assert chunk_byte_size == dask_data.nbytes, "Default chunk is not equal to data array size!" else: assert chunk_byte_size <= dask.utils.parse_bytes(chunk), "Calculated chunk exceeded max chunk!" 
+
+def test_set_netcdf_encodings():
+    # create a test dataset
+    ds = xr.Dataset(
+        {
+            "var1": xr.DataArray(np.random.rand(10), dims="dim1"),
+            "var2": xr.DataArray(np.random.rand(10), dims="dim1", attrs={"attr1": "value1"}),
+            "var3": xr.DataArray(["a", "b", "c"], dims="dim2"),
+        },
+        attrs={"global_attr": "global_value"},
+    )
+
+    # test with default compression settings
+    encoding = set_netcdf_encodings(ds, {})
+    assert isinstance(encoding, dict)
+    assert len(encoding) == 3
+    assert "var1" in encoding
+    assert "var2" in encoding
+    assert "var3" in encoding
+    assert encoding["var1"]["zlib"] is True
+    assert encoding["var1"]["complevel"] == 4
+    assert encoding["var2"]["zlib"] is True
+    assert encoding["var2"]["complevel"] == 4
+    assert encoding["var3"]["zlib"] is False
+
+    # test with custom compression settings
+    compression_settings = {"zlib": True, "complevel": 5}
+    encoding = set_netcdf_encodings(ds, compression_settings)
+    assert isinstance(encoding, dict)
+    assert len(encoding) == 3
+    assert "var1" in encoding
+    assert "var2" in encoding
+    assert "var3" in encoding
+    assert encoding["var1"]["zlib"] is True
+    assert encoding["var1"]["complevel"] == 5
+    assert encoding["var2"]["zlib"] is True
+    assert encoding["var2"]["complevel"] == 5
+    assert encoding["var3"]["zlib"] is False
diff --git a/echopype/utils/coding.py b/echopype/utils/coding.py
index e199840ee..626aeb6e5 100644
--- a/echopype/utils/coding.py
+++ b/echopype/utils/coding.py
@@ -1,5 +1,5 @@
 from re import search
-from typing import Tuple
+from typing import Any, Dict, Tuple
 
 import numpy as np
 import xarray as xr
@@ -212,18 +212,38 @@ def set_zarr_encodings(
     return encoding
 
 
-def set_netcdf_encodings(ds: xr.Dataset, compression_settings: dict) -> dict:
-    """
-    Obtains all variable encodings based on netcdf default values
+def set_netcdf_encodings(
+    ds: xr.Dataset,
+    compression_settings: Dict[str, Any] = {},
+) -> Dict[str, Dict[str, Any]]:
     """
+    Obtains all variable encodings based on netcdf default values
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        The dataset object to generate encoding for
+    compression_settings : dict
+        The compression settings dictionary
 
-    # TODO: below is the encoding we were using for netcdf, we need to make
-    # sure that the encoding is appropriate for all data variables
-    encoding = (
-        {var: compression_settings for var in ds.data_vars}
-        if compression_settings is not None
-        else {}
-    )
+    Returns
+    -------
+    dict
+        The final encoding values for dataset variables
+    """
+    encoding = dict()
+    for name, val in ds.variables.items():
+        encoding[name] = {**val.encoding}
+        if np.issubdtype(val.dtype, np.str_):
+            encoding[name].update(
+                {
+                    "zlib": False,
+                }
+            )
+        elif compression_settings:
+            encoding[name].update(compression_settings)
+        else:
+            encoding[name].update(COMPRESSION_SETTINGS["netcdf4"])
 
     return encoding
diff --git a/requirements.txt b/requirements.txt
index 9d5823b1c..eb5ca32c8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,6 @@
 dask[array,distributed]
 jinja2
-# https://github.com/Unidata/netcdf4-python/issues/1175#issuecomment-1173142506
-netCDF4<1.6
+netCDF4>1.6
 numpy
 pynmea2
 pytz