Enforce PyTorch version 2.1 and run attention tests with torch.compile (NVIDIA#1516)

* Enforce torch 2.0 and run attn tests with torch.compile

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* replace torch.compile with jit_fuser

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman authored Feb 28, 2025
1 parent 9588109 commit 303c6d1
Showing 8 changed files with 30 additions and 33 deletions.
2 changes: 1 addition & 1 deletion qa/L0_pytorch_unittest/test.sh
@@ -23,6 +23,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1

exit $FAIL
2 changes: 1 addition & 1 deletion setup.py
@@ -103,7 +103,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
install_reqs.extend(["torch"])
install_reqs.extend(["torch>=2.1"])
# Blackwell is not supported as of Triton 3.2.0, need custom internal build
# install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision", "prettytable"])
2 changes: 1 addition & 1 deletion tests/pytorch/distributed/test_torch_fsdp2.py
@@ -6,8 +6,8 @@
import pytest
import subprocess
from pathlib import Path
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import torch_version

import torch

16 changes: 14 additions & 2 deletions transformer_engine/pytorch/__init__.py
@@ -7,16 +7,25 @@
# pylint: disable=wrong-import-position,wrong-import-order

import logging
import functools
import sys
import importlib
import importlib.util
import sys
import torch
from importlib.metadata import version
from packaging.version import Version as PkgVersion

import torch

from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension


@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
"""Get PyTorch version"""
return PkgVersion(str(torch.__version__)).release


def _load_library():
"""Load shared library with Transformer Engine C extensions"""
module_name = "transformer_engine_torch"
@@ -60,6 +69,9 @@ def _load_library():
spec.loader.exec_module(solib)


assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."


_load_library()
from transformer_engine.pytorch.module import LayerNormLinear
from transformer_engine.pytorch.module import Linear
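For context, the new torch_version() helper and the import-time check can be reproduced standalone as below. This is only a usage sketch, not additional code from the commit, and it assumes the packaging dependency the helper already relies on.

import functools

import torch
from packaging.version import Version as PkgVersion

@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
    # Return the installed PyTorch version as a release tuple, e.g. (2, 1, 0).
    return PkgVersion(str(torch.__version__)).release

# The assertion added above makes `import transformer_engine.pytorch` fail fast
# on unsupported installs instead of erroring later inside framework code.
assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."

Using PkgVersion(...).release keeps the tuple comparison robust for builds such as 2.1.0a0+gitabc123, whose release tuple is still (2, 1, 0).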
16 changes: 8 additions & 8 deletions transformer_engine/pytorch/attention.py
@@ -1385,7 +1385,7 @@ def _get_full_cu_seqlens(
return _cu_seqlens_cache[(batch_size, max_seqlen)]


@torch.compile
@jit_fuser
def pack_tensor(
indices: torch.Tensor,
tensor: torch.Tensor,
@@ -1409,7 +1409,7 @@ def pack_tensor(
return packed


@torch.compile
@jit_fuser
def pack_2_tensors(
indices: torch.Tensor,
t1: torch.Tensor,
@@ -1423,7 +1423,7 @@ def pack_2_tensors(
return t1_packed, t2_packed


@torch.compile
@jit_fuser
def pack_3_tensors(
indices: torch.Tensor,
t1: torch.Tensor,
@@ -1439,7 +1439,7 @@ def pack_3_tensors(
return t1_packed, t2_packed, t3_packed


@torch.compile
@jit_fuser
def unpack_tensor(
indices: torch.Tensor,
dim0: int,
@@ -1462,7 +1462,7 @@ def unpack_tensor(
return unpacked


@torch.compile
@jit_fuser
def unpack_2_tensors(
indices: torch.Tensor,
dim0: int,
@@ -1477,7 +1477,7 @@ def unpack_2_tensors(
return t1_unpacked, t2_unpacked


@torch.compile
@jit_fuser
def unpack_3_tensors(
indices: torch.Tensor,
dim0: int,
@@ -1645,7 +1645,7 @@ def get_cu_seqlens_on_cp_rank(
return cu_seqlens_on_cp_rank


@torch.compile
@jit_fuser
def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
"""
Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
@@ -1665,7 +1665,7 @@ def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
return chunk_ids


@torch.compile
@jit_fuser
def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
"""Reorder sequence chunk for A2A communication."""
if before_attn:
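The bodies of the pack/unpack helpers above are collapsed in this view. As a rough illustration only, index-based packing and unpacking of this kind boils down to a gather and a scatter over the first dimension, i.e. small indexing kernels that benefit from fusion when jit_fuser resolves to torch.compile (see jit.py below). The names pack_rows and unpack_rows are hypothetical simplifications, not the actual Transformer Engine implementations.

import torch

def pack_rows(indices: torch.Tensor, tensor: torch.Tensor) -> torch.Tensor:
    # Gather only the rows selected by `indices` into a dense, packed tensor.
    return tensor[indices]

def unpack_rows(indices: torch.Tensor, dim0: int, packed: torch.Tensor) -> torch.Tensor:
    # Scatter the packed rows back into a zero-filled tensor with `dim0` rows.
    unpacked = torch.zeros(dim0, *packed.shape[1:], dtype=packed.dtype, device=packed.device)
    unpacked[indices] = packed
    return unpacked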
16 changes: 4 additions & 12 deletions transformer_engine/pytorch/jit.py
@@ -10,28 +10,20 @@

# pylint: disable=unnecessary-lambda-assignment

jit_fuser = torch.jit.script
jit_fuser = lambda func: func
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
jit_fuser = torch.compile


# See: https://github.com/NVIDIA/TransformerEngine/issues/597
dropout_fuser = torch.jit.script
if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
dropout_fuser = torch.compile


# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
import torch._dynamo

if torch.__version__ >= "2.1":
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(
f, recursive=recursive
)
else:
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)


def set_jit_fusion_options() -> None:
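A minimal standalone sketch of how a jit_fuser-style decorator is selected and then consumed by the helpers in attention.py; fused_scale_add is a made-up example function, and the exact conditions used by jit.py are the ones shown in the diff above.

import os

import torch

# Passthrough by default; upgraded to torch.compile when NVTE_TORCH_COMPILE=1 (the default).
jit_fuser = lambda func: func
if bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
    jit_fuser = torch.compile

@jit_fuser
def fused_scale_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # A small element-wise op of the kind torch.compile can fuse into a single kernel.
    return 2.0 * x + y

out = fused_scale_add(torch.randn(8), torch.randn(8))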
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/ops/_common.py
@@ -10,13 +10,13 @@
import torch

from transformer_engine_torch import FP8TensorMeta
from .. import torch_version
from ..fp8 import FP8GlobalStateManager
from ..tensor.float8_tensor import Float8Tensor
from ..utils import (
canonicalize_device,
canonicalize_dtype,
devices_match,
torch_version,
)


7 changes: 0 additions & 7 deletions transformer_engine/pytorch/utils.py
@@ -8,7 +8,6 @@
import math
import os
from typing import Any, Callable, List, Optional, Tuple
from packaging.version import Version as PkgVersion

import torch
import transformer_engine.pytorch.cpp_extensions as ext
@@ -387,9 +386,3 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:

# Pop NVTX range
torch.cuda.nvtx.range_pop()


@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
"""Get PyTorch version"""
return PkgVersion(str(torch.__version__)).release
