diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index 870e869795..fe36b33384 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -23,6 +23,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1
 pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1
 pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
 pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1
-NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
+NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1

 exit $FAIL
diff --git a/setup.py b/setup.py
index 856c518f79..996027bd9e 100644
--- a/setup.py
+++ b/setup.py
@@ -103,7 +103,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         if "pytorch" in frameworks:
-            install_reqs.extend(["torch"])
+            install_reqs.extend(["torch>=2.1"])
             # Blackwell is not supported as of Triton 3.2.0, need custom internal build
             # install_reqs.append("triton")
             test_reqs.extend(["numpy", "torchvision", "prettytable"])
diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py
index bad09bf32a..f5c186a3bc 100644
--- a/tests/pytorch/distributed/test_torch_fsdp2.py
+++ b/tests/pytorch/distributed/test_torch_fsdp2.py
@@ -6,8 +6,8 @@
 import pytest
 import subprocess
 from pathlib import Path
+from transformer_engine.pytorch import torch_version
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
-from transformer_engine.pytorch.utils import torch_version


 import torch
diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py
index 92250cd322..966115c29e 100644
--- a/transformer_engine/pytorch/__init__.py
+++ b/transformer_engine/pytorch/__init__.py
@@ -7,16 +7,25 @@
 # pylint: disable=wrong-import-position,wrong-import-order

 import logging
+import functools
+import sys
 import importlib
 import importlib.util
-import sys
-import torch
 from importlib.metadata import version
+from packaging.version import Version as PkgVersion
+
+import torch

 from transformer_engine.common import get_te_path, is_package_installed
 from transformer_engine.common import _get_sys_extension


+@functools.lru_cache(maxsize=None)
+def torch_version() -> tuple[int, ...]:
+    """Get PyTorch version"""
+    return PkgVersion(str(torch.__version__)).release
+
+
 def _load_library():
     """Load shared library with Transformer Engine C extensions"""
     module_name = "transformer_engine_torch"
@@ -60,6 +69,9 @@ def _load_library():
         spec.loader.exec_module(solib)


+assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."
+
+
 _load_library()
 from transformer_engine.pytorch.module import LayerNormLinear
 from transformer_engine.pytorch.module import Linear
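Note on the new version gate in `transformer_engine/pytorch/__init__.py`: `torch_version()` caches the parsed version with `functools.lru_cache` and compares release tuples instead of raw version strings. A minimal standalone sketch of why tuples are used (not part of the diff; assumes only the `packaging` package is available):

```python
from packaging.version import Version as PkgVersion

# .release keeps only the numeric release segment:
# "2.1.0a0+cu121" -> (2, 1, 0), so pre-release and CUDA suffixes cannot break the gate.
assert PkgVersion("2.1.0a0+cu121").release == (2, 1, 0)

# Tuples compare numerically; raw strings compare lexicographically.
assert (2, 10) >= (2, 2)      # correct numeric comparison
assert not ("2.10" >= "2.2")  # a string comparison would wrongly reject torch 2.10
```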
diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py
index 7666d3f32b..cc92c1377d 100644
--- a/transformer_engine/pytorch/attention.py
+++ b/transformer_engine/pytorch/attention.py
@@ -1385,7 +1385,7 @@ def _get_full_cu_seqlens(
     return _cu_seqlens_cache[(batch_size, max_seqlen)]


-@torch.compile
+@jit_fuser
 def pack_tensor(
     indices: torch.Tensor,
     tensor: torch.Tensor,
@@ -1409,7 +1409,7 @@
     return packed


-@torch.compile
+@jit_fuser
 def pack_2_tensors(
     indices: torch.Tensor,
     t1: torch.Tensor,
@@ -1423,7 +1423,7 @@
     return t1_packed, t2_packed


-@torch.compile
+@jit_fuser
 def pack_3_tensors(
     indices: torch.Tensor,
     t1: torch.Tensor,
@@ -1439,7 +1439,7 @@
     return t1_packed, t2_packed, t3_packed


-@torch.compile
+@jit_fuser
 def unpack_tensor(
     indices: torch.Tensor,
     dim0: int,
@@ -1462,7 +1462,7 @@
     return unpacked


-@torch.compile
+@jit_fuser
 def unpack_2_tensors(
     indices: torch.Tensor,
     dim0: int,
@@ -1477,7 +1477,7 @@
     return t1_unpacked, t2_unpacked


-@torch.compile
+@jit_fuser
 def unpack_3_tensors(
     indices: torch.Tensor,
     dim0: int,
@@ -1645,7 +1645,7 @@
     return cu_seqlens_on_cp_rank


-@torch.compile
+@jit_fuser
 def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
     """
     Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
@@ -1665,7 +1665,7 @@
     return chunk_ids


-@torch.compile
+@jit_fuser
 def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
     """Reorder sequence chunk for A2A communication."""
     if before_attn:
diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py
index cda3939d6f..aae35ded68 100644
--- a/transformer_engine/pytorch/jit.py
+++ b/transformer_engine/pytorch/jit.py
@@ -10,28 +10,20 @@

 # pylint: disable=unnecessary-lambda-assignment

-jit_fuser = torch.jit.script
+jit_fuser = lambda func: func
 if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     jit_fuser = torch.compile
+

 # See: https://github.com/NVIDIA/TransformerEngine/issues/597
 dropout_fuser = torch.jit.script
 if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
     dropout_fuser = torch.compile
+

 # Decorator to disable Torch Dynamo
 # See: https://github.com/NVIDIA/TransformerEngine/issues/308
-no_torch_dynamo = lambda recursive=True: lambda func: func
-if torch.__version__ >= "2":
-    import torch._dynamo
-
-    if torch.__version__ >= "2.1":
-        no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(
-            f, recursive=recursive
-        )
-    else:
-        # no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
-        no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
+no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)


 def set_jit_fusion_options() -> None:
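With the 2.1 floor in place, the `jit_fuser` fallback changes from `torch.jit.script` to an identity decorator, so `NVTE_TORCH_COMPILE=0` now means plain eager execution rather than TorchScript. That is presumably why the `NVTE_TORCH_COMPILE=0` override is dropped from `qa/L0_pytorch_unittest/test.sh`, and why the `pack_*`/`unpack_*` and context-parallel helpers in `attention.py` move from a hard-coded `@torch.compile` to `@jit_fuser`. A toy sketch of the selection logic (the decorated function is illustrative, not from the diff; the `torch.__version__ >= "2"` guard kept in the diff is omitted here since it is always true under the new floor):

```python
import os

import torch

# Identity fallback; replaced by torch.compile unless NVTE_TORCH_COMPILE=0.
jit_fuser = lambda func: func
if bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
    jit_fuser = torch.compile


@jit_fuser
def scale_and_mask(x: torch.Tensor, scale: float, mask: torch.Tensor) -> torch.Tensor:
    """Toy fusible op: runs eagerly when compilation is disabled."""
    return (x * scale).masked_fill(mask, 0.0)
```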
diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py
index b4631eb9a7..20e63e0e63 100644
--- a/transformer_engine/pytorch/ops/_common.py
+++ b/transformer_engine/pytorch/ops/_common.py
@@ -10,13 +10,13 @@
 import torch

 from transformer_engine_torch import FP8TensorMeta
+from .. import torch_version
 from ..fp8 import FP8GlobalStateManager
 from ..tensor.float8_tensor import Float8Tensor
 from ..utils import (
     canonicalize_device,
     canonicalize_dtype,
     devices_match,
-    torch_version,
 )

diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py
index 4678097dc4..1922a7e867 100644
--- a/transformer_engine/pytorch/utils.py
+++ b/transformer_engine/pytorch/utils.py
@@ -8,7 +8,6 @@
 import math
 import os
 from typing import Any, Callable, List, Optional, Tuple
-from packaging.version import Version as PkgVersion

 import torch
 import transformer_engine.pytorch.cpp_extensions as ext
@@ -387,9 +386,3 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:

     # Pop NVTX range
     torch.cuda.nvtx.range_pop()
-
-
-@functools.lru_cache(maxsize=None)
-def torch_version() -> tuple[int, ...]:
-    """Get PyTorch version"""
-    return PkgVersion(str(torch.__version__)).release
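After this change, callers import the helper from the package root rather than from `transformer_engine.pytorch.utils`, as the `test_torch_fsdp2.py` and `ops/_common.py` hunks above show. A short usage sketch (illustrative; the `(2, 2)` gate mirrors the `dropout_fuser` check in `jit.py`):

```python
from transformer_engine.pytorch import torch_version

# The helper is cached by lru_cache, so repeated feature gates are cheap.
print("running against torch", ".".join(map(str, torch_version())))

if torch_version() >= (2, 2):
    pass  # e.g. jit.py only uses torch.compile for the dropout fuser from 2.2 on
```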