Thread the scaling type argument throughout fp8 #301
base: gh/drisspg/1/base
Conversation
ghstack-source-id: a740bcf5ce2098160870ee8da085861e956c5254 Pull Request resolved: #301
ghstack-source-id: 705ded1417d32b52fec3bae871c5f0c2922a5d0e Pull Request resolved: #301
ghstack-source-id: 583ea3369732127d122e45e182e1d1bc7c45fcc0 Pull Request resolved: #301
ghstack-source-id: c34e19a3ce0453fd1abb2b05db6bbb60ce3c90b8 Pull Request resolved: #301
ghstack-source-id: 333db4234b522fa07eced1b63c3d998317955c74 Pull Request resolved: #301
amax_buffer: Optional[torch.Tensor] = None,
mm_config: Optional[ScaledMMConfig] = None,
float8_dtype: torch.dtype,
amax_buffer: Optional[torch.Tensor],
I removed the default args since this is always called from the inner func with default args.
ghstack-source-id: aa6f0c03f3fffaee5277337518c520a5895719ba Pull Request resolved: #301
ghstack-source-id: b09361e159b17dafe7940b24b3482ed482bba811 Pull Request resolved: #301
float8_experimental/float8_tensor.py (Outdated)
@@ -31,6 +28,20 @@
)

class ScalingStrategy(Enum):
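The hunk above introduces the new enum. A minimal sketch of what such a class might look like, assuming standard `Enum` conventions (member names other than `TensorWise` are hypothetical; the PR only confirms tensor-wise scaling is supported):

```python
from enum import Enum

class ScalingStrategy(Enum):
    # TensorWise is referenced elsewhere in the PR; AxisWise is an
    # assumed example of a finer granularity, not confirmed by the diff.
    TensorWise = "tensor_wise"  # one scale shared by the whole tensor
    AxisWise = "axis_wise"      # one scale per row or column
```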
Thoughts about using `Granularity`, which is more specific than `Strategy`?
Yeah, that's a better word; this needed some bikeshedding.
return Float8Tensor(bits_fp8, x_scale, x.dtype, mm_config=mm_config)
return Float8Tensor(
    bits_fp8,
    x_scale,
Just curious: since we decided not to add `scaling_strategy` to `torch._scaled_mm`, why do we need it here?
We could make this a property of Float8Tensor, e.g. infer it from the existing scales... hmm.
Actually, I might like this more.
We still need the enum, though, since we want modules to specify their granularity.
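The "infer from the existing scales" idea in the comment above could be sketched as follows. This is a hypothetical helper, not code from the PR: it assumes the scale's shape alone distinguishes the granularities (a scalar scale implies tensor-wise scaling):

```python
from math import prod

def infer_granularity(scale_shape: tuple) -> str:
    # A scalar scale (empty shape, or a single element) implies one scale
    # for the whole tensor, i.e. tensor-wise scaling. Anything larger
    # implies a finer granularity such as per-row (axis-wise) scaling.
    if prod(scale_shape) == 1:
        return "tensor_wise"
    return "axis_wise"
```

This would not remove the need for the enum: modules still have to declare which granularity they want before any scales exist, which is the point made in the reply.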
ghstack-source-id: 6f9b9299f4429ede127c0ed639a652d8888e947a Pull Request resolved: #301
# Summary

This PR adds a ScalingGranularity Enum and threads it through the stack to all the places we call `tensor_to_amax` and `tensor_to_scale`.

- Currently hardcodes `TensorWise` scaling in Float8Linear, Float8DynamicLinear, and Float8InferenceLinear; asserts that granularity is TensorWise for now.
- Added this as a property of WeightWithDynamicFloat8CastTensor, since we need to know a priori how to do the scaling for fp8 comms.

### Testing

``` Shell
============================================================================= test session starts =============================================================================
platform linux -- Python 3.12.4, pytest-7.4.0, pluggy-1.5.0
rootdir: /home/drisspg/meta/float8_experimental
plugins: hypothesis-6.104.1
collected 9 items

test/test_fsdp2/test_fsdp2_eager.py ......... [100%]

============================================================================= 9 passed in 30.77s ==============================================================================
all tests successful
```

[ghstack-poisoned]
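The threading described in the summary can be illustrated with a simplified sketch. This is not the PR's code: the real functions operate on torch tensors, and the constant and epsilon below are assumptions (448.0 is the max representable value of `float8_e4m3fn`). It mirrors the summary's "asserts that granularity is TensorWise for now":

```python
from enum import Enum

class ScalingGranularity(Enum):
    TensorWise = "tensor_wise"  # one scale for the whole tensor

FP8_E4M3_MAX = 448.0  # max representable value of float8_e4m3fn

def tensor_to_scale(amax: float, granularity: ScalingGranularity) -> float:
    # The granularity argument is threaded through even though only one
    # value is currently accepted, so callers declare their intent now.
    assert granularity is ScalingGranularity.TensorWise, (
        "only TensorWise granularity is currently supported"
    )
    # Scale maps the observed abs-max onto the fp8 representable range;
    # the small floor guards against division by zero.
    return FP8_E4M3_MAX / max(amax, 1e-12)
```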
Stack from ghstack (oldest at bottom):