From 860527ecd08cbb83ffadec59ca3f687a8423bce1 Mon Sep 17 00:00:00 2001
From: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Date: Thu, 13 Jun 2024 17:02:20 -0400
Subject: [PATCH] Adding Float8 Linear variants

Co-authored-by: Mauricio Serrano
---
 float8_experimental/__init__.py      |  4 +-
 float8_experimental/float8_linear.py | 84 ++++++++++++++++++++++++++++
 2 files changed, 86 insertions(+), 2 deletions(-)

diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py
index 72c09052..98158252 100644
--- a/float8_experimental/__init__.py
+++ b/float8_experimental/__init__.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
-from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear import Float8Linear, Float8SWLinear, Float8DASWLinear
 from float8_experimental.float8_tensor import Float8Tensor

-__all__ = ["Float8Tensor", "Float8Linear"]
+__all__ = ["Float8Tensor", "Float8Linear", "Float8SWLinear", "Float8DASWLinear"]

diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 568b36f5..bc730256 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -340,3 +340,87 @@ def from_float(cls, mod, emulate: bool = False):
         # I think its okay to send all params and buffers to device
         new_mod.to(mod.weight.device)
         return new_mod
+
+class Float8SWLinear(torch.nn.Linear):
+    """
+    A variation of Float8Linear that operates on torch tensors directly instead
+    of Float8Tensor, since delayed scaling support is not needed. It uses a direct
+    fp8 downcast for the activation and a static per-tensor scale for the weight.
+ """ + + def __init__(self, in_features, out_features, bias=True, use_triton=False): + super(Float8SWLinear, self).__init__(in_features=in_features, out_features=out_features, bias=bias) + self.w_inv_s = None + self.dtype = torch.float8_e4m3fn + self.use_triton = use_triton + + @classmethod + def from_float(cls, mod, emulate: bool = False): + new_mod = cls(mod.in_features, mod.out_features, bias=mod.bias is not None) + assert(not emulate) # no emulation support + new_mod.emulate = emulate + + w_f8, w_inv_s = new_mod.to_float8(mod.weight) + + new_mod.weight = torch.nn.Parameter(w_f8, requires_grad=False) + new_mod.w_inv_s = torch.nn.Parameter(w_inv_s, requires_grad=False) + new_mod.bias = (torch.nn.Parameter(mod.bias.to(torch.float16), requires_grad=False) # force bias to be fp16 for now + if mod.bias is not None else None) + new_mod.unit_scale = torch.nn.Parameter(torch.tensor(1.0, dtype=torch.float32), requires_grad=False) + new_mod.to(mod.weight.device) + + return new_mod + + def to_float8(self, x): + finfo = torch.finfo(self.dtype) + # Calculate the scale as dtype max divided by absmax + scale = finfo.max / x.abs().max().clamp(min=1e-12) + # scale and clamp the tensor to bring it to + # the representative range of float8 data type + # (as default cast is unsaturated) + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + # Return both float8 data and the inverse scale (as float), + # as both required as inputs to torch._scaled_mm + return x_scl_sat.to(self.dtype), scale.float().reciprocal() # returns x in self.dtype and scale in f32 + + def forward(self, x): + x_f8 = x.to(self.dtype) + ishape= list(x_f8.shape) + + if ishape[0] == 0: # special case handling for mixtral + return torch.empty([ishape[0], self.weight.shape[0]], dtype=torch.float16, device=x.device) + + if len(ishape) == 3: + x_f8 = x_f8.view(-1,ishape[-1]) + + y, _ = torch._scaled_mm(x_f8, self.weight.T, out_dtype=torch.float16, + scale_b=self.w_inv_s, bias=self.bias, use_fast_accum=False) + + if len(ishape) == 3: + y = y.view(ishape[0],ishape[1],-1) + + return y + +class Float8DASWLinear(Float8SWLinear): + """ + A variation of Float8Linear that operates on torch tensors directly instead of + Float8Tensor since the delayed scaling support is not needed. It supports Dynamic + per-tensor scale for Activation, and Static per-tensor scale for Weight. + """ + def forward(self, x): + ishape= list(x.shape) + if ishape[0] == 0: # special case handling for mixtral + return torch.empty([ishape[0], self.weight.shape[0]], dtype=torch.float16, device=x.device) + + x_f8, x_inv_s = self.to_float8(x) + + if len(ishape) == 3: + x_f8 = x_f8.view(-1,ishape[-1]) + + y, _ = torch._scaled_mm(x_f8, self.weight.T, out_dtype=torch.float16, scale_a=x_inv_s, + scale_b=self.w_inv_s, bias=self.bias, use_fast_accum=False) + + if len(ishape) == 3: + y = y.view(ishape[0],ishape[1],-1) + + return y