This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Adding Float8 Linear variants supporting inference-only with lower overhead #283

Closed
wants to merge 1 commit

Conversation

@cyang49 commented Jun 14, 2024

The changes include two new Float8 Linear implementations that remove some of the extra wiring in Float8Linear that is unnecessary for inference-only use cases, resulting in lower latency.

  • Float8SWLinear downcasts the activation directly to the fp8 type and uses a static per-tensor scale for the weight. Our analysis shows that this results in no loss of accuracy in Llama models.
  • Float8DASWLinear uses a dynamic per-tensor scale for the activation and a static per-tensor scale for the weight. It is intended for cases where the activation tensor requires dynamic scaling. Compared to Float8SWLinear, it has higher overhead from computing the activation scale at runtime, which can be mitigated with torch.compile. A rough sketch of both schemes follows below.
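A minimal sketch of the two scaling schemes, not the PR's actual code: per_tensor_scale, fp8_gemm, and the *Sketch class names are placeholders, and a real implementation would dispatch to a fused fp8 GEMM (e.g. torch._scaled_mm, whose signature varies across PyTorch versions) instead of dequantizing.

import torch

FP8 = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8).max

def per_tensor_scale(t: torch.Tensor) -> torch.Tensor:
    # scale chosen so that t / scale fills the fp8 representable range
    return t.abs().max().float().clamp(min=1e-12) / FP8_MAX

def fp8_gemm(x_fp8, x_scale, w_fp8, w_scale, bias=None):
    # stand-in for a fused fp8 GEMM kernel; dequantize + float matmul for clarity
    out = (x_fp8.float() * x_scale) @ (w_fp8.float() * w_scale).t()
    return out if bias is None else out + bias

class Float8SWLinearSketch(torch.nn.Linear):
    # static weight scale; activation is downcast to fp8 directly (no activation scaling)
    def quantize_weight(self):
        self.w_scale = per_tensor_scale(self.weight.detach())
        self.w_fp8 = (self.weight.detach() / self.w_scale).to(FP8)

    def forward(self, x):
        x_fp8 = x.to(FP8)  # direct downcast, no per-forward scale computation
        one = torch.ones((), device=x.device)
        return fp8_gemm(x_fp8, one, self.w_fp8, self.w_scale, self.bias)

class Float8DASWLinearSketch(Float8SWLinearSketch):
    # dynamic per-tensor activation scale; static weight scale
    def forward(self, x):
        x_scale = per_tensor_scale(x)  # extra reduction on every forward pass
        x_fp8 = (x / x_scale).to(FP8)
        return fp8_gemm(x_fp8, x_scale, self.w_fp8, self.w_scale, self.bias)

A caller would construct the module, load the trained weight, call quantize_weight() once, and then run forward passes. The dynamic variant pays for an abs().max() reduction over the activation on every call, which is the overhead that torch.compile can help fuse away.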

cc: @ani300

Co-authored-by: Mauricio Serrano <[email protected]>
@facebook-github-bot added the CLA Signed label Jun 14, 2024
@vkuzo (Contributor) commented Jun 14, 2024

Hi folks, thanks for the PR; we definitely want this functionality. We have some changes coming up to set up the inference path. Let me give a preview below, and @drisspg is planning to publish a more detailed RFC related to float8 + inference UX soon.

  1. We plan to unify all training logic in Float8Linear: Float8DynamicLinear would fold into Float8Linear, and each tensor (activation/weight/gradient) will be configurable as to whether it is scaled dynamically or with delayed scaling.
  2. We plan to add a separate construct for inference, since the weights are frozen. It will be something like Float8InferenceLinear, although we might also provide a tensor-subclass-based version. The scaling behavior of the input activation would be configurable at construction time.

In the framing above, this PR would be adding additional options to (2). Are you ok to wait a little bit for us to put out the official plan?
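To make that framing concrete, here is a purely hypothetical sketch of what per-tensor configurable scaling for an inference module could look like; none of these names (ScalingType, Float8InferenceConfig) are taken from the repository or the forthcoming RFC.

import enum
from dataclasses import dataclass

class ScalingType(enum.Enum):
    DYNAMIC = "dynamic"  # recompute the per-tensor scale on every forward pass
    DELAYED = "delayed"  # reuse a scale derived from amax history (training)
    STATIC = "static"    # fixed scale, e.g. calibrated offline

@dataclass
class Float8InferenceConfig:
    # weights are frozen for inference, so a static weight scale is the natural default
    activation_scaling: ScalingType = ScalingType.DYNAMIC
    weight_scaling: ScalingType = ScalingType.STATIC

In that framing, Float8DASWLinear corresponds roughly to activation_scaling=DYNAMIC, and Float8SWLinear to a direct downcast (effectively a fixed activation scale of 1), both with a static weight scale.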

super(Float8SWLinear, self).__init__(in_features=in_features, out_features=out_features, bias=bias)
self.w_inv_s = None
self.dtype = torch.float8_e4m3fn
self.use_triton = use_triton
Contributor commented on the diff:
is this used anywhere?

x_f8 = x.to(self.dtype)
ishape = list(x_f8.shape)

if ishape[0] == 0: # special case handling for mixtral
Contributor commented on the diff:

is this not supported by scaled_mm today? cc @drisspg

@cyang49 (Author) commented Jul 29, 2024

no longer needed

@cyang49 closed this Jul 29, 2024