Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve order of variables in combine_by_coords #9070

Merged
merged 22 commits into from
Jan 30, 2025

Conversation

kmuehlbauer
Copy link
Contributor

@kmuehlbauer kmuehlbauer commented Jun 5, 2024

@kmuehlbauer kmuehlbauer changed the title FIX: do not sort datasets in combine_by_coords Preserve order of variables in in combine_by_coords Jun 5, 2024
@kmuehlbauer kmuehlbauer changed the title Preserve order of variables in in combine_by_coords Preserve order of variables in combine_by_coords Jun 5, 2024
@TomNicholas
Copy link
Member

I think the reason I originally put the sort call in there was because in the itertools.groupby docs https://docs.python.org/3/library/itertools.html#itertools.groupby it says

Generally, the iterable needs to already be sorted on the same key function.

But if it seems to work without that then I guess it's fine?

@kmuehlbauer
Copy link
Contributor Author

I think this was from times where we had to deal with unsorted dict. At least that was what I understood from a previous comment at that code position.

@kmuehlbauer
Copy link
Contributor Author

Generally, the iterable needs to already be sorted on the same key function.

Now, in light of this... But somehow it works.

@kmuehlbauer
Copy link
Contributor Author

OK, let me add some more testing to be sure this works in datasets of any order.

@keewis
Copy link
Collaborator

keewis commented Jun 5, 2024

the reason itertools.groupby needs sorted iterables is that it combines groups locally, not globally. So list(map(list, itertools.groupby([1, 1, 2, 1, 3], key=lambda x: x))) would result in [(1, [1, 1]), (2, [2]), (1, [1]), (3, [3])], not [(1, [1, 1, 1]), (2, [2]), (3, [3])] (this is not the case for more_itertools.bucket and toolz.itertoolz.groupby).

@kmuehlbauer
Copy link
Contributor Author

kmuehlbauer commented Jun 6, 2024

Thanks @keewis. AFAICT, the sorting before groupby makes sure that all Datasets with same variables are grouped together regardless of the variable order within each Dataset. Update: And the position in the object list.

Removing the sorting will result in more groups but doesn't break the tests. Does that mean we are undertesting?. Or is it just fixed by the subsequent merge?

The issue which should be solved by this PR is that this sorting rearranges the input objects and when using compat="override" might move the first object to another position, resulting in wrong output. Then, I was thinking to special case compat="override" to keep the first object at it's place, but didn't find a way to incorporate that with the groupby.

@kmuehlbauer
Copy link
Contributor Author

After reading on itertools and collections I've found that it's possible to change itertools.groupby with a collections.defaultdict implementation to preserve order and with some intermediate performance gain:

import xarray as xr
import collections
import itertools

def vars_as_keys(ds):
    return tuple(sorted(ds))

def groupby_defaultdict(iter, key=lambda x: x):
    idx = collections.defaultdict(list)
    for i, obj in enumerate(iter):
        idx[key(obj)].append(i)
    for k, ix in idx.items():
        yield k, (iter[i] for i in ix)
        
def groupby_itertools(iter, key=lambda x: x):
    iter = sorted(iter, key=vars_as_keys)
    return itertools.groupby(iter, key=vars_as_keys)

x1 = xr.Dataset({"a": (("y", "x"), [[1]]),
                 "c": (("y", "x"), [[1]]),
                 "b": (("y", "x"), [[1]]),}, 
                coords={"y": [0], "x": [0]})
x2 = xr.Dataset({"d": (("y", "x"), [[1]]),
                 "a": (("y", "x"), [[2]]),},
                coords={"y": [0], "x": [0]})
x3 = xr.Dataset({"a": (("y", "x"), [[3]]),
                 "d": (("y", "x"), [[2]]),},
                coords={"y": [0], "x": [1]})
