
Add AdEMAMix optimizer #1360

Merged
merged 5 commits into main on Sep 20, 2024

Conversation

@matthewdouglas (Member) commented on Sep 16, 2024:

Adds support for the AdEMAMix optimizer described here: https://arxiv.org/abs/2409.03137

Includes blockwise 8bit and 32bit versions, each supporting paged operation.

AdEMAMix is a modification of Adam that introduces an additional, slower-moving EMA of the gradients. The paper observes that AdEMAMix forgets training data at a slower pace and can reach a similar loss to AdamW with significantly less training data.

TODO: Implement scheduler for alpha/beta3
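
For readers unfamiliar with the method, here is a minimal pure-PyTorch sketch of one update step as described in the paper. This is illustrative only, not the bnb kernel: the state layout, the argument names (betas as a (beta1, beta2, beta3) triple, alpha), and the defaults are assumptions following the paper's notation.

import torch

@torch.no_grad()
def ademamix_step(p, grad, state, lr=1e-3, betas=(0.9, 0.999, 0.9999),
                  alpha=5.0, eps=1e-8, weight_decay=0.0):
    # Sketch of one AdEMAMix step (arXiv:2409.03137). state is initialized as:
    #   {"step": 0, "m1": zeros_like(p), "m2": zeros_like(p), "nu": zeros_like(p)}
    beta1, beta2, beta3 = betas
    state["step"] += 1
    t = state["step"]

    # Fast EMA (Adam's first moment), slow EMA (the extra AdEMAMix term),
    # and the usual second moment.
    state["m1"].mul_(beta1).add_(grad, alpha=1 - beta1)
    state["m2"].mul_(beta3).add_(grad, alpha=1 - beta3)
    state["nu"].mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # Bias correction applies to m1 and nu only; the slow EMA m2 is
    # left uncorrected, following the paper.
    m1_hat = state["m1"] / (1 - beta1 ** t)
    nu_hat = state["nu"] / (1 - beta2 ** t)

    update = (m1_hat + alpha * state["m2"]) / (nu_hat.sqrt() + eps)
    p.add_(update + weight_decay * p, alpha=-lr)  # AdamW-style decoupled decay

The TODO refers to the paper's warmup schedulers, which ramp alpha and beta3 up from small starting values over the early part of training; the paper reports this is needed to avoid instability when the slow EMA kicks in.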

@matthewdouglas added the enhancement (New feature or request) label on Sep 16, 2024
Comment on lines +60 to +62
# For parity with the bnb implementation, we combine the fast
# and slow EMA stats into one stacked tensor.
state["m1_m2"] = p.new_zeros((2, *p.size()))
@matthewdouglas (Member, Author):

This is done for ease of compatibility with the existing test suite. In most other implementations we'll see two separate buffers here.
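
As a rough illustration (variable names here are hypothetical, and beta1/beta3/grad are assumed from the surrounding step), the stacked layout costs nothing at update time, since indexing the leading dimension yields views rather than copies:

# state["m1_m2"] holds both EMAs in one (2, *p.size()) tensor;
# indexing the leading dimension gives in-place-writable views.
m1, m2 = state["m1_m2"][0], state["m1_m2"][1]
m1.mul_(beta1).add_(grad, alpha=1 - beta1)  # fast EMA
m2.mul_(beta3).add_(grad, alpha=1 - beta3)  # slow EMA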

// AdEMAMix has an additional state buffer, which we packed
// into state1. We need thread-local storage here for these.
// TODO: Mark with [[maybe_unused]] after the minimum supported compiler is upgraded.
float s3_vals[NUM_PER_THREAD];
@matthewdouglas (Member, Author):

There are a few extra memory allocations like this to support AdEMAMix. I haven't confirmed whether the compiler optimizes these out for instantiations with OPTIMIZER=ADAM, but if not, the overhead is minimal.

@TimDettmers (Collaborator) previously approved these changes on Sep 20, 2024, commenting:

This looks all good to me.

@matthewdouglas merged commit d964546 into main on Sep 20, 2024
54 checks passed
matthewdouglas added a commit to matthewdouglas/bitsandbytes that referenced this pull request on Oct 28, 2024:
* Add AdEMAMix optimizer

* Add PagedAdEMAMix32bit, AdEMAMix32bit

* Add PagedAdEMAMix32bit, AdEMAMix32bit

* AdEMAMix: add support for alpha/beta3 scheduling

* Update paged AdEMAMix
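
For reference, usage should mirror the other bnb optimizers. A sketch, assuming the AdEMAMix32bit class name from the commits above; the constructor arguments follow the paper's notation (a beta3 entry for the slow EMA plus alpha to weight it) and may differ from the final API:

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(1024, 1024).cuda()

# betas gains a third entry (beta3) for the slow EMA; alpha weights it.
optimizer = bnb.optim.AdEMAMix32bit(
    model.parameters(), lr=1e-3, betas=(0.9, 0.999, 0.9999), alpha=5.0
)

for _ in range(10):
    loss = model(torch.randn(16, 1024, device="cuda")).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()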