From e19b8281453ec1835319448bba93697fc8b0f537 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 14 Feb 2025 08:11:32 -0800 Subject: [PATCH] [JAX] Fixes for CI failures with the latest JAX (#1469) * fixes L1 test * fix test_multigpu_encoder * fixes for other multi-encoder tests * jax.extend.ffi to jax.ffi * initialization with float32 * add init_dtype as an optional arg to all modules * update use_scan query from xla flags * relax threshold for test_encoder fp8 * relax the tols --------- Signed-off-by: Phuong Nguyen --- .../encoder/test_model_parallel_encoder.py | 6 +++--- examples/jax/encoder/test_multigpu_encoder.py | 2 +- .../encoder/test_multiprocessing_encoder.py | 4 ++-- .../jax/encoder/test_single_gpu_encoder.py | 2 +- qa/L1_jax_distributed_unittest/test.sh | 8 +------- tests/jax/test_distributed_fused_attn.py | 20 ++++++++++++++----- .../jax/cpp_extensions/activation.py | 2 +- .../jax/cpp_extensions/attention.py | 6 ++---- .../jax/cpp_extensions/custom_call.py | 7 +++---- .../jax/cpp_extensions/normalization.py | 2 +- .../jax/cpp_extensions/quantization.py | 2 +- .../jax/cpp_extensions/softmax.py | 2 +- .../jax/cpp_extensions/transpose.py | 2 +- transformer_engine/jax/flax/transformer.py | 7 ++++--- 14 files changed, 37 insertions(+), 35 deletions(-) diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 918dfd8238..f02cc562b5 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -239,7 +239,7 @@ def to_device_axis(logical_axis): ) params_axes_sharding = flax.core.unfreeze(params_axes_sharding) params_sharding = jax.tree_util.tree_map( - lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY] ) params_sharding = {**params_sharding, **params_axes_sharding} return params_sharding @@ -447,7 +447,7 @@ def test_te_fp8(self): """Test Transformer Engine with FP8""" self.args.use_fp8 = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.45 and actual[1] > 0.79 + assert actual[0] < 0.455 and actual[1] > 0.785 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_sp(self): @@ -462,7 +462,7 @@ def test_te_fp8_sp(self): self.args.enable_sp = True self.args.use_fp8 = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.45 and actual[1] > 0.79 + assert actual[0] < 0.455 and actual[1] > 0.785 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index c0325d3e28..eb4a1d0afb 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -218,7 +218,7 @@ def to_device_axis(logical_axis): ) params_axes_sharding = flax.core.unfreeze(params_axes_sharding) params_sharding = jax.tree_util.tree_map( - lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY] ) params_sharding = {**params_sharding, **params_axes_sharding} return params_sharding diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 7d2df77b7d..91186a15c4 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -320,7 +320,7 @@ def to_device_axis(logical_axis): ) params_axes_sharding = flax.core.unfreeze(params_axes_sharding) params_sharding = jax.tree_util.tree_map( - lambda x: NamedSharding(mesh, ()), abs_var_collect[PARAMS_KEY] + lambda x: NamedSharding(mesh, PartitionSpec(None)), abs_var_collect[PARAMS_KEY] ) params_sharding = {**params_sharding, **params_axes_sharding} return params_sharding @@ -587,7 +587,7 @@ def test_te_bf16(self): def test_te_fp8(self): """Test Transformer Engine with FP8""" result = self.exec(True) - assert result[0] < 0.45 and result[1] > 0.79 + assert result[0] < 0.455 and result[1] > 0.79 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index b2439278ea..dd1997fe6f 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -334,7 +334,7 @@ def test_te_fp8(self): """Test Transformer Engine with FP8""" self.args.use_fp8 = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.45 and actual[1] > 0.79 + assert actual[0] < 0.455 and actual[1] > 0.79 if __name__ == "__main__": diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index deb0f93cec..e47aa15fbd 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -6,10 +6,4 @@ set -xe : ${TE_PATH:=/opt/transformerengine} -# Skip ring attention tests since they need fixed environment vars -pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* -k 'not test_context_parallel_ring_attn' - -# Test ring attention with and without scan loop -NVTE_FUSED_RING_ATTENTION_USE_SCAN=0 pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn -NVTE_FUSED_RING_ATTENTION_USE_SCAN=1 XLA_FLAGS="--xla_experimental_ignore_channel_id" \ - pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_fused_attn.py -k test_context_parallel_ring_attn +pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_* diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index d7e015dbf7..898993f5d1 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -2,6 +2,8 @@ # # See LICENSE for license information. +import os +import pytest import jax import jax.numpy as jnp import numpy as np @@ -11,7 +13,7 @@ generate_context_parallel_configs, generate_collectives_count, ) -from transformer_engine.jax import fp8_autocast +from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat from transformer_engine.jax.attention import ( is_fused_attn_kernel_available, AttnBiasType, @@ -22,10 +24,7 @@ inverse_reorder_causal_load_balancing, CPStrategy, ) -from transformer_engine.jax.sharding import MeshResource -import pytest -from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat DTYPES = [jnp.bfloat16] @@ -355,6 +354,10 @@ def test_context_parallel_allgather_attn( CPStrategy.ALL_GATHER, ) + @pytest.mark.parametrize( + "use_scan", + [pytest.param(False, id="NO_SCAN"), pytest.param(True, id="USE_SCAN")], + ) def test_context_parallel_ring_attn( self, device_count, @@ -367,8 +370,14 @@ def test_context_parallel_ring_attn( dtype, qkv_layout, load_balanced, + use_scan, ): - return self.impl_test_context_parallel_attn( + if use_scan: + os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1" + else: + os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" + + self.impl_test_context_parallel_attn( device_count, mesh_shape, mesh_axes, @@ -381,6 +390,7 @@ def test_context_parallel_ring_attn( load_balanced, CPStrategy.RING, ) + del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] class TestReorderCausalLoadBalancing: diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 4a29fce2c4..076ec98aba 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -11,7 +11,7 @@ from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi +from jax import ffi from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import NVTE_Activation_Type diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 5ec556ab34..1c32ef4cba 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -15,7 +15,7 @@ from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi +from jax import ffi from transformer_engine.jax.attention import CPStrategy, SequenceDescriptor @@ -1602,9 +1602,7 @@ def use_scanloop(): def truthy(val): return val.lower() in ["1", "true"] - x = use_scan and get_xla_flag( - "--xla_experimental_ignore_channel_id", default=False, cast=truthy - ) + x = use_scan and get_xla_flag("--xla_ignore_channel_id", default=True, cast=truthy) return x def check_supported(self): diff --git a/transformer_engine/jax/cpp_extensions/custom_call.py b/transformer_engine/jax/cpp_extensions/custom_call.py index 6739ac8bda..6f6c9962cf 100644 --- a/transformer_engine/jax/cpp_extensions/custom_call.py +++ b/transformer_engine/jax/cpp_extensions/custom_call.py @@ -5,9 +5,8 @@ from dataclasses import dataclass from enum import IntEnum +import jax from jax.interpreters import mlir -import jax.extend as jex - from transformer_engine import transformer_engine_jax from .misc import is_ffi_enabled @@ -30,11 +29,11 @@ class CustomCallAPIVersion(IntEnum): for _name, _value in transformer_engine_jax.registrations().items(): if _name.endswith("_ffi"): if is_ffi_enabled(): - jex.ffi.register_ffi_target( + jax.ffi.register_ffi_target( _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value ) else: - jex.ffi.register_ffi_target( + jax.ffi.register_ffi_target( _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value ) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index d7512b0e70..1107dd3a0f 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -13,7 +13,7 @@ from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi +from jax import ffi from transformer_engine import transformer_engine_jax diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index c3ea8cb7aa..2f29a64f18 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -9,7 +9,7 @@ from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi +from jax import ffi from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import DType as TEDType diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index 5c55dd3672..dba1f504da 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -12,7 +12,7 @@ from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi +from jax import ffi from transformer_engine import transformer_engine_jax diff --git a/transformer_engine/jax/cpp_extensions/transpose.py b/transformer_engine/jax/cpp_extensions/transpose.py index d07b6944fb..bb9b104e7e 100644 --- a/transformer_engine/jax/cpp_extensions/transpose.py +++ b/transformer_engine/jax/cpp_extensions/transpose.py @@ -11,7 +11,7 @@ from jax import dtypes from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec, NamedSharding -from jax.extend import ffi +from jax import ffi from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import DType as TEDType diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 6c96e7ba1a..fbae73f131 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -150,8 +150,8 @@ def __call__( del self.scale_factor if self.float32_logits: - query = query.astype(jnp.float32) - key = key.astype(jnp.float32) + query = query.astype(self.dtype) + key = key.astype(self.dtype) h_q, h_kv = query.shape[-2], key.shape[-2] # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv. # Therefore, we have to maintain two code paths. @@ -989,6 +989,7 @@ def __post_init__(self): self.kernel_init = nn.initializers.variance_scaling( 1.0, "fan_in", "normal", self.weight_dtype ) + self.kernel_init = _kernel_init.astype(self.dtype) if self.num_gqa_groups is None: self.num_gqa_groups = self.num_attention_heads super().__post_init__() @@ -1281,7 +1282,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): f"expected query shape {expected_shape} instead got {query.shape}." ) - cur_index = cache_index.value + cur_index = cache_index.value.astype(jnp.int32) one_hot_indices = jax_nn.one_hot(cur_index, length, dtype=key.dtype) one_hot_indices = jnp.reshape(one_hot_indices, one_hot_indices_shape) key = cached_key.value + key * one_hot_indices