[JAX] Support non-deterministic algo for cuDNN FA (NVIDIA#1056)
* Support non-deterministic algo

Signed-off-by: Reese Wang <[email protected]>

* Refine the helper function name

Signed-off-by: Reese Wang <[email protected]>

* Move fixture to conftest.py

Signed-off-by: Reese Wang <[email protected]>

---------

Signed-off-by: Reese Wang <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
zlsh80826 and phu0ngng authored Aug 8, 2024
1 parent 6717554 commit 86f27e1
Showing 7 changed files with 63 additions and 33 deletions.
19 changes: 19 additions & 0 deletions tests/jax/conftest.py
@@ -2,9 +2,12 @@
#
# See LICENSE for license information.
"""conftest for tests/jax"""
import os
import jax
import pytest

from transformer_engine.transformer_engine_jax import get_device_compute_capability


@pytest.fixture(autouse=True, scope="function")
def clear_live_arrays():
@@ -14,3 +17,19 @@ def clear_live_arrays():
yield
for arr in jax.live_arrays():
arr.delete()


@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn():
"""
Enable fused attn for hopper+ arch.
Fused attn kernels on pre-hopper arch are not deterministic.
"""
if get_device_compute_capability(0) >= 90:
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
yield
if "NVTE_FUSED_ATTN" in os.environ:
del os.environ["NVTE_FUSED_ATTN"]
if "NVTE_ALLOW_NONDETERMINISTIC_ALGO" in os.environ:
del os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"]
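
A hypothetical sketch (not part of this commit) of how a test module under tests/jax could confirm that the autouse fixture above selected deterministic fused attention on Hopper (sm90) and newer GPUs:

```python
import os

import pytest

from transformer_engine.transformer_engine_jax import get_device_compute_capability


def test_fixture_requests_deterministic_fused_attn():
    """Fixture contract: fused attention on, non-deterministic kernels off."""
    if get_device_compute_capability(0) < 90:
        pytest.skip("The fixture only enables fused attention on Hopper+ GPUs")
    assert os.environ["NVTE_FUSED_ATTN"] == "1"
    assert os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] == "0"
```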
14 changes: 0 additions & 14 deletions tests/jax/test_praxis_layers.py
@@ -15,7 +15,6 @@

from utils import assert_allclose

from transformer_engine.transformer_engine_jax import get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
@@ -43,19 +42,6 @@
FP8_FORMATS = [Format.E4M3, Format.HYBRID]


@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn():
"""
Enable fused attn for hopper+ arch.
Fused attn kernels on pre-hopper arch are not deterministic.
"""
if get_device_compute_capability(0) >= 90:
os.environ["NVTE_FUSED_ATTN"] = "1"
yield
if "NVTE_FUSED_ATTN" in os.environ:
del os.environ["NVTE_FUSED_ATTN"]


def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd:
assert key in test_fd, f"{key} not found in test dict {test_fd}"
14 changes: 13 additions & 1 deletion transformer_engine/jax/cpp_extensions/attention.py
@@ -3,8 +3,9 @@
# See LICENSE for license information.
"""JAX/TE custom ops for attention"""
from dataclasses import dataclass
from functools import partial, reduce
from functools import partial, reduce, cache
import operator
import os
from typing import Optional, Tuple
import warnings

@@ -84,6 +85,12 @@ def get_fused_attn_backend(self):
self.head_dim,
)

@staticmethod
@cache
def is_non_deterministic_allowed():
"""Check if non-deterministic kernels are allowed"""
return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))

@staticmethod
def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout):
"""Parse qkv aval"""
@@ -365,6 +372,7 @@ def lowering(
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
)

out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
@@ -642,6 +650,8 @@ def abstract(
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape
bias_batch = reduce(operator.mul, bias_batch_shape)

deterministic = not FusedAttnHelper.is_non_deterministic_allowed()

input_batch = reduce(operator.mul, batch_shape)
wkspace_shape, wkspace_dtype = transformer_engine_jax.get_fused_attn_bwd_workspace_sizes(
input_batch,
@@ -659,6 +669,7 @@
qkv_layout,
jax_dtype_to_te_dtype(q_aval.dtype),
is_training,
deterministic,
max_segments_per_seq,
)

@@ -764,6 +775,7 @@ def lowering(
jax_dtype_to_te_dtype(q_aval.dtype),
jax_dtype_to_te_dtype(wkspace_aval.dtype),
is_training,
not FusedAttnHelper.is_non_deterministic_allowed(),
)

out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
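A note on the caching above: because is_non_deterministic_allowed is wrapped in functools.cache, NVTE_ALLOW_NONDETERMINISTIC_ALGO is read once per process and the result is reused for every subsequent fused-attention call. A minimal standalone sketch of that behavior (the helper here mirrors the one added above; it is not the library's own object):

```python
import os
from functools import cache


@cache
def is_non_deterministic_allowed():
    """Read the env var once and memoize the result, like the helper above."""
    return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))


os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"
print(is_non_deterministic_allowed())  # False: deterministic kernels requested

os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1"
print(is_non_deterministic_allowed())  # still False: the first result is cached

is_non_deterministic_allowed.cache_clear()  # standard functools API
print(is_non_deterministic_allowed())  # True: the environment is re-read
```

This is why the test fixture in conftest.py sets the variable before any fused-attention primitive is lowered.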
6 changes: 4 additions & 2 deletions transformer_engine/jax/csrc/extensions.h
@@ -147,14 +147,16 @@ struct CustomCallFusedAttnDescriptor {
DType dtype;
DType wkspace_dtype;
bool is_training;
bool deterministic;
};

pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training);
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic);

// Transpose

@@ -260,7 +262,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq);
bool deterministic, size_t max_segments_per_seq);

void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);

30 changes: 16 additions & 14 deletions transformer_engine/jax/csrc/extensions/attention.cpp
@@ -336,7 +336,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type,
NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
size_t max_segments_per_seq) {
bool deterministic, size_t max_segments_per_seq) {
// For qkv_packed
auto qkv_shape = std::vector<size_t>{input_batch * q_max_seqlen, 3, attn_heads, head_dim};
auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype);
@@ -392,13 +392,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
auto dummy_ragged_offset_tensor =
TensorWrapper(nullptr, std::vector<size_t>{num_segments + 1}, DType::kInt32);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
nvte_fused_attn_bwd_qkvpacked(
qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability,
qkv_layout, bias_type, mask_type, -1, -1, true, query_workspace_tensor.data(), nullptr);
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
s_tensor.data(), // not used for F16
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, deterministic,
query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
nvte_fused_attn_bwd_kvpacked(
q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(),
@@ -408,7 +409,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(),
dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen,
kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, true, query_workspace_tensor.data(), nullptr);
-1, deterministic, query_workspace_tensor.data(), nullptr);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(),
doutput_tensor.data(),
@@ -419,7 +420,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(),
dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen,
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, -1,
-1, true, query_workspace_tensor.data(), nullptr);
-1, deterministic, query_workspace_tensor.data(), nullptr);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
@@ -467,6 +468,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
auto bias_type = descriptor.bias_type;
auto mask_type = descriptor.mask_type;
auto dtype = descriptor.dtype;
auto deterministic = descriptor.deterministic;
auto max_segments_per_seq = descriptor.max_segments_per_seq;

/* Input tensors */
@@ -539,7 +541,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
s_tensor.data(), // not used for F16
&aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(),
q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout,
bias_type, mask_type, -1, -1, true, workspace_tensor.data(), stream);
bias_type, mask_type, -1, -1, deterministic, workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) {
auto q = buffers[0];
auto q_shape = std::vector<size_t>{input_batch * q_max_seqlen, attn_heads, head_dim};
@@ -566,7 +568,7 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
&aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(),
q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, deterministic,
workspace_tensor.data(), stream);
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) {
auto q = buffers[0];
@@ -602,8 +604,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
dbias_tensor.data(), q_cu_seqlens_tensor.data(),
kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(),
k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor,
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1, true,
workspace_tensor.data(), stream);
dropout_probability, qkv_layout, bias_type, mask_type, -1, -1,
deterministic, workspace_tensor.data(), stream);
} else {
NVTE_ERROR("Unsupported qkv_layout.");
}
5 changes: 3 additions & 2 deletions transformer_engine/jax/csrc/extensions/packing.cpp
@@ -68,11 +68,12 @@ pybind11::bytes PackCustomCallFusedAttnDescriptor(
size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training) {
NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
bool deterministic) {
return PackOpaque(CustomCallFusedAttnDescriptor{
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads,
head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type,
mask_type, qkv_layout, dtype, wkspace_dtype, is_training});
mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic});
}

} // namespace jax
8 changes: 8 additions & 0 deletions transformer_engine/jax/flax/transformer.py
@@ -359,6 +359,14 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
kernel is not available on the system, a warning will be issued, and the module will
automatically fall back to the unfused backend.
.. note::
The DotProductAttention default setting enables non-deterministic kernels for reduced
workspace requirements and faster computation. Users can disable the non-deterministic
kernels via the :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO` environment variable:
* :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=0` to allow only deterministic kernels.
* :attr:`NVTE_ALLOW_NONDETERMINISTIC_ALGO=1` to allow non-deterministic kernels (default).
Parameters
----------
head_dim: int
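
A minimal usage sketch of the new knob (the FusedAttnHelper import is only for illustration; a typical user would simply set the variables and run their model):

```python
import os

# Must be set before the first fused-attention call in the process, because the
# helper that reads this variable caches its result (see attention.py above).
os.environ["NVTE_FUSED_ATTN"] = "1"
os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "0"  # deterministic backward pass

from transformer_engine.jax.cpp_extensions.attention import FusedAttnHelper

assert not FusedAttnHelper.is_non_deterministic_allowed()
# Any DotProductAttention / fused-attention call made from here on will request
# deterministic cuDNN kernels, at the cost of a larger workspace and some speed.
```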
