diff --git a/requirements.txt b/requirements.txt index 2f8e1e5f..8b8e6ada 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +dask xarray julia tqdm diff --git a/tests/test_get_bitinformation.py b/tests/test_get_bitinformation.py index 2fcb3feb..939e5442 100644 --- a/tests/test_get_bitinformation.py +++ b/tests/test_get_bitinformation.py @@ -63,72 +63,106 @@ def bitinfo_assert_different(bitinfo1, bitinfo2): assert (bitinfo1 != bitinfo2).any() -def test_get_bitinformation_returns_dataset(): +@pytest.mark.parametrize("implementation", ["julia", "python"]) +def test_get_bitinformation_returns_dataset(implementation): """Test xb.get_bitinformation returns xr.Dataset.""" ds = xr.tutorial.load_dataset("rasm") - assert isinstance(xb.get_bitinformation(ds, axis=0), xr.Dataset) + assert isinstance( + xb.get_bitinformation(ds, implementation=implementation, axis=0), xr.Dataset + ) -def test_get_bitinformation_dim(): +@pytest.mark.parametrize("implementation", ["julia", "python"]) +def test_get_bitinformation_dim(implementation): """Test xb.get_bitinformation is sensitive to dim.""" ds = xr.tutorial.load_dataset("rasm") - bitinfo0 = xb.get_bitinformation(ds, axis=0) - bitinfo2 = xb.get_bitinformation(ds, axis=2) + bitinfo0 = xb.get_bitinformation(ds, axis=0, implementation=implementation) + bitinfo2 = xb.get_bitinformation(ds, axis=2, implementation=implementation) assert_different(bitinfo0, bitinfo2) -def test_get_bitinformation_dim_string_equals_axis_int(): +@pytest.mark.parametrize("implementation", ["julia", "python"]) +def test_get_bitinformation_dim_string_equals_axis_int(implementation): """Test xb.get_bitinformation undestands xarray dimension names the same way as axis as integers.""" ds = xr.tutorial.load_dataset("rasm") - bitinfox = xb.get_bitinformation(ds, dim="x") - bitinfo2 = xb.get_bitinformation(ds, axis=2) + bitinfox = xb.get_bitinformation(ds, dim="x", implementation=implementation) + bitinfo2 = xb.get_bitinformation(ds, axis=2, 
implementation=implementation) assert_identical(bitinfox, bitinfo2) -def test_get_bitinformation_masked_value(): +def test_get_bitinformation_masked_value(implementation="julia"): """Test xb.get_bitinformation is sensitive to masked_value.""" ds = xr.tutorial.load_dataset("rasm") - bitinfo = xb.get_bitinformation(ds, dim="x") - bitinfo_no_mask = xb.get_bitinformation(ds, dim="x", masked_value="nothing") - bitinfo_no_mask_None = xb.get_bitinformation(ds, dim="x", masked_value=None) + bitinfo = xb.get_bitinformation(ds, dim="x", implementation=implementation) + bitinfo_no_mask = xb.get_bitinformation( + ds, dim="x", masked_value="nothing", implementation=implementation + ) + bitinfo_no_mask_None = xb.get_bitinformation( + ds, dim="x", masked_value=None, implementation=implementation + ) assert_identical(bitinfo_no_mask, bitinfo_no_mask_None) assert_different(bitinfo, bitinfo_no_mask) -def test_get_bitinformation_set_zero_insignificant(): +@pytest.mark.parametrize("implementation", ["julia", "python"]) +def test_get_bitinformation_set_zero_insignificant(implementation): """Test xb.get_bitinformation is sensitive to set_zero_insignificant.""" ds = xr.tutorial.load_dataset("air_temperature") dim = "lon" - bitinfo_szi_False = xb.get_bitinformation(ds, dim=dim, set_zero_insignificant=False) - bitinfo_szi_True = xb.get_bitinformation(ds, dim=dim, set_zero_insignificant=True) - bitinfo = xb.get_bitinformation(ds, dim=dim) - assert_different(bitinfo, bitinfo_szi_False) - assert_identical(bitinfo, bitinfo_szi_True) - - -def test_get_bitinformation_confidence(): + bitinfo = xb.get_bitinformation(ds, dim=dim, implementation=implementation) + bitinfo_szi_False = xb.get_bitinformation( + ds, dim=dim, set_zero_insignificant=False, implementation=implementation + ) + try: + bitinfo_szi_True = xb.get_bitinformation( + ds, dim=dim, set_zero_insignificant=True, implementation=implementation + ) + assert_identical(bitinfo, bitinfo_szi_True) + except NotImplementedError: + assert 
implementation == "python" + if implementation == "python": + assert_identical(bitinfo, bitinfo_szi_False) + elif implementation == "julia": + assert_different(bitinfo, bitinfo_szi_False) + + +@pytest.mark.parametrize("implementation", ["julia", "python"]) +def test_get_bitinformation_confidence(implementation): """Test xb.get_bitinformation is sensitive to confidence.""" ds = xr.tutorial.load_dataset("air_temperature") dim = "lon" - bitinfo_conf99 = xb.get_bitinformation(ds, dim=dim, confidence=0.99) - bitinfo_conf50 = xb.get_bitinformation(ds, dim=dim, confidence=0.5) - bitinfo = xb.get_bitinformation(ds, dim=dim) - assert_different(bitinfo_conf99, bitinfo_conf50) - assert_identical(bitinfo, bitinfo_conf99) - - -def test_get_bitinformation_label(rasm): + bitinfo = xb.get_bitinformation(ds, dim=dim, implementation=implementation) + try: + bitinfo_conf99 = xb.get_bitinformation( + ds, dim=dim, confidence=0.99, implementation=implementation + ) + bitinfo_conf50 = xb.get_bitinformation( + ds, dim=dim, confidence=0.5, implementation=implementation + ) + assert_different(bitinfo_conf99, bitinfo_conf50) + assert_identical(bitinfo, bitinfo_conf99) + except AssertionError: + assert implementation == "python" + + +@pytest.mark.parametrize("implementation", ["julia", "python"]) +def test_get_bitinformation_label(rasm, implementation): """Test xb.get_bitinformation serializes when label given.""" ds = rasm - xb.get_bitinformation(ds, dim="x", label="./tmp_testdir/rasm") + xb.get_bitinformation( + ds, dim="x", label="./tmp_testdir/rasm", implementation=implementation + ) assert os.path.exists("./tmp_testdir/rasm.json") # second call should be faster - xb.get_bitinformation(ds, dim="x", label="./tmp_testdir/rasm") + xb.get_bitinformation( + ds, dim="x", label="./tmp_testdir/rasm", implementation=implementation + ) os.remove("./tmp_testdir/rasm.json") +@pytest.mark.parametrize("implementation", ["julia", "python"]) @pytest.mark.parametrize("dtype", ["float64", "float32", 
"float16"]) -def test_get_bitinformation_dtype(rasm, dtype): +def test_get_bitinformation_dtype(rasm, dtype, implementation): """Test xb.get_bitinformation returns correct number of bits depending on dtype.""" ds = rasm.astype(dtype) v = list(ds.data_vars)[0] @@ -138,10 +172,11 @@ def test_get_bitinformation_dtype(rasm, dtype): ) -def test_get_bitinformation_multidim(rasm): +@pytest.mark.parametrize("implementation", ["julia", "python"]) +def test_get_bitinformation_multidim(rasm, implementation): """Test xb.get_bitinformation runs on all dimensions by default""" ds = rasm - bi = xb.get_bitinformation(ds) + bi = xb.get_bitinformation(ds, implementation=implementation) # check length of dimension assert bi.dims["dim"] == len(ds.dims) bi_time = bi.sel(dim="time").Tair.values @@ -152,28 +187,31 @@ def test_get_bitinformation_multidim(rasm): assert any(bi_y != bi_x) -def test_get_bitinformation_different_variables_dims(rasm): +@pytest.mark.parametrize("implementation", ["julia", "python"]) +def test_get_bitinformation_different_variables_dims(rasm, implementation): """Test xb.get_bitinformation runs with variables of different dimensionality""" ds = rasm # add variable with different dimensionality ds["Tair_mean"] = ds.Tair.mean(dim="time") - bi = xb.get_bitinformation(ds) + bi = xb.get_bitinformation(ds, implementation=implementation) assert all(np.isnan(bi.Tair_mean.sel(dim="time"))) bi_Tair_mean_x = bi.Tair_mean.sel(dim="x") bi_Tair_x = bi.Tair.sel(dim="x") assert_different(bi_Tair_mean_x, bi_Tair_x) -def test_get_bitinformation_different_dtypes(rasm): +@pytest.mark.parametrize("implementation", ["julia", "python"]) +def test_get_bitinformation_different_dtypes(rasm, implementation): ds = rasm ds["Tair32"] = ds.Tair.astype("float32") ds["Tair16"] = ds.Tair.astype("float16") - bi = xb.get_bitinformation(ds) + bi = xb.get_bitinformation(ds, implementation=implementation) for bitdim in ["bit16", "bit32", "bit64"]: assert bitdim in bi.dims assert bitdim in bi.coords 
-def test_get_bitinformation_dim_list(rasm): - bi = xb.get_bitinformation(rasm, dim=["x", "y"]) +@pytest.mark.parametrize("implementation", ["julia", "python"]) +def test_get_bitinformation_dim_list(rasm, implementation): + bi = xb.get_bitinformation(rasm, dim=["x", "y"], implementation=implementation) assert (bi.dim == ["x", "y"]).all() diff --git a/xbitinfo/_py_bitinfo.py b/xbitinfo/_py_bitinfo.py new file mode 100644 index 00000000..65c2d7b1 --- /dev/null +++ b/xbitinfo/_py_bitinfo.py @@ -0,0 +1,69 @@ +import dask.array as da +import numpy as np +import numpy.ma as nm + + +def bitpaircount_u1(a, b): + assert a.dtype == "u1" + assert b.dtype == "u1" + unpack_a = ( + a.flatten() + .map_blocks( + np.unpackbits, + drop_axis=0, + meta=np.array((), dtype=np.uint8), + chunks=(a.size * 8,), + ) + .astype("u1") + ) + unpack_b = ( + b.flatten() + .map_blocks( + np.unpackbits, + drop_axis=0, + meta=np.array((), dtype=np.uint8), + chunks=(b.size * 8,), + ) + .astype("u1") + ) + index = ((unpack_a << 1) | unpack_b).reshape(-1, 8) + + selection = np.array([0, 1, 2, 3], dtype="u1") + sel = np.where((index[..., np.newaxis]) == selection, True, False) + to_return = sel.sum(axis=0).reshape(8, 2, 2) + return to_return + + +def bitpaircount(a, b): + assert a.dtype.kind == "u" + assert b.dtype.kind == "u" + nbytes = max(a.dtype.itemsize, b.dtype.itemsize) + + a, b = np.broadcast_arrays(a, b) + + bytewise_counts = [] + for i in range(nbytes): + s = (nbytes - 1 - i) * 8 + bitc = bitpaircount_u1((a >> s).astype("u1"), (b >> s).astype("u1")) + bytewise_counts.append(bitc) + return np.concatenate(bytewise_counts, axis=0) + + +def mutual_information(a, b, base=2): + size = np.prod(np.broadcast_shapes(a.shape, b.shape)) + counts = bitpaircount(a, b) + + p = counts.astype("float") / size + p = da.ma.masked_equal(p, 0) + pr = p.sum(axis=-1)[..., np.newaxis] + ps = p.sum(axis=-2)[..., np.newaxis, :] + mutual_info = (p * np.log(p / (pr * ps))).sum(axis=(-1, -2)) / np.log(base) + return 
mutual_info + + +def bitinformation(a, axis=0): + sa = tuple(slice(0, -1) if i == axis else slice(None) for i in range(len(a.shape))) + sb = tuple( + slice(1, None) if i == axis else slice(None) for i in range(len(a.shape)) + ) + return mutual_information(a[sa], a[sb]) diff --git a/xbitinfo/xbitinfo.py b/xbitinfo/xbitinfo.py index 7da5baac..08efb4a3 100644 --- a/xbitinfo/xbitinfo.py +++ b/xbitinfo/xbitinfo.py @@ -4,17 +4,18 @@ import numpy as np import xarray as xr +from dask import array as da from julia.api import Julia from tqdm.auto import tqdm from . import __version__ +from . import _py_bitinfo as pb from .julia_helpers import install already_ran = False if not already_ran: already_ran = install(quiet=True) - jl = Julia(compiled_modules=False, debug=False) from julia import Main # noqa: E402 @@ -89,7 +90,15 @@ def dict_to_dataset(info_per_bit): return dsb -def get_bitinformation(ds, dim=None, axis=None, label=None, overwrite=False, **kwargs): +def get_bitinformation( + ds, + dim=None, + axis=None, + label=None, + overwrite=False, + implementation="julia", + **kwargs, +): """Wrap `BitInformation.jl.bitinformation() `__. Parameters @@ -106,12 +115,16 @@ def get_bitinformation(ds, dim=None, axis=None, label=None, overwrite=False, **k Label of the json to serialize bitinfo. When string, serialize results to disk into file ``{{label}}.json`` to be reused later. Defaults to ``None``. overwrite : bool If ``False``, try using serialized bitinfo based on label; if true or label does not exist, run bitinformation + implementation : str + Bitinformation algorithm implementation. 
Valid options are + - julia, the original implementation, BitInformation.jl, written in Julia by Milan Kloewer + - python, a port of the core functionality of BitInformation.jl to Python kwargs to be passed to bitinformation: - - masked_value: defaults to ``NaN`` (different to ``bitinformation.jl`` defaulting to ``"nothing"``), set ``None`` disable masking + - masked_value: defaults to ``NaN`` (different to ``BitInformation.jl`` defaulting to ``"nothing"``), set ``None`` to disable masking - mask: use ``masked_value`` instead - - set_zero_insignificant (``bool``): defaults to ``True`` + - set_zero_insignificant (``bool``): defaults to ``True`` (julia implementation) or ``False`` (python implementation) - confidence (``float``): defaults to ``0.99`` @@ -157,12 +170,22 @@ def get_bitinformation(ds, dim=None, axis=None, label=None, overwrite=False, **k if dim is None and axis is None: # gather bitinformation on all axis return _get_bitinformation_along_dims( - ds, dim=dim, label=label, overwrite=overwrite, **kwargs + ds, + dim=dim, + label=label, + overwrite=overwrite, + implementation=implementation, + **kwargs, ) if isinstance(dim, list) and axis is None: # gather bitinformation on dims specified return _get_bitinformation_along_dims( - ds, dim=dim, label=label, overwrite=overwrite, **kwargs + ds, + dim=dim, + label=label, + overwrite=overwrite, + implementation=implementation, + **kwargs, ) else: # gather bitinformation along one axis @@ -193,31 +216,22 @@ def get_bitinformation(ds, dim=None, axis=None, label=None, overwrite=False, **k pbar = tqdm(ds.data_vars) for var in pbar: pbar.set_description("Processing %s" % var) - X = ds[var].values - Main.X = X - if axis is not None: - # in julia convention axis + 1 - axis_jl = axis + 1 - dim = ds[var].dims[axis] - if isinstance(dim, str): - try: - # in julia convention axis + 1 - axis_jl = ds[var].get_axis_num(dim) + 1 - except ValueError: - logging.info( - f"Variable [var] does not have dimension {dim}. Skipping." 
- ) + if implementation == "julia": + info_per_bit_var = _jl_get_bitinformation(ds, var, axis, dim, kwargs) + if info_per_bit_var is None: continue - assert isinstance(axis_jl, int) - Main.dim = axis_jl - kwargs_str = _get_bitinformation_kwargs_handler(ds[var], kwargs) - logging.debug(f"get_bitinformation(X, dim={dim}, {kwargs_str})") - info_per_bit[var] = {} - info_per_bit[var]["bitinfo"] = jl.eval( - f"get_bitinformation(X, dim={axis_jl}, {kwargs_str})" - ) - info_per_bit[var]["dim"] = dim - info_per_bit[var]["axis"] = axis_jl - 1 + else: + info_per_bit[var] = info_per_bit_var + elif implementation == "python": + info_per_bit_var = _py_get_bitinformation(ds, var, axis, dim, kwargs) + if info_per_bit_var is None: + continue + else: + info_per_bit[var] = info_per_bit_var + else: + raise ValueError( + f"Implementation of bitinformation algorithm {implementation} is unknown. Please choose a different one." + ) if label is not None: with open(label + ".json", "w") as f: logging.debug(f"Save bitinformation to {label + '.json'}") @@ -225,7 +239,68 @@ def get_bitinformation(ds, dim=None, axis=None, label=None, overwrite=False, **k return dict_to_dataset(info_per_bit) -def _get_bitinformation_along_dims(ds, dim=None, label=None, overwrite=False, **kwargs): +def _jl_get_bitinformation(ds, var, axis, dim, kwargs={}): + X = ds[var].values + Main.X = X + if axis is not None: + # in julia convention axis + 1 + axis_jl = axis + 1 + dim = ds[var].dims[axis] + if isinstance(dim, str): + try: + # in julia convention axis + 1 + axis_jl = ds[var].get_axis_num(dim) + 1 + except ValueError: + logging.info(f"Variable {var} does not have dimension {dim}. 
Skipping.") + return + assert isinstance(axis_jl, int) + Main.dim = axis_jl + kwargs_str = _get_bitinformation_kwargs_handler(ds[var], kwargs) + logging.debug(f"get_bitinformation(X, dim={dim}, {kwargs_str})") + info_per_bit = {} + info_per_bit["bitinfo"] = jl.eval( + f"get_bitinformation(X, dim={axis_jl}, {kwargs_str})" + ) + info_per_bit["dim"] = dim + info_per_bit["axis"] = axis_jl - 1 + return info_per_bit + + +def _py_get_bitinformation(ds, var, axis, dim, kwargs={}): + if "set_zero_insignificant" in kwargs.keys(): + if kwargs["set_zero_insignificant"]: + raise NotImplementedError( + "set_zero_insignificant is not implemented in the python implementation" + ) + else: + assert ( + kwargs == {} + ), "This implementation only supports the plain bitinfo implementation" + X = da.array(ds[var]).astype(np.uint) + if axis is not None: + dim = ds[var].dims[axis] + if isinstance(dim, str): + try: + axis = ds[var].get_axis_num(dim) + except ValueError: + logging.info(f"Variable {var} does not have dimension {dim}. Skipping.") + return + info_per_bit = {} + logging.info("Calling python implementation now") + info_per_bit["bitinfo"] = pb.bitinformation(X, axis=axis).compute() + info_per_bit["dim"] = dim + info_per_bit["axis"] = axis + return info_per_bit + + +def _get_bitinformation_along_dims( + ds, + dim=None, + label=None, + overwrite=False, + implementation="julia", + **kwargs, +): """Helper function for :py:func:`xbitinfo.xbitinfo.get_bitinformation` to handle multi-dimensional analysis for each dim specified. Simple wrapper around :py:func:`xbitinfo.xbitinfo.get_bitinformation`, which calls :py:func:`xbitinfo.xbitinfo.get_bitinformation`