
enumerate breakages of torch.compile + Float8Linear + FSDP/TP/SP #168

Open · wants to merge 1 commit into base: main

Conversation

vkuzo (Contributor) commented on Dec 20, 2023

Summary:

This PR creates tests for composability of torch.compile + Float8Linear with:

  1. FSDP - this works if we do torch.compile(FSDP(model)) and only apply torch.compile after the 1st iteration (a sketch of this ordering follows the list)
  2. TP/SP - torch.compile graph breaks on fairscale's RowParallelLinear/ColumnParallelLinear even without float8, so TP/SP alone works but is slow (see https://gist.github.com/vkuzo/670b2806e222bef04da5f173c758a165); our float8 code currently crashes on graph breaks, so float8 + torch.compile + TP/SP is broken (a sketch for surfacing these graph breaks follows the Test Plan)
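For item 1, a minimal sketch of the ordering that works. This is a sketch only, assuming hypothetical `model`/`inputs` names (an nn.Module containing Float8Linear layers) and an already-initialized process group; none of these names come from this PR's tests:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Hypothetical setup: `model` is an nn.Module containing Float8Linear
# layers, `inputs` is an example batch, and torch.distributed is
# already initialized.
model = FSDP(model)

# Run the first iteration eagerly so Float8Linear's scaling state is
# initialized before any tracing happens.
model(inputs).sum().backward()

# Only now apply torch.compile, keeping the torch.compile(FSDP(model))
# ordering rather than FSDP(torch.compile(model)).
model = torch.compile(model)
```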

Test Plan:

```
./test_everything.sh
```
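For item 2 in the summary, torch._dynamo.explain can list where the graph breaks happen. A sketch, assuming the torch 2.x explain API and hypothetical `model`/`inputs` standing in for the fairscale TP/SP test module:

```python
import torch._dynamo as dynamo

# Hypothetical: `model` wraps fairscale's ColumnParallelLinear /
# RowParallelLinear and `inputs` is an example batch.
explanation = dynamo.explain(model)(inputs)

# The returned ExplainOutput summarizes the captured graphs and the
# reason for each graph break (e.g. breaks inside fairscale's
# collectives).
print(explanation)
```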

Reviewers:

Subscribers:

Tasks:

Tags:

@facebook-github-bot added the CLA Signed label on Dec 20, 2023