From 86f27e12a58b9a5e2b3be9531e695a3ac93a4f69 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Thu, 8 Aug 2024 21:28:26 +0800 Subject: [PATCH] [JAX] Support non-deterministic algo for cuDNN FA (#1056) * Support non-deterministic algo Signed-off-by: Reese Wang * Refine the helper function name Signed-off-by: Reese Wang * Move fixture to conftest.py Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang Co-authored-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --- tests/jax/conftest.py | 19 ++++++++++++ tests/jax/test_praxis_layers.py | 14 --------- .../jax/cpp_extensions/attention.py | 14 ++++++++- transformer_engine/jax/csrc/extensions.h | 6 ++-- .../jax/csrc/extensions/attention.cpp | 30 ++++++++++--------- .../jax/csrc/extensions/packing.cpp | 5 ++-- transformer_engine/jax/flax/transformer.py | 8 +++++ 7 files changed, 63 insertions(+), 33 deletions(-) diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 55494c42d6..ccb6690a87 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -2,9 +2,12 @@ # # See LICENSE for license information. """conftest for tests/jax""" +import os import jax import pytest +from transformer_engine.transformer_engine_jax import get_device_compute_capability + @pytest.fixture(autouse=True, scope="function") def clear_live_arrays(): @@ -14,3 +17,19 @@ def clear_live_arrays(): yield for arr in jax.live_arrays(): arr.delete() + + +@pytest.fixture(autouse=True, scope="module") +def enable_fused_attn(): + """ + Enable fused attn for hopper+ arch. + Fused attn kernels on pre-hopper arch are not deterministic. + """ + if get_device_compute_capability(0) >= 90: + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0" + yield + if "NVTE_FUSED_ATTN" in os.environ: + del os.environ["NVTE_FUSED_ATTN"] + if "NVTE_ALLOW_NONDETERMINISTIC_ALGO" in os.environ: + del os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index 92a6c80028..ccab73088a 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -15,7 +15,6 @@ from utils import assert_allclose -from transformer_engine.transformer_engine_jax import get_device_compute_capability from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.jax import fp8_autocast, update_collections from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral @@ -43,19 +42,6 @@ FP8_FORMATS = [Format.E4M3, Format.HYBRID] -@pytest.fixture(autouse=True, scope="module") -def enable_fused_attn(): - """ - Enable fused attn for hopper+ arch. - Fused attn kernels on pre-hopper arch are not deterministic. - """ - if get_device_compute_capability(0) >= 90: - os.environ["NVTE_FUSED_ATTN"] = "1" - yield - if "NVTE_FUSED_ATTN" in os.environ: - del os.environ["NVTE_FUSED_ATTN"] - - def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08): for key in ref_fd: assert key in test_fd, f"{key} not found in test dict {test_fd}" diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6fa43b7961..76ccec363b 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3,8 +3,9 @@ # See LICENSE for license information. 
"""JAX/TE custom ops for attention""" from dataclasses import dataclass -from functools import partial, reduce +from functools import partial, reduce, cache import operator +import os from typing import Optional, Tuple import warnings @@ -84,6 +85,12 @@ def get_fused_attn_backend(self): self.head_dim, ) + @staticmethod + @cache + def is_non_deterministic_allowed(): + """Check if non-deterministic kernels are allowed""" + return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1"))) + @staticmethod def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): """Parse qkv aval""" @@ -365,6 +372,7 @@ def lowering( jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training, + not FusedAttnHelper.is_non_deterministic_allowed(), ) out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) @@ -642,6 +650,8 @@ def abstract( *bias_batch_shape, bias_heads, _, _ = bias_aval.shape bias_batch = reduce(operator.mul, bias_batch_shape) + deterministic = not FusedAttnHelper.is_non_deterministic_allowed() + input_batch = reduce(operator.mul, batch_shape) wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( input_batch, @@ -659,6 +669,7 @@ def abstract( qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training, + deterministic, max_segments_per_seq, ) @@ -764,6 +775,7 @@ def lowering( jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training, + not FusedAttnHelper.is_non_deterministic_allowed(), ) out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index c541fb8afa..c084ab09e9 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -147,6 +147,7 @@ struct CustomCallFusedAttnDescriptor { DType dtype; DType wkspace_dtype; bool is_training; + bool deterministic; }; pybind11::bytes PackCustomCallFusedAttnDescriptor( @@ -154,7 +155,8 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training); + NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, + bool deterministic); // Transpose @@ -260,7 +262,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, - size_t max_segments_per_seq); + bool deterministic, size_t max_segments_per_seq); void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 866147b336..1d367f5cc1 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -336,7 +336,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type 
mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, - size_t max_segments_per_seq) { + bool deterministic, size_t max_segments_per_seq) { // For qkv_packed auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); @@ -392,13 +392,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( auto dummy_ragged_offset_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - nvte_fused_attn_bwd_qkvpacked( - qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, -1, -1, true, query_workspace_tensor.data(), nullptr); + nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), + q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, -1, -1, deterministic, + query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_bwd_kvpacked( q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -408,7 +409,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, - -1, true, query_workspace_tensor.data(), nullptr); + -1, deterministic, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -419,7 +420,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1, - -1, true, query_workspace_tensor.data(), nullptr); + -1, deterministic, query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -467,6 +468,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, auto bias_type = descriptor.bias_type; auto mask_type = descriptor.mask_type; auto dtype = descriptor.dtype; + auto deterministic = descriptor.deterministic; auto max_segments_per_seq = descriptor.max_segments_per_seq; /* Input tensors */ @@ -539,7 +541,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, s_tensor.data(), // not used for F16 &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, -1, -1, true, workspace_tensor.data(), stream); + bias_type, mask_type, -1, -1, deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q = 
buffers[0]; auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; @@ -566,7 +568,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true, + dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, deterministic, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q = buffers[0]; @@ -602,8 +604,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, dbias_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true, - workspace_tensor.data(), stream); + dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, + deterministic, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } diff --git a/transformer_engine/jax/csrc/extensions/packing.cpp b/transformer_engine/jax/csrc/extensions/packing.cpp index 8c948d0a8f..128564db64 100644 --- a/transformer_engine/jax/csrc/extensions/packing.cpp +++ b/transformer_engine/jax/csrc/extensions/packing.cpp @@ -68,11 +68,12 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor( size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training) { + NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, + bool deterministic) { return PackOpaque(CustomCallFusedAttnDescriptor{ input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, - mask_type, qkv_layout, dtype, wkspace_dtype, is_training}); + mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic}); } } // namespace jax diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index d53a4e5202..c62c2bb77d 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -359,6 +359,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods kernel is not available on the system, a warning will be issued, and the module will automatically fall back to the unfused backend. + .. note:: + The DotProductAttention default setting enables non-deterministic kernels for reduced + workspace requirements and faster computation. Users can disable the non-deterministic + kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable: + + * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels. + * :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default). + Parameters ---------- head_dim: int
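
Usage sketch (illustrative, not part of the patch): with this change, determinism of the cuDNN fused-attention kernels on the JAX side is controlled by the NVTE_ALLOW_NONDETERMINISTIC_ALGO environment variable, read once by FusedAttnHelper.is_non_deterministic_allowed() and forwarded as the new `deterministic` flag of the fused-attention custom calls. A minimal example follows, assuming transformer_engine.jax.cpp_extensions.attention is importable under the file path shown in this diff; note that the variables must be set before any fused-attention primitive is traced, because the helper is wrapped in functools.cache.

    import os

    # Opt in to cuDNN fused attention and request deterministic kernels only.
    # Set these before the first fused-attention call is traced: the helper
    # below is cached, so later changes to the variable are ignored.
    os.environ["NVTE_FUSED_ATTN"] = "1"
    os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"

    from transformer_engine.jax.cpp_extensions.attention import FusedAttnHelper

    # With "0", non-deterministic algorithms are disallowed, so the backward
    # custom call is built with deterministic=True.
    assert not FusedAttnHelper.is_non_deterministic_allowed()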