
Commit fc4d718

Merge branch 'NVIDIA:main' into misc_2.4

2 parents 0d8fb83 + 2761205


54 files changed: +1035, -938 lines

3rdparty/cudnn-frontend

Submodule cudnn-frontend updated 69 files

build_tools/build_ext.py

Lines changed: 12 additions & 6 deletions

@@ -130,18 +130,24 @@ def run(self) -> None:
         super().run()
         self.extensions = all_extensions

-        # Ensure that binaries are not in global package space.
+        # Ensure that shared objects files for source and PyPI installations live
+        # in separate directories to avoid conflicts during install and runtime.
         lib_dir = (
             "wheel_lib"
             if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or framework_extension_only
             else ""
         )
-        target_dir = install_dir / "transformer_engine" / lib_dir
-        target_dir.mkdir(exist_ok=True, parents=True)

-        for ext in Path(self.build_lib).glob("*.so"):
-            self.copy_file(ext, target_dir)
-            os.remove(ext)
+        # Ensure that binaries are not in global package space.
+        # For editable/inplace builds this is not a concern as
+        # the SOs will be in a local directory anyway.
+        if not self.inplace:
+            target_dir = install_dir / "transformer_engine" / lib_dir
+            target_dir.mkdir(exist_ok=True, parents=True)
+
+            for ext in Path(self.build_lib).glob("*.so"):
+                self.copy_file(ext, target_dir)
+                os.remove(ext)

     def build_extensions(self):
         # For core lib + JAX install, fix build_ext from pybind11.setup_helpers
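The lib_dir selection above is self-contained enough to check in isolation. A minimal standalone sketch of that branch; select_lib_dir is a hypothetical wrapper around the expression taken verbatim from the diff:

import os

def select_lib_dir(framework_extension_only: bool) -> str:
    # Release wheels and framework-only builds keep their .so files in a
    # dedicated wheel_lib directory; source installs use the package root.
    return (
        "wheel_lib"
        if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or framework_extension_only
        else ""
    )

os.environ["NVTE_RELEASE_BUILD"] = "0"
assert select_lib_dir(framework_extension_only=False) == ""
os.environ["NVTE_RELEASE_BUILD"] = "1"
assert select_lib_dir(framework_extension_only=False) == "wheel_lib"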

build_tools/jax.py

Lines changed: 7 additions & 2 deletions

@@ -9,7 +9,7 @@

 import setuptools

-from .utils import get_cuda_include_dirs, all_files_in_dir
+from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled
 from typing import List


@@ -41,7 +41,7 @@ def setup_jax_extension(
     # Source files
     csrc_source_files = Path(csrc_source_files)
     extensions_dir = csrc_source_files / "extensions"
-    sources = all_files_in_dir(extensions_dir, ".cpp")
+    sources = all_files_in_dir(extensions_dir, name_extension="cpp")

     # Header files
     include_dirs = get_cuda_include_dirs()

@@ -57,6 +57,11 @@ def setup_jax_extension(

     # Compile flags
     cxx_flags = ["-O3"]
+    if debug_build_enabled():
+        cxx_flags.append("-g")
+        cxx_flags.append("-UNDEBUG")
+    else:
+        cxx_flags.append("-g0")

     # Define TE/JAX as a Pybind11Extension
     from pybind11.setup_helpers import Pybind11Extension
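debug_build_enabled is imported from build_tools/utils.py, but its body is outside this diff. A minimal sketch of what such a helper could look like, assuming it keys off an environment variable; the NVTE_BUILD_DEBUG name is an assumption, not taken from this commit:

import os

def debug_build_enabled() -> bool:
    # Assumed env-var name; the real helper lives in build_tools/utils.py.
    return bool(int(os.getenv("NVTE_BUILD_DEBUG", "0")))

With a helper like this, debug builds compile with -g and keep assertions enabled (-UNDEBUG), while regular builds strip debug info with -g0.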

build_tools/pytorch.py

Lines changed: 8 additions & 10 deletions

@@ -8,7 +8,7 @@

 import setuptools

-from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs
+from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled


 def setup_pytorch_extension(

@@ -19,11 +19,7 @@ def setup_pytorch_extension(
     """Setup CUDA extension for PyTorch support"""

     # Source files
-    csrc_source_files = Path(csrc_source_files)
-    extensions_dir = csrc_source_files / "extensions"
-    sources = [
-        csrc_source_files / "common.cpp",
-    ] + all_files_in_dir(extensions_dir)
+    sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp")

     # Header files
     include_dirs = get_cuda_include_dirs()

@@ -37,10 +33,12 @@ def setup_pytorch_extension(
     )

     # Compiler flags
-    cxx_flags = [
-        "-O3",
-        "-fvisibility=hidden",
-    ]
+    cxx_flags = ["-O3", "-fvisibility=hidden"]
+    if debug_build_enabled():
+        cxx_flags.append("-g")
+        cxx_flags.append("-UNDEBUG")
+    else:
+        cxx_flags.append("-g0")

     # Version-dependent CUDA options
     try:

build_tools/utils.py

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ def all_files_in_dir(path, name_extension=None):
     all_files = []
     for dirname, _, names in os.walk(path):
         for name in names:
-            if name_extension is not None and name_extension not in name:
+            if name_extension is not None and not name.endswith(f".{name_extension}"):
                 continue
             all_files.append(Path(dirname, name))
     return all_files
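Two things change at once here: the match becomes a strict suffix test, and callers now pass the extension without the leading dot (name_extension="cpp"), since the helper prepends it. A quick illustration with hypothetical file names:

def old_match(name, name_extension):
    return name_extension in name  # substring test (before)

def new_match(name, name_extension):
    return name.endswith(f".{name_extension}")  # strict suffix test (after)

for name in ["common.cpp", "common.cpp.orig", "kernels.cu"]:
    print(name, old_match(name, ".cpp"), new_match(name, "cpp"))
# common.cpp      True  True
# common.cpp.orig True  False   <- no longer picked up as a source file
# kernels.cu      False False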

examples/jax/encoder/test_multiprocessing_encoder.py

Lines changed: 1 addition & 1 deletion

@@ -609,7 +609,7 @@ def test_te_bf16(self):
     def test_te_delayed_scaling_fp8(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         result = self.exec(True, "DelayedScaling")
-        assert result[0] < 0.505 and result[1] > 0.754
+        assert result[0] < 0.505 and result[1] > 0.753

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"

qa/L0_pytorch_unittest/test.sh

Lines changed: 2 additions & 2 deletions

@@ -42,8 +42,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
 NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
-NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
-NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"

 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases:$FAILED_CASES"

qa/L3_pytorch_FA_versions_test/test.sh

Lines changed: 5 additions & 3 deletions

@@ -11,15 +11,17 @@ mkdir -p "$XML_LOG_DIR"
 pip3 install pytest==8.2.1

 # Limit parallel build jobs to avoid overwhelming system resources
-export MAX_JOBS=4
+export MAX_JOBS=32

 # Iterate over Flash Attention versions
 sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"`
+export FLASH_ATTN_CUDA_ARCHS=$sm_arch
 if [ $sm_arch -gt 90 ]
 then
     FA_versions=(2.7.3)
-else
-    FA_versions=(2.3.0 2.4.1 2.5.7 2.7.3 3.0.0b1)
+elif [ $sm_arch -eq 90 ]
+then
+    FA_versions=(2.5.7 2.7.3 3.0.0b1)
 fi

 for fa_version in "${FA_versions[@]}"

setup.py

Lines changed: 1 addition & 3 deletions

@@ -123,7 +123,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
         )
         # Blackwell is not supported as of Triton 3.2.0, need custom internal build
         # install_reqs.append("triton")
-        test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
+        test_reqs.extend(["numpy", "torchvision"])
     if "jax" in frameworks:
         setup_reqs.extend(["jax[cuda12]", "flax>=0.7.1"])
         install_reqs.extend(["jax", "flax>=0.7.1"])

@@ -144,7 +144,6 @@
         int(os.getenv("NVTE_RELEASE_BUILD", "0"))
     ), "NVTE_RELEASE_BUILD env must be set for metapackage build."
     ext_modules = []
-    cmdclass = {}
     package_data = {}
     include_package_data = False
     setup_requires = []

@@ -156,7 +155,6 @@
 else:
     setup_requires, install_requires, test_requires = setup_requirements()
     ext_modules = [setup_common_extension()]
-    cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
     package_data = {"": ["VERSION.txt"]}
     include_package_data = True
     extras_require = {"test": test_requires}

tests/jax/utils.py

Lines changed: 36 additions & 32 deletions

@@ -13,7 +13,6 @@
 import jax.numpy as jnp
 import numpy as np
 from flax import linen as nn
-from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import combine_masks
 from jax import lax, vmap
 from jax import nn as jax_nn
@@ -316,16 +315,22 @@ def __call__(self, inputs: Array) -> Array:

         kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
         kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
-        kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes
+        kernel = self.param(
+            "kernel",
+            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
+            kernel_param_shape,
+            self.dtype,
         )

         kernel = jnp.asarray(kernel, input_dtype)
         kernel = jnp.reshape(kernel, kernel_shape)

         if self.use_bias:
-            bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes
+            bias = self.param(
+                "bias",
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
+                self.features,
+                self.dtype,
             )
             bias = bias.astype(input_dtype)
         else:
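The pattern repeated throughout this file replaces the deprecated nn_partitioning.param_with_axes helper with plain self.param plus nn.with_logical_partitioning, which boxes the parameter together with its logical axis names. A minimal self-contained sketch of the new pattern; TinyDense is a hypothetical module, and the exact boxing behavior assumes a recent Flax:

import jax
import jax.numpy as jnp
from flax import linen as nn

class TinyDense(nn.Module):
    features: int = 4

    @nn.compact
    def __call__(self, x):
        # with_logical_partitioning wraps the initializer so the stored
        # parameter carries its logical axis names as sharding metadata.
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
            (x.shape[-1], self.features),
            jnp.float32,
        )
        return x @ kernel

params = TinyDense().init(jax.random.PRNGKey(0), jnp.ones((2, 3)))["params"]
print(type(params["kernel"]).__name__)  # Partitioned -- axis names travel with the param
print(params["kernel"].names)           # ('embed', 'mlp')

Inside the module, self.param still returns the plain array (Flax unboxes the metadata), which is why the surrounding jnp.asarray and reshape calls in this diff are unchanged.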
@@ -422,9 +427,9 @@ def __call__(self, inputs, deterministic: bool = False):
         ) # Broadcast along length.

         if self.transpose_batch_sequence:
-            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
+            x = nn.with_logical_constraint(x, ("length", "batch", "mlp"))
         else:
-            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp"))
+            x = nn.with_logical_constraint(x, ("batch", "length", "mlp"))
         output = DenseGeneral(
             inputs.shape[-1],
             dtype=self.dtype,

@@ -688,21 +693,13 @@ def qkv_init(key, shape, dtype):
         value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))

         if self.transpose_batch_sequence:
-            query = nn_partitioning.with_sharding_constraint(
-                query, ("length", "batch", "heads", "kv")
-            )
-            key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
-            value = nn_partitioning.with_sharding_constraint(
-                value, ("length", "batch", "heads", "kv")
-            )
+            query = nn.with_logical_constraint(query, ("length", "batch", "heads", "kv"))
+            key = nn.with_logical_constraint(key, ("length", "batch", "heads", "kv"))
+            value = nn.with_logical_constraint(value, ("length", "batch", "heads", "kv"))
         else:
-            query = nn_partitioning.with_sharding_constraint(
-                query, ("batch", "length", "heads", "kv")
-            )
-            key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
-            value = nn_partitioning.with_sharding_constraint(
-                value, ("batch", "length", "heads", "kv")
-            )
+            query = nn.with_logical_constraint(query, ("batch", "length", "heads", "kv"))
+            key = nn.with_logical_constraint(key, ("batch", "length", "heads", "kv"))
+            value = nn.with_logical_constraint(value, ("batch", "length", "heads", "kv"))

         if decode:
             # Detect if we're initializing by absence of existing cache data.

@@ -809,9 +806,9 @@ def qkv_init(key, shape, dtype):
         x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

         if self.transpose_batch_sequence:
-            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv"))
+            x = nn.with_logical_constraint(x, ("length", "batch", "joined_kv"))
         else:
-            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))
+            x = nn.with_logical_constraint(x, ("batch", "length", "joined_kv"))

         # Back to the original inputs dimensions.

@@ -857,17 +854,23 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
         input_dtype = x.dtype
         features = x.shape[-1]

-        scale = nn_partitioning.param_with_axes(
-            "scale", self.scale_init, (features,), self.dtype, axes=("embed",)
+        scale = self.param(
+            "scale",
+            nn.with_logical_partitioning(self.scale_init, ("embed",)),
+            (features,),
+            self.dtype,
         )
         x_ = x.astype(jnp.float32)
         if self.layernorm_type == "layernorm":
             mean = jnp.mean(x_, axis=-1, keepdims=True)
             var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
             y = (x_ - mean) * lax.rsqrt(var + self.epsilon)

-            bias = nn_partitioning.param_with_axes(
-                "ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
+            bias = self.param(
+                "ln_bias",
+                nn.with_logical_partitioning(self.bias_init, ("embed",)),
+                (features,),
+                self.dtype,
             )
             bias = jnp.asarray(bias, input_dtype)

@@ -976,12 +979,11 @@ def __call__(self, qlen, klen, bidirectional=True):
             num_buckets=self.num_buckets,
             max_distance=self.max_distance,
         )
-        relative_attention_bias = nn_partitioning.param_with_axes(
+        relative_attention_bias = self.param(
             "rel_embedding",
-            self.embedding_init,
+            nn.with_logical_partitioning(self.embedding_init, ("heads", "relpos_buckets")),
             (self.num_heads, self.num_buckets),
             jnp.float32,
-            axes=("heads", "relpos_buckets"),
         )

         relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
@@ -1559,14 +1561,16 @@ def sync_params_values(dst, src, transformations, sep="/"):
     """
     src_values = {}
     for key, value in jax.tree_util.tree_leaves_with_path(src):
-        normalized_key = sep.join(x.key for x in key)
+        # Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
+        normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
         src_values[normalized_key] = value

     flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst)
     synced_dst_values = []

     for key, value in flatten_dst:
-        normalized_key = sep.join(x.key for x in key)
+        # Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
+        normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
         if normalized_key in transformations:
             corresponding_src_key = transformations[normalized_key]
         else:
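The filter matters because the migration to nn.with_logical_partitioning stores parameters in nn.Partitioned boxes, whose leaf is exposed through a .value attribute; the resulting trailing path entry is a GetAttrKey, which carries a .name attribute rather than .key. A small sketch, assuming a recent Flax/JAX pairing:

import jax
import jax.numpy as jnp
from flax import linen as nn

# A Partitioned box is itself a pytree node; its leaf sits behind .value.
tree = {"dense": {"kernel": nn.Partitioned(jnp.ones((2, 2)), names=("embed", "mlp"))}}

for path, leaf in jax.tree_util.tree_leaves_with_path(tree):
    print(path)
    # (DictKey(key='dense'), DictKey(key='kernel'), GetAttrKey(name='value'))
    print("/".join(p.key for p in path if hasattr(p, "key")))
    # dense/kernel  <- the trailing GetAttrKey is skipped, keeping keys stable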

tests/pytorch/test_numerics.py

Lines changed: 7 additions & 1 deletion

@@ -42,7 +42,7 @@
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
 from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
-from transformer_engine.pytorch.utils import get_device_compute_capability
+from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex

@@ -2293,6 +2293,12 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
         pytest.skip("FusedAttention and FlashAttention do not support FP32")
     if use_RoPE:
         pytest.skip("KV cache does not support starting positions for RoPE")
+    if (
+        backend == "FusedAttention"
+        and get_device_compute_capability() == (8, 9)
+        and get_cudnn_version() < (9, 11, 0)
+    ):
+        pytest.skip("Skip KV cache for sm89 and cuDNN < 9.11")

     os.environ["NVTE_FLASH_ATTN"] = "0"
     os.environ["NVTE_FUSED_ATTN"] = "0"
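The new skip relies on both helpers returning version tuples, so the cuDNN gate is an ordinary lexicographic tuple comparison:

# Lexicographic tuple comparison is what makes the version gate work:
assert (9, 10, 2) < (9, 11, 0)        # older cuDNN -> test is skipped
assert not ((9, 11, 0) < (9, 11, 0))  # cuDNN 9.11.0 itself runs the test
assert (8, 9) == (8, 9)               # sm89 detected via exact tuple match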

transformer_engine/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -11,12 +11,12 @@

 try:
     from . import pytorch
-except (ImportError, StopIteration) as e:
+except ImportError as e:
     pass

 try:
     from . import jax
-except (ImportError, StopIteration) as e:
+except ImportError as e:
     pass

 __version__ = str(metadata.version("transformer_engine"))
