Commit 575f4c4

Decouple recipe and scaling mode
Signed-off-by: Jeremy Berchtold <[email protected]>
1 parent 4ceb3d4 commit 575f4c4
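This commit changes how quantizers are configured: instead of passing scaling_mode, q_dtype, and q_layout as loose keyword arguments to QuantizerFactory.create, callers bundle them into a QuantizerParams object, and layer-level code looks the scaling mode up per tensor usage through QuantizeConfig.RECIPE_MANAGER rather than reading a single global QuantizeConfig.SCALING_MODE. A minimal sketch of the new calling convention, based on the updated tests below; it assumes transformer_engine's JAX package is installed and exports these names, and the dtype values are illustrative:

# Sketch of the calling-convention change introduced by this commit.
# Assumes transformer_engine[jax] is installed; dtype values are illustrative.
import jax.numpy as jnp
from transformer_engine.jax.quantize import (
    QuantizerFactory,
    QuantizerParams,
    QuantizeLayout,
    ScalingMode,
)

# Before: scaling mode, dtype, and layout were passed as loose keyword args.
# quantizer = QuantizerFactory.create(
#     scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
#     q_dtype=jnp.float8_e4m3fn,
#     q_layout=QuantizeLayout.ROWWISE,
# )

# After: they are bundled into a single QuantizerParams object.
quantizer = QuantizerFactory.create(
    QuantizerParams(
        scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
        q_dtype=jnp.float8_e4m3fn,
        q_layout=QuantizeLayout.ROWWISE,
    )
)

# When other factory options are also needed, the params go through q_params=...
te_quantizer, jax_quantizer = QuantizerFactory.create(
    n_quantizers=2,
    q_params=QuantizerParams(
        scaling_mode=ScalingMode.DELAYED_TENSOR_SCALING,
        q_dtype=jnp.float8_e4m3fn,
        q_layout=QuantizeLayout.ROWWISE,
    ),
)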

File tree

5 files changed: +309, -82 lines changed


tests/jax/test_custom_call_compute.py

Lines changed: 36 additions & 15 deletions
@@ -36,6 +36,7 @@
     ScalingMode,
     QuantizerFactory,
     QuantizeLayout,
+    QuantizerParams,
 )
 from transformer_engine.jax.quantize import helper
 from transformer_engine.jax.activation import activation
@@ -188,9 +189,11 @@ def test_act_grad_with_tensor_scaling_fp8(
         )
 
         quantizer = QuantizerFactory.create(
-            scaling_mode=scaling_mode,
-            q_dtype=output_type,
-            q_layout=QuantizeLayout.ROWWISE,
+            QuantizerParams(
+                scaling_mode=scaling_mode,
+                q_dtype=output_type,
+                q_layout=QuantizeLayout.ROWWISE,
+            )
         )
 
         prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
@@ -219,9 +222,11 @@ def test_act_forward_with_tensor_scaling_fp8(
 
         te_quantizer, jax_quantizer = QuantizerFactory.create(
             n_quantizers=2,
-            scaling_mode=scaling_mode,
-            q_dtype=output_type,
-            q_layout=q_layout,
+            q_params=QuantizerParams(
+                scaling_mode=scaling_mode,
+                q_dtype=output_type,
+                q_layout=q_layout,
+            ),
         )
 
         te_output = tex.act_lu(x, activation_type, te_quantizer)
@@ -244,7 +249,9 @@ def test_act_forward_with_block_scaling_fp8(
         self.activation_type = activation_type
 
         quantizer = QuantizerFactory.create(
-            scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
+            QuantizerParams(
+                scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
+            )
         )
 
         output = tex.act_lu(x, activation_type, quantizer)
@@ -378,7 +385,7 @@ def test_norm_grad_with_tensor_scaling_fp8(
             pytest.skip("RMSNorm and zero_centered_gamma is not supported!")
 
         quantizer = QuantizerFactory.create(
-            scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
+            QuantizerParams(scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout)
         )
         self._test_norm_grad(
             n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
@@ -406,7 +413,12 @@ def _test_norm_forward(
         gamma = jnp.asarray(gamma, inp_dtype)
 
         quantizer, ref_quantizer = QuantizerFactory.create(
-            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
+            n_quantizers=2,
+            q_params=QuantizerParams(
+                scaling_mode=scaling_mode,
+                q_dtype=out_dtype,
+                q_layout=q_layout,
+            ),
         )
         if norm_type == "layernorm":
             beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
@@ -562,9 +574,11 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt
 
         # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
         quantizer = QuantizerFactory.create(
-            scaling_mode=scaling_mode,
-            q_dtype=q_dtype,
-            q_layout=q_layout,
+            QuantizerParams(
+                scaling_mode=scaling_mode,
+                q_dtype=q_dtype,
+                q_layout=q_layout,
+            )
         )
         # Adding dimension to test if padding is done correctly when flatten 3D to 2D
         if flatten_axis == -2:
@@ -587,7 +601,8 @@ def test_quantize_bitwise(
         input = jax.random.uniform(key, input_shape, in_dtype)
 
         te_quantizer, jax_quantizer = QuantizerFactory.create(
-            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
+            n_quantizers=2,
+            q_params=QuantizerParams(q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout),
         )
 
         jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
@@ -619,7 +634,10 @@ def test_quantize_dbias(
         input = jax.random.uniform(key, input_shape, in_dtype)
 
         jax_quantizer, te_quantizer = QuantizerFactory.create(
-            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
+            n_quantizers=2,
+            q_params=QuantizerParams(
+                q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
+            ),
         )
 
         te_output, te_dbias = jit(
@@ -649,7 +667,10 @@ def _test_quantize_dact_dbias(
         dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)
 
         jax_quantizer, te_quantizer = QuantizerFactory.create(
-            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
+            n_quantizers=2,
+            q_params=QuantizerParams(
+                q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
+            ),
         )
         is_casted_output = te_quantizer is not None
 
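Throughout this test file the pattern is the same: single-quantizer calls pass QuantizerParams positionally, while calls that also set factory options such as n_quantizers pass it via q_params=. For readers without the library source at hand, here is a hypothetical, self-contained sketch of what such a params object bundles; the dataclass below mirrors the names in the diff but is not the library's actual implementation:

# Hypothetical, self-contained sketch of a params object like QuantizerParams.
# Illustrative only; the real class lives in transformer_engine.jax.quantize.
from dataclasses import dataclass
from enum import Enum, auto


class ScalingMode(Enum):
    DELAYED_TENSOR_SCALING = auto()
    MXFP8_1D_SCALING = auto()


class QuantizeLayout(Enum):
    ROWWISE = auto()
    COLWISE = auto()
    ROWWISE_COLWISE = auto()


@dataclass(frozen=True)
class QuantizerParams:
    """Bundles everything a factory needs to build one quantizer."""
    scaling_mode: ScalingMode
    q_dtype: str = "float8_e4m3fn"
    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE


# A recipe can now describe different tensors with different params,
# instead of relying on one global scaling mode.
x_params = QuantizerParams(ScalingMode.DELAYED_TENSOR_SCALING)
grad_params = QuantizerParams(ScalingMode.MXFP8_1D_SCALING, q_dtype="float8_e5m2")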

tests/jax/test_layer.py

Lines changed: 9 additions & 1 deletion
@@ -27,6 +27,8 @@
     ScalingMode,
     is_fp8_available,
     update_collections,
+    UsageContext,
+    UsageType,
 )
 
 
@@ -354,7 +356,13 @@ def test_backward(
             test_others,
             test_layer,
         )
-        if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
+        if (
+            QuantizeConfig.RECIPE_MANAGER is not None
+            and QuantizeConfig.RECIPE_MANAGER.get_quantizer_params(
+                UsageContext(UsageType.X)
+            ).scaling_mode
+            == ScalingMode.DELAYED_TENSOR_SCALING
+        ):
             _, updated_quantize_meta = flax.core.pop(
                 updated_state[0], QuantizeConfig.COLLECTION_NAME
             )
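The guard in test_backward no longer reads a single global QuantizeConfig.SCALING_MODE; it asks the recipe manager for the quantizer params of one tensor usage (the input X) and inspects that tensor's scaling mode, which is what lets a recipe mix scaling modes across tensors. A hypothetical, self-contained sketch of such a per-usage lookup; RecipeManager, UsageContext, and UsageType here are stand-ins, not the library's classes:

# Hypothetical sketch of a per-usage recipe lookup; names mirror the diff
# (RECIPE_MANAGER, UsageContext, UsageType) but this is not the library code.
from dataclasses import dataclass
from enum import Enum, auto


class ScalingMode(Enum):
    DELAYED_TENSOR_SCALING = auto()
    MXFP8_1D_SCALING = auto()


class UsageType(Enum):
    X = auto()        # layer input / activations
    KERNEL = auto()   # weights
    GRAD = auto()     # gradients


@dataclass(frozen=True)
class UsageContext:
    usage: UsageType


@dataclass(frozen=True)
class QuantizerParams:
    scaling_mode: ScalingMode


class RecipeManager:
    """Maps a usage context to the quantizer params the recipe prescribes."""

    def __init__(self, params_by_usage):
        self._params_by_usage = params_by_usage

    def get_quantizer_params(self, ctx: UsageContext) -> QuantizerParams:
        return self._params_by_usage[ctx.usage]


def uses_delayed_scaling_for_x(recipe_manager) -> bool:
    # Guard pattern from the diff: check the scaling mode of the X usage
    # instead of a single global QuantizeConfig.SCALING_MODE.
    if recipe_manager is None:
        return False
    params = recipe_manager.get_quantizer_params(UsageContext(UsageType.X))
    return params.scaling_mode is ScalingMode.DELAYED_TENSOR_SCALING


# Example: a recipe that delayed-scales activations but block-scales grads.
manager = RecipeManager({
    UsageType.X: QuantizerParams(ScalingMode.DELAYED_TENSOR_SCALING),
    UsageType.KERNEL: QuantizerParams(ScalingMode.DELAYED_TENSOR_SCALING),
    UsageType.GRAD: QuantizerParams(ScalingMode.MXFP8_1D_SCALING),
})
assert uses_delayed_scaling_for_x(manager)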

transformer_engine/jax/flax/module.py

Lines changed: 16 additions & 2 deletions
@@ -31,7 +31,15 @@
     jax_scaled_masked_softmax,
     jax_scaled_upper_triang_masked_softmax,
 )
-from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
+from ..quantize import (
+    QuantizerFactory,
+    QuantizeConfig,
+    QuantizeMeta,
+    QuantizeMetaSet,
+    ScalingMode,
+    UsageContext,
+    UsageType,
+)
 from ..sharding import get_non_contracting_logical_axes
 
 PRNGKey = Any
@@ -356,7 +364,13 @@ def generate_quantize_meta(quantizer_name: str):
             ).value
             return QuantizeMeta(scale=scale, amax_history=amax_history)
 
-        if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
+        if (
+            QuantizeConfig.RECIPE_MANAGER is not None
+            and QuantizeConfig.RECIPE_MANAGER.get_quantizer_params(
+                UsageContext(UsageType.X)
+            ).scaling_mode
+            == ScalingMode.DELAYED_TENSOR_SCALING
+        ):
            x_meta = generate_quantize_meta("x")
            kernel_meta = generate_quantize_meta("kernel")
            grad_meta = generate_quantize_meta("grad")
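module.py applies the same guard when deciding whether to create scale and amax_history variables: only delayed tensor scaling keeps that per-tensor state, so other recipes (e.g. MXFP8 block scaling) skip QuantizeMeta generation entirely. As background, here is a hypothetical sketch of the state delayed tensor scaling maintains and how it is typically updated; the helper names and defaults are illustrative, not the module's actual code:

# Hypothetical sketch of the per-tensor state that delayed tensor scaling
# needs (and other scaling modes do not), mirroring the
# QuantizeMeta(scale=..., amax_history=...) construction in the diff.
import jax.numpy as jnp


def init_delayed_scaling_meta(amax_history_len: int = 1024):
    # One running scale plus a rolling window of past absolute maxima.
    scale = jnp.ones((1,), dtype=jnp.float32)
    amax_history = jnp.zeros((amax_history_len,), dtype=jnp.float32)
    return scale, amax_history


def update_delayed_scaling_meta(scale, amax_history, new_amax,
                                fp8_max: float = 448.0, margin: float = 0.0):
    # Shift the history, record the newest amax, and derive the next scale
    # from the largest amax in the window (standard delayed scaling).
    amax_history = jnp.roll(amax_history, shift=1).at[0].set(new_amax)
    amax = jnp.max(amax_history)
    scale = jnp.where(amax > 0.0, (fp8_max / amax) / (2.0 ** margin), scale)
    return scale, amax_history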
