This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Add a Float8LinearInference module to support static, dynamic, and wo quant #287

Closed

drisspg wants to merge 9 commits from the float8-linear-inference branch

Conversation

drisspg (Contributor) commented Jun 20, 2024

Summary

Perf script:

https://gist.github.com/drisspg/f7a553710d64cce013227a2249d582d2

Performance

In eager this produces:

Operation                     Time (μs)
bf16                          2667.9172
fp8_dynamic_activations       2494.7294
fp8_static_activations        2449.1784
fp8_weight_only_activations   4084.7190

With compile this produces:

Operation                     Time (μs)
bf16                          2547.1938
fp8_dynamic_activations       1542.0729
fp8_static_activations        1407.0310
fp8_weight_only_activations   2750.6369

UX

Dynamic activation quantization

original_mlp = FeedForward().to("cuda", dtype=dtype)
original_mlp.reset_parameters()

dynamic_fp8_mlp = copy.deepcopy(original_mlp)

quant_config = QuantConfig(ActivationCasting.DYNAMIC)
quantize_to_float8(dynamic_fp8_mlp, quant_config)
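
After swapping, the quantized module is used like the original. A minimal usage sketch (the input shape and the torch.compile call are assumptions, not part of this PR; 4096 matches the FeedForward sizes mentioned later in this thread):

    # Usage sketch with an assumed input shape; torch.compile is optional but is what
    # produces the "With compile" numbers in the tables above.
    input_tensor = torch.randn(16, 4096, device="cuda", dtype=dtype)
    with torch.no_grad():
        out_eager = dynamic_fp8_mlp(input_tensor)
        compiled_mlp = torch.compile(dynamic_fp8_mlp)
        out_compiled = compiled_mlp(input_tensor)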

Static activation quantization

original_mlp = FeedForward().to("cuda", dtype=dtype)
original_mlp.reset_parameters()

static_fp8_mlp = copy.deepcopy(original_mlp)
quant_config = QuantConfig(
    ActivationCasting.STATIC,
    static_quantization_scale=torch.tensor(
        [1.0], device="cuda", dtype=torch.float32
    ),
)
quantize_to_float8(static_fp8_mlp, quant_config)

Weight Only quantization

original_mlp = FeedForward().to("cuda", dtype=dtype)
original_mlp.reset_parameters()

wo_fp8_mlp = copy.deepcopy(original_mlp)
quant_config = QuantConfig(ActivationCasting.WEIGHT_ONLY)
quantize_to_float8(wo_fp8_mlp, quant_config)

All of these use per-tensor scaling. A follow-up PR will add row-wise scaling and will likely make it the default.
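
For context, per-tensor scaling means one amax-derived scale for the whole tensor, while row-wise scaling keeps one scale per output row of the weight. A rough sketch of the difference (illustrative helpers, not this repo's API; the convention here is scale = fp8_max / amax, so quantization multiplies by the scale):

    import torch

    FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

    def per_tensor_scale(t: torch.Tensor) -> torch.Tensor:
        # One scale for the whole tensor: quantize as (t * scale).to(float8), dequantize as fp8 / scale.
        return FP8_MAX / torch.clamp(t.abs().max().float(), min=1e-12)

    def row_wise_scale(w: torch.Tensor) -> torch.Tensor:
        # One scale per output row of a 2D weight (finer grained, the planned follow-up).
        return FP8_MAX / torch.clamp(w.abs().amax(dim=1, keepdim=True).float(), min=1e-12)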

@facebook-github-bot added the CLA Signed label on Jun 20, 2024
@drisspg changed the title from "Updates to enable static weight quantization/dynamic activation quanization" to "Updates to enable static weight quantization + Static & Dynamic Activation quantization" on Jun 21, 2024
mod,
emulate: bool = False,
static_quantize_weight: bool = False,
activation_scale: Optional[torch.Tensor] = None,

Contributor:

we would probably want to provide two options:

  1. user specifies the scale they want (for example, for cases where the preceding activation has a bounded range)
  2. user specifies how to calculate the scale from calibration data

For (2), the current quantization UXs do this with Observer objects (https://fburl.com/code/ady6pz23). cc @jerryzh168 on whether torchao has any different plans for this in future UXs.
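
As a rough illustration of option (2), a max-abs observer for deriving a static activation scale from calibration data might look like the following (names and structure are assumptions, not an existing torchao or float8_experimental API):

    import torch

    class MaxAbsObserver:
        """Track the running max-abs of activations seen during calibration."""

        def __init__(self) -> None:
            self.amax = 0.0

        def observe(self, x: torch.Tensor) -> None:
            self.amax = max(self.amax, x.abs().max().item())

        def scale(self) -> torch.Tensor:
            # Same convention as the sketch above: scale = fp8_max / amax,
            # so quantization is (x * scale).to(float8).
            fp8_max = torch.finfo(torch.float8_e4m3fn).max
            return torch.tensor([fp8_max / max(self.amax, 1e-12)], dtype=torch.float32)

    # Hypothetical calibration loop:
    # observer = MaxAbsObserver()
    # for batch in calibration_batches:
    #     observer.observe(batch)
    # static_scale = observer.scale().to("cuda")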

@drisspg force-pushed the float8-linear-inference branch 5 times, most recently from 0bd148f to 2ff810c, on June 22, 2024
ani300 (Contributor) commented Jun 25, 2024

On a simple static weight and activation MLP I am seeing a copy_ error:

    swap_linear_with_float8_linear(
        static_fp8_mlp,
        Float8DynamicLinear,
        from_float_kwargs={
            "static_quantize_weight": True,
            "activation_scale": torch.tensor([1.0], device="cuda", dtype=torch.float32),
        },
    )

    print(f"out_static = {static_fp8_mlp(input_tensor)}")
    torch.save(static_fp8_mlp.state_dict(), "/home/drisspg/meta/scripts/fp8/saving/dumm_dict2.pt")

    static_load = torch.load("/home/drisspg/meta/scripts/fp8/saving/dumm_dict2.pt")
    static_fp8_mlp.load_state_dict(static_load)
    # Call the reloaded module (not the state_dict) to check the outputs match:
    print(f"out_static_load = {static_fp8_mlp(input_tensor)}")
    
RuntimeError: Error(s) in loading state_dict for FeedForward:
       While copying the parameter named "w1.weight", whose dimensions in the model are torch.Size([14336, 4096]) and whose dimensions in the checkpoint are torch.Size([14336, 4096]), an exception occurred : ('attempting to run aten.copy_.default, this is not supported',).
       While copying the parameter named "w3.weight", whose dimensions in the model are torch.Size([14336, 4096]) and whose dimensions in the checkpoint are torch.Size([14336, 4096]), an exception occurred : ('attempting to run aten.copy_.default, this is not supported',).
       While copying the parameter named "w2.weight", whose dimensions in the model are torch.Size([4096, 14336]) and whose dimensions in the checkpoint are torch.Size([4096, 14336]), an exception occurred : ('attempting to run aten.copy_.default, this is not supported',).

@ani300 I imagine you were doing some state_dict loading for Float8Tensors?

Hey, we were quantizing bf16 weights on the fly from our checkpoints, but I think we'll do something akin to AutoFP8 (https://github.com/neuralmagic/AutoFP8) to handle the FP8 checkpoints and load into Float8Tensors
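
For reference, one possible way around the copy_ error above (an assumption, not something this PR does): if the checkpoint already contains Float8Tensor entries, loading with assign=True assigns them directly instead of copy_-ing into the existing parameters, which sidesteps the unsupported aten.copy_ path (requires a PyTorch version whose load_state_dict supports the assign flag):

    # Possible workaround sketch: assign checkpoint tensors instead of copying in place.
    static_load = torch.load("/home/drisspg/meta/scripts/fp8/saving/dumm_dict2.pt")
    static_fp8_mlp.load_state_dict(static_load, assign=True)
    print(f"out_static_load = {static_fp8_mlp(input_tensor)}")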



@dataclass(frozen=True)
class QuantConfig:

Contributor:

maybe we can see if there is any alignment possible with torchao on the name or structure of this? It doesn't have to align now, but ideally the structure is similar so it's easier to unify later. cc @jerryzh168

@drisspg force-pushed the float8-linear-inference branch 2 times, most recently from d54bd00 to c1f2cad, on June 28, 2024
False,
device=torch.device("meta"),
)
linear.set_weight_and_bias(module.weight, module.bias)

Contributor:

Looks good; maybe inline it, since there is only one callsite and the function is super short?


Contributor Author (drisspg):

I kinda think it reads nicer, tbh. Is it okay if we keep it as is and remove it in a follow-up if it's distracting?

@drisspg drisspg requested a review from vkuzo June 28, 2024 23:40
@facebook-github-bot (Contributor):

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@drisspg changed the title from "Updates to enable static weight quantization + Static & Dynamic Activation quantization" to "Add a Float8LinearInference module to support static, dynamic, and wo quant" on Jun 28, 2024
Comment on lines +75 to +77
in_features: int,
out_features: int,
bias: bool = True,

Contributor:

if we're inheriting from torch.nn.Linear, thoughts on adding the additional arguments last? I understand the code works, but it's extra mental load to think about how the argument order differs, unless it's a simple "superclass args first, subclass args after".


Contributor Author (drisspg):

The problem with this is that we don't want the two additional args to have default values. Putting two new args into the init is another example of breaking the Liskov substitution principle. However, since we don't really expect anyone but "from_float" to call the constructor, this is less likely to be hit in practice.

The only other way I can immediately think of to get around this is to do what was done before, where you create the instance in a partially initialized state and then modify it in place.


drisspg (Contributor Author) commented Jun 29, 2024:

I can do

    def __init__(
        self,
        # FP8 specific arguments
        quant_config: QuantConfig,
        forward_config: ScaledMMConfig,
        # nn.Linear arguments
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:

But I imagine you would like this even less

self.weight, Float8Tensor
), "Weight has already been quantized, cannot quantize again."
scale = tensor_to_scale(self.weight, dtype)
quantized_weight = to_fp8_no_autograd(

Contributor:

maybe a TODO for later to fix the name of this function? Ideally we shouldn't have to think about autograd in inference code.


if self.activation_casting == ActivationCasting.STATIC:
self.register_buffer(
"static_quantization_scale", quant_config.static_quantization_scale

Contributor:

it would be nice to allow the user to pass a static scale (for example, for an activation with bounded range) without having to move it to the right device, can be a TODO for later
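
One possible shape for that TODO (a sketch under assumed names; the buffer name follows this PR, while the helper and its placement are assumptions):

    import torch

    def register_static_scale(module: torch.nn.Module, scale: torch.Tensor) -> None:
        # Move a user-provided static activation scale onto the module's weight device
        # before registering it, so callers can pass a plain CPU tensor.
        scale = scale.to(device=module.weight.device, dtype=torch.float32)
        module.register_buffer("static_quantization_scale", scale)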

return Float8Tensor.to_float8(inpt_tensor, scale, e4m3_dtype, mm_config=mm_config)


def quantize_to_float8(

Contributor:

nit: maybe either move to the same utils file as the training swap function, or create inference utils? Just to stay consistent with how the training code is laid out.


Contributor:

also can the names of the training vs inference top level UX be similar?

self.set_quantization_config(quant_config)
self.forward_config = forward_config

def forward(self, input: torch.Tensor) -> torch.Tensor:

Contributor:

thinking about this some more, there is some confusion I have now about how we want Float8Tensor to relate to which gemm gets called

in training code, we seem to be avoiding modifying the F.linear snippet, and instead Float8Tensor.__torch_dispatch__ contains the logic override

here, we are sometimes using F.linear and sometimes using Float8Tensor.__torch_dispatch__

I can see two more principled solutions:

  1. remove the call to torch._scaled_mm from Float8Tensor.__torch_dispatch__ and force the user to call it directly
  2. create a new variant of Float8Tensor for weight-only quantization

tbh I don't really love those solutions either, but if I had to pick I'd probably want (2), just to remove ambiguity. Maybe we should brainstorm on how to better separate data representation (float8 tensor = float8 data + scale) versus modeling code (which gemm to call).


Contributor:

or, (3) accept that now the user needs to understand both Float8InferenceLinear and Float8Tensor to understand which gemm is called (current PR)


Contributor Author (drisspg):

This is interesting and I was talking about this somewhat with Jerry.

Currently the dispatch entry for mm only accepts Float8Tensor for both of its args.
We could update it to accept FP8 + FP8 AND FP8 + regular. I didn't want to do this since we have multiple systems controlling who casts to FP8 (FSDP, DTensor, nn_module logic); this feels error prone, and we would silently not be casting when we expect to. I find comfort in this invariant lol

That being said, my initial suggestion is to make a generic "weight_only" subclass wrapper that will handle the casting up.

Another way to achieve something similar is to allow FP8 + regular and expand scaled_mm_config to add constraints. I think the first option is more elegant personally, but I have been pretty embedded in the subclass world, so it may not be as easy for others to grok.
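
To make the "casting up" concrete, here is a minimal module-level sketch of the weight-only idea (illustrative only; this PR actually routes through Float8Tensor and its dispatch, and the class and attribute names below are assumptions):

    import torch
    import torch.nn.functional as F

    class WeightOnlyFp8Linear(torch.nn.Module):
        """Sketch: store the weight in float8 with one per-tensor scale, cast it back up to
        the original dtype at forward time, and run the normal high-precision matmul."""

        def __init__(self, linear: torch.nn.Linear):
            super().__init__()
            w = linear.weight.detach()
            self.orig_dtype = w.dtype
            fp8_max = torch.finfo(torch.float8_e4m3fn).max
            amax = torch.clamp(w.abs().max().float(), min=1e-12)
            self.register_buffer("scale", fp8_max / amax)
            # Saturate to the fp8 range before casting (e4m3 has no inf).
            w_scaled = (w.float() * self.scale).clamp(min=-fp8_max, max=fp8_max)
            self.register_buffer("weight_fp8", w_scaled.to(torch.float8_e4m3fn))
            self.bias = linear.bias

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # "Cast up": dequantize the fp8 weight back to the original dtype,
            # then use the regular gemm instead of torch._scaled_mm.
            w = (self.weight_fp8.to(torch.float32) / self.scale).to(self.orig_dtype)
            return F.linear(x, w, self.bias)

A subclass-based version would push the same cast-up into Float8Tensor.__torch_dispatch__ (or a dedicated weight-only tensor variant) instead of the module's forward.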


vkuzo (Contributor) left a comment:

looks great! I have a lot of comments but we can figure those out in future discussions, don't want to block.


@facebook-github-bot (Contributor):

@drisspg merged this pull request in 36405a7.

@drisspg mentioned this pull request on Jul 10, 2024