Allow splitting to any number of partitions #829
Thanks for the contribution! I left a bunch of comments all over the place and would appreciate hearing your opinion on them.
@@ -135,21 +135,45 @@ def filter(
def partition(
    pytree: PyTree,
    filter_spec: PyTree[AxisSpec],
    *filter_specs: PyTree[AxisSpec],
I dislike breaking backwards compatibility! Could you instead make filter_specs a keyword argument? This will be less ergonomic, since it will force you to actually spell eqx.partition(..., filter_specs=()), but it is worth it.
Also, the ordering of the filters matters in your implementation, so it is more 'logical' to pass them as a sequence rather than as variadic arguments: with the current signature it is not clear that the order of the filters matters a lot, but if they are a separate parameter, it is pretty clear.
FWIW if we made this a keyword then we'd need to allow not passing filter_spec itself in the case that it's used.
Hmm, this backward compatibility break may sink this whole endeavour.
IMO forcing replace and is_leaf to be keyword arguments is actually an improvement as it could prevent user errors. It does break backward compatibility slightly, but I don't think there are many (if any) codebases using these arguments as positional. Note that jax.tree.map forces is_leaf as a keyword argument.
Concerning using a Sequence instead of variadic arguments, I disagree. The built-in map and even jax.tree.map use variadic arguments in a similar manner.
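To make the trade-off in this thread concrete, here is a minimal sketch of how Python binds arguments under the proposed signature. The stub partition_signature is hypothetical: it only records how its arguments are bound and performs no partitioning, so it is an illustration rather than Equinox's implementation.

```python
# Hypothetical stub mirroring the proposed signature; it only reports how
# Python binds the arguments and performs no partitioning.
def partition_signature(pytree, filter_spec, *filter_specs, replace=None, is_leaf=None):
    return {
        "filter_specs": (filter_spec, *filter_specs),
        "replace": replace,
        "is_leaf": is_leaf,
    }


# Old-style call that passes `replace` positionally: the third positional
# argument is now silently absorbed as a second filter specification.
old_style = partition_signature({"a": 1}, True, 0)

# New-style call: `replace` has to be spelled out as a keyword.
new_style = partition_signature({"a": 1}, True, replace=0)
```

The old-style call does not raise an error, which is why code passing replace positionally would change behaviour silently rather than fail loudly.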
More generally, provide $N$ filter specifications to split the tree into $N + 1$
non-overlapping partitions.
Could you slightly extend the documentation? Before I looked at the code, it was not obvious what multiple partitioning does. My first assumption was that you would split things with some redundancy:
everything_that_is_array, everything_that_is_inexact_array, rest = eqx.filter(x, eqx.is_array, eqx.is_inexact_array)
I feel like this is more in line with the previous partitioning logic, at least that is how my brain understands the standard partitioning.
Actually, I really like this way to specify things, but it does indeed break backwards compatibility :(
I would probably implement it as a separate partition function, e.g. eqx.multi_partition? This way the syntax could be wonderfully concise, and compatibility will be kept intact. Anyhow, that is for Patrick to decide, I'm just passing by.
This does also speak to another good point, how do overlapping filters work? It's not really clear which group things end up in.
The current 'only split into 2' scenario necessitates making this explicit.
When a leaf satisfies several filters, it is contained in the partition corresponding to the first filter that is satisfied, as determined by the user-provided order. The last partition is dedicated to leaves that do not satisfy any filter.
This is aligned with eqx.combine, which assumes that each leaf is only represented in one of the partitions.
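The first-match rule described above can be sketched in plain Python. This is a simplified stand-in, not the PR's code: it works on a flat dict of leaves instead of a pytree, takes plain predicates instead of filter specifications, and the names partition_first_match, is_number and is_float are hypothetical.

```python
def partition_first_match(leaves, *predicates, replace=None):
    """Split a flat dict of leaves into len(predicates) + 1 dicts.

    Each leaf is kept in the partition of the first predicate it satisfies
    and replaced by `replace` everywhere else; leaves that satisfy no
    predicate are kept only in the final "rest" partition.
    """
    partitions = [dict.fromkeys(leaves, replace) for _ in range(len(predicates) + 1)]
    for key, leaf in leaves.items():
        # Index of the first satisfied predicate, or the last slot if none match.
        index = next(
            (i for i, pred in enumerate(predicates) if pred(leaf)),
            len(predicates),
        )
        partitions[index][key] = leaf
    return tuple(partitions)


def is_number(x):
    return isinstance(x, (int, float))


def is_float(x):
    return isinstance(x, float)


leaves = {"weight": 1.5, "count": 3, "name": "layer"}
numbers, floats, rest = partition_first_match(leaves, is_number, is_float)
```

Here weight satisfies both predicates but is kept only in the first partition, so the is_float partition ends up empty. Swapping the two predicates would move weight into the is_float partition, which is why the order of the filters matters.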
    filter_trees = [
        jtu.tree_map(_make_filter_tree(is_leaf), spec, pytree)
        for spec in (filter_spec, *filter_specs)
    ]

    partitions = []

    for i in range(len(filter_trees)):
        partition = jtu.tree_map(
            lambda x, curr, *prev: x if curr and not any(prev) else replace,
            pytree,
            filter_trees[i],
            *filter_trees[:i],
        )

        partitions.append(partition)

    rest = jtu.tree_map(
        lambda x, prev: replace if any(prev) else x,
        pytree,
        *filter_trees,
    )

    return *partitions, rest
Especially considering how much new code is added into this 'simple' function, separating it into a distinct function makes sense to me.
    for i in range(len(filter_trees)):
        partition = jtu.tree_map(
            lambda x, curr, *prev: x if curr and not any(prev) else replace,
Are you sure it should be any here, not jnp.any? Could you also test this under jit, just to make sure it works correctly?
I don't understand your concern here. The filters should return True or False values, never arrays. In fact, you should never use boolean arrays for control flow.
@knyazer has spotted some meaningful flaws, I think.
For that reason I'm afraid I'm inclined not to take this PR -- neither the backward compatibility break nor the unclear semantics seem worth the trade-off for this feature.
Targets #824
Hi @patrick-kidger, in this PR I modify eqx.partition such that it accepts any ($N$) number of filter specifications and returns $N + 1$ partitions. To keep backward compatibility, I have kept the filter_spec argument and added a variadic positional argument *filter_specs. This however forces replace and is_leaf to become keyword arguments. I think that is ok.