This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[wip] make all 3 gemms in float8 linear configurable #258

Closed
wants to merge 1 commit

Conversation

@vkuzo (Contributor) commented May 10, 2024


Summary:

This PR ensures each of the 3 gemms in fw+bw of the float8 linear
can have its own configuration.

The user interface is clean: the float8 module has separate configs for
each of the gemms. Configuring this can be exposed in module utilities in
a future PR.
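
As a sketch of what such an interface could look like (the class and attribute names below are hypothetical illustrations, not the API added by this PR):

```
# hypothetical module holding one config per gemm, as described above
class Float8LinearSketch:
    def __init__(self, config_Y=None, config_gradX=None, config_gradW=None):
        self.config_Y = config_Y          # config for Y     = X @ W_t
        self.config_gradX = config_gradX  # config for gradX = gradY @ W
        self.config_gradW = config_gradW  # config for gradW = X_t @ gradY
```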

The implementation is a bit ugly. Given the 3 gemms in fw+bw

```
Y = X @ W_t
gradX = gradY @ W
gradW = X_t @ gradY
```

and the fact that a `torch.mm` doesn't have access to any global state
other than its arguments, we need to ensure that the matmul arguments
contain the right state to map to the right config.  We adopt a simple
set of rules to do this:
1. if only one of the arguments to mm has a config, use the defined config
2. if both arguments of mm have a config, use the second argument's
   config
3. in the float8 modules, do the following:
   3a. set X's config for arg0 to config_Y
   3b. set W's config for arg0 and arg1 to None
   3c. set gradY's config for arg0 to config_gradX, and for arg1 to config_gradW

If 3 is done correctly, following 1 and 2 will lead to the right config
being used for each gemm.  It's ugly, but it works.
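
As a rough illustration of how rules 1, 2, and 3 compose, here is a minimal runnable sketch; `GemmConfig`, `TaggedTensor`, `choose_config`, and the `config_arg0`/`config_arg1` slots are hypothetical stand-ins for the mechanism described above, not the code in this PR:

```
from dataclasses import dataclass
from typing import Optional

@dataclass
class GemmConfig:          # hypothetical per-gemm config stand-in
    name: str

@dataclass
class TaggedTensor:
    # hypothetical float8-tensor stand-in: a per-position config that is
    # consulted when this tensor is the first (arg0) or second (arg1) mm argument
    config_arg0: Optional[GemmConfig] = None
    config_arg1: Optional[GemmConfig] = None

def choose_config(a: TaggedTensor, b: TaggedTensor) -> Optional[GemmConfig]:
    a_cfg, b_cfg = a.config_arg0, b.config_arg1
    if a_cfg is not None and b_cfg is not None:
        return b_cfg                                  # rule 2: second argument's config wins
    return a_cfg if a_cfg is not None else b_cfg      # rule 1: use the defined config

# rule 3: how the module would tag X, W and gradY (3a/3b/3c above)
config_Y, config_gradX, config_gradW = GemmConfig("Y"), GemmConfig("gradX"), GemmConfig("gradW")
x      = TaggedTensor(config_arg0=config_Y)                                 # 3a
w      = TaggedTensor(config_arg0=None, config_arg1=None)                   # 3b
grad_y = TaggedTensor(config_arg0=config_gradX, config_arg1=config_gradW)   # 3c

# (x stands in for X_t below; only the config tags matter in this sketch)
assert choose_config(x, w).name == "Y"           # Y     = X @ W_t
assert choose_config(grad_y, w).name == "gradX"  # gradX = gradY @ W
assert choose_config(x, grad_y).name == "gradW"  # gradW = X_t @ gradY (rule 2)
```

With this tagging, each of the three gemms resolves to its own config even though the mm itself only sees its two arguments.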

Test Plan:

for now, this works on a single GPU:
```
pytest -s test/test_base.py
pytest -s test/test_compile.py
```

it's just a question of engineering time to also make the distributed tests pass

Reviewers:

Subscribers:

Tasks:

Tags:
@facebook-github-bot added the CLA Signed label May 10, 2024
@vkuzo (Contributor, Author) commented Jul 19, 2024

closing in favor of #315

@vkuzo closed this Jul 19, 2024