|
13 | 13 | "as quantization (Q). Conversely, de-quantization (DQ) rescales the FP8 data back\n",
|
14 | 14 | "to its original type.\n",
|
15 | 15 | "\n",
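For intuition, here is a minimal, self-contained sketch of what per-tensor Q and DQ look like, assuming a simple scale tied to the tensor's maximum absolute value (an illustration of the idea, not the exact `fp8_ops` implementation; the tutorial defines `e4m3` and `E4M3_MAX` the same way shortly):

```python
import jax.numpy as jnp

e4m3 = jnp.float8_e4m3fn
E4M3_MAX = jnp.finfo(e4m3).max.astype(jnp.float32)

x = jnp.linspace(-3.0, 3.0, 8, dtype=jnp.float32)

# Quantize (Q): choose a per-tensor scale so the largest magnitude maps to the
# FP8 maximum, then cast the rescaled data down to FP8.
scale = jnp.max(jnp.abs(x)) / E4M3_MAX
x_fp8 = (x / scale).astype(e4m3)

# De-quantize (DQ): cast back up and undo the rescaling.
x_dq = x_fp8.astype(jnp.float32) * scale

print(x)
print(x_dq)  # close to x, up to FP8 rounding error
```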
|
16 |
| - "Although jnp.dot supports FP8 inputs, certain limitations make it impractical\n", |
17 |
| - "for real-world applications. Alternatively, XLA, our compiler, can recognize\n", |
18 |
| - "patterns like <FP8>->DQ->Dot and subsequently invoke FP8 backends (e.g.,\n", |
19 |
| - "cublasLt for GPUs). FLAX encapsulates such patterns into the\n", |
20 |
| - "nn.fp8_ops.Fp8DotGeneralOp module, allowing users to easily configure it for\n", |
21 |
| - "existing layers (e.g., nn.Dense).\n", |
| 16 | + "While jnp.dot supports FP8 inputs directly, proper quantization and\n", |
| 17 | + "dequantization is needed for optimal performance. Flax provides\n", |
| 18 | + "nn.fp8_ops.Fp8DotGeneral and nn.fp8_ops.Fp8Einsum modules that handle\n", |
| 19 | + "this automatically and can be used with existing layers like nn.Dense.\n", |
22 | 20 | "\n",
|
23 | 21 | "This tutorial will walk you through the basics of how to use it.\n",
|
24 | 22 | "\n",
|
|
50 | 48 | "from flax.linen import fp8_ops\n",
|
51 | 49 | "\n",
|
52 | 50 | "e4m3 = jnp.float8_e4m3fn\n",
|
53 |
| - "e5m2 = jnp.float8_e5m2\n", |
54 | 51 | "f32 = jnp.float32\n",
|
55 | 52 | "E4M3_MAX = jnp.finfo(e4m3).max.astype(f32)\n",
|
56 | 53 | "\n",
|
|
82 | 79 | "metadata": {},
|
83 | 80 | "outputs": [],
|
84 | 81 | "source": [
|
85 |
| - "key = random.key(0)\n", |
86 |
| - "A = random.uniform(key, (16, 32))\n", |
87 |
| - "B = random.uniform(key, (32, 64))\n", |
| 82 | + "k0, k1 = random.split(random.key(0), 2)\n", |
| 83 | + "a = random.uniform(k0, (16, 32))\n", |
| 84 | + "b = random.uniform(k1, (32, 64))\n", |
88 | 85 | "@jax.jit\n",
|
89 |
| - "def dot_fp8(A, B):\n", |
90 |
| - " return jnp.dot(A.astype(e4m3), B.astype(e4m3), preferred_element_type=f32)\n", |
91 |
| - "check_fp8_call(dot_fp8.lower(A, B))" |
| 86 | + "def dot_fp8(a, b):\n", |
| 87 | + " return jnp.dot(a.astype(e4m3), b.astype(e4m3), preferred_element_type=f32)\n", |
| 88 | + "check_fp8_call(dot_fp8.lower(a, b))" |
92 | 89 | ]
|
93 | 90 | },
|
94 | 91 | {
|
95 | 92 | "cell_type": "markdown",
|
96 | 93 | "id": "adb22878",
|
97 | 94 | "metadata": {},
|
98 | 95 | "source": [
|
99 |
| - "However, there are two main issues with this approach. Firstly, `jnp.dot` does\n", |
100 |
| - "not accept scaling factors for the operands, defaulting to a scaling factor of\n", |
101 |
| - "1.0. Secondly, it does not support operands of mixed FP8 data types. For\n", |
102 |
| - "example, when the operands are E5M2 and E4M3, the dot product is performed using\n", |
103 |
| - "the promoted FP16 data type.\n", |
| 96 | + "However, this approach has two key limitations:\n", |
104 | 97 | "\n",
|
105 |
| - "In real-world scenarios, it is essential to specify scaling factors, either from\n", |
106 |
| - "calibration for inference or a user-defined algorithm during training.\n", |
107 |
| - "Additionally, it is common practice to use E5M2 for gradients and E4M3 for\n", |
108 |
| - "activations and kernels. These limitations make this method less practical for\n", |
109 |
| - "real-world applications.\n", |
| 98 | + "1. `jnp.dot` does not support custom scaling factors for operands, defaulting to\n", |
| 99 | + " a scale of 1.0\n", |
| 100 | + "2. The autodiff does not automatically use E5M2 for gradients and E4M3 for\n", |
| 101 | + " activations/weights during training, which is the recommended practice\n", |
110 | 102 | "\n",
|
111 |
| - "To address these limitations and create a more versatile FP8 dot product, we\n", |
112 |
| - "recommend leveraging XLA-FP8. Let's begin with a simple scaling strategy.\n", |
| 103 | + "To overcome these limitations and implement proper FP8 matrix multiplication, we\n", |
| 104 | + "recommend using the Flax FP8 APIs. Let's start with a basic scaling approach.\n", |
113 | 105 | "\n",
|
114 | 106 | "\n",
|
115 | 107 | "### Current Scaling\n",
|
|
129 | 121 | "outputs": [],
|
130 | 122 | "source": [
|
131 | 123 | "@jax.jit\n",
|
132 |
| - "def dot_fp8(A, B):\n", |
133 |
| - " A_scale = jnp.max(jnp.abs(A)) / E4M3_MAX\n", |
134 |
| - " B_scale = jnp.max(jnp.abs(B)) / E4M3_MAX\n", |
135 |
| - " A = fp8_ops.quantize_dequantize(A, e4m3, A_scale, f32)\n", |
136 |
| - " B = fp8_ops.quantize_dequantize(B, e4m3, B_scale, f32)\n", |
137 |
| - "\n", |
138 |
| - " C = jnp.dot(A, B)\n", |
139 |
| - " return C\n", |
| 124 | + "def dot_fp8(a, b):\n", |
| 125 | + "  a_scale = jnp.max(jnp.abs(a)) / E4M3_MAX\n", |
| 126 | + "  b_scale = jnp.max(jnp.abs(b)) / E4M3_MAX\n", |
| 127 | + " a = fp8_ops.quantize(a, e4m3, a_scale, f32)\n", |
| 128 | + " b = fp8_ops.quantize(b, e4m3, b_scale, f32)\n", |
| 129 | + "\n", |
| 130 | + " c = jnp.dot(a, b, preferred_element_type=f32)\n", |
| 131 | + " c = fp8_ops.dequantize(c, f32, a_scale * b_scale)\n", |
| 132 | + " return c\n", |
140 | 133 | "\n",
|
141 |
| - "C = dot_fp8(A, B)\n", |
142 |
| - "check_fp8_call(dot_fp8.lower(A, B))" |
| 134 | + "c = dot_fp8(a, b)\n", |
| 135 | + "check_fp8_call(dot_fp8.lower(a, b))" |
143 | 136 | ]
|
144 | 137 | },
|
145 | 138 | {
|
146 | 139 | "cell_type": "markdown",
|
147 | 140 | "id": "59aca6fe",
|
148 | 141 | "metadata": {},
|
149 | 142 | "source": [
|
150 |
| - "As shown in the code, we perform fake quantization\n", |
151 |
| - "(`fp8_ops.quantize_dequantize`) on the operands of the dot product. Although the\n", |
152 |
| - "`jnp.dot` still processes higher-precision inputs, XLA detects this pattern and\n", |
153 |
| - "rewrites the dot operation as an FP8 dot call (e.g., cublasLt call for GPUs).\n", |
154 |
| - "This approach effectively mimics the first example but offers greater\n", |
155 |
| - "flexibility. We can control the input dtypes (both are set to E4M3 here, but we\n", |
156 |
| - "could use mixed E4M3 and E5M2) and define scaling factors, which XLA can detect\n", |
157 |
| - "and use in the dot backend.\n", |
158 |
| - "\n", |
159 |
| - "One major issue with the current scaling method is the overhead introduced by\n", |
160 |
| - "computing `A_scale` and `B_scale`, which requires additional loading of the\n", |
161 |
| - "operand tensors. To overcome this issue, we recommend the delayed scaling.\n", |
| 143 | + "As shown in the code, we perform quantization (`fp8_ops.quantize`) on the\n", |
| 144 | + "tensors to get the lower precision operands. The `jnp.dot` processes them and\n", |
| 145 | + "accumulates the output in high precision (i.e., the `preferred_element_type`).\n", |
| 146 | + "After that, we multiply the result by the scaling factors to dequantize back to\n", |
| 147 | + "the original range (`fp8_ops.dequantize`). Note that while this example uses\n", |
| 148 | + "E4M3 for both inputs, it is possible to use different FP8 dtypes like E4M3 and\n", |
| 149 | + "E5M2 for the inputs. The quantization method and the scaling factors can also be\n", |
| 150 | + "customized based on application needs.\n", |
| 151 | + "\n", |
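As a quick, illustrative sanity check (assuming `a`, `b`, and `c` from the cell above; the error magnitude depends on the data and is not a guarantee), you can compare the FP8 result against a plain float32 dot product:

```python
# Reference result in float32, without any FP8 quantization.
c_ref = jnp.dot(a, b)

# The gap comes from FP8 rounding of the operands.
print("max abs error:", jnp.max(jnp.abs(c - c_ref)))
print("max rel error:", jnp.max(jnp.abs(c - c_ref) / jnp.abs(c_ref)))
```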
| 152 | + "One major issue with the current scaling method is the performance overhead\n", |
| 153 | + "introduced by computing `a_scale` and `b_scale`, which requires additional\n", |
| 154 | + "loading of the operand tensors. To overcome this issue, we recommend the delayed\n", |
| 155 | + "scaling.\n", |
162 | 156 | "\n",
|
163 | 157 | "### Delayed Scaling\n",
|
164 | 158 | "\n",
|
|
167 | 161 | "values from recent steps (e.g., 1024 steps). Both tensors are computed from\n",
|
168 | 162 | "previous steps and maintained in the model parameters.\n",
|
169 | 163 | "\n",
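Conceptually, each step derives its scale from the amax history recorded in earlier steps and then appends the current step's amax for future use, roughly as in the sketch below (`delayed_scale_sketch` is a hypothetical helper for illustration only; the exact update rule inside `fp8_ops` may differ, e.g., in margins or how the history is aggregated):

```python
def delayed_scale_sketch(x, amax_history, fp8_max=E4M3_MAX):
  # Derive this step's scale from amax values observed in previous steps,
  # not from the current tensor.
  amax = jnp.max(amax_history)
  scale = jnp.where(amax > 0, amax / fp8_max, 1.0)

  # Record the current tensor's amax into the rolling history for later steps.
  new_history = jnp.roll(amax_history, 1).at[0].set(jnp.max(jnp.abs(x)))
  return scale, new_history
```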
|
170 |
| - "Fake quantization for delayed scaling is provided by `fp8_ops.in_qdq` for the\n", |
171 |
| - "activations and weights, and `fp8_ops.out_qdq` for the gradients." |
| 164 | + "The quantization and dequantization operations for delayed scaling are provided\n", |
| 165 | + "by `fp8_ops.in_q` and `fp8_ops.out_dq` respectively. `fp8_ops.in_q` handles\n", |
| 166 | + "input quantization and update the amax history and scaling factor, while\n", |
| 167 | + "`fp8_ops.out_dq` performs output dequantization." |
172 | 168 | ]
|
173 | 169 | },
|
174 | 170 | {
|
|
180 | 176 | "source": [
|
181 | 177 | "a_scale = jnp.array(1.0)\n",
|
182 | 178 | "b_scale = jnp.array(1.0)\n",
|
183 |
| - "g_scale = jnp.array(1.0)\n", |
184 | 179 | "a_amax_hist = jnp.zeros((1024,))\n",
|
185 | 180 | "b_amax_hist = jnp.zeros((1024,))\n",
|
186 |
| - "g_amax_hist = jnp.zeros((1024,))\n", |
187 | 181 | "\n",
|
188 | 182 | "@jax.jit\n",
|
189 |
| - "def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist,\n", |
190 |
| - " g_scale, g_amax_hist):\n", |
191 |
| - " a = fp8_ops.in_qdq(f32, e4m3, a, a_scale, a_amax_hist)\n", |
192 |
| - " b = fp8_ops.in_qdq(f32, e4m3, b, b_scale, b_amax_hist)\n", |
| 183 | + "def dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist):\n", |
| 184 | + " a, a_scale = fp8_ops.in_q(f32, e4m3, a, a_scale, a_amax_hist)\n", |
| 185 | + " b, b_scale = fp8_ops.in_q(f32, e4m3, b, b_scale, b_amax_hist)\n", |
193 | 186 | " \n",
|
194 |
| - " c = jnp.dot(a, b)\n", |
195 |
| - " c = fp8_ops.out_qdq(f32, e5m2, c, g_scale, g_amax_hist)\n", |
| 187 | + " c = jnp.dot(a, b, preferred_element_type=f32)\n", |
| 188 | + " c = fp8_ops.out_dq(f32, a_scale, b_scale, c)\n", |
196 | 189 | " return c\n",
|
197 | 190 | "\n",
|
198 |
| - "C = dot_fp8(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,\n", |
199 |
| - " g_scale, g_amax_hist)\n", |
200 |
| - "check_fp8_call(dot_fp8.lower(A, a_scale, a_amax_hist, B, b_scale, b_amax_hist,\n", |
201 |
| - " g_scale, g_amax_hist))" |
| 191 | + "c = dot_fp8(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist)\n", |
| 192 | + "check_fp8_call(dot_fp8.lower(a, a_scale, a_amax_hist, b, b_scale, b_amax_hist))" |
202 | 193 | ]
|
203 | 194 | },
|
204 | 195 | {
|
|
208 | 199 | "source": [
|
209 | 200 | "In this example, we first prepare three pairs of scaling factors and amax\n",
|
210 | 201 | "histories, treating them as results computed from previous steps. Then, we apply\n",
|
211 |
| - "`fp8_ops.in_qdq` to the input operands of `jnp.dot`, followed by\n", |
212 |
| - "`fp8_ops.out_qdq` to the output of `jnp.dot`. Note the `fp8_ops.out_qdq` will\n", |
213 |
| - "apply fake quantization to the gradient of the output via custom_vjp functions.\n", |
214 |
| - "The new scaling factors and amax histories will be returned through their\n", |
215 |
| - "gradients, which will be covered in the next section.\n", |
| 202 | + "`fp8_ops.in_q` to the input operands of `jnp.dot`, followed by `fp8_ops.out_dq`\n", |
| 203 | + "to the output of `jnp.dot`.\n", |
216 | 204 | "\n",
|
217 | 205 | "\n",
|
218 | 206 | "## FLAX High Level API\n",
|
219 |
| - "With the FLAX library, incorporating FP8 operations into existing FLAX layers\n", |
220 |
| - "is a seamless process. Users don't need to manipulate the low-level APIs for\n", |
221 |
| - "quantization. Instead, they can integrate the provided custom FP8 dot\n", |
222 |
| - "(`fp8_ops.Fp8DotGeneralOp`) into FLAX layers using a straightforward\n", |
223 |
| - "\"code-injection\" approach. This custom operation encapsulates all FP8-related\n", |
224 |
| - "tasks, including the placement of quantization-dequantization ops, algorithms\n", |
225 |
| - "for updating scaling factors, and the selection of FP8 dtype combinations for\n", |
226 |
| - "forward and backward propagation.\n", |
| 207 | + "Flax provides high-level operations to seamlessly integrate FP8 quantization\n", |
| 208 | + "into existing layers. Instead of manually handling quantization of the delayed\n", |
| 209 | + "scaling (e.g., the maintanence of the amax history and scaling factors), users\n", |
| 210 | + "can simply use these drop-in replacements:\n", |
| 211 | + "\n", |
| 212 | + "* `fp8_ops.Fp8DotGeneral` for `lax.dot_general` operations\n", |
| 213 | + "* `fp8_ops.Fp8Einsum` for `jnp.einsum` operations \n", |
| 214 | + "\n", |
| 215 | + "These operations automatically handle all FP8-related functionality, including\n", |
| 216 | + "quantization/dequantization, scale factor updates, and FP8 dtype selection for\n", |
| 217 | + "both forward and backward passes.\n", |
227 | 218 | "\n",
|
228 | 219 | "Consider the following example:"
|
229 | 220 | ]
|
|
235 | 226 | "metadata": {},
|
236 | 227 | "outputs": [],
|
237 | 228 | "source": [
|
238 |
| - "model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneralOp)\n", |
239 |
| - "params = model.init(key, A)\n", |
| 229 | + "model = nn.Dense(features=64, dot_general_cls=fp8_ops.Fp8DotGeneral)\n", |
| 230 | + "params = model.init(k0, A)\n", |
240 | 231 | "\n",
|
241 | 232 | "@jax.jit\n",
|
242 | 233 | "def train_step(var, a): \n",
|
|
248 | 239 | },
|
249 | 240 | {
|
250 | 241 | "cell_type": "markdown",
|
251 |
| - "id": "a83b0851", |
| 242 | + "id": "ba280e79", |
252 | 243 | "metadata": {},
|
253 | 244 | "source": [
|
254 |
| - "In this example, we simply set `dot_general_cls=fp8_ops.Fp8DotGeneralOp` to\n", |
255 |
| - "enable the Dense layer to utilize the FP8 dot operation. The usage of the model\n", |
256 |
| - "remains almost the same as before. The main difference is the addition of a new\n", |
257 |
| - "category of parameters: the sets of scaling factors and amax history. In the\n", |
258 |
| - "next section, we will explore how to update these parameters.\n", |
| 245 | + "By setting `dot_general_cls=fp8_ops.Fp8DotGeneral`, we replace the\n", |
| 246 | + "default `lax.dot_general` operation in `nn.Dense` with an FP8-enabled version.\n", |
| 247 | + "The model usage remains similar, but now includes additional parameters for FP8\n", |
| 248 | + "quantization: scaling factors and amax history values. The next section explains\n", |
| 249 | + "how to update these FP8-specific parameters.\n", |
| 250 | + "\n", |
| 251 | + "For models that use `jnp.einsum` operations, such as Mixture of Experts (MoE)\n", |
| 252 | + "layers, users can replace them with `fp8_ops.Fp8Einsum` to enable FP8\n", |
| 253 | + "quantization. Here's an example:" |
| 254 | + ] |
| 255 | + }, |
| 256 | + { |
| 257 | + "cell_type": "code", |
| 258 | + "execution_count": null, |
| 259 | + "id": "961b4549", |
| 260 | + "metadata": {}, |
| 261 | + "outputs": [], |
| 262 | + "source": [ |
| 263 | + "from typing import Any\n", |
| 264 | + "class FooModule(nn.Module):\n", |
| 265 | + " einsum: Any = None\n", |
| 266 | + " @nn.compact\n", |
| 267 | + " def __call__(self, a, b):\n", |
| 268 | + " if self.einsum is not None:\n", |
| 269 | + " einsum_fn = self.einsum()\n", |
| 270 | + " elif self.einsum is None:\n", |
| 271 | + " einsum_fn = jnp.einsum\n", |
| 272 | + " c = einsum_fn(\"mk,kn->mn\", a, b)\n", |
| 273 | + " return c\n", |
| 274 | + "\n", |
| 275 | + "model = FooModule(einsum=fp8_ops.Fp8Einsum)\n", |
| 276 | + "params = model.init(k0, a, b)\n", |
259 | 277 | "\n",
|
| 278 | + "@jax.jit\n", |
| 279 | + "def train_step(var, a, b):\n", |
| 280 | + " c = model.apply(var, a, b)\n", |
| 281 | + " return jnp.sum(c)\n", |
| 282 | + "\n", |
| 283 | + "check_fp8_call(train_step.lower(params, a, b))" |
| 284 | + ] |
| 285 | + }, |
| 286 | + { |
| 287 | + "cell_type": "markdown", |
| 288 | + "id": "a83b0851", |
| 289 | + "metadata": {}, |
| 290 | + "source": [ |
260 | 291 | "## Manipulate FP8 params\n",
|
| 292 | + "\n", |
| 293 | + "The following sections explain the internal FP8 parameters managed by\n", |
| 294 | + "`fp8_ops.Fp8DotGeneral` and `fp8_ops.Fp8Einsum`. These parameters\n", |
| 295 | + "include scaling factors and amax history values that control the FP8\n", |
| 296 | + "quantization process. While most users don't need to interact with these\n", |
| 297 | + "directly, understanding them can be valuable for advanced optimization and\n", |
| 298 | + "debugging.\n", |
| 299 | + "\n", |
261 | 300 | "Let's first examine the data structure of `params`. In the code below, we redact\n",
|
262 | 301 | "the parameter values and then display the PyTree structure."
|
263 | 302 | ]
|
|
285 | 324 | "The output is as follows:\n",
|
286 | 325 | "\n",
|
287 | 326 | "```plaintext\n",
|
288 |
| - "{'_overwrite_with_gradient': {'Fp8DotGeneralOp_0': {'input_amax_history': '*',\n", |
289 |
| - " 'input_scale': '*',\n", |
290 |
| - " 'kernel_amax_history': '*',\n", |
291 |
| - " 'kernel_scale': '*',\n", |
292 |
| - " 'output_grad_amax_history': '*',\n", |
293 |
| - " 'output_grad_scale': '*'}},\n", |
294 |
| - " 'params': {'bias': '*', 'kernel': '*'}}\n", |
| 327 | + "{'_overwrite_with_gradient': {'Fp8Einsum_0': {'input_amax_history': '*',\n", |
| 328 | + " 'input_scale': '*',\n", |
| 329 | + " 'kernel_amax_history': '*',\n", |
| 330 | + " 'kernel_scale': '*',\n", |
| 331 | + " 'output_grad_amax_history': '*',\n", |
| 332 | + " 'output_grad_scale': '*'}}}\n", |
295 | 333 | "```\n",
|
296 | 334 | "\n",
|
297 | 335 | "In addition to the expected `params`, there is an additional category called\n",
|
|
400 | 438 | "2.0 [5. 0. 0. ... 0. 0. 0.]\n",
|
401 | 439 | "```\n",
|
402 | 440 | "\n",
|
403 |
| - "This casting is already included if users choose to use the high-level APIs." |
| 441 | + "This casting is already included if users choose to use the high-level APIs.\n", |
| 442 | + "\n", |
| 443 | + "## Deprecated APIs\n", |
| 444 | + "Previously, we provided APIs like `fp8_ops.quantize_dequantize` for current\n", |
| 445 | + "scaling and `fp8_ops.[in|out]_qdq` for delayed scaling. These were used with\n", |
| 446 | + "high precision dot operations, leveraging an XLA-FP8 feature that\n", |
| 447 | + "pattern-matched QDQ->dot sequences to Q->fp8_cublas_gemm. The corresponding\n", |
| 448 | + "high-level API was called `fp8_ops.Fp8DotGeneralOp`. However, this pattern\n", |
| 449 | + "matching-based solution proved brittle, as the patterns could be easily broken\n", |
| 450 | + "by other XLA optimizations. We recommend users migrate from these deprecated\n", |
| 451 | + "APIs to the newer ones described above.\n", |
| 452 | + "\n", |
| 453 | + "For migration, users should replace:\n", |
| 454 | + "* `fp8_ops.quantize_dequantize -> jnp.dot` with `fp8_ops.quantize -> jnp.dot ->\n", |
| 455 | + " fp8_ops.dequantize`\n", |
| 456 | + "* `fp8_ops.in_qdq -> jnp.dot -> fp8_ops.out_qdq` with `fp8_ops.in_q -> jnp.dot\n", |
| 457 | + " -> fp8_ops.out_dq`\n", |
| 458 | + "* `fp8_ops.Fp8DotGeneralOp` with `fp8_ops.Fp8DotGeneral`\n", |
| 459 | + "\n", |
| 460 | + "Additionally, we provide an einsum variant through `fp8_ops.Fp8Einsum`." |
404 | 461 | ]
|
405 | 462 | }
|
406 | 463 | ],
|