data_objects = [x2, x1, x3]
%%timeit
grouped_by_vars = groupby_defaultdict(data_objects, key=vars_as_keys)
274 ns ± 0.934 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
%%timeit
grouped_by_vars = groupby_itertools(data_objects, key=vars_as_keys)
4.98 µs ± 14 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

@kmuehlbauer
Copy link
Contributor Author

It looks like this can be replaced here too:

# TODO: is the sorted need?
combined_ids = dict(sorted(combined_ids.items(), key=_new_tile_id))
grouped = itertools.groupby(combined_ids.items(), key=_new_tile_id)

@dcherian dcherian requested a review from TomNicholas June 6, 2024 18:15
@kmuehlbauer kmuehlbauer added the run-benchmark Run the ASV benchmark workflow label Jun 7, 2024
Copy link
Collaborator

@headtr1ck headtr1ck left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, some type hints works be nice!

xarray/core/combine.py Outdated Show resolved Hide resolved
xarray/tests/test_combine.py Outdated Show resolved Hide resolved
@kmuehlbauer
Copy link
Contributor Author

From all what I read about the comparison itertools.groupby vs. defaultdict(list) both have their use cases.

Both have O(n) time complexity, but the sorting step of itertools.groupby makes this a O(n log n). The major PRO for itertools.groupby is the lazy yielding of groups, the major CON is the need of pre-sorting (which is the root-cause of #8828 and what this PR should fix).

The major PRO for defaultdict(list) is that it can work on unsorted input in a fast way, the major CON is it is storing the whole result in memory.

I'm not sure if there is another solution including itertools.groupby to fix #8828, so if someone has another idea of tackling that, please shoot.

xarray/core/combine.py Outdated Show resolved Hide resolved
xarray/core/combine.py Outdated Show resolved Hide resolved
xarray/core/combine.py Outdated Show resolved Hide resolved
@kmuehlbauer kmuehlbauer reopened this Jan 27, 2025
@kmuehlbauer
Copy link
Contributor Author

kmuehlbauer commented Jan 27, 2025

I've fixed the typing issue and replaced the second occurrence here. For the above shown minibenchmark there is a slight performance gain. The asv bench doesn't show a performance change.

The major win is that it fixes #8828. Ready for another review.

@kmuehlbauer
Copy link
Contributor Author

Do you see any blockers here @keewis, @TomNicholas?

@kmuehlbauer
Copy link
Contributor Author

I've added this to #10002 and will merge this tomorrow, if there are no complaints.

@TomNicholas
Copy link
Member

Sorry for taking so long to reply to this 😓

AFAICT, the sorting before groupby makes sure that all Datasets with same variables are grouped together regardless of the variable order within each Dataset. Update: And the position in the object list.

IIRC the point of the groupby is to group datasets into sets whose variables have common concatenation dimensions. See CWorthy-ocean/roms-tools#223 (comment) for an example of a use case.

Removing the sorting will result in more groups but doesn't break the tests. Does that mean we are undertesting?. Or is it just fixed by the subsequent merge?

I think as long as those are just subgroups of the original groups (i.e. those created by the code that does sorting) then yes the subsequent merge should fix it. The reason being that any subset of a group containing variables with common concatenation dimensions should also have those same common concatenation dimensions?

@kmuehlbauer
Copy link
Contributor Author

Sorry for taking so long to reply to this 😓

No worries :-) . So, essentially the proposed code does the same groupby, but keeps the initial order. Did you had a chance to test your notebook against this PR?

@kmuehlbauer
Copy link
Contributor Author

Just for the record, the example from above with the according output. We can see that the defaultdict_groupby generates the same groups and keeps them in their natural order. I think this is along the lines the behaviour @keewis mentioned for more_itertools.bucket and toolz.itertoolz.groupby. So I'd say we can replace the current implementation with the proposed.

import xarray as xr
import collections
import itertools

def vars_as_keys(ds):
    return tuple(sorted(ds))

