This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

bring back torch.autograd.Function for float8 matmul #344

Closed
Wants to merge 1 commit

Commits on Jul 26, 2024

  1. bring back torch.autograd.Function for float8 matmul

    Summary:
    
    This is a redo of
    #316
    
    With upcoming support for scaling granularities other than tensorwise,
    we need a good way to control which gemm kernel to call and how to
    scale the input tensors in the forward and backward passes. A
    `torch.autograd.Function` override is the cleanest way to do that,
    and as of 2024 it works with `torch.compile`.
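
    To make the idea concrete, here is a minimal, illustrative sketch of such
    an override (not the code from this PR). The `Float8Matmul` class and the
    tensorwise scaling helper are hypothetical names, and the
    dequantize-then-`torch.mm` step stands in for a real float8 gemm kernel
    such as `torch._scaled_mm`, whose exact signature varies across PyTorch
    versions.
    
    ```
    import torch
    
    # Illustrative sketch only: the quant/dequant round trip plus a plain
    # torch.mm stands in for a real float8 gemm kernel, so the
    # autograd.Function structure is visible.
    
    E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
    E5M2_MAX = torch.finfo(torch.float8_e5m2).max
    
    def _to_float8(x: torch.Tensor, float8_dtype: torch.dtype, max_val: float):
        # tensorwise scale: map the tensor's amax onto the dtype's max value
        scale = max_val / x.abs().max().clamp(min=1e-12)
        x_f8 = (x * scale).clamp(-max_val, max_val).to(float8_dtype)
        return x_f8, scale
    
    class Float8Matmul(torch.autograd.Function):
        """Hypothetical float8 matmul for 2D input (M, K) and weight (N, K)."""
    
        @staticmethod
        def forward(ctx, input, weight):
            # forward gemm: e4m3 activations x e4m3 weights
            input_f8, s_in = _to_float8(input, torch.float8_e4m3fn, E4M3_MAX)
            weight_f8, s_w = _to_float8(weight, torch.float8_e4m3fn, E4M3_MAX)
            ctx.save_for_backward(input_f8, weight_f8, s_in, s_w)
            out = torch.mm(input_f8.float() / s_in, (weight_f8.float() / s_w).t())
            return out.to(input.dtype)
    
        @staticmethod
        def backward(ctx, grad_output):
            input_f8, weight_f8, s_in, s_w = ctx.saved_tensors
            # backward gemms: grad_output cast to e5m2, a different dtype and
            # scaling choice than the forward pass; this per-gemm control is
            # what the Function override makes explicit
            grad_f8, s_g = _to_float8(grad_output, torch.float8_e5m2, E5M2_MAX)
            grad_hp = grad_f8.float() / s_g
            grad_input = torch.mm(grad_hp, weight_f8.float() / s_w)
            grad_weight = torch.mm(grad_hp.t(), input_f8.float() / s_in)
            return grad_input.to(grad_output.dtype), grad_weight.to(grad_output.dtype)
    ```
    
    With an override like this, used as `out = Float8Matmul.apply(x, w)`, the
    forward and backward gemms can pick their float8 dtypes and scaling
    strategies independently, which is the flexibility the summary above
    refers to, and recent PyTorch can trace `torch.autograd.Function`
    subclasses under `torch.compile`.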
    
    Test Plan:
    
    ```
    ./test/test_everything.sh
    ```
    
    Reviewers:
    
    Subscribers:
    
    Tasks:
    
    Tags:
    
    [ghstack-poisoned]
    vkuzo committed Jul 26, 2024
    Commit: f5beb24