
Commit: some more tweaks
drisspg committed Jun 29, 2024
1 parent 3e0a2fd commit 125390f
Showing 3 changed files with 25 additions and 31 deletions.
4 changes: 2 additions & 2 deletions float8_experimental/float8_linear_utils.py
@@ -103,7 +103,7 @@ def swap_linear_layers(
*,
skip_fqn_list: Optional[List[str]] = None,
linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
-) -> nn.Module:
+) -> Optional[nn.Module]:
"""
Generic function to swap linear layers in a module with a new type of linear layer.
@@ -174,7 +174,7 @@ def swap_linear_with_float8_linear(
skip_fqn_list: Optional[List[str]] = None,
emulate: bool = False,
linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
-) -> nn.Module:
+) -> Optional[nn.Module]:
return swap_linear_layers(
module,
lambda m: module_cls.from_float(m, emulate=emulate),
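For reference, a minimal usage sketch of the swapped helper after this change. The Float8Linear import path and the positional (module, module_cls) argument order are assumptions inferred from the call to swap_linear_layers above, not something this diff shows directly.

import torch.nn as nn

from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 16))

# The return annotation is now Optional[nn.Module], so prefer using the
# returned module instead of relying on a purely in-place swap.
model = swap_linear_with_float8_linear(model, Float8Linear, emulate=True)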
48 changes: 21 additions & 27 deletions float8_experimental/inference.py
@@ -35,6 +35,7 @@ class ActivationCasting(Enum):
DYNAMIC: Activation is quantized during forward pass with a dynamic scale calculated from the input activation
"""

+# TODO: A better name would be NONE, we should unify this with torchao
WEIGHT_ONLY = auto()
DYNAMIC = auto()
STATIC = auto()
@@ -59,19 +60,21 @@ def __post_init__(self):
), "When activation_casting is 'static', activation_scale must be a tensor."


-class Float8LinearInference(torch.nn.Linear):
+class Float8InferenceLinear(torch.nn.Linear):
"""
This is a wrapper around torch.nn.Linear that supports FP8 inference
-Supported forms of infernce:
-- FP8 inference with fp32 matmul - weight only
+Supported forms of inference:
+- FP8 inference with high precision matmul - weight only
- FP8 inference with fp8 matmul and dynamic weight casting
- FP8 inference with fp8 matmul and static weight casting
"""

def __init__(
self,
# FP8 specific arguments
quant_config: QuantConfig,
forward_config: ScaledMMConfig,
# nn.Linear arguments
in_features: int,
out_features: int,
bias: bool = True,
@@ -80,16 +83,22 @@ def __init__(
) -> None:
# Construct the superclass this will create dummy weights and biases
super().__init__(in_features, out_features, bias, device, dtype)
-self.set_quantization_config(quant_config)
+self.forward_config = forward_config
+self.activation_casting = quant_config.activation_casting
+if self.activation_casting == ActivationCasting.STATIC:
+self.register_buffer(
+"static_quantization_scale", quant_config.static_quantization_scale
+)
+else:
+self.static_quantization_scale = None

def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.activation_casting == ActivationCasting.WEIGHT_ONLY:
return torch.nn.functional.linear(
input, self.weight.to_original_precision()
)

-x_fp8 = cast_to_float8_e4m3fn(
+x_fp8 = cast_to_float8_e4m3_inference(
input,
self.forward_config,
static_quantization_scale=self.static_quantization_scale,
@@ -127,25 +136,10 @@ def set_weight_and_bias(
self.weight = weight
self.bias = bias

-def set_quantization_config(
-self,
-quant_config: QuantConfig,
-):
-# We destructure the quant_config into the individual fields
-# If an activation config is passed in we want to register that as a buffer
-self.activation_casting: ActivationCasting = quant_config.activation_casting
-
-if self.activation_casting == ActivationCasting.STATIC:
-self.register_buffer(
-"static_quantization_scale", quant_config.static_quantization_scale
-)
-else:
-self.static_quantization_scale = None

@classmethod
def from_float(
cls, module: nn.Module, quant_config: QuantConfig, use_fast_accum: bool
-) -> "Float8LinearInference":
+) -> "Float8InferenceLinear":
"""
Create an nn.Linear with fp8 compute from another nn.Linear
@@ -169,13 +163,13 @@ def from_float(
return linear


-def cast_to_float8_e4m3fn(
+def cast_to_float8_e4m3_inference(
inpt_tensor: torch.Tensor,
mm_config: ScaledMMConfig,
reduce_amax: bool = False,
static_quantization_scale: Optional[torch.Tensor] = None,
) -> Float8Tensor:
-"""Casts an input tensor to the Float8 (e4m3fn) format for efficient computation.
+"""Casts an input tensor to the Float8 (e4m3fn*)
Args:
inpt_tensor: The input tensor to be cast.
@@ -205,9 +199,9 @@ def quantize_to_float8(
*,
skip_fqn_list: Optional[List[str]] = None,
use_fast_accum: bool = True,
-) -> nn.Module:
+) -> Optional[nn.Module]:
"""
-Converts torch.nn.Linear layers in the given module to Float8LinearInference.
+Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.
Note:
If applied to a root-level nn.Linear, the module will not be modified in place
@@ -217,7 +211,7 @@ def quantize_to_float8(
module (nn.Module): The module to modify.
quant_config (QuantConfig): Quantization configuration for Float8 conversion.
skip_fqn_list (List[str], optional): List of module FQNs to skip during conversion.
-use_fast_accum : Whether to enable fast accumulation for the Float8LinearInference. Defaults to True.
+use_fast_accum : Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.
Returns:
nn.Module: The modified module with applicable Linear layers converted to Float8.
@@ -227,6 +221,6 @@
"""
return swap_linear_layers(
module,
-lambda m: Float8LinearInference.from_float(m, quant_config, use_fast_accum),
+lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
skip_fqn_list=skip_fqn_list,
)
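To make the renamed pieces concrete, here is a minimal inference sketch. The imports mirror the ones in test/test_inference_flows.py below; the QuantConfig constructor taking the casting mode as its first field, the bfloat16 dtype, and the need for fp8-capable hardware are assumptions rather than anything stated in this diff.

import torch
import torch.nn as nn

from float8_experimental.inference import (
    ActivationCasting,
    QuantConfig,
    quantize_to_float8,
)

# Toy model; fp8 matmul generally needs supporting hardware (or emulation).
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8))
model = model.to(device="cuda", dtype=torch.bfloat16)

# Dynamic activation casting: activations are quantized on the fly.
quant_config = QuantConfig(ActivationCasting.DYNAMIC)
model = quantize_to_float8(model, quant_config)  # use the returned module

with torch.no_grad():
    out = model(torch.randn(16, 32, device="cuda", dtype=torch.bfloat16))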
4 changes: 2 additions & 2 deletions test/test_inference_flows.py
@@ -19,7 +19,7 @@
from float8_experimental.float8_utils import compute_error
from float8_experimental.inference import (
ActivationCasting,
-Float8LinearInference,
+Float8InferenceLinear,
QuantConfig,
quantize_to_float8,
)
@@ -228,7 +228,7 @@ def test_fp8_save_and_load(self, dtype: torch.dtype):

fp8_mod_count = 0
for module in new_fp8_mlp.modules():
-if isinstance(module, Float8LinearInference):
+if isinstance(module, Float8InferenceLinear):
assert isinstance(module.weight, Float8Tensor)
assert module.weight.requires_grad is False
fp8_mod_count += 1
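As a complement to the dynamic-casting sketch above, this shows how the other two ActivationCasting modes would be configured, based on the QuantConfig fields visible earlier in the diff. The scalar scale value is an arbitrary placeholder, not a calibrated one.

import torch

from float8_experimental.inference import ActivationCasting, QuantConfig

# Weight-only: weights are stored in fp8, the matmul stays in high precision.
weight_only_config = QuantConfig(ActivationCasting.WEIGHT_ONLY)

# Static: requires a precomputed activation scale, per the assertion in
# QuantConfig.__post_init__ above. A real scale would come from calibration.
static_scale = torch.tensor(1.0, device="cuda")
static_config = QuantConfig(
    ActivationCasting.STATIC,
    static_quantization_scale=static_scale,
)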
