This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

bring back torch.autograd.Function #316

Open
wants to merge 5 commits into base: gh/vkuzo/29/base

Conversation

@vkuzo (Contributor) commented Jul 16, 2024

Stack from ghstack (oldest at bottom):

Summary:

I want to plan for how we are going to add scaling granularities in the Python layer of the float8 code. Today we only have per-tensor scaling, which is transposable. For other types of scaling such as rowwise, the scaling is not transposable, and the user needs to choose how to handle the mismatch between fwd and bwd (see the sketch after this list):
a. keep the bf16 copy to be able to rescale across dim0 and dim1
b. scale bf16 across dim0/dim1, keep that, then requantize along the other dim in the bwd (reduces memory usage, loses some precision)
c. keep some of the gemms in bf16 to avoid the need to scale twice
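
A minimal illustration of why rowwise scales do not transpose, using a toy tensor (this sketch is not from the PR; the names are made up):

```python
import torch

x = torch.randn(4, 8)

# tensorwise: a single scale, valid for both x and x.t()
amax_tensorwise = x.abs().amax()

# rowwise: the fwd gemm wants one scale per row of x (reduce over dim1), but the
# bwd gemm consumes x.t(), whose rows are the columns of x (reduce over dim0)
amax_rowwise_fwd = x.abs().amax(dim=1, keepdim=True)  # shape (4, 1)
amax_rowwise_bwd = x.abs().amax(dim=0, keepdim=True)  # shape (1, 8)

# the two sets of scales are different tensors, so an fp8 copy quantized with the
# fwd scales cannot simply be reused for the transposed gemm -> options a/b/c above
```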

The modeling logic in Float8Linear for a/b would look like:

```python
def forward(self, x):
    if scaling_type == TENSORWISE:
        x_maybe_fp8 = to_fp8_tensorwise(x, ...)
    elif scaling_type == ROWWISE:
        x_maybe_fp8 = to_fp8_rowwise(x, dim=0, ...)
    # repeat for w

    y = float8_mm_op(x_maybe_fp8, w_maybe_fp8, ...)
```

And, there are at least two choices I see for `float8_mm_op`:

```python
# Option 1 (current code without this PR): use the torch.mm override
@implements([aten.mm.default, aten.matmul.default])
def float8_mm(aten_op, args, kwargs=None):
    ...

# Option 2 (this PR): use torch.autograd.Function
class float8_mm(torch.autograd.Function):
    ...
```
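
For concreteness, here is a minimal, runnable structural sketch of roughly what Option 2 could look like (this is not the PR's implementation: the helper names are made up, scaling is hard-coded to tensorwise, and the fp8 gemm is emulated by dequantizing and calling a regular matmul instead of torch._scaled_mm):

```python
import torch

E4M3_MAX = 448.0  # max representable magnitude of torch.float8_e4m3fn


def _to_fp8_tensorwise(t: torch.Tensor):
    # one scale for the whole tensor, mapping its amax onto the fp8 range
    scale = E4M3_MAX / t.abs().amax().float().clamp(min=1e-12)
    t_fp8 = (t.float() * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return t_fp8, scale


class float8_mm_sketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x_hp: torch.Tensor, w_hp: torch.Tensor):
        x_fp8, x_scale = _to_fp8_tensorwise(x_hp)
        w_fp8, w_scale = _to_fp8_tensorwise(w_hp)
        ctx.save_for_backward(x_fp8, x_scale, w_fp8, w_scale)
        ctx.out_dtype = x_hp.dtype
        # the real code would call torch._scaled_mm on the fp8 tensors here;
        # emulate it by dequantizing and using a regular matmul
        x_deq = x_fp8.to(ctx.out_dtype) / x_scale.to(ctx.out_dtype)
        w_deq = w_fp8.to(ctx.out_dtype) / w_scale.to(ctx.out_dtype)
        return x_deq @ w_deq

    @staticmethod
    def backward(ctx, go):
        x_fp8, x_scale, w_fp8, w_scale = ctx.saved_tensors
        x_deq = x_fp8.to(go.dtype) / x_scale.to(go.dtype)
        w_deq = w_fp8.to(go.dtype) / w_scale.to(go.dtype)
        # for y = x @ w: dx = dy @ w.t(), dw = x.t() @ dy
        return (go @ w_deq.t()).to(ctx.out_dtype), (x_deq.t() @ go).to(ctx.out_dtype)


# usage: y = float8_mm_sketch.apply(x, w), with x and w 2D and of the same floating dtype
```

In the real code the conversion would also need to accept inputs that are already Float8Tensor and to branch on the configured scaling granularity, which is what the next snippet sketches.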

To support future scaling granularities, whichever choice we go with will have to do something like below:

```python
def float8_mm(x_maybe_fp8, w_maybe_fp8):
    if isinstance(x_maybe_fp8, Float8Tensor):
        x_fp8 = x_maybe_fp8
    else:
        x_fp8 = to_fp8(x_maybe_fp8, scaling_granularity, ...)
    # repeat for w
    # call torch._scaled_mm
```

Furthermore, to keep things readable / debuggable, it would be good to:

1. be able to print tensors before/after quantization
2. be able to associate tensors to their parent module, and the specific gemm in fwd/bwd in that module

To do the above, we'll need to pass around metadata such as module FQNs.
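
As a rough illustration of that metadata plumbing (a hypothetical sketch, not this PR's code: `fqn`, `gemm_name`, and `verbose` are invented names, and the fp8 gemm is replaced by a plain matmul):

```python
import torch


def float8_mm_with_debug(x, w, *, fqn: str = "", gemm_name: str = "output", verbose: bool = False):
    # `fqn` would be the parent module's fully qualified name, threaded in by Float8Linear;
    # `gemm_name` distinguishes the fwd/bwd gemms (e.g. output / grad_input / grad_weight)
    if verbose:
        print(f"[{fqn}][{gemm_name}] x amax before quantization: {x.abs().amax().item():.4f}")
    y = x @ w  # placeholder for the scaled fp8 gemm
    if verbose:
        print(f"[{fqn}][{gemm_name}] y amax after gemm: {y.abs().amax().item():.4f}")
    return y


# e.g. inside a module's forward (hypothetical FQN shown):
# y = float8_mm_with_debug(x, self.weight.t(), fqn="layers.0.attn.wq", gemm_name="output", verbose=True)
```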

This PR implements Option 2, since IMO it is more readable/debuggable.

Test plan:

```
// all green
./test/test_everything.sh
```

[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Jul 16, 2024
ghstack-source-id: 09c4625b2a859ce6468bac328d5f0ff61bb86251
Pull Request resolved: #316
@facebook-github-bot added the CLA Signed label Jul 16, 2024
Summary:

This approach is more readable as we add additional scaling options.

For now, seeing how many things break in 2024-07 with
torch.autograd.Function + subclasses + compile.

```

# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files  
# and modified to only support dynamic scaling                                                       
#                                                                                                    
# Why do we want a torch.autograd.Function here? Vasiliy's opinion is that                           
# as we add more scaling granularities, keeping the scaling code close to Float8Linear               
# will be really useful for readability and debuggability of numerics.                               
#                                                                                                    
# For example, a future PR to add rowwise scaling could do                                           
#                                                                                                    
#   # forward                                                                                        
#   x_bf16 = ...                                                                                     
#   if scaling_granularity == ScalingGranularity.PER_TENSOR:                                         
#       # we can scale the same way for fwd/bwd                                                      
#       x_maybe_fp8 = to_fp8(...)                                                                    
#   else:                                                                                            
#       assert scaling_granularity == ScalingGranularity.PER_ROW
#       # defer scaling to float8_mm                                                                 
#       x_maybe_fp8 = x_bf16                                                                         
#                                                                                                    
#   # repeat for w                                                                                   
#                                                                                                    
#   y_bf16 = float8_mm(x_maybe_fp8, w_maybe_fp8)                                                     
#                                                                                                    
#   Requirements for float8_mm                                                                       
#   - composes with DTensor, compile, autograd                                                       
#   - readable/debuggable                                                                            
#                                                                                                    
#   Option 1 (this PR): float8_mm is a torch.autograd.Function                                       
#   - pros                                                                                           
#   - cons                                                                                           
#   Option 2 (current code without this PR): float8_mm is an override of torch.mm                    
#   - pros                                                                                           
#   - cons                                                                                           
#                                                                                                    

```

[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Jul 16, 2024
ghstack-source-id: 75842f4858804bc6f204eb55222a493ea9074630
Pull Request resolved: #316

We should discuss whether we want Option 1 (keep overriding torch.mm) or Option 2 (torch.autograd.Function).

Vasiliy: I think Option 2 is cleaner/more readable/more debuggable, modeling code is usually written in the module or similar torch.autograd.Function overrides.  I would consider scaling tensors to float8 modeling code, and it's unintuitive IMO for this to happen deep inside op overrides.  However, Option 1 is less risky technically as we avoid torch.autograd.Function which is less mature in interactions with torch.compile.  While the current PR is all green, we are using `allow_in_graph` which is a bit unsafe.
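
For reference, a minimal sketch of one way the `allow_in_graph` escape hatch can be applied (assuming a Function shaped roughly like the `float8_mm` above; the body is a placeholder matmul, not the real fp8 code). `torch._dynamo.allow_in_graph` asks torch.compile to put the call into the graph without tracing through its Python body, which sidesteps some autograd.Function + tensor-subclass tracing limitations but also bypasses Dynamo's usual checks, hence "a bit unsafe":

```python
import torch


@torch._dynamo.allow_in_graph
class float8_mm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x @ w  # placeholder for the scaled fp8 gemm

    @staticmethod
    def backward(ctx, go):
        x, w = ctx.saved_tensors
        return go @ w.t(), x.t() @ go


# usage under compile (sketch):
# compiled = torch.compile(lambda x, w: float8_mm.apply(x, w))
```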

Test plan:

```
// all green
./test/test_everything.sh
```

[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Jul 22, 2024
ghstack-source-id: ea5bd3e7ec037b351703154363a6bbe9f4f638d5
Pull Request resolved: #316
@vkuzo changed the title from "[TBD if for land] bring back torch.autograd.Function" to "bring back torch.autograd.Function" Jul 22, 2024
vkuzo added a commit that referenced this pull request Jul 22, 2024
ghstack-source-id: c1b8d0c42c7f73fb8a0c8806ae88479a40d0be40
Pull Request resolved: #316
@vkuzo requested review from bdhirsh and drisspg July 22, 2024 21:44
```
@@ -71,6 +71,54 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
scale.copy_(new_scale)

# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
```
Contributor:

nit: Does the structure work out to put this in float8 ops?

vkuzo (Contributor, Author):

in how things look after this PR it would make sense, but it might be good to see how the code looks after we add different granularities and the if/else branches on when to convert to lower precision. Maybe we can revisit then?

```
return res_bits

@staticmethod
def backward(ctx, go_fp8):
```
Contributor:

nit: align go_fp8 / other naming to the other PR

vkuzo (Contributor, Author):

we can do that in separate PRs, since it's not user facing. Just keeping things small.

Contributor:

I don't know if that changes the size of the PR much, but sure, that's fine

vkuzo (Contributor, Author):

probably just a style difference on how to sequence the renames, either is ok IMO

@drisspg (Contributor) left a comment:

Looks good, TBH I think this is a good balance of both subclassing + autograd func

vkuzo added a commit that referenced this pull request Jul 22, 2024
ghstack-source-id: a7abb00cce87d18273d3bb18996eebb2bb0c4c99
Pull Request resolved: #316
vkuzo added a commit that referenced this pull request Jul 25, 2024
Summary:

This is a redo of
#316

With upcoming support of scaling granularities other than tensorwise,
we need a good way to control which gemm kernel to call and how to scale
the input tensors in fwd and bwd. A `torch.autograd.Function` override
is the cleanest way to do that, and in 2024 this now works with
`torch.compile`.

Test Plan:

```
./test/test_everything.sh
```

[ghstack-poisoned]
vkuzo added a commit that referenced this pull request Jul 25, 2024
ghstack-source-id: 6cb1588bf59be73b5782f6af94e7a360eba7f40e
Pull Request resolved: #336
vkuzo added a commit that referenced this pull request Jul 25, 2024
ghstack-source-id: 42dd59511e4ec2a55846c2593955c4ff5f12b254
Pull Request resolved: #336
vkuzo added a commit that referenced this pull request Jul 26, 2024
Differential Revision: [D60252068](https://our.internmc.facebook.com/intern/diff/D60252068)
facebook-github-bot pushed a commit that referenced this pull request Jul 26, 2024
Summary:
Pull Request resolved: #344

Reviewed By: drisspg

Differential Revision: D60291446

fbshipit-source-id: 472f392227bca1c7f83ea0c1234285bc576e58d2
Labels: CLA Signed
Projects: None yet
3 participants