     ScalingMode,
     QuantizerFactory,
     QuantizeLayout,
+    QuantizerParams,
 )
 from transformer_engine.jax.quantize import helper
 from transformer_engine.jax.activation import activation
@@ -188,9 +189,11 @@ def test_act_grad_with_tensor_scaling_fp8(
         )

         quantizer = QuantizerFactory.create(
-            scaling_mode=scaling_mode,
-            q_dtype=output_type,
-            q_layout=QuantizeLayout.ROWWISE,
+            QuantizerParams(
+                scaling_mode=scaling_mode,
+                q_dtype=output_type,
+                q_layout=QuantizeLayout.ROWWISE,
+            )
         )

         prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer)
@@ -219,9 +222,11 @@ def test_act_forward_with_tensor_scaling_fp8(

         te_quantizer, jax_quantizer = QuantizerFactory.create(
             n_quantizers=2,
-            scaling_mode=scaling_mode,
-            q_dtype=output_type,
-            q_layout=q_layout,
+            q_params=QuantizerParams(
+                scaling_mode=scaling_mode,
+                q_dtype=output_type,
+                q_layout=q_layout,
+            ),
         )

         te_output = tex.act_lu(x, activation_type, te_quantizer)
@@ -244,7 +249,9 @@ def test_act_forward_with_block_scaling_fp8(
         self.activation_type = activation_type

         quantizer = QuantizerFactory.create(
-            scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
+            QuantizerParams(
+                scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
+            )
         )

         output = tex.act_lu(x, activation_type, quantizer)
@@ -378,7 +385,7 @@ def test_norm_grad_with_tensor_scaling_fp8(
             pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

         quantizer = QuantizerFactory.create(
-            scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
+            QuantizerParams(scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout)
         )
         self._test_norm_grad(
             n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
@@ -406,7 +413,12 @@ def _test_norm_forward(
         gamma = jnp.asarray(gamma, inp_dtype)

         quantizer, ref_quantizer = QuantizerFactory.create(
-            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
+            n_quantizers=2,
+            q_params=QuantizerParams(
+                scaling_mode=scaling_mode,
+                q_dtype=out_dtype,
+                q_layout=q_layout,
+            ),
         )
         if norm_type == "layernorm":
             beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
@@ -562,9 +574,11 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt

         # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling)
         quantizer = QuantizerFactory.create(
-            scaling_mode=scaling_mode,
-            q_dtype=q_dtype,
-            q_layout=q_layout,
+            QuantizerParams(
+                scaling_mode=scaling_mode,
+                q_dtype=q_dtype,
+                q_layout=q_layout,
+            )
         )
         # Adding dimension to test if padding is done correctly when flatten 3D to 2D
         if flatten_axis == -2:
@@ -587,7 +601,8 @@ def test_quantize_bitwise(
         input = jax.random.uniform(key, input_shape, in_dtype)

         te_quantizer, jax_quantizer = QuantizerFactory.create(
-            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
+            n_quantizers=2,
+            q_params=QuantizerParams(q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout),
         )

         jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
@@ -619,7 +634,10 @@ def test_quantize_dbias(
         input = jax.random.uniform(key, input_shape, in_dtype)

         jax_quantizer, te_quantizer = QuantizerFactory.create(
-            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
+            n_quantizers=2,
+            q_params=QuantizerParams(
+                q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
+            ),
         )

         te_output, te_dbias = jit(
@@ -649,7 +667,10 @@ def _test_quantize_dact_dbias(
         dz = jax.random.uniform(subkeys[1], input_shape, in_dtype, -1, 1)

         jax_quantizer, te_quantizer = QuantizerFactory.create(
-            n_quantizers=2, q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
+            n_quantizers=2,
+            q_params=QuantizerParams(
+                q_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout
+            ),
         )
         is_casted_output = te_quantizer is not None

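Every hunk in this diff applies the same calling-convention change: QuantizerFactory.create no longer receives scaling_mode / q_dtype / q_layout as direct keyword arguments; the tests now bundle them in a QuantizerParams object, passed positionally when a single quantizer is created and via q_params= (together with n_quantizers) when several are. A minimal sketch of the two patterns, assuming QuantizerParams is importable from transformer_engine.jax.quantize alongside the other names as the surrounding import lines suggest; the concrete q_dtype below is an illustrative choice, not one taken from the tests:

    import jax.numpy as jnp
    from transformer_engine.jax.quantize import (
        QuantizeLayout,
        QuantizerFactory,
        QuantizerParams,
        ScalingMode,
    )

    # Per-quantizer settings are now grouped in a single params object.
    params = QuantizerParams(
        scaling_mode=ScalingMode.MXFP8_1D_SCALING,  # scaling mode used in the block-scaling test above
        q_dtype=jnp.float8_e4m3fn,                  # illustrative FP8 dtype; the tests parametrize this
        q_layout=QuantizeLayout.ROWWISE,
    )

    # Single quantizer: the params object is passed positionally.
    quantizer = QuantizerFactory.create(params)

    # Multiple quantizers: n_quantizers stays a keyword; the settings move behind q_params=.
    te_quantizer, jax_quantizer = QuantizerFactory.create(n_quantizers=2, q_params=params)

Grouping the settings this way lets one params object be reused across the paired TE/JAX quantizers that the bitwise-comparison tests create, instead of repeating the three keyword arguments at every call site.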