This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Thread the scaling type argument throughout fp8 #301

Open · wants to merge 9 commits into base: gh/drisspg/1/base

Conversation


@drisspg drisspg commented Jul 3, 2024

Summary

This PR adds a ScalingGranularity enum and threads it through the stack to all the places where we call `tensor_to_amax` and `tensor_to_scale`.

  • Currently hardcodes TensorWise scaling in Float8Linear, Float8DynamicLinear, and Float8InferenceLinear, and asserts that the granularity is TensorWise for now.
  • Added this as a property of WeightWithDynamicFloat8CastTensor, since we need to know a priori how to do the scaling for fp8 comms. A rough sketch of the change follows this list.
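A minimal sketch of the shape of the change, not the actual diff: the AxisWise member and the amax-to-scale arithmetic are illustrative placeholders, and the real helpers live in the repo's float8 utils.

```python
from enum import Enum, auto

import torch


class ScalingGranularity(Enum):
    # Only TensorWise is supported by this PR; AxisWise is an illustrative
    # placeholder for finer-grained scaling.
    TensorWise = auto()
    AxisWise = auto()


def tensor_to_amax(x: torch.Tensor, granularity: ScalingGranularity) -> torch.Tensor:
    # For now every call site passes (and asserts) tensor-wise scaling.
    assert granularity is ScalingGranularity.TensorWise, "only TensorWise is supported"
    return x.abs().max()


def tensor_to_scale(
    x: torch.Tensor,
    float8_dtype: torch.dtype,
    granularity: ScalingGranularity = ScalingGranularity.TensorWise,
) -> torch.Tensor:
    amax = tensor_to_amax(x, granularity)
    # amax -> scale conversion, roughly what the repo's amax_to_scale helper does
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)
```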

Testing

============================================================================= test session starts =============================================================================
platform linux -- Python 3.12.4, pytest-7.4.0, pluggy-1.5.0
rootdir: /home/drisspg/meta/float8_experimental
plugins: hypothesis-6.104.1
collected 9 items                                                                                                                                                             

test/test_fsdp2/test_fsdp2_eager.py .........                                                                                                                           [100%]

============================================================================= 9 passed in 30.77s ==============================================================================
all tests successful

Stack from ghstack (oldest at bottom):

[ghstack-poisoned]
drisspg added a commit that referenced this pull request Jul 3, 2024
ghstack-source-id: a740bcf5ce2098160870ee8da085861e956c5254
Pull Request resolved: #301
@facebook-github-bot facebook-github-bot added the CLA Signed label Jul 3, 2024
@drisspg drisspg marked this pull request as draft July 3, 2024 00:10
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Jul 3, 2024
ghstack-source-id: 705ded1417d32b52fec3bae871c5f0c2922a5d0e
Pull Request resolved: #301
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Jul 3, 2024
ghstack-source-id: 583ea3369732127d122e45e182e1d1bc7c45fcc0
Pull Request resolved: #301
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Jul 3, 2024
ghstack-source-id: c34e19a3ce0453fd1abb2b05db6bbb60ce3c90b8
Pull Request resolved: #301
@drisspg drisspg changed the title from "threading the needle" to "Thread through the scaling type argument throughout fp8" Jul 3, 2024
@drisspg drisspg changed the title from "Thread through the scaling type argument throughout fp8" to "Thread the scaling type argument throughout fp8" Jul 3, 2024
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Jul 3, 2024
ghstack-source-id: 333db4234b522fa07eced1b63c3d998317955c74
Pull Request resolved: #301
amax_buffer: Optional[torch.Tensor] = None,
mm_config: Optional[ScaledMMConfig] = None,
float8_dtype: torch.dtype,
amax_buffer: Optional[torch.Tensor],
Contributor Author

I removed the default args since this is always called from the inner func with default args; a sketch of the pattern is below.
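A minimal sketch of that pattern, with hypothetical function names (not the actual API): the public entry point keeps the default arguments, so the inner function it always goes through can require every argument explicitly.

```python
from typing import Optional

import torch


def cast_to_float8(
    x: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
    amax_buffer: Optional[torch.Tensor] = None,
):
    # Defaults live only here, on the outer helper.
    return _cast_to_float8_inner(x, float8_dtype, amax_buffer)


def _cast_to_float8_inner(
    x: torch.Tensor,
    float8_dtype: torch.dtype,
    amax_buffer: Optional[torch.Tensor],
):
    # The real casting logic lives here in the actual code.
    ...
```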

[ghstack-poisoned]
drisspg added a commit that referenced this pull request Jul 3, 2024
ghstack-source-id: aa6f0c03f3fffaee5277337518c520a5895719ba
Pull Request resolved: #301
[ghstack-poisoned]
drisspg added a commit that referenced this pull request Jul 3, 2024
ghstack-source-id: b09361e159b17dafe7940b24b3482ed482bba811
Pull Request resolved: #301
@drisspg drisspg marked this pull request as ready for review July 3, 2024 05:48
@@ -31,6 +28,20 @@
)


class ScalingStrategy(Enum):
Contributor

thoughts about using Granularity, which is more specific than Strategy?

Contributor Author

Yeah, that's a better word; this needed some bikeshedding

return Float8Tensor(bits_fp8, x_scale, x.dtype, mm_config=mm_config)
return Float8Tensor(
bits_fp8,
x_scale,
Contributor

just curious, since we decided to not add scaling_strategy to torch._scaled_mm, why do we need it here?

Contributor Author

We could make this a property of Float8Tensor, e.g. infer it from the existing scales... hmm, actually I might like this more.
We still need the enum, since we want modules to specify their granularity (a rough sketch of the inference idea is below).
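A minimal sketch of that inference idea, assuming the granularity can be read off the scale's shape (hypothetical helper, not part of the PR; reuses the illustrative enum from the earlier sketch):

```python
from enum import Enum, auto

import torch


class ScalingGranularity(Enum):  # same illustrative enum as above
    TensorWise = auto()
    AxisWise = auto()


def infer_granularity(scale: torch.Tensor) -> ScalingGranularity:
    # A scalar scale implies tensor-wise scaling; a non-scalar scale implies
    # some finer granularity.
    if scale.numel() == 1:
        return ScalingGranularity.TensorWise
    return ScalingGranularity.AxisWise
```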

[ghstack-poisoned]
drisspg added a commit that referenced this pull request Jul 3, 2024
ghstack-source-id: 6f9b9299f4429ede127c0ed639a652d8888e947a
Pull Request resolved: #301