Commit 4ceb3d4
[JAX] Distributed Current Scaling (#1699)
* Update test_helper.py and add QuantizeConfig class for CurrentScaling
* WIP distributed current scaling
* Distributed Current Scaling (debugging). The distributed implementation with a replicated scale_inv works for layernorm_mlp but feels like a hack: each device holds a different scale_inv value, yet jax.debug.print only shows one of them. Because we tell JAX/XLA that this scale is replicated, it assumes all the values are equal but never checks it, so per-device scales happen to work for current scaling. This is fragile, though; it may break if we or the user change the partitioning, or if XLA ever acts on the assumption that all the scale_invs are identical.
* Implement distributed current scaling by computing a global amax and scale before quantization (see the sketch below)
* Add encoder and mnist tests for current scaling
* Add primitive prefix to shardy unique_vars to prevent factor conflicts when performing unfused primitives for current scaling
* Remove scale_shape primitive arg that is no longer used
* Format
* Fix expected result on multiprocessing encoder test
* Lint fix
* Update multiprocessing current scaling tolerances
* Uncomment test case that was disabled for testing
* Remove commented-out debug line

Signed-off-by: Jeremy Berchtold <[email protected]>
1 parent 643fb0a commit 4ceb3d4

16 files changed (+230 / -202 lines)
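The approach of the final implementation — every device computes its local amax, the amaxes are all-reduced so the whole tensor shares one scale, and only then is the tensor quantized — can be sketched in plain JAX as follows. This is a hedged illustration rather than Transformer Engine's actual kernel: the mesh axis name "dp", the E4M3_MAX constant, and the quantize_current_scaling helper are hypothetical names introduced only for this example.

# Hypothetical sketch of distributed current scaling: agree on one global amax
# via an all-reduce, derive one scale from it, then quantize the local shards.
from functools import partial

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

E4M3_MAX = 448.0  # largest finite magnitude of float8_e4m3fn (assumed here)

mesh = Mesh(jax.devices(), axis_names=("dp",))

@partial(shard_map, mesh=mesh, in_specs=P("dp"), out_specs=(P("dp"), P()))
def quantize_current_scaling(x):
    local_amax = jnp.max(jnp.abs(x))                        # amax of this shard only
    global_amax = jax.lax.pmax(local_amax, axis_name="dp")  # one value for all devices
    scale = E4M3_MAX / jnp.maximum(global_amax, 1e-12)
    x_fp8 = (x * scale).astype(jnp.float8_e4m3fn)
    return x_fp8, 1.0 / scale  # sharded FP8 data plus a replicated scale_inv

x = jnp.linspace(-2.0, 2.0, 1024 * jax.device_count())
x_fp8, scale_inv = quantize_current_scaling(x)

The single FP32 all-reduce of the amax is also what the updated collectives count in tests/jax/test_distributed_layernorm.py below accounts for: 4 extra bytes of all-reduce traffic per quantization when data parallelism is enabled.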

examples/jax/encoder/common.py

Lines changed: 2 additions & 0 deletions
@@ -37,5 +37,7 @@ def get_fp8_recipe_from_name_string(name: str):
             return recipe.DelayedScaling()
         case "MXFP8BlockScaling":
             return recipe.MXFP8BlockScaling()
+        case "Float8CurrentScaling":
+            return recipe.Float8CurrentScaling()
         case _:
             raise ValueError(f"Invalid fp8_recipe, got {name}")
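For context, the encoder and MNIST examples exercise the new branch roughly as below. This is a hedged usage sketch: the import path for the helper and the surrounding argument plumbing are assumptions; only get_fp8_recipe_from_name_string and fp8_autocast come from this diff and the TE JAX API.

# Hypothetical usage: map the recipe name from the CLI to a recipe object and
# enter fp8_autocast with it.
from common import get_fp8_recipe_from_name_string  # examples/jax/encoder/common.py
import transformer_engine.jax as te

fp8_recipe = get_fp8_recipe_from_name_string("Float8CurrentScaling")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    ...  # build and train the model in FP8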

examples/jax/encoder/run_test_multiprocessing_encoder.sh

Lines changed: 2 additions & 0 deletions
@@ -8,9 +8,11 @@ NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
 TEST_CASES=(
     "test_te_bf16"
     "test_te_delayed_scaling_fp8"
+    "test_te_current_scaling_fp8"
     "test_te_mxfp8"
     "test_te_bf16_shardy"
     "test_te_delayed_scaling_fp8_shardy"
+    "test_te_current_scaling_fp8_shardy"
 )

 echo

examples/jax/encoder/test_multigpu_encoder.py

Lines changed: 17 additions & 0 deletions
@@ -441,6 +441,14 @@ def test_te_delayed_scaling_fp8(self):
         actual = train_and_evaluate(self.args)
         assert actual[0] < 0.535 and actual[1] > 0.73

+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_current_scaling_fp8(self):
+        """Test Transformer Engine with CurrentScaling FP8"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "Float8CurrentScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.535 and actual[1] > 0.73
+
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
         """Test Transformer Engine with MXFP8"""
@@ -467,6 +475,15 @@ def test_te_delayed_scaling_fp8_shardy(self):

     # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.

+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_current_scaling_fp8_shardy(self):
+        """Test Transformer Engine with CurrentScaling FP8"""
+        self.args.enable_shardy = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "Float8CurrentScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.535 and actual[1] > 0.73
+

 if __name__ == "__main__":
     train_and_evaluate(encoder_parser(None))

examples/jax/encoder/test_multiprocessing_encoder.py

Lines changed: 17 additions & 1 deletion
@@ -611,6 +611,14 @@ def test_te_delayed_scaling_fp8(self):
         result = self.exec(True, "DelayedScaling")
         assert result[0] < 0.505 and result[1] > 0.754

+    @unittest.skipIf(
+        not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
+    )
+    def test_te_current_scaling_fp8(self):
+        """Test Transformer Engine with CurrentScaling FP8"""
+        result = self.exec(True, "Float8CurrentScaling")
+        assert result[0] < 0.507 and result[1] > 0.753
+
     @unittest.skipIf(
         not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
     )
@@ -631,10 +639,18 @@ def test_te_bf16_shardy(self):
     def test_te_delayed_scaling_fp8_shardy(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         result = self.exec(True, "DelayedScaling", enable_shardy=True)
-        assert result[0] < 0.505 and result[1] > 0.754
+        assert result[0] < 0.505 and result[1] > 0.753

     # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.

+    @unittest.skipIf(
+        not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
+    )
+    def test_te_current_scaling_fp8_shardy(self):
+        """Test Transformer Engine with CurrentScaling FP8"""
+        result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
+        assert result[0] < 0.507 and result[1] > 0.753
+

 if __name__ == "__main__":
     train_and_evaluate(encoder_parser(None))

examples/jax/encoder/test_single_gpu_encoder.py

Lines changed: 8 additions & 0 deletions
@@ -348,6 +348,14 @@ def test_te_delayed_scaling_fp8(self):
         actual = train_and_evaluate(self.args)
         assert actual[0] < 0.455 and actual[1] > 0.79

+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_current_scaling_fp8(self):
+        """Test Transformer Engine with CurrentScaling FP8"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "Float8CurrentScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.455 and actual[1] > 0.79
+
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
         """Test Transformer Engine with MXFP8"""

examples/jax/mnist/test_single_gpu_mnist.py

Lines changed: 8 additions & 0 deletions
@@ -350,6 +350,14 @@ def test_te_mxfp8(self):
         actual = train_and_evaluate(self.args)
         self.verify(actual)

+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_current_scaling_fp8(self):
+        """Test Transformer Engine with CurrentScaling FP8"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "Float8CurrentScaling"
+        actual = train_and_evaluate(self.args)
+        self.verify(actual)
+

 if __name__ == "__main__":
     train_and_evaluate(mnist_parser(None))

tests/jax/test_distributed_layernorm.py

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,7 @@
 SUPPORTED_RECIPES = []
 if is_fp8_supported:
     SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
+    SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
 if is_mxfp8_supported:
     SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

@@ -76,6 +77,8 @@ def generate_collectives_count_ref(
     other_bytes = 0
     if fp8_recipe == recipe.MXFP8BlockScaling() and "dp" in mesh_axes:
         other_bytes = 384  # required for small scale shapes that require padding
+    if fp8_recipe == recipe.Float8CurrentScaling():
+        allreduce_total_bytes += 4  # 1 * FP32 for the amax reduction
     return generate_collectives_count(
         allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
     )

tests/jax/test_distributed_layernorm_mlp.py

Lines changed: 1 addition & 27 deletions
@@ -41,6 +41,7 @@
 SUPPORTED_RECIPES = []
 if is_fp8_supported:
     SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling"))
+    SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling"))
 if is_mxfp8_supported:
     SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))

@@ -217,37 +218,10 @@ def _test_layernorm_mlp_grad(
                 m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
             )
         else:
-            is_gated = len(activation_type) > 1
-            rtol = None
-            atol = None
-            if is_gated:
-                if dtype == jnp.bfloat16:
-                    if i == 2:
-                        rtol = 800
-                        atol = 9e-2
-                    if i == 4:
-                        atol = 300
-                        rtol = 1e-1
-                if dtype == jnp.float16:
-                    if i == 1:  # gamma
-                        rtol = 200
-                        atol = 1e-2
-                    if i == 2:
-                        rtol = 2000
-                        atol = 7e-2
-                if i == 4 and fp8_recipe == recipe.MXFP8BlockScaling():  # bias_1
-                    # Accumulating dbias across a large tensor introduces a larger difference
-                    rtol = 200
-                    atol = 4e-2
-                if i == 4 and fp8_recipe == recipe.DelayedScaling():
-                    rtol = 2200
-                    atol = 9e-2
             assert_allclose(
                 multi_grads[i],
                 single_grads[i],
                 dtype=dtype,
-                rtol=rtol,
-                atol=atol,
                 err_msg=f"multi_grads[{i}] is not close",
             )

tests/jax/test_helper.py

Lines changed: 72 additions & 37 deletions
@@ -10,47 +10,22 @@
 import numpy as np

 from utils import assert_allclose
-from transformer_engine.common.recipe import DelayedScaling
+from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling
 from transformer_engine.common.recipe import Format as FP8Format
 from transformer_engine.jax import fp8_autocast, get_delayed_scaling
-from transformer_engine.jax.quantize import QuantizeConfig, is_fp8_available, AmaxComputeAlgo
+from transformer_engine.jax.quantize import (
+    QuantizeConfig,
+    is_fp8_available,
+    ScalingMode,
+    update_collections,
+)
 from transformer_engine.jax.sharding import MeshResource, global_mesh_resource

 is_fp8_supported, reason = is_fp8_available()
+is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)


-class TestQuantizeConfig(unittest.TestCase):
-
-    @unittest.skipIf(not is_fp8_supported, reason=reason)
-    def test_initialize(self):
-        margin = 5.0
-        fp8_format = FP8Format.E4M3
-        amax_history_len = 10
-
-        QuantizeConfig.initialize(
-            margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
-        )
-
-        self.assertEqual(
-            QuantizeConfig.MARGIN,
-            margin,
-            f"QuantizeConfig.MARGIN initialization failed, should be {margin}"
-            f" but got {QuantizeConfig.MARGIN}.",
-        )
-        self.assertEqual(
-            QuantizeConfig.FP8_FORMAT,
-            fp8_format,
-            f"QuantizeConfig.FP8_FORMAT initialization failed, should be {fp8_format}"
-            f" but got {QuantizeConfig.FP8_FORMAT}.",
-        )
-        self.assertEqual(
-            QuantizeConfig.AMAX_HISTORY_LEN,
-            amax_history_len,
-            f"QuantizeConfig.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
-            f" but got {QuantizeConfig.AMAX_HISTORY_LEN}.",
-        )
-
-        QuantizeConfig.finalize()
+class TestHelper(unittest.TestCase):

     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_update_collections(self):
@@ -61,12 +36,12 @@ def test_update_collections(self):
             "test1": original_val,
             "test2": original_val,
         }
-        updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
+        updated_state = update_collections({"test1": updated_val}, original_state)
         self.assertEqual(updated_state["test1"], updated_val)
         self.assertEqual(updated_state["test2"], original_val)

         original_state = flax.core.frozen_dict.FrozenDict(original_state)
-        updated_state = QuantizeConfig.update_collections({"test1": updated_val}, original_state)
+        updated_state = update_collections({"test1": updated_val}, original_state)
         self.assertEqual(updated_state["test1"], updated_val)
         self.assertEqual(updated_state["test2"], original_val)

@@ -82,8 +57,18 @@ def _compare_delay_scaling(self, ref, test):
         self.assertTrue(ref.amax_history_len == test.amax_history_len)
         self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)

+    def _compare_current_scaling(self, test):
+        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
+        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
+        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
+
+    def _compare_mxfp8_scaling(self, test):
+        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
+        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
+        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)
+
     @unittest.skipIf(not is_fp8_supported, reason=reason)
-    def test_fp8_autocast(self):
+    def test_fp8_autocast_delayed_scaling(self):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_defult_state()

@@ -107,6 +92,56 @@ def test_fp8_autocast(self):

         self._check_defult_state()

+    @unittest.skipIf(not is_fp8_supported, reason=reason)
+    def test_fp8_autocast_current_scaling(self):
+        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
+        self._check_defult_state()
+
+        with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()):
+            self.assertFalse(QuantizeConfig.is_fp8_enabled())
+            self._compare_current_scaling(Float8CurrentScaling())
+
+        self._check_defult_state()
+
+        cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3)
+        with fp8_autocast(enabled=True, fp8_recipe=cs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_current_scaling(cs)
+
+        self._check_defult_state()
+
+        cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
+        with fp8_autocast(enabled=True, fp8_recipe=cs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_current_scaling(cs)
+
+        self._check_defult_state()
+
+    @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
+    def test_fp8_autocast_mxfp8_scaling(self):
+        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
+        self._check_defult_state()
+
+        with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()):
+            self.assertFalse(QuantizeConfig.is_fp8_enabled())
+            self._compare_mxfp8_scaling(MXFP8BlockScaling())
+
+        self._check_defult_state()
+
+        bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
+        with fp8_autocast(enabled=True, fp8_recipe=bs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_mxfp8_scaling(bs)
+
+        self._check_defult_state()
+
+        bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
+        with fp8_autocast(enabled=True, fp8_recipe=bs):
+            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self._compare_mxfp8_scaling(bs)
+
+        self._check_defult_state()
+
     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_with_sharding_resource(self):
         QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.