diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index 319c8a9a7a0..a1261c1e016 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -11,17 +11,32 @@ from .indexing import get_indexer_nd from .utils import is_dict_like, is_full_slice from .variable import IndexVariable, Variable +from .indexes import or_with_tolerance, and_with_tolerance, equal_indexes_with_tolerance if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset -def _get_joiner(join): +def _get_joiner(join, tolerance=0): if join == "outer": - return functools.partial(functools.reduce, operator.or_) + if tolerance == 0: + or_ = operator.or_ + else: + + def or_(x1, x2, tolerance=tolerance): + return or_with_tolerance(x1, x2, tolerance) + + return functools.partial(functools.reduce, or_) elif join == "inner": - return functools.partial(functools.reduce, operator.and_) + if tolerance == 0: + and_ = operator.and_ + else: + + def and_(x1, x2, tolerance=tolerance): + return and_with_tolerance(x1, x2, tolerance) + + return functools.partial(functools.reduce, and_) elif join == "left": return operator.itemgetter(0) elif join == "right": @@ -65,6 +80,7 @@ def align( indexes=None, exclude=frozenset(), fill_value=dtypes.NA, + tolerance=0, ): """ Given any number of Dataset and/or DataArray objects, returns new @@ -107,6 +123,8 @@ def align( Value to use for newly missing values. If a dict-like, maps variable names to fill values. Use a data array's name to refer to its values. + tolerance: numerical + Value used to check equality between the coordinate values with numerical tolerance. Returns ------- @@ -284,30 +302,41 @@ def align( # - It ensures it's possible to do operations that don't require alignment # on indexes with duplicate values (which cannot be reindexed with # pandas). This is useful, e.g., for overwriting such duplicate indexes. 
- joiner = _get_joiner(join) joined_indexes = {} for dim, matching_indexes in all_indexes.items(): if dim in indexes: index = utils.safe_cast_to_index(indexes[dim]) if ( - any(not index.equals(other) for other in matching_indexes) - or dim in unlabeled_dim_sizes - ): + not equal_indexes_with_tolerance(index, matching_indexes, tolerance) + ) or dim in unlabeled_dim_sizes: joined_indexes[dim] = index - else: + elif ( + any(not matching_indexes[0].equals(other) for other in matching_indexes[1:]) + or dim in unlabeled_dim_sizes + ): if ( - any( - not matching_indexes[0].equals(other) - for other in matching_indexes[1:] + not equal_indexes_with_tolerance( + matching_indexes[0], matching_indexes[1:], tolerance ) or dim in unlabeled_dim_sizes ): if join == "exact": raise ValueError(f"indexes along dimension {dim!r} are not equal") + # this logic could be moved out if _get_joiner is changed into _do_join(join, matching_index, tolerance) + if (tolerance > 0) and all( + index.is_numeric() for index in matching_indexes + ): + joiner = _get_joiner(join, tolerance) + else: + joiner = _get_joiner(join) index = joiner(matching_indexes) joined_indexes[dim] = index else: + # they are identical within tolerance, reindexing is necessary index = matching_indexes[0] + joined_indexes[dim] = index + else: + index = matching_indexes[0] # I believe this line is not useful... 
if dim in unlabeled_dim_sizes: unlabeled_sizes = unlabeled_dim_sizes[dim] @@ -336,6 +365,14 @@ def align( if not valid_indexers: # fast path for no reindexing necessary new_obj = obj.copy(deep=copy) + elif tolerance > 0: + new_obj = obj.reindex( + copy=copy, + fill_value=fill_value, + **valid_indexers, + tolerance=tolerance, + method="nearest", + ) else: new_obj = obj.reindex(copy=copy, fill_value=fill_value, **valid_indexers) new_obj.encoding = obj.encoding diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 846e4044a2c..08b7629bed9 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -13,6 +13,7 @@ cast, ) +import numpy as np import pandas as pd from . import formatting, indexing @@ -107,8 +108,49 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index: return self._data.get_index(dim) # type: ignore else: indexes = [self._data.get_index(k) for k in ordered_dims] # type: ignore - names = list(ordered_dims) - return pd.MultiIndex.from_product(indexes, names=names) + + # compute the sizes of the repeat and tile for the cartesian product + # (taken from pandas.core.reshape.util) + index_lengths = np.fromiter( + (len(index) for index in indexes), dtype=np.intp + ) + cumprod_lengths = np.cumproduct(index_lengths) + + if cumprod_lengths[-1] != 0: + # sizes of the repeats + repeat_counts = cumprod_lengths[-1] / cumprod_lengths + else: + # if any factor is empty, the cartesian product is empty + repeat_counts = np.zeros_like(cumprod_lengths) + + # sizes of the tiles + tile_counts = np.roll(cumprod_lengths, 1) + tile_counts[0] = 1 + + # loop over the indexes + # for each MultiIndex or Index compute the cartesian product of the codes + + code_list = [] + level_list = [] + names = [] + + for i, index in enumerate(indexes): + if isinstance(index, pd.MultiIndex): + codes, levels = index.codes, index.levels + else: + code, level = pd.factorize(index) + codes = [code] + levels = [level] + + # compute the cartesian 
product + code_list += [ + np.tile(np.repeat(code, repeat_counts[i]), tile_counts[i]) + for code in codes + ] + level_list += levels + names += index.names + + return pd.MultiIndex(level_list, code_list, names=names) def update(self, other: Mapping[Hashable, Any]) -> None: other_vars = getattr(other, "variables", other) diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 84cf35d3b4f..13908117877 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -143,3 +143,55 @@ def propagate_indexes( new_indexes = None # type: ignore return new_indexes + + +def equal_indexes_with_tolerance( + index: pd.Index, others: Iterable[pd.Index], tolerance: float +) -> bool: + + return all(other.equals(index) for other in others) or ( + tolerance > 0 + and index.is_numeric() + and all(other.is_numeric() for other in others) + and all(np.allclose(index, other, atol=tolerance, rtol=0) for other in others) + ) + + +def unique_with_tolerance(x, tolerance): + # check uniqueness based on the difference between successive values + # this eliminates all the numbers that are within tolerance of the previous one. 
+ # a number can be removed even if the previous one was removed + + x = np.sort(x) + # x is sorted + notclose = np.diff(x) > tolerance + + unique = x[1:][notclose] + return np.insert(unique, 0, x[0]) + + + def or_with_tolerance(x1, x2, tolerance): + return unique_with_tolerance(np.concatenate((x1, x2)), tolerance=tolerance) + + + def and_with_tolerance(x1, x2, tolerance): + # return values common in the two arrays (within some tolerance) + + x = np.concatenate((x1, x2)) + + # sort x, remember the origin (x1 or x2) + ind = np.argsort(x) + x = x[ind] + + # check for close values + close = np.diff(x) < tolerance + + origin = ind < len(x1) + first = origin[:-1][close] # where is the 'first' coming from + second = origin[1:][close] # where is the 'second' coming from + + intersect = ( + first != second + ) # they must come from the two different arrays to be an intersection + + return x[:-1][close][intersect] diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 5e0fe13ea52..314c0eb391a 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -453,7 +453,7 @@ def test_equals_and_identical(self): assert not expected.identical(actual) actual = expected.copy() - actual["a"] = 100000 + actual["a"] = 100_000 assert not expected.equals(actual) assert not expected.identical(actual) @@ -3372,6 +3372,37 @@ def test_align_without_indexes_errors(self): DataArray([1, 2], coords=[("x", [0, 1])]), ) + def test_align_with_tolerance_when_nearly_equal_index(self): + da1 = DataArray([1, 2, 3], coords=[("x", [1.0001, 2.0004, 2.9999])]) + + da2 = DataArray([4, 5, 6], coords=[("x", [1, 2, 3])]) + + aligned_da1, aligned_da2 = align(da1, da2, tolerance=4e-4) + + assert_identical(aligned_da1.x, da1.x) + assert_identical(aligned_da2.x, da1.x) + + def test_align_with_tolerance_and_intersect(self): + da1 = DataArray([1, 2, 3, 4], coords=[("x", [1.0001, 1.5, 2.0004, 2.9999])]) + + da2 = DataArray([5, 6, 7, 8], coords=[("x", [1, 2, 3, 5])]) + + 
aligned_da1, aligned_da2 = align(da1, da2, tolerance=4e-4, join="inner") + + assert_identical(aligned_da1.x, aligned_da2.x) + + assert_array_equal(aligned_da1.x, np.array([1.0, 2.0, 2.9999])) + + def test_align_with_tolerance_and_union(self): + da1 = DataArray([1, 2, 3, 4], coords=[("x", [1.0001, 1.5, 2.0004, 2.9999])]) + + da2 = DataArray([5, 6, 7, 8], coords=[("x", [1, 2, 3, 5])]) + + aligned_da1, aligned_da2 = align(da1, da2, tolerance=4e-4, join="outer") + + assert_identical(aligned_da1.x, aligned_da2.x) + assert_array_equal(aligned_da1.x, np.array([1.0, 1.5, 2.0, 2.9999, 5.0])) + def test_broadcast_arrays(self): x = DataArray([1, 2], coords=[("a", [-1, -2])], name="x") y = DataArray([1, 2], coords=[("b", [3, 4])], name="y") @@ -3520,6 +3551,33 @@ def test_to_dataframe(self): with raises_regex(ValueError, "unnamed"): arr.to_dataframe() + def test_to_dataframe_multiindex(self): + # regression test for #3008 + arr_np = np.random.randn(4, 3) + + mindex = pd.MultiIndex.from_product([[1, 2], list("ab")], names=["A", "B"]) + + arr = DataArray(arr_np, [("MI", mindex), ("C", [5, 6, 7])], name="foo") + + actual = arr.to_dataframe() + assert_array_equal(actual["foo"].values, arr_np.flatten()) + assert_array_equal(actual.index.names, list("ABC")) + assert_array_equal(actual.index.levels[0], [1, 2]) + assert_array_equal(actual.index.levels[1], ["a", "b"]) + assert_array_equal(actual.index.levels[2], [5, 6, 7]) + + def test_to_dataframe_0length(self): + # regression test for #3008 + arr_np = np.random.randn(4, 0) + + mindex = pd.MultiIndex.from_product([[1, 2], list("ab")], names=["A", "B"]) + + arr = DataArray(arr_np, [("MI", mindex), ("C", [])], name="foo") + + actual = arr.to_dataframe() + assert len(actual) == 0 + assert_array_equal(actual.index.names, list("ABC")) + def test_to_pandas_name_matches_coordinate(self): # coordinate with same name as array arr = DataArray([1, 2, 3], dims="x", name="x")