Adaptive Pooling #50

Open
dlwh opened this issue Nov 6, 2023 · 0 comments
Labels
help wanted Extra attention is needed

@dlwh (Member) commented Nov 6, 2023

Adaptive pooling is a pain in JAX/XLA, at least if you want the same results as torch in the general case. (Equinox doesn't aim for torch equivalence here, and other major JAX frameworks don't implement it at all, except sometimes for the boring case where input % output == 0.)
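For the boring case, a minimal sketch (the function name is hypothetical, not taken from Equinox or any other library): when the input length is a multiple of the output size, every window has the same size k = n // output_size, no windows overlap, and adaptive average pooling reduces to a reshape plus a mean over the last axis.

```python
import jax.numpy as jnp

def adaptive_avg_pool1d_divisible(x, output_size):
    """Adaptive average pooling over the last axis, only for the case where
    the input length is a multiple of output_size. Windows are then uniform
    and non-overlapping, so a reshape followed by a mean is enough."""
    n = x.shape[-1]
    if n % output_size != 0:
        raise ValueError("input length must be a multiple of output_size")
    k = n // output_size  # window size, equal to the stride
    return x.reshape(*x.shape[:-1], output_size, k).mean(axis=-1)
```

For example, adaptive_avg_pool1d_divisible(jnp.arange(6.0), 3) returns [0.5, 2.5, 4.5]. Because window size and stride are both k, the same computation could also be a single lax.reduce_window call.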

https://stackoverflow.com/a/63603993/1736826 seems to be a correct description of how it works, at least in the 1-D case. You end up with overlapping windows of differing sizes, which (I think?) can't be expressed as a single reduce_window call. Maybe you can get it done with a second quasi-mask argument to reduce_window, but I haven't figured it out yet.
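A minimal sketch of the general 1-D case, assuming the window for output i covers floor(i*n/m) <= j < ceil((i+1)*n/m), where n is the input length and m the output size (our reading of the linked answer). It sidesteps reduce_window entirely: a dense (m, n) window mask is built in NumPy at trace time (shapes are static), the mean becomes a matmul, and the max becomes a masked reduction. All function names are hypothetical. The mask costs O(m*n) memory, so this is only reasonable for modest sizes.

```python
import numpy as np
import jax.numpy as jnp

def adaptive_window_mask(n, m):
    """Boolean (m, n) mask; row i marks the inputs pooled into output i.

    Assumed window bounds: floor(i*n/m) <= j < ceil((i+1)*n/m).
    Computed with NumPy because the shapes are static."""
    i = np.arange(m)[:, None]
    j = np.arange(n)[None, :]
    start = (i * n) // m              # floor(i * n / m)
    end = -((-(i + 1) * n) // m)      # ceil((i + 1) * n / m) via integer floor division
    return (j >= start) & (j < end)

def adaptive_avg_pool1d(x, m):
    """Average over overlapping, variable-size windows on the last axis."""
    mask = jnp.asarray(adaptive_window_mask(x.shape[-1], m), dtype=x.dtype)
    # (..., n) @ (n, m) sums each window; divide by each window's size.
    return (x @ mask.T) / mask.sum(axis=-1)

def adaptive_max_pool1d(x, m):
    """Max over the same windows; x must be floating point (uses -inf)."""
    mask = adaptive_window_mask(x.shape[-1], m)
    # Broadcast to (..., m, n) and hide everything outside each window.
    masked = jnp.where(mask, x[..., None, :], -jnp.inf)
    return masked.max(axis=-1)
```

With x = jnp.arange(5.0) and m = 3, the windows are [0, 2), [1, 4) and [3, 5): sizes 2, 3 and 2, with indices 1 and 3 shared between neighbours. The average is [0.5, 2.0, 3.5] and the max is [1.0, 3.0, 4.0].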

Maybe going for torch equivalence here isn't worth it? Maybe it can be done with Pallas?

It seems like the easiest approach would be roughly symmetric padding, but no one seems to do it that way?
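One possible reading of the padding idea, as a hypothetical sketch that is explicitly not torch-equivalent: pad the input so its length becomes a multiple of the output size, splitting the padding as evenly as possible between the two ends, then pool with uniform, non-overlapping windows. Replicating edge values (mode="edge") is our assumption; it avoids averaging in artificial zeros, even when a window falls entirely in the padding.

```python
import jax.numpy as jnp

def padded_adaptive_avg_pool1d(x, output_size):
    """Hypothetical padding-based variant; NOT torch-equivalent.

    Pads the last axis (replicating edge values) so its length becomes a
    multiple of output_size, then averages uniform, non-overlapping windows."""
    n = x.shape[-1]
    k = -(-n // output_size)            # ceil(n / output_size): window size and stride
    pad = k * output_size - n           # total padding, always < output_size
    lo, hi = pad // 2, pad - pad // 2   # split as evenly as possible
    pad_width = [(0, 0)] * (x.ndim - 1) + [(lo, hi)]
    xp = jnp.pad(x, pad_width, mode="edge")
    return xp.reshape(*x.shape[:-1], output_size, k).mean(axis=-1)
```

Since every window then has the same size and stride, this also fits a single lax.reduce_window call. The results differ from the floor/ceil bounds of the linked answer: for jnp.arange(5.0) and an output size of 3 this returns [0.5, 2.5, 4.0], while those bounds give [0.5, 2.0, 3.5].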

@dlwh dlwh added the help wanted Extra attention is needed label Nov 6, 2023