diff --git a/float8_experimental/float8_python_api.py b/float8_experimental/float8_python_api.py
index d8aa081..15237bc 100644
--- a/float8_experimental/float8_python_api.py
+++ b/float8_experimental/float8_python_api.py
@@ -51,6 +51,13 @@ def addmm_float8_unwrapped(
         )
         output += bias
         return output
+    # Weight tensors are stored in N, K format and we call tensor_to_scale(dim=1)
+    # on them, which produces an (N, 1) scale. scaled_mm, however, expects
+    # M x K @ K x N operands with scales of shape (M, 1) and (1, N), so transpose.
+    b_inverse_scale = (
+        b_inverse_scale.T if b_inverse_scale.dim() == 2 else b_inverse_scale
+    )
+
     output = torch._scaled_mm(
         a_data,
         b_data,
diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index c1c97e7..6b827f9 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -109,7 +109,7 @@ def to_fp8_no_autograd(
         mm_config: Defines the configuration for the scaled_mm
     """

-    x_scaled = x * x_scale
+    x_scaled = x * x_scale.to(dtype=x.dtype)
     bits_fp8 = to_fp8_saturated(x_scaled, float8_dtype)

     if isinstance(bits_fp8, DTensor):
@@ -195,7 +195,9 @@ class FromFloat8ConstrFunc(torch.autograd.Function):

     @staticmethod
     def forward(ctx, tensor):
-        return tensor._data.to(tensor._orig_dtype) / tensor._scale
+        return tensor._data.to(tensor._orig_dtype) / tensor._scale.to(
+            tensor._orig_dtype
+        )

     @staticmethod
     def backward(ctx, g):
@@ -253,11 +255,8 @@ def __init__(
         orig_dtype: torch.dtype,
         mm_config: Optional[ScaledMMConfig],
     ):
-        assert (
-            scale.numel() == 1
-        ), "Scale should contain a single value, but got: {} elements".format(
-            scale.numel()
-        )
+        # The scale is no longer required to be a single element: with AxisWise
+        # granularity it is a per-axis tensor (e.g. (N, 1) or (M, 1)).

         self._data = data
         self._scale = scale
diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py
index 96df619..8999746 100644
--- a/float8_experimental/float8_utils.py
+++ b/float8_experimental/float8_utils.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Iterable, Literal, Tuple, Union
+from typing import Iterable, Literal, Optional, Tuple, Union

 import float8_experimental.config as config
 import torch
@@ -32,6 +32,12 @@
 e5m2_dtype = torch.float8_e5m2 if not config.use_fnuz_dtype else torch.float8_e5m2fnuz


+def get_supported_granularity():
+    from float8_experimental.float8_tensor import ScalingGranularity
+
+    return [ScalingGranularity.TensorWise, ScalingGranularity.AxisWise]
+
+
 @torch.no_grad()
 def amax_to_scale(
     amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
@@ -103,20 +109,34 @@ def amax_history_to_scale_stack(
 def tensor_to_amax(
     x: torch.Tensor,
     scaling_granularity,
+    dim: Optional[int] = None,
     reduce_amax: bool = False,
 ) -> torch.Tensor:
     """Calculates the amax of a tensor.

     Args:
         x: The tensor to calculate the amax for.
         scaling_granularity: The granularity of with which to calcualte the tensor amax
+        dim: The dimension along which to calculate the amax. This is only used if scaling_granularity is AxisWise.
         reduce_amax: Whether to perform a distributed reduction on the amax.
""" from float8_experimental.float8_tensor import ScalingGranularity - assert ( - scaling_granularity == ScalingGranularity.TensorWise - ), f"Currently only TensorWise is supported for but given scaling_granularity: {scaling_granularity}" - amax = torch.max(torch.abs(x)) + supported_granularities = get_supported_granularity() + + if scaling_granularity not in supported_granularities: + raise ValueError( + f"Currently only {supported_granularities} are supported. Given scaling_granularity: {scaling_granularity}" + ) + + if scaling_granularity == ScalingGranularity.TensorWise: + amax = torch.max(torch.abs(x)) + elif scaling_granularity == ScalingGranularity.AxisWise: + if dim is None: + raise ValueError("For AxisWise scaling, a dim must be passed in!") + amax = torch.max(torch.abs(x), dim=dim, keepdim=True).values + else: + # This should never be reached due to the earlier check, but it's here for completeness + raise ValueError(f"Unsupported scaling_granularity: {scaling_granularity}") # If the user asked for distributed reduction, do it. # If the user did not ask for it, assume that it will @@ -132,16 +152,20 @@ def tensor_to_scale( x: torch.Tensor, float8_dtype: torch.dtype, scaling_granularity, + dim: Optional[int] = None, reduce_amax: bool = False, + collapse_leading_dims: bool = False, ) -> torch.Tensor: """Calculates the scale that will be used for quantization to Float8Tensor Args: x: The tensor to calculate the scale for. float8_dtype: The Float8 dtype to use. scaling_granularity: The granularity of the scale. See ScalingGranularity for more details. + dim: The dimension along which to calculate the scale. This is only used if scaling_granularity is AxisWise. reduce_amax: Whether to perform a distributed reduction on the amax. + collapse_leading_dims: Whether to collapse leading dimensions of the tensor. 
""" - amax = tensor_to_amax(x, scaling_granularity, reduce_amax=reduce_amax) + amax = tensor_to_amax(x, scaling_granularity, dim=dim, reduce_amax=reduce_amax) return amax_to_scale(amax, float8_dtype, x.dtype) diff --git a/float8_experimental/inference.py b/float8_experimental/inference.py index 6a950d7..7ce6b31 100644 --- a/float8_experimental/inference.py +++ b/float8_experimental/inference.py @@ -25,7 +25,13 @@ tensor_already_casted_to_fp8, to_fp8_no_autograd, ) -from float8_experimental.float8_utils import e4m3_dtype, tensor_to_scale +from float8_experimental.float8_utils import ( + e4m3_dtype, + get_supported_granularity, + tensor_to_scale, +) + +SUPPORTED_GRANULARITY = get_supported_granularity() class ActivationCasting(Enum): @@ -75,7 +81,7 @@ def __init__( # FP8 specific arguments quant_config: QuantConfig, forward_config: ScaledMMConfig, - scaling_granularity: ScalingGranularity, + scaling_granularity: Optional[ScalingGranularity], # nn.Linear arguments in_features: int, out_features: int, @@ -86,7 +92,26 @@ def __init__( # Construct the superclass this will create dummy weights and biases super().__init__(in_features, out_features, bias, device, dtype) self.forward_config = forward_config - self.scaling_granularity = scaling_granularity + if scaling_granularity is None: + self.scaling_granularity = ( + ScalingGranularity.AxisWise + if dtype == torch.bfloat16 + and quant_config.static_quantization_scale is None + else ScalingGranularity.TensorWise + ) + else: + assert ( + scaling_granularity in SUPPORTED_GRANULARITY + ), f"scaling_granularity must be in {SUPPORTED_GRANULARITY} but got {scaling_granularity}" + if ( + scaling_granularity == ScalingGranularity.AxisWise + and dtype != torch.bfloat16 + ): + raise ValueError( + "AxisWise scaling granularity is only supported for bfloat16." + ) + self.scaling_granularity = scaling_granularity + self.activation_casting = quant_config.activation_casting if self.activation_casting == ActivationCasting.STATIC: self.register_buffer( @@ -101,13 +126,22 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: input, self.weight.to_original_precision() ) + # TODO we arent folding leading dims yet, but need it to calculate the proper scale.. this sucks + original_m = input.shape[:-1] + input = input.view(-1, input.shape[-1]) + x_fp8 = cast_to_float8_e4m3_inference( input, self.forward_config, static_quantization_scale=self.static_quantization_scale, scaling_granularity=self.scaling_granularity, ) - return torch.nn.functional.linear(x_fp8, self.weight, self.bias) + return torch.nn.functional.linear(x_fp8, self.weight, self.bias).view( + *original_m, -1 + ) + + def extra_repr(self): + return f"{super().extra_repr()},activation_casting={self.activation_casting.name},scaling_granularity={self.scaling_granularity.name}" # Builder functions for Float8LinearInference def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None: @@ -124,7 +158,12 @@ def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None: assert not isinstance( self.weight, Float8Tensor ), "Weight has already been quantized, cannot quantize again." 
-        scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity)
+
+        # For AxisWise weight scaling, reduce over dim=1 (K) to get one scale per output channel: an (N, 1) tensor.
+        dim = None
+        if self.scaling_granularity == ScalingGranularity.AxisWise:
+            dim = 1
+        scale = tensor_to_scale(self.weight, dtype, self.scaling_granularity, dim=dim)
         quantized_weight = to_fp8_no_autograd(
             self.weight, scale, dtype, self.forward_config
         )
@@ -143,6 +182,7 @@ def from_float(
         module: nn.Module,
         quant_config: QuantConfig,
         use_fast_accum: bool,
+        scaling_granularity: Optional[ScalingGranularity],
     ) -> "Float8InferenceLinear":
         """
         Create an nn.Linear with fp8 compute from another nn.Linear
@@ -150,12 +190,12 @@ def from_float(
         Args:
             mod (torch.nn.Linear): nn.Linear to convert
             quant_config (QuantConfig): Configuration for the weight and activation casting
+            use_fast_accum (bool): Whether to enable fast accumulation for the Float8InferenceLinear.
+            scaling_granularity: The granularity of the scale. See ScalingGranularity for more details.
         """
         forward_config = ScaledMMConfig(
             False, use_fast_accum, pad_inner_dim=config.pad_inner_dim
         )
-        # TODO: For now hardcode TensorWise scaling
-        scaling_granularity = ScalingGranularity.TensorWise
         linear = cls(
             quant_config,
             forward_config,
@@ -164,6 +204,7 @@ def from_float(
             module.out_features,
             False,
             device=torch.device("meta"),
+            dtype=module.weight.dtype,
         )
         linear.set_weight_and_bias(module.weight, module.bias)
         linear.quantize_weight()
@@ -194,11 +235,21 @@ def cast_to_float8_e4m3_inference(
     """
     if tensor_already_casted_to_fp8(inpt_tensor):
         return inpt_tensor
+
+    # For AxisWise input scaling, reduce over dim=1 (K) to get one scale per row: an (M, 1) tensor.
+    dim = None
+    if scaling_granularity == ScalingGranularity.AxisWise:
+        dim = 1
+
     scale = (
         static_quantization_scale
         if static_quantization_scale is not None
         else tensor_to_scale(
-            inpt_tensor, e4m3_dtype, scaling_granularity, reduce_amax=reduce_amax
+            inpt_tensor,
+            e4m3_dtype,
+            scaling_granularity,
+            dim=dim,
+            reduce_amax=reduce_amax,
         )
     )
     return Float8Tensor.to_float8(
@@ -206,6 +257,7 @@ def cast_to_float8_e4m3_inference(
         scale,
         e4m3_dtype,
         mm_config=mm_config,
+        scaling_granularity=scaling_granularity,
     )


@@ -215,6 +267,7 @@ def quantize_to_float8(
     *,
     skip_fqn_list: Optional[List[str]] = None,
     use_fast_accum: bool = True,
+    scaling_granularity: Optional[ScalingGranularity] = None,
 ) -> Optional[nn.Module]:
     """
     Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
@@ -228,6 +281,7 @@
         quant_config (QuantConfig): Quantization configuration for Float8 conversion.
         skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
         use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
+        scaling_granularity: The granularity of the scale. See ScalingGranularity for more details. If None, AxisWise is chosen automatically for bfloat16 modules without a static quantization scale, and TensorWise otherwise.

     Returns:
         nn.Module: The modified module with applicable Linear layers converted to Float8.
@@ -237,6 +291,8 @@
     """
     return swap_linear_layers(
         module,
-        lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
+        lambda m: Float8InferenceLinear.from_float(
+            m, quant_config, use_fast_accum, scaling_granularity
+        ),
         skip_fqn_list=skip_fqn_list,
     )
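
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new scaling_granularity argument is expected to flow through quantize_to_float8 into Float8InferenceLinear. The QuantConfig(ActivationCasting.DYNAMIC) spelling, the toy model, and the CUDA/bfloat16 setup are assumptions based on the surrounding code rather than a verified example; an fp8-capable GPU is required for torch._scaled_mm.

import torch
import torch.nn as nn

from float8_experimental.float8_tensor import ScalingGranularity
from float8_experimental.inference import (
    ActivationCasting,
    QuantConfig,
    quantize_to_float8,
)

# A toy bfloat16 model; AxisWise granularity is gated to bfloat16 by
# Float8InferenceLinear.__init__.
model = nn.Sequential(nn.Linear(1024, 4096), nn.Linear(4096, 1024)).to(
    device="cuda", dtype=torch.bfloat16
)

# Dynamic activation casting with per-axis scales: one scale per input row and
# one per weight output channel, with the weight scale transposed for
# torch._scaled_mm inside addmm_float8_unwrapped. Passing scaling_granularity=None
# would instead auto-select AxisWise here (bfloat16, no static scale).
quantize_to_float8(
    model,
    QuantConfig(ActivationCasting.DYNAMIC),
    scaling_granularity=ScalingGranularity.AxisWise,
)

with torch.inference_mode():
    out = model(torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16))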