diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index 870e869795..fe36b33384 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -23,6 +23,6 @@
 pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1
 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1
 pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
 pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1
-NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
+NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
 exit $FAIL
diff --git a/setup.py b/setup.py
index 856c518f79..996027bd9e 100644
--- a/setup.py
+++ b/setup.py
@@ -103,7 +103,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         if "pytorch" in frameworks:
-            install_reqs.extend(["torch"])
+            install_reqs.extend(["torch>=2.1"])
             # Blackwell is not supported as of Triton 3.2.0, need custom internal build
             # install_reqs.append("triton")
             test_reqs.extend(["numpy", "torchvision", "prettytable"])
diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py
index bad09bf32a..f5c186a3bc 100644
--- a/tests/pytorch/distributed/test_torch_fsdp2.py
+++ b/tests/pytorch/distributed/test_torch_fsdp2.py
@@ -6,8 +6,8 @@
 import pytest
 import subprocess
 from pathlib import Path
 
+from transformer_engine.pytorch import torch_version
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.utils import torch_version
 
 import torch
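Note: the `torch>=2.1` install-time pin is enforced again at import time by the `torch_version()` helper added to `transformer_engine/pytorch/__init__.py` below, which compares release tuples from `packaging` rather than raw version strings. A minimal standalone sketch of why the tuple form is the safe one (only the `packaging` library is assumed):

```python
from packaging.version import Version as PkgVersion

# .release strips pre-release/local suffixes and yields a numeric tuple.
assert PkgVersion("2.1.0a0+fe05266").release == (2, 1, 0)

# Tuples order numerically; strings order lexicographically and get this wrong.
assert (2, 10, 0) >= (2, 2)
assert not ("2.10.0" >= "2.2")
```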
diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py
index 56b01f1dbc..9d01527ac5 100644
--- a/tests/pytorch/test_float8tensor.py
+++ b/tests/pytorch/test_float8tensor.py
@@ -161,6 +161,36 @@ def test_basic_ops(
         with pytest.raises(AssertionError):
             torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols)
 
+    @pytest.mark.parametrize("dims", [2, [4, 4], [8, 5, 3, 3]])
+    def test_chunk_op(
+        self,
+        dims: DimsType,
+        fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
+        scale: float = 3.5,
+        dtype: torch.dtype = torch.float32,
+    ) -> None:
+        """Test for ops for which shape of inputs and outputs differ."""
+
+        # Initialize random data
+        dims = _to_list(dims)
+        x_ref = torch.randn(dims, dtype=dtype, device="cpu")
+        x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=1.0)
+
+        # Get chunks.
+        chunk1, chunk2 = x_fp8.chunk(2, dim=0)
+
+        # Test chunks.
+        torch.testing.assert_close(x_fp8[0 : dims[0] // 2,], chunk1, atol=0, rtol=0)
+        torch.testing.assert_close(x_fp8[dims[0] // 2 :,], chunk2, atol=0, rtol=0)
+
+        # Check shapes.
+        assert (
+            chunk1.shape == torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:]
+        ), "Wrong shape for chunk1"
+        assert (
+            chunk2.shape == torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:]
+        ), "Wrong shape for chunk2"
+
     def test_inplace_ops(
         self,
         dims: DimsType = 23,
diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py
index 92250cd322..966115c29e 100644
--- a/transformer_engine/pytorch/__init__.py
+++ b/transformer_engine/pytorch/__init__.py
@@ -7,16 +7,25 @@
 # pylint: disable=wrong-import-position,wrong-import-order
 
 import logging
+import functools
+import sys
 import importlib
 import importlib.util
-import sys
-import torch
 from importlib.metadata import version
+from packaging.version import Version as PkgVersion
+
+import torch
 
 from transformer_engine.common import get_te_path, is_package_installed
 from transformer_engine.common import _get_sys_extension
 
 
+@functools.lru_cache(maxsize=None)
+def torch_version() -> tuple[int, ...]:
+    """Get PyTorch version"""
+    return PkgVersion(str(torch.__version__)).release
+
+
 def _load_library():
     """Load shared library with Transformer Engine C extensions"""
     module_name = "transformer_engine_torch"
@@ -60,6 +69,9 @@ def _load_library():
             spec.loader.exec_module(solib)
 
 
+assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
+
+
 _load_library()
 from transformer_engine.pytorch.module import LayerNormLinear
 from transformer_engine.pytorch.module import Linear
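Note: with the helper defined at the package root, callers import it as the FSDP2 test above does. A hypothetical usage sketch (the gate shown is illustrative, not code from this patch):

```python
from transformer_engine.pytorch import torch_version

# torch_version() returns a release tuple such as (2, 4, 1); functools.lru_cache
# makes repeated calls free.
if torch_version() >= (2, 2):
    print("torch.compile-based dropout fusion is eligible (cf. jit.py below)")
```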
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 7666d3f32b..cc92c1377d 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1385,7 +1385,7 @@ def _get_full_cu_seqlens(
     return _cu_seqlens_cache[(batch_size, max_seqlen)]
 
 
-@torch.compile
+@jit_fuser
 def pack_tensor(
     indices: torch.Tensor,
     tensor: torch.Tensor,
@@ -1409,7 +1409,7 @@ def pack_tensor(
     return packed
 
 
-@torch.compile
+@jit_fuser
 def pack_2_tensors(
     indices: torch.Tensor,
     t1: torch.Tensor,
@@ -1423,7 +1423,7 @@ def pack_2_tensors(
     return t1_packed, t2_packed
 
 
-@torch.compile
+@jit_fuser
 def pack_3_tensors(
     indices: torch.Tensor,
     t1: torch.Tensor,
@@ -1439,7 +1439,7 @@ def pack_3_tensors(
     return t1_packed, t2_packed, t3_packed
 
 
-@torch.compile
+@jit_fuser
 def unpack_tensor(
     indices: torch.Tensor,
     dim0: int,
@@ -1462,7 +1462,7 @@ def unpack_tensor(
     return unpacked
 
 
-@torch.compile
+@jit_fuser
 def unpack_2_tensors(
     indices: torch.Tensor,
     dim0: int,
@@ -1477,7 +1477,7 @@ def unpack_2_tensors(
     return t1_unpacked, t2_unpacked
 
 
-@torch.compile
+@jit_fuser
 def unpack_3_tensors(
     indices: torch.Tensor,
     dim0: int,
@@ -1645,7 +1645,7 @@ def get_cu_seqlens_on_cp_rank(
     return cu_seqlens_on_cp_rank
 
 
-@torch.compile
+@jit_fuser
 def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
     """
     Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
@@ -1665,7 +1665,7 @@ def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
     return chunk_ids
 
 
-@torch.compile
+@jit_fuser
 def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
     """Reorder sequence chunk for A2A communication."""
     if before_attn:
diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py
index cda3939d6f..aae35ded68 100644
--- a/transformer_engine/pytorch/jit.py
+++ b/transformer_engine/pytorch/jit.py
@@ -10,28 +10,20 @@
 
 # pylint: disable=unnecessary-lambda-assignment
 
-jit_fuser = torch.jit.script
+jit_fuser = lambda func: func
 if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     jit_fuser = torch.compile
 
+
 # See: https://github.com/NVIDIA/TransformerEngine/issues/597
 dropout_fuser = torch.jit.script
 if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     dropout_fuser = torch.compile
 
+
 # Decorator to disable Torch Dynamo
 # See: https://github.com/NVIDIA/TransformerEngine/issues/308
-no_torch_dynamo = lambda recursive=True: lambda func: func
-if torch.__version__ >= "2":
-    import torch._dynamo
-
-    if torch.__version__ >= "2.1":
-        no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(
-            f, recursive=recursive
-        )
-    else:
-        # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
-        no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
+no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)
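Note: attention.py's pack/unpack helpers now honor the same NVTE_TORCH_COMPILE switch as the rest of jit.py instead of being unconditionally compiled. A minimal standalone sketch of that gating pattern (the decorated function is illustrative, not code from this patch):

```python
import os
import torch

# Identity decorator unless torch.compile is enabled via NVTE_TORCH_COMPILE
# (default "1"), mirroring the simplified jit_fuser definition above.
jit_fuser = lambda func: func
if bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
    jit_fuser = torch.compile

@jit_fuser
def scale_and_relu(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.relu(x) * 2.0

print(scale_and_relu(torch.randn(4)))
```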
diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py
index b4631eb9a7..20e63e0e63 100644
--- a/transformer_engine/pytorch/ops/_common.py
+++ b/transformer_engine/pytorch/ops/_common.py
@@ -10,13 +10,13 @@
 import torch
 
 from transformer_engine_torch import FP8TensorMeta
+from .. import torch_version
 from ..fp8 import FP8GlobalStateManager
 from ..tensor.float8_tensor import Float8Tensor
 from ..utils import (
     canonicalize_device,
     canonicalize_dtype,
     devices_match,
-    torch_version,
 )
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index 49bf4facfa..c9e65bd93a 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -402,7 +402,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
                 [data] + list(args[1:]),
                 kwargs,
             )
-            return [Float8Tensor.make_like(tensor, data=split_tensor) for split_tensor in func_out]
+            return [
+                Float8Tensor.make_like(tensor, data=split_tensor, shape=split_tensor.shape)
+                for split_tensor in func_out
+            ]
         if func == aten.new_zeros.default:
             tensor = args[0]
             data = tensor._data
@@ -412,7 +415,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
                 [data] + list(args[1:]),
                 kwargs,
             )
-            return Float8Tensor.make_like(tensor, data=func_out)
+            return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape)
         if func == torch.ops.aten.as_strided.default:
             tensor = args[0]
             data = tensor._data
@@ -422,7 +425,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
                 [data] + list(args[1:]),
                 kwargs,
             )
-            return Float8Tensor.make_like(tensor, data=func_out)
+            return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape)
         if func == torch.ops.aten.detach.default:
             return cls.detach(args[0])
         if func == torch.ops.aten.clone.default:
diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py
index ef21412ca7..b540cd91a1 100644
--- a/transformer_engine/pytorch/tensor/quantized_tensor.py
+++ b/transformer_engine/pytorch/tensor/quantized_tensor.py
@@ -433,7 +433,8 @@ def make_like(
             data.
 
         """
-        shape = shape if shape is not None else tensor.shape
+        if shape is None:
+            shape = data.shape if data is not None else tensor.shape
         dtype = dtype if dtype is not None else tensor.dtype
         kwargs = tensor.get_metadata()
         if data is not None:
diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py
index 4678097dc4..1922a7e867 100644
--- a/transformer_engine/pytorch/utils.py
+++ b/transformer_engine/pytorch/utils.py
@@ -8,7 +8,6 @@
 import math
 import os
 from typing import Any, Callable, List, Optional, Tuple
-from packaging.version import Version as PkgVersion
 
 import torch
 import transformer_engine.pytorch.cpp_extensions as ext
@@ -387,9 +386,3 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:
 
     # Pop NVTX range
     torch.cuda.nvtx.range_pop()
-
-
-@functools.lru_cache(maxsize=None)
-def torch_version() -> tuple[int, ...]:
-    """Get PyTorch version"""
-    return PkgVersion(str(torch.__version__)).release
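Note: make_like now resolves shape with the precedence explicit `shape` argument, then `data.shape`, then `tensor.shape`, and the split/new_zeros/as_strided paths pass `shape` explicitly, so a Float8Tensor view reports the shape of its own byte buffer rather than its parent's. A standalone sketch of the invariant the new test_chunk_op asserts, using a plain uint8 tensor as a stand-in for Float8Tensor._data:

```python
import torch

# Stand-in for Float8Tensor._data: raw FP8 bytes stored as uint8.
data = torch.zeros(8, 5, 3, 3, dtype=torch.uint8)

# aten.split yields views whose leading dimension is halved; the wrapper's
# logical shape must follow each view, not the parent tensor.
chunk1, chunk2 = torch.ops.aten.split.Tensor(data, data.shape[0] // 2)
assert chunk1.shape == torch.Size([data.shape[0] // 2]) + data.shape[1:]
assert chunk2.shape == torch.Size([data.shape[0] // 2]) + data.shape[1:]
```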