# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
"""
Defines an nn module designed to be used during inference
"""

from dataclasses import dataclass
from enum import auto, Enum
from typing import Callable, List, Optional

import torch
import torch.nn as nn

from torchao.float8.float8_linear_utils import swap_linear_layers
from torchao.float8.float8_tensor import (
    Float8Tensor,
    GemmInputRole,
    hp_tensor_and_scale_to_float8,
    LinearMMConfig,
    ScaledMMConfig,
    tensor_already_casted_to_fp8,
)
from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
class ActivationCasting(Enum):
    """Types of quantization to perform on the activations.

    WEIGHT_ONLY: Only quantize the weight; no activation casting. The weight is dequantized in the forward pass.
    STATIC: Activation is quantized during model initialization with a static scale.
    DYNAMIC: Activation is quantized during the forward pass with a dynamic scale calculated from the input activation.
    """

    # TODO: A better name would be NONE; we should unify this with torchao
    WEIGHT_ONLY = auto()
    DYNAMIC = auto()
    STATIC = auto()
@dataclass(frozen=True)
class QuantConfig:
    """Defines the configuration for quantizing a linear module to fp8.

    Args:
        activation_casting: The type of quantization to perform on the activations
        static_quantization_scale: The scale of the input to this linear module, used for static quantization only
    """

    activation_casting: ActivationCasting
    static_quantization_scale: Optional[torch.Tensor] = None

    # If True, then prior to performing the fp8 scaled matmul we pad the inner
    # dimension of a (dim 1) and b (dim 2) with zeros. This is needed for
    # torch._scaled_mm, which has the strong constraint that, for M, N, K, both
    # N and K must be multiples of 16. Padding can cause a memory spike, so it
    # is off by default.
    pad_inner_dim = False

    def __post_init__(self):
        if self.activation_casting == ActivationCasting.STATIC:
            assert isinstance(
                self.static_quantization_scale, torch.Tensor
            ), "When activation_casting is STATIC, static_quantization_scale must be a tensor."
class Float8InferenceLinear(torch.nn.Linear):
    """
    This is a wrapper around torch.nn.Linear that supports FP8 inference.

    Supported forms of inference:
        - FP8 inference with a high precision matmul - weight-only quantization
        - FP8 inference with an fp8 matmul and dynamic activation casting
        - FP8 inference with an fp8 matmul and static activation casting
    """
    def __init__(
        self,
        # FP8 specific arguments
        quant_config: QuantConfig,
        linear_mm_config: LinearMMConfig,
        # nn.Linear arguments
        in_features: int,
        out_features: int,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        # Construct the superclass; this creates dummy weights and biases.
        super().__init__(in_features, out_features, bias, device, dtype)
        self.linear_mm_config = linear_mm_config
        self.activation_casting = quant_config.activation_casting
        if self.activation_casting == ActivationCasting.STATIC:
            self.register_buffer(
                "static_quantization_scale", quant_config.static_quantization_scale
            )
        else:
            self.static_quantization_scale = None
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.activation_casting == ActivationCasting.WEIGHT_ONLY:
            # Weight-only mode: dequantize the weight and run the matmul (and
            # bias add, if present) in high precision.
            return torch.nn.functional.linear(
                input, self.weight.to_original_precision(), self.bias
            )

        x_fp8 = cast_to_float8_e4m3_inference(
            input,
            self.linear_mm_config,
            static_quantization_scale=self.static_quantization_scale,
        )
        return torch.nn.functional.linear(x_fp8, self.weight, self.bias)
    # Builder functions for Float8InferenceLinear
    def quantize_weight(self, dtype: torch.dtype = e4m3_dtype) -> None:
        """Converts the weight to a Float8Tensor and sets its requires_grad to False.

        Args:
            dtype: The dtype to quantize the weight to. Default is e4m3_dtype.

        Note:
            This function is typically called during inference to quantize the weight once,
            since the weight is not updated during inference.
        """
        assert not isinstance(
            self.weight, Float8Tensor
        ), "Weight has already been quantized, cannot quantize again."
        scale = tensor_to_scale(self.weight, dtype)
        quantized_weight = hp_tensor_and_scale_to_float8(
            self.weight,
            scale,
            dtype,
            self.linear_mm_config,
            GemmInputRole.WEIGHT,
        )
        self.weight = nn.Parameter(quantized_weight)
        self.weight.requires_grad = False

    def set_weight_and_bias(
        self, weight: torch.nn.Parameter, bias: Optional[torch.nn.Parameter]
    ):
        self.weight = weight
        self.bias = bias

    @classmethod
    def from_float(
        cls, module: nn.Module, quant_config: QuantConfig, use_fast_accum: bool
    ) -> "Float8InferenceLinear":
        """
        Create an nn.Linear with fp8 compute from another nn.Linear.

        Args:
            module (torch.nn.Linear): nn.Linear to convert
            quant_config (QuantConfig): Configuration for the weight and activation casting
            use_fast_accum (bool): Whether to enable fast accumulation in the scaled matmul
        """
        forward_config = ScaledMMConfig(
            False, use_fast_accum, pad_inner_dim=quant_config.pad_inner_dim
        )
        linear_mm_config = LinearMMConfig(
            forward_config, forward_config, forward_config
        )
        linear = cls(
            quant_config,
            linear_mm_config,
            module.in_features,
            module.out_features,
            False,
            device=torch.device("meta"),
        )
        linear.set_weight_and_bias(module.weight, module.bias)
        linear.quantize_weight()
        return linear
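

# Illustrative usage (not part of the original file): converting a single
# nn.Linear by hand with from_float. quantize_to_float8 (defined below) applies
# this to every eligible Linear in a model; the layer shape here is an
# arbitrary placeholder.
#
#   fp32_linear = nn.Linear(1024, 4096, bias=True)
#   fp8_linear = Float8InferenceLinear.from_float(
#       fp32_linear, QuantConfig(ActivationCasting.DYNAMIC), use_fast_accum=True
#   )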
def cast_to_float8_e4m3_inference(
    inpt_tensor: torch.Tensor,
    linear_mm_config: LinearMMConfig,
    reduce_amax: bool = False,
    static_quantization_scale: Optional[torch.Tensor] = None,
) -> Float8Tensor:
    """Casts an input tensor to Float8 (e4m3).

    Args:
        inpt_tensor: The input tensor to be cast.
        linear_mm_config: Configuration settings for the matrix multiplication.
        reduce_amax: Whether to reduce the amax (absolute maximum) across the local distributed group.
        static_quantization_scale: Optional tensor specifying the scale for the activation. Default is None.

    Returns:
        Float8Tensor: The input tensor cast to Float8 (e4m3) format.

    Note:
        If the input tensor is already in Float8 format, it is returned as is without re-casting.
    """
    if tensor_already_casted_to_fp8(inpt_tensor):
        return inpt_tensor
    scale = (
        static_quantization_scale
        if static_quantization_scale is not None
        else tensor_to_scale(inpt_tensor, e4m3_dtype, reduce_amax)
    )
    return hp_tensor_and_scale_to_float8(
        inpt_tensor,
        scale,
        e4m3_dtype,
        linear_mm_config,
        GemmInputRole.INPUT,
    )
def quantize_to_float8(
    module: nn.Module,
    quant_config: QuantConfig,
    *,
    module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None,
    use_fast_accum: bool = True,
) -> nn.Module:
    """
    Converts torch.nn.Linear layers in the given module to Float8InferenceLinear.

    Note:
        If applied to a root-level nn.Linear, the module is not modified in place;
        the converted module is returned instead.

    Args:
        module (nn.Module): The module to modify.
        quant_config (QuantConfig): Quantization configuration for Float8 conversion.
        module_filter_fn: If specified, only the `torch.nn.Linear` subclasses that
            pass the filter function will be swapped. The inputs to the filter
            function are the module instance and the FQN.
        use_fast_accum: Whether to enable fast accumulation for the Float8InferenceLinear. Defaults to True.

    Returns:
        nn.Module: The modified module with applicable Linear layers converted to Float8.

    Raises:
        AssertionError: If a root-level nn.Linear with children is encountered.
    """
    return swap_linear_layers(
        module,
        lambda m: Float8InferenceLinear.from_float(m, quant_config, use_fast_accum),
        module_filter_fn=module_filter_fn,
    )
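

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes a CUDA GPU
# with fp8 support (e.g. SM89/SM90); the toy model, its dimensions, and the
# filter function below are made-up placeholders for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class ToyMLP(nn.Module):
        """A hypothetical two-layer MLP used only to demonstrate the conversion."""

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(1024, 4096, bias=False)
            self.fc2 = nn.Linear(4096, 1024, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.fc2(torch.relu(self.fc1(x)))

    model = ToyMLP().to(device="cuda", dtype=torch.bfloat16).eval()

    # Arbitrary example of module_filter_fn: swap every Linear except the one
    # whose fully qualified name is "fc2".
    def keep_all_but_fc2(mod: nn.Module, fqn: str) -> bool:
        return fqn != "fc2"

    model = quantize_to_float8(
        model,
        QuantConfig(ActivationCasting.DYNAMIC),
        module_filter_fn=keep_all_but_fc2,
    )

    with torch.inference_mode():
        x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
        y = model(x)
    print(y.shape)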