Add a Float8LinearInference module to support static, dynamic, and wo quant #287
Conversation
force-pushed from 9bee8a3 to d1eae9a
force-pushed from 12b32d3 to 5d5a48e
mod,
emulate: bool = False,
static_quantize_weight: bool = False,
activation_scale: Optional[torch.Tensor] = None,
we would probably want to provide two options:
1. user specifies the scale they want (for example, for cases where the preceding activation has a bounded range)
2. user specifies how to calculate the scale from calibration data
For (2), the current quantization UXs do this with Observer objects (https://fburl.com/code/ady6pz23). cc @jerryzh168 on whether torchao has any different plans to do this in future UXs.
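To make the two options concrete, here is a minimal sketch; the helper names and the observer-style loop below are illustrative, not the existing quantization API:

import torch

# max representable value of the target float8 format
E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max

def scale_from_user_bound(bound: float) -> torch.Tensor:
    # option 1: the user already knows the activation range
    # (e.g. a post-sigmoid activation is bounded by 1.0)
    return torch.tensor(E4M3_MAX / bound)

def scale_from_calibration(calibration_batches) -> torch.Tensor:
    # option 2: derive the scale from calibration data, observer-style,
    # by tracking the max absolute value seen across batches
    amax = torch.tensor(0.0)
    for x in calibration_batches:
        amax = torch.maximum(amax, x.abs().max().float())
    return E4M3_MAX / torch.clamp(amax, min=1e-12)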
force-pushed from 0bd148f to 2ff810c
Hey, we were quantizing bf16 weights on the fly from our checkpoints, but I think we'll do something akin to AutoFP8 (https://github.com/neuralmagic/AutoFP8) to handle the FP8 checkpoints and load them into Float8Tensors.
force-pushed from c7e087d to b89515c
@dataclass(frozen=True)
class QuantConfig:
maybe we can see if there is any alignment possible with torchao on the name or structure of this? It doesn't have to align now, but ideally the structure is similar so it's easier to unify later. cc @jerryzh168
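For reference, roughly the shape of the config as it reads from the diff snippets in this PR; the field names are reconstructed from those snippets, and the validation is a guess:

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

import torch

class ActivationCasting(Enum):
    WEIGHT_ONLY = auto()
    DYNAMIC = auto()
    STATIC = auto()

@dataclass(frozen=True)
class QuantConfig:
    activation_casting: ActivationCasting
    # only used for ActivationCasting.STATIC
    static_quantization_scale: Optional[torch.Tensor] = None

    def __post_init__(self):
        if self.activation_casting == ActivationCasting.STATIC:
            assert (
                self.static_quantization_scale is not None
            ), "static activation casting requires a static_quantization_scale"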
force-pushed from d54bd00 to c1f2cad
False,
device=torch.device("meta"),
)
linear.set_weight_and_bias(module.weight, module.bias)
lg, maybe inline since there's only one call site and the function is super short?
I kinda think it reads nicer tbh. Is it okay if we keep it as is and remove it in a follow-up if it's distracting?
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
in_features: int,
out_features: int,
bias: bool = True,
if we're inheriting from torch.nn.Linear, thoughts on adding the additional arguments last? I understand the code works, but it's extra mental load to think about how the argument order differs, unless it's a simple "superclass args first, subclass args after".
The problem with this is that we don't want the two additional args to have default values. Putting the two new args into the init is another example of breaking the Liskov substitution principle; however, since we don't really expect anyone but from_float to call the constructor, this is less likely to be hit in practice.
The only other way I can immediately think of to get around this is to do what was done before, where you create the instance in a partially initialized state and then modify it in place.
I can do
def __init__(
    self,
    # FP8 specific arguments
    quant_config: QuantConfig,
    forward_config: ScaledMMConfig,
    # nn.Linear arguments
    in_features: int,
    out_features: int,
    bias: bool = True,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> None:
But I imagine you would like this even less
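One more variant worth considering (a sketch, not what the PR currently does): keep the nn.Linear arguments first and make the FP8-specific arguments required keyword-only, which avoids giving them defaults while preserving the superclass argument order:

def __init__(
    self,
    # nn.Linear arguments, same order as the superclass
    in_features: int,
    out_features: int,
    bias: bool = True,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    # FP8 specific arguments: keyword-only and required, so no defaults needed
    *,
    quant_config: QuantConfig,
    forward_config: ScaledMMConfig,
) -> None:
    ...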
    self.weight, Float8Tensor
), "Weight has already been quantized, cannot quantize again."
scale = tensor_to_scale(self.weight, dtype)
quantized_weight = to_fp8_no_autograd(
maybe a TODO for later to fix the name of this function? Ideally we shouldn't have to think about autograd in inference code.
float8_experimental/inference.py
if self.activation_casting == ActivationCasting.STATIC:
    self.register_buffer(
        "static_quantization_scale", quant_config.static_quantization_scale
it would be nice to allow the user to pass a static scale (for example, for an activation with a bounded range) without having to move it to the right device; can be a TODO for later
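For illustration, because the scale is stored with register_buffer it already follows the module across devices, so accepting a plain CPU tensor and relying on .to() is one way to get there (standalone sketch, not this PR's code):

import torch

class _StaticScaleDemo(torch.nn.Module):
    def __init__(self, scale: torch.Tensor):
        super().__init__()
        # buffers are moved by .to()/.cuda() and saved in the state_dict
        self.register_buffer("static_quantization_scale", scale)

m = _StaticScaleDemo(torch.tensor(1.0))  # user passes a plain CPU scalar
if torch.cuda.is_available():
    m = m.to("cuda")
    assert m.static_quantization_scale.device.type == "cuda"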
return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)


def quantize_to_float8(
nit: maybe either move to the same utils file as the training swap function, or create inference utils? Just to stay consistent with how the training code is laid out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also, can the names of the training vs. inference top-level UX be similar?
self.set_quantization_config(quant_config)
self.forward_config = forward_config

def forward(self, input: torch.Tensor) -> torch.Tensor:
thinking about this some more, there is some confusion I have now on how we want Float8Tensor to be related to what gemm is being called:
- in training code, we seem to be avoiding modifying the F.linear snippet, and instead Float8Tensor.__torch_dispatch__ contains the logic override
- here, we are sometimes using F.linear and sometimes using Float8Tensor.__torch_dispatch__

I can see two more principled solutions:
1. remove the call to torch._scaled_mm from Float8Tensor.__torch_dispatch__ and force the user to call it directly
2. create a new variant of Float8Tensor for weight-only quantization

tbh I don't really love those solutions either, but if I had to pick I'd probably want (2), just to remove ambiguity. Maybe we should brainstorm on how to better separate data representation (float8 tensor = float8 data + scale) versus modeling code (which gemm to call).
or, (3) accept that the user now needs to understand both Float8InferenceLinear and Float8Tensor to understand which gemm is called (current PR)
This is interesting, and I was talking about this somewhat with Jerry.
Currently the dispatch entry for mm only accepts Float8Tensor & Float8Tensor for both of its args.
We could update it to accept FP8 + FP8 and FP8 + regular. I didn't want to do this since we have multiple systems controlling who casts to FP8 (fsdp, dtensor, nn_module logic); this feels error-prone, and we could silently not be casting when we expect to. I find comfort in this invariant lol
That being said, my initial suggestion is to make a generic "weight_only" subclass wrapper that will handle the casting up.
Another way to achieve something similar is to allow FP8 + regular and expand scaled_mm_config to add constraints. I think the first option is more elegant personally, but I have been pretty embedded in the subclass world, so maybe it's not as easy for others to grok.
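To make the generic "weight_only" subclass wrapper idea concrete, a very rough sketch; the names and structure are illustrative only, and a real version would also need to preserve the wrapper through view/detach ops and handle serialization:

import torch
from torch.utils._pytree import tree_map

class WeightOnlyQuantTensor(torch.Tensor):
    """Holds quantized data plus a scale, and casts back up whenever an op touches it."""

    @staticmethod
    def __new__(cls, quantized_data: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype):
        # the wrapper advertises the original high-precision dtype and shape
        return torch.Tensor._make_wrapper_subclass(
            cls, quantized_data.shape, dtype=orig_dtype, device=quantized_data.device
        )

    def __init__(self, quantized_data, scale, orig_dtype):
        self._data = quantized_data
        self._scale = scale
        self._orig_dtype = orig_dtype

    def dequantize(self) -> torch.Tensor:
        return self._data.to(self._orig_dtype) / self._scale

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # "cast up" before running any op: replace wrappers with their dequantized
        # values, then run the op as a normal high-precision op
        def unwrap(t):
            return t.dequantize() if isinstance(t, cls) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

With this, wrapping the weight once at conversion time means F.linear(x, wrapped_weight) casts the weight up inside dispatch, so the modeling code stays a plain linear.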
looks great! I have a lot of comments but we can figure those out in future discussions, don't want to block.
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Summary
Perf script:
https://gist.github.com/drisspg/f7a553710d64cce013227a2249d582d2
Performance
In eager this produces:
With compile this produces:
UX
Dynamic activation quantization
Static activation quantization
Weight Only quantization
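A rough usage sketch covering the three modes above; the module path is taken from float8_experimental/inference.py in this diff, but the exact signatures and the return-vs-in-place behavior of quantize_to_float8 are assumptions:

import torch
from float8_experimental.inference import ActivationCasting, QuantConfig, quantize_to_float8

def make_model():
    return torch.nn.Sequential(
        torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 16)
    ).to(torch.bfloat16).cuda()

# Dynamic activation quantization: activations are scaled and cast to float8 on the fly
dynamic_model = quantize_to_float8(make_model(), QuantConfig(ActivationCasting.DYNAMIC))

# Static activation quantization: a fixed activation scale (e.g. from calibration)
# is stored on the module as a buffer
static_model = quantize_to_float8(
    make_model(),
    QuantConfig(
        ActivationCasting.STATIC,
        static_quantization_scale=torch.tensor(1.0, device="cuda"),  # placeholder scale
    ),
)

# Weight-only quantization: only the weights are held in float8; activations stay in bf16
weight_only_model = quantize_to_float8(make_model(), QuantConfig(ActivationCasting.WEIGHT_ONLY))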
All of these use per-tensor scaling; row-wise scaling will be added in a follow-up PR and will likely become the default.