
Suggestion: Allow bfloat16 use to improve speed/memory usage #25

Open
pokepress opened this issue Feb 24, 2024 · 1 comment

Comments

@pokepress

Hi, as you may have noticed, I've been developing a fork of this project in an attempt to repurpose it to train models for upscaling AM and FM radio recordings. While researching ways to improve speed and memory consumption, I started looking at 16-bit floating-point formats. Standard float16 doesn't appear to have enough range to be useful for this project (I ran into NaN values relatively quickly), but bfloat16 (which uses the same number of exponent bits as float32) seems to work quite well: it speeds up training a bit and significantly reduces memory usage. You can see an example of its implementation in a recent commit I made. Note that I did have to restructure the code so that the loss calculation is done using standard 32-bit float values (as recommended by PyTorch).

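Not from the linked commit, just a minimal sketch of one way to do this in PyTorch with torch.autocast, assuming a stand-in model and random data: the forward pass runs in bfloat16 while the loss is computed in float32 as described above.

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in model and data (hypothetical, not the project's actual code):
# a tiny denoiser-style network on random (noisy, clean) feature pairs.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
loss_fn = nn.L1Loss()
dataloader = [(torch.randn(8, 1024), torch.randn(8, 1024)) for _ in range(10)]

for noisy, clean in dataloader:
    noisy, clean = noisy.to(device), clean.to(device)
    optimizer.zero_grad()
    # Run the forward pass in bfloat16: it keeps float32's exponent range,
    # so it avoids the NaN/overflow issues float16 hits, while saving memory.
    with torch.autocast(device_type=device, dtype=torch.bfloat16):
        estimate = model(noisy)
    # Leave autocast and compute the loss in full float32 for numerical stability.
    loss = loss_fn(estimate.float(), clean.float())
    loss.backward()
    optimizer.step()
```
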
@m-mandel
Collaborator

m-mandel commented Mar 1, 2024

Thank you for your contribution!