def groupby_defaultdict(iter, key=lambda x: x):
    idx = collections.defaultdict(list)
    for i, obj in enumerate(iter):
        idx[key(obj)].append(i)
    for k, ix in idx.items():
        yield k, (iter[i] for i in ix)
        
def groupby_itertools(iter, key=lambda x: x):
    iter = sorted(iter, key=vars_as_keys)
    return itertools.groupby(iter, key=vars_as_keys)

x1 = xr.Dataset({"a": (("y", "x"), [[1]]),
                 "c": (("y", "x"), [[1]]),
                 "b": (("y", "x"), [[1]]),}, 
                coords={"y": [0], "x": [0]})
x2 = xr.Dataset({"d": (("y", "x"), [[1]]),
                 "a": (("y", "x"), [[2]]),},
                coords={"y": [0], "x": [0]})
x3 = xr.Dataset({"a": (("y", "x"), [[3]]),
                 "d": (("y", "x"), [[2]]),},
                coords={"y": [0], "x": [1]})
data_objects = [x2, x1, x3]

for vars, ds_with_same_vars in groupby_defaultdict(data_objects, key=vars_as_keys):
    print(vars, tuple(ds_with_same_vars))
print("---------------------------------------------------------------------")
for vars, ds_with_same_vars in groupby_itertools(data_objects, key=vars_as_keys):
    print(vars, tuple(ds_with_same_vars))
('a', 'd') (<xarray.Dataset> Size: 32B
Dimensions:  (y: 1, x: 1)
Coordinates:
  * y        (y) int64 8B 0
  * x        (x) int64 8B 0
Data variables:
    d        (y, x) int64 8B 1
    a        (y, x) int64 8B 2, <xarray.Dataset> Size: 32B
Dimensions:  (y: 1, x: 1)
Coordinates:
  * y        (y) int64 8B 0
  * x        (x) int64 8B 1
Data variables:
    a        (y, x) int64 8B 3
    d        (y, x) int64 8B 2)
('a', 'b', 'c') (<xarray.Dataset> Size: 40B
Dimensions:  (y: 1, x: 1)
Coordinates:
  * y        (y) int64 8B 0
  * x        (x) int64 8B 0
Data variables:
    a        (y, x) int64 8B 1
    c        (y, x) int64 8B 1
    b        (y, x) int64 8B 1,)
---------------------------------------------------------------------
('a', 'b', 'c') (<xarray.Dataset> Size: 40B
Dimensions:  (y: 1, x: 1)
Coordinates:
  * y        (y) int64 8B 0
  * x        (x) int64 8B 0
Data variables:
    a        (y, x) int64 8B 1
    c        (y, x) int64 8B 1
    b        (y, x) int64 8B 1,)
('a', 'd') (<xarray.Dataset> Size: 32B
Dimensions:  (y: 1, x: 1)
Coordinates:
  * y        (y) int64 8B 0
  * x        (x) int64 8B 0
Data variables:
    d        (y, x) int64 8B 1
    a        (y, x) int64 8B 2, <xarray.Dataset> Size: 32B
Dimensions:  (y: 1, x: 1)
Coordinates:
  * y        (y) int64 8B 0
  * x        (x) int64 8B 1
Data variables:
    a        (y, x) int64 8B 3
    d        (y, x) int64 8B 2)

Copy link
Member

@TomNicholas TomNicholas left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can see that the groupby_defaultdict generates the same groups and keeps them in their natural order

This sounds fine to me.

@kmuehlbauer kmuehlbauer enabled auto-merge (squash) January 30, 2025 16:49
@kmuehlbauer kmuehlbauer merged commit 5b3f127 into pydata:main Jan 30, 2025
28 checks passed
@kmuehlbauer kmuehlbauer deleted the fix-combine branch January 30, 2025 17:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
run-benchmark Run the ASV benchmark workflow
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Dataset combine_by_coords unexpected behavior
5 participants