
Commit fe0b35d

Flax Authors authored and committed
Merge pull request #4686 from kaixih:support_fp8_einsum
PiperOrigin-RevId: 745695362
2 parents: e58aeea + 01cdcc7

File tree

5 files changed: +493, -365 lines


docs/guides/quantization/fp8_basics.ipynb

Lines changed: 147 additions & 90 deletions
@@ -13,12 +13,10 @@
1313
"as quantization (Q). Conversely, de-quantization (DQ) rescales the FP8 data back\n",
1414
"to its original type.\n",
1515
"\n",
16-
"Although jnp.dot supports FP8 inputs, certain limitations make it impractical\n",
17-
"for real-world applications. Alternatively, XLA, our compiler, can recognize\n",
18-
"patterns like <FP8>->DQ->Dot and subsequently invoke FP8 backends (e.g.,\n",
19-
"cublasLt for GPUs). FLAX encapsulates such patterns into the\n",
20-
"nn.fp8_ops.Fp8DotGeneralOp module, allowing users to easily configure it for\n",
21-
"existing layers (e.g., nn.Dense).\n",
16+
"While jnp.dot supports FP8 inputs directly, proper quantization and\n",
17+
"dequantization is needed for optimal performance. Flax provides\n",
18+
"nn.fp8_ops.Fp8DotGeneral and nn.fp8_ops.Fp8Einsum modules that handle\n",
19+
"this automatically and can be used with existing layers like nn.Dense.\n",
2220
"\n",
2321
"This tutorial will walk you through the basics of how to use it.\n",
2422
"\n",
@@ -50,7 +48,6 @@
5048
"from flax.linen import fp8_ops\n",
5149
"\n",
5250
"e4m3 = jnp.float8_e4m3fn\n",
53-
"e5m2 = jnp.float8_e5m2\n",
5451
"f32 = jnp.float32\n",
5552
"E4M3_MAX = jnp.finfo(e4m3).max.astype(f32)\n",
5653
"\n",
@@ -82,34 +79,29 @@
8279
"metadata": {},
8380
"outputs": [],
8481
"source": [
85-
"key = random.key(0)\n",
86-
"A = random.uniform(key, (16, 32))\n",
87-
"B = random.uniform(key, (32, 64))\n",
82+
"k0, k1 = random.split(random.key(0), 2)\n",
83+
"a = random.uniform(k0, (16, 32))\n",
84+
"b = random.uniform(k1, (32, 64))\n",
8885
"@jax.jit\n",
89-
"def dot_fp8(A, B):\n",
90-
" return jnp.dot(A.astype(e4m3), B.astype(e4m3), preferred_element_type=f32)\n",
91-
"check_fp8_call(dot_fp8.lower(A, B))"
86+
"def dot_fp8(a, b):\n",
87+
" return jnp.dot(a.astype(e4m3), b.astype(e4m3), preferred_element_type=f32)\n",
88+
"check_fp8_call(dot_fp8.lower(a, b))"
9289
]
9390
},
9491
{
9592
"cell_type": "markdown",
9693
"id": "adb22878",
9794
"metadata": {},
9895
"source": [
99-
"However, there are two main issues with this approach. Firstly, `jnp.dot` does\n",
100-
"not accept scaling factors for the operands, defaulting to a scaling factor of\n",
101-
"1.0. Secondly, it does not support operands of mixed FP8 data types. For\n",
102-
"example, when the operands are E5M2 and E4M3, the dot product is performed using\n",
103-
"the promoted FP16 data type.\n",
96+
"However, this approach has two key limitations:\n",
10497
"\n",
105-
"In real-world scenarios, it is essential to specify scaling factors, either from\n",
106-
"calibration for inference or a user-defined algorithm during training.\n",
107-
"Additionally, it is common practice to use E5M2 for gradients and E4M3 for\n",
108-
"activations and kernels. These limitations make this method less practical for\n",
109-
"real-world applications.\n",
98+
"1. `jnp.dot` does not support custom scaling factors for operands, defaulting to\n",
99+
" a scale of 1.0\n",
100+
"2. The autodiff does not automatically use E5M2 for gradients and E4M3 for\n",
101+
" activations/weights during training, which is the recommended practice\n",
110102
"\n",
111-
"To address these limitations and create a more versatile FP8 dot product, we\n",
112-
"recommend leveraging XLA-FP8. Let's begin with a simple scaling strategy.\n",
103+
"To overcome these limitations and implement proper FP8 matrix multiplication, we\n",
104+
"recommend using the Flax FP8 APIs. Let's start with a basic scaling approach.\n",
113105
"\n",
114106
"\n",
115107
"### Current Scaling\n",
@@ -129,36 +121,38 @@
129121
"outputs": [],
130122
"source": [
131123
"@jax.jit\n",
132-
"def dot_fp8(A, B):\n",
133-
" A_scale = jnp.max(jnp.abs(A)) / E4M3_MAX\n",
134-
" B_scale = jnp.max(jnp.abs(B)) / E4M3_MAX\n",
135-
" A = fp8_ops.quantize_dequantize(A, e4m3, A_scale, f32)\n",
136-
" B = fp8_ops.quantize_dequantize(B, e4m3, B_scale, f32)\n",
137-
"\n",
138-
" C = jnp.dot(A, B)\n",
139-
" return C\n",
124+
"def dot_fp8(a, b):\n",
125+
" a_scale = jnp.max(jnp.abs(A)) / E4M3_MAX\n",
126+
" b_scale = jnp.max(jnp.abs(B)) / E4M3_MAX\n",
127+
" a = fp8_ops.quantize(a, e4m3, a_scale, f32)\n",
128+
" b = fp8_ops.quantize(b, e4m3, b_scale, f32)\n",
129+
"\n",
130+
" c = jnp.dot(a, b, preferred_element_type=f32)\n",
131+
" c = fp8_ops.dequantize(c, f32, a_scale * b_scale)\n",
132+
" return c\n",
140133
"\n",
141-
"C = dot_fp8(A, B)\n",
142-
"check_fp8_call(dot_fp8.lower(A, B))"
134+
"c = dot_fp8(a, b)\n",
135+
"check_fp8_call(dot_fp8.lower(a, b))"
143136
]
144137
},
145138
{
146139
"cell_type": "markdown",
147140
"id": "59aca6fe",
148141
"metadata": {},
149142
"source": [
150-
"As shown in the code, we perform fake quantization\n",
151-
"(`fp8_ops.quantize_dequantize`) on the operands of the dot product. Although the\n",
152-
"`jnp.dot` still processes higher-precision inputs, XLA detects this pattern and\n",
153-
"rewrites the dot operation as an FP8 dot call (e.g., cublasLt call for GPUs).\n",
154-
"This approach effectively mimics the first example but offers greater\n",
155-
"flexibility. We can control the input dtypes (both are set to E4M3 here, but we\n",
156-
"could use mixed E4M3 and E5M2) and define scaling factors, which XLA can detect\n",
157-
"and use in the dot backend.\n",
158-
"\n",
159-
"One major issue with the current scaling method is the overhead introduced by\n",
160-
"computing `A_scale` and `B_scale`, which requires additional loading of the\n",
161-
"operand tensors. To overcome this issue, we recommend the delayed scaling.\n",
143+
"As shown in the code, we perform quantization (`fp8_ops.quantize`) on the\n",
144+
"tensors to get the lower precision operands. The `jnp.dot` processes them and\n",
145+
"accumulates the output in high precision (i.e., the `preferred_element_type`).\n",
146+
"After that, we multiply the result by the scaling factors to dequantize back to\n",
147+
"the original range (`fp8_ops.dequantize`). Note that while this example uses\n",
148+
"E4M3 for both inputs, it is possible to use different FP8 dtypes like E4M3 and\n",
149+
"E5M2 for the inputs. The quantization method and the scaling factors can also be\n",
150+
"customized based on application needs.\n",
151+
"\n",
152+
"One major issue with the current scaling method is the performance overhead\n",
153+
"introduced by computing `a_scale` and `b_scale`, which requires additional\n",
154+
"loading of the operand tensors. To overcome this issue, we recommend the delayed\n",
155+
"scaling.\n",
162156
"\n",
163157
"### Delayed Scaling\n",
164158
"\n",
@@ -167,8 +161,10 @@
167161
"values from recent steps (e.g., 1024 steps). Both tensors are computed from\n",
168162
"previous steps and maintained in the model parameters.\n",
169163
"\n",
170-
"Fake quantization for delayed scaling is provided by `fp8_ops.in_qdq` for the\n",
171-
"activations and weights, and `fp8_ops.out_qdq` for the gradients."
164+
"The quantization and dequantization operations for delayed scaling are provided\n",
165+
"by `fp8_ops.in_q` and `fp8_ops.out_dq` respectively. `fp8_ops.in_q` handles\n",
166+
"input quantization and update the amax history and scaling factor, while\n",
167+
"`fp8_ops.out_dq` performs output dequantization."
172168
]
173169
},
174170
{
@@ -180,25 +176,20 @@
180176
"source": [
181177
"a_scale = jnp.array(1.0)\n",
182178
"b_scale = jnp.array(1.0)\n",
183-
"g_scale = jnp.array(1.0)\n",
184179
"a_amax_hist = jnp.zeros((1024,))\n",
185180
"b_amax_hist = jnp.zeros((1024,))\n",
186-
"g_amax_hist = jnp.zeros((1024,))\n",
187181
"\n",
188182
"@jax.jit\n",
189-
"def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist,\n",
190-
" g_scale, g_amax_hist):\n",
191-
" a = fp8_ops.in_qdq(f32, e4m3, a, a_scale, a_amax_hist)\n",
192-
" b = fp8_ops.in_qdq(f32, e4m3, b, b_scale, b_amax_hist)\n",
183+
"def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist):\n",
184+
" a, a_scale = fp8_ops.in_q(f32, e4m3, a, a_scale, a_amax_hist)\n",
185+
" b, b_scale = fp8_ops.in_q(f32, e4m3, b, b_scale, b_amax_hist)\n",
193186
" \n",
194-
" c = jnp.dot(a, b)\n",
195-
" c = fp8_ops.out_qdq(f32, e5m2, c, g_scale, g_amax_hist)\n",
187+
" c = jnp.dot(a, b, preferred_element_type=f32)\n",
188+
" c = fp8_ops.out_dq(f32, a_scale, b_scale, c)\n",
196189
" return c\n",
197190
"\n",
198-
"C = dot_fp8(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,\n",
199-
" g_scale, g_amax_hist)\n",
200-
"check_fp8_call(dot_fp8.lower(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,\n",
201-
" g_scale, g_amax_hist))"
191+
"c = dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist)\n",
192+
"check_fp8_call(dot_fp8.lower(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist))"
202193
]
203194
},
204195
{
@@ -208,22 +199,22 @@
208199
"source": [
209200
"In this example, we first prepare three pairs of scaling factors and amax\n",
210201
"histories, treating them as results computed from previous steps. Then, we apply\n",
211-
"`fp8_ops.in_qdq` to the input operands of `jnp.dot`, followed by\n",
212-
"`fp8_ops.out_qdq` to the output of `jnp.dot`. Note the `fp8_ops.out_qdq` will\n",
213-
"apply fake quantization to the gradient of the output via custom_vjp functions.\n",
214-
"The new scaling factors and amax histories will be returned through their\n",
215-
"gradients, which will be covered in the next section.\n",
202+
"`fp8_ops.in_q` to the input operands of `jnp.dot`, followed by `fp8_ops.out_dq`\n",
203+
"to the output of `jnp.dot`.\n",
216204
"\n",
217205
"\n",
218206
"## FLAX High Level API\n",
219-
"With the FLAX library, incorporating FP8 operations into existing FLAX layers\n",
220-
"is a seamless process. Users don't need to manipulate the low-level APIs for\n",
221-
"quantization. Instead, they can integrate the provided custom FP8 dot\n",
222-
"(`fp8_ops.Fp8DotGeneralOp`) into FLAX layers using a straightforward\n",
223-
"\"code-injection\" approach. This custom operation encapsulates all FP8-related\n",
224-
"tasks, including the placement of quantization-dequantization ops, algorithms\n",
225-
"for updating scaling factors, and the selection of FP8 dtype combinations for\n",
226-
"forward and backward propagation.\n",
207+
"Flax provides high-level operations to seamlessly integrate FP8 quantization\n",
208+
"into existing layers. Instead of manually handling quantization of the delayed\n",
209+
"scaling (e.g., the maintanence of the amax history and scaling factors), users\n",
210+
"can simply use these drop-in replacements:\n",
211+
"\n",
212+
"* `fp8_ops.Fp8DotGeneral` for `lax.dot_general` operations\n",
213+
"* `fp8_ops.Fp8Einsum` for `jnp.einsum` operations \n",
214+
"\n",
215+
"These operations automatically handle all FP8-related functionality, including\n",
216+
"quantization/dequantization, scale factor updates, and FP8 dtype selection for\n",
217+
"both forward and backward passes.\n",
227218
"\n",
228219
"Consider the following example:"
229220
]
@@ -235,8 +226,8 @@
235226
"metadata": {},
236227
"outputs": [],
237228
"source": [
238-
"model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneralOp)\n",
239-
"params = model.init(key, A)\n",
229+
"model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneral)\n",
230+
"params = model.init(k0, A)\n",
240231
"\n",
241232
"@jax.jit\n",
242233
"def train_step(var, a): \n",
@@ -248,16 +239,64 @@
248239
},
249240
{
250241
"cell_type": "markdown",
251-
"id": "a83b0851",
242+
"id": "ba280e79",
252243
"metadata": {},
253244
"source": [
254-
"In this example, we simply set `dot_general_cls=fp8_ops.Fp8DotGeneralOp` to\n",
255-
"enable the Dense layer to utilize the FP8 dot operation. The usage of the model\n",
256-
"remains almost the same as before. The main difference is the addition of a new\n",
257-
"category of parameters: the sets of scaling factors and amax history. In the\n",
258-
"next section, we will explore how to update these parameters.\n",
245+
"By setting `dot_general_cls=fp8_ops.Fp8DotGeneral`, we replace the\n",
246+
"default `lax.dot_general` operation in `nn.Dense` with an FP8-enabled version.\n",
247+
"The model usage remains similar, but now includes additional parameters for FP8\n",
248+
"quantization: scaling factors and amax history values. The next section explains\n",
249+
"how to update these FP8-specific parameters.\n",
250+
"\n",
251+
"For models that use `jnp.einsum` operations, such as Mixture of Experts (MoE)\n",
252+
"layers, users can replace them with `fp8_ops.Fp8Einsum` to enable FP8\n",
253+
"quantization. Here's an example:"
254+
]
255+
},
256+
{
257+
"cell_type": "code",
258+
"execution_count": null,
259+
"id": "961b4549",
260+
"metadata": {},
261+
"outputs": [],
262+
"source": [
263+
"from typing import Any\n",
264+
"class FooModule(nn.Module):\n",
265+
" einsum: Any = None\n",
266+
" @nn.compact\n",
267+
" def __call__(self, a, b):\n",
268+
" if self.einsum is not None:\n",
269+
" einsum_fn = self.einsum()\n",
270+
" elif self.einsum is None:\n",
271+
" einsum_fn = jnp.einsum\n",
272+
" c = einsum_fn(\"mk,kn->mn\", a, b)\n",
273+
" return c\n",
274+
"\n",
275+
"model = FooModule(einsum=fp8_ops.Fp8Einsum)\n",
276+
"params = model.init(k0, a, b)\n",
259277
"\n",
278+
"@jax.jit\n",
279+
"def train_step(var, a, b):\n",
280+
" c = model.apply(var, a, b)\n",
281+
" return jnp.sum(c)\n",
282+
"\n",
283+
"check_fp8_call(train_step.lower(params, a, b))"
284+
]
285+
},
286+
{
287+
"cell_type": "markdown",
288+
"id": "a83b0851",
289+
"metadata": {},
290+
"source": [
260291
"## Manipulate FP8 params\n",
292+
"\n",
293+
"The following sections explain the internal FP8 parameters managed by\n",
294+
"`fp8_ops.Fp8DotGeneral` and `fp8_ops.Fp8Einsum`. These parameters\n",
295+
"include scaling factors and amax history values that control the FP8\n",
296+
"quantization process. While most users don't need to interact with these\n",
297+
"directly, understanding them can be valuable for advanced optimization and\n",
298+
"debugging.\n",
299+
"\n",
261300
"Let's first examine the data structure of `params`. In the code below, we redact\n",
262301
"the parameter values and then display the PyTree structure."
263302
]
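The cell that performs the redaction falls outside the hunks shown in this diff; a minimal sketch of the idea (the notebook's actual cell may differ) is:

```python
# Hedged sketch: replace every leaf of the params PyTree with '*' so that only
# the tree structure is printed, matching the redacted output shown below.
from pprint import pprint

redacted = jax.tree_util.tree_map(lambda _: '*', params)
pprint(redacted)
```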
@@ -285,13 +324,12 @@
285324
"The output is as follows:\n",
286325
"\n",
287326
"```plaintext\n",
288-
"{'_overwrite_with_gradient': {'Fp8DotGeneralOp_0': {'input_amax_history': '*',\n",
289-
" 'input_scale': '*',\n",
290-
" 'kernel_amax_history': '*',\n",
291-
" 'kernel_scale': '*',\n",
292-
" 'output_grad_amax_history': '*',\n",
293-
" 'output_grad_scale': '*'}},\n",
294-
" 'params': {'bias': '*', 'kernel': '*'}}\n",
327+
"{'_overwrite_with_gradient': {'Fp8Einsum_0': {'input_amax_history': '*',\n",
328+
" 'input_scale': '*',\n",
329+
" 'kernel_amax_history': '*',\n",
330+
" 'kernel_scale': '*',\n",
331+
" 'output_grad_amax_history': '*',\n",
332+
" 'output_grad_scale': '*'}}}\n",
295333
"```\n",
296334
"\n",
297335
"In addition to the expected `params`, there is an additional category called\n",
@@ -400,7 +438,26 @@
400438
"2.0 [5. 0. 0. ... 0. 0. 0.]\n",
401439
"```\n",
402440
"\n",
403-
"This casting is already included if users choose to use the high-level APIs."
441+
"This casting is already included if users choose to use the high-level APIs.\n",
442+
"\n",
443+
"## Deprecated APIs\n",
444+
"Previously, we provided APIs like `fp8_ops.quantize_dequantize` for current\n",
445+
"scaling and `fp8_ops.[in|out]_qdq` for delayed scaling. These were used with\n",
446+
"high precision dot operations, leveraging an XLA-FP8 feature that\n",
447+
"pattern-matched QDQ->dot sequences to Q->fp8_cublas_gemm. The corresponding\n",
448+
"high-level API was called `fp8_ops.Fp8DotGeneralOp`. However, this pattern\n",
449+
"matching-based solution proved brittle, as the patterns could be easily broken\n",
450+
"by other XLA optimizations. We recommend users migrate from these deprecated\n",
451+
"APIs to the newer ones described above.\n",
452+
"\n",
453+
"For migration, users should replace:\n",
454+
"* `fp8_ops.quantize_dequantize -> jnp.dot` with `fp8_ops.quantize -> jnp.dot ->\n",
455+
" fp8_ops.dequantize`\n",
456+
"* `fp8_ops.in_qdq -> jnp.dot -> fp8_ops.out_qdq` with `fp8_ops.in_q -> jnp.dot\n",
457+
" -> fp8_ops.out_dq`\n",
458+
"* `fp8_ops.Fp8DotGeneralOp` with `fp8_ops.Fp8DotGeneral`\n",
459+
"\n",
460+
"Additionally, we provide an einsum variant through `fp8_ops.Fp8Einsum`."
404461
]
405462
}
406463
],
