[pull] main from NVIDIA:main #81

Merged: 2 commits, Feb 28, 2025
2 changes: 1 addition & 1 deletion qa/L0_pytorch_unittest/test.sh
@@ -23,6 +23,6 @@ pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || FAIL=1
NVTE_TORCH_COMPILE=0 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || FAIL=1

exit $FAIL
2 changes: 1 addition & 1 deletion setup.py
@@ -103,7 +103,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
# Framework-specific requirements
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
install_reqs.extend(["torch"])
install_reqs.extend(["torch>=2.1"])
# Blackwell is not supported as of Triton 3.2.0, need custom internal build
# install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision", "prettytable"])
2 changes: 1 addition & 1 deletion tests/pytorch/distributed/test_torch_fsdp2.py
@@ -6,8 +6,8 @@
import pytest
import subprocess
from pathlib import Path
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import torch_version

import torch

30 changes: 30 additions & 0 deletions tests/pytorch/test_float8tensor.py
@@ -161,6 +161,36 @@ def test_basic_ops(
with pytest.raises(AssertionError):
torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols)

@pytest.mark.parametrize("dims", [2, [4, 4], [8, 5, 3, 3]])
def test_chunk_op(
self,
dims: DimsType,
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 3.5,
dtype: torch.dtype = torch.float32,
) -> None:
"""Test for ops for which shape of inputs and outputs differ."""

# Initialize random data
dims = _to_list(dims)
x_ref = torch.randn(dims, dtype=dtype, device="cpu")
x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=1.0)

# Get chunks.
chunk1, chunk2 = x_fp8.chunk(2, dim=0)

# Test chunks.
torch.testing.assert_close(x_fp8[0 : dims[0] // 2,], chunk1, atol=0, rtol=0)
torch.testing.assert_close(x_fp8[dims[0] // 2 :,], chunk2, atol=0, rtol=0)

# Check shapes.
assert (
chunk1.shape == torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:]
), "Wrong shape for chunk1"
assert (
chunk2.shape == torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:]
), "Wrong shape for chunk2"

def test_inplace_ops(
self,
dims: DimsType = 23,
16 changes: 14 additions & 2 deletions transformer_engine/pytorch/__init__.py
@@ -7,16 +7,25 @@
# pylint: disable=wrong-import-position,wrong-import-order

import logging
import functools
import sys
import importlib
import importlib.util
import sys
import torch
from importlib.metadata import version
from packaging.version import Version as PkgVersion

import torch

from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension


@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
"""Get PyTorch version"""
return PkgVersion(str(torch.__version__)).release


def _load_library():
"""Load shared library with Transformer Engine C extensions"""
module_name = "transformer_engine_torch"
@@ -60,6 +69,9 @@ def _load_library():
spec.loader.exec_module(solib)


assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."


_load_library()
from transformer_engine.pytorch.module import LayerNormLinear
from transformer_engine.pytorch.module import Linear
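A note on the new `torch_version()` helper above: it normalizes `torch.__version__` through `packaging` before the `(2, 1)` comparison, which matters because NGC and nightly wheels carry pre-release and local suffixes. A minimal sketch of why `.release` is used (the version strings below are illustrative, not taken from this PR):

```python
from packaging.version import Version as PkgVersion

# .release keeps only the numeric release segment as a tuple of ints,
# so tuple comparisons work even for pre-release/local builds.
print(PkgVersion("2.1.0").release)                      # (2, 1, 0)
print(PkgVersion("2.6.0a0+df5bbc0").release)            # (2, 6, 0)
print(PkgVersion("2.6.0a0+df5bbc0").release >= (2, 1))  # True
print("2.10.0" >= "2.2")                                # False: plain string comparison
                                                        # misorders versions, hence the helper
```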
16 changes: 8 additions & 8 deletions transformer_engine/pytorch/attention.py
@@ -1385,7 +1385,7 @@ def _get_full_cu_seqlens(
return _cu_seqlens_cache[(batch_size, max_seqlen)]


@torch.compile
@jit_fuser
def pack_tensor(
indices: torch.Tensor,
tensor: torch.Tensor,
@@ -1409,7 +1409,7 @@ def pack_tensor(
return packed


@torch.compile
@jit_fuser
def pack_2_tensors(
indices: torch.Tensor,
t1: torch.Tensor,
@@ -1423,7 +1423,7 @@ def pack_2_tensors(
return t1_packed, t2_packed


@torch.compile
@jit_fuser
def pack_3_tensors(
indices: torch.Tensor,
t1: torch.Tensor,
@@ -1439,7 +1439,7 @@ def pack_3_tensors(
return t1_packed, t2_packed, t3_packed


@torch.compile
@jit_fuser
def unpack_tensor(
indices: torch.Tensor,
dim0: int,
@@ -1462,7 +1462,7 @@ def unpack_tensor(
return unpacked


@torch.compile
@jit_fuser
def unpack_2_tensors(
indices: torch.Tensor,
dim0: int,
@@ -1477,7 +1477,7 @@ def unpack_2_tensors(
return t1_unpacked, t2_unpacked


@torch.compile
@jit_fuser
def unpack_3_tensors(
indices: torch.Tensor,
dim0: int,
@@ -1645,7 +1645,7 @@ def get_cu_seqlens_on_cp_rank(
return cu_seqlens_on_cp_rank


@torch.compile
@jit_fuser
def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
"""
Context parallelism assigns two discontiguous sequence chunks to each GPU for load balancing.
@@ -1665,7 +1665,7 @@ def get_seq_chunk_ids_for_reordering(cp_size, device, to_contiguous):
return chunk_ids


@torch.compile
@jit_fuser
def reorder_seq_chunks_for_a2a(x, chunk_ids_for_a2a, seq_dim, cp_size, before_attn):
"""Reorder sequence chunk for A2A communication."""
if before_attn:
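The decorated helpers above implement context-parallel (CP) packing and chunk reordering; swapping `@torch.compile` for `@jit_fuser` changes only how they are fused, not what they compute. For readers unfamiliar with the load-balanced layout mentioned in `get_seq_chunk_ids_for_reordering`'s docstring, here is a hedged sketch of the chunk permutation, assuming the usual assignment in which rank r owns sequence chunks r and 2*cp_size - 1 - r (an illustration, not the function body from this diff):

```python
import torch

def seq_chunk_ids_sketch(cp_size: int, to_contiguous: bool) -> torch.Tensor:
    """Permutation between per-rank (load-balanced) order and contiguous sequence order."""
    ids = torch.empty(2 * cp_size, dtype=torch.int64)
    for r in range(cp_size):
        if to_contiguous:
            # Gather: original chunk r sits at per-rank slot 2*r,
            # original chunk 2*cp_size-1-r sits at slot 2*r+1.
            ids[r] = 2 * r
            ids[2 * cp_size - 1 - r] = 2 * r + 1
        else:
            # Scatter: rank r receives chunk r and the mirrored chunk from the end.
            ids[2 * r] = r
            ids[2 * r + 1] = 2 * cp_size - 1 - r
    return ids

# cp_size=2: scatter order is [0, 3, 1, 2], gather order is [0, 2, 3, 1];
# each permutation inverts the other.
```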
16 changes: 4 additions & 12 deletions transformer_engine/pytorch/jit.py
@@ -10,28 +10,20 @@

# pylint: disable=unnecessary-lambda-assignment

jit_fuser = torch.jit.script
jit_fuser = lambda func: func
if torch.__version__ >= "2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
jit_fuser = torch.compile


# See: https://github.com/NVIDIA/TransformerEngine/issues/597
dropout_fuser = torch.jit.script
if torch.__version__ >= "2.2" and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))):
dropout_fuser = torch.compile


# Decorator to disable Torch Dynamo
# See: https://github.com/NVIDIA/TransformerEngine/issues/308
no_torch_dynamo = lambda recursive=True: lambda func: func
if torch.__version__ >= "2":
import torch._dynamo

if torch.__version__ >= "2.1":
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(
f, recursive=recursive
)
else:
# no "recursive" option in pyTorch 2.0 - it acts as if recursive was True
no_torch_dynamo = lambda recursive=True: torch._dynamo.disable
no_torch_dynamo = lambda recursive=True: lambda f: torch._dynamo.disable(f, recursive=recursive)


def set_jit_fusion_options() -> None:
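With the `@torch.compile` decorators in attention.py replaced by `@jit_fuser`, and `jit_fuser` falling back to a plain pass-through when `NVTE_TORCH_COMPILE=0`, a single environment variable now opts these helpers out of Dynamo compilation. A usage sketch; `fused_scale_add` is a made-up function purely for illustration:

```python
import os
# Must be set before transformer_engine is imported, since jit.py reads it at import time.
os.environ["NVTE_TORCH_COMPILE"] = "0"

import torch
from transformer_engine.pytorch.jit import jit_fuser

@jit_fuser  # torch.compile when enabled, identity decorator when disabled
def fused_scale_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    return x * alpha + y

out = fused_scale_add(torch.randn(4), torch.randn(4), 0.5)
```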
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/ops/_common.py
@@ -10,13 +10,13 @@
import torch

from transformer_engine_torch import FP8TensorMeta
from .. import torch_version
from ..fp8 import FP8GlobalStateManager
from ..tensor.float8_tensor import Float8Tensor
from ..utils import (
canonicalize_device,
canonicalize_dtype,
devices_match,
torch_version,
)


9 changes: 6 additions & 3 deletions transformer_engine/pytorch/tensor/float8_tensor.py
@@ -402,7 +402,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
[data] + list(args[1:]),
kwargs,
)
return [Float8Tensor.make_like(tensor, data=split_tensor) for split_tensor in func_out]
return [
Float8Tensor.make_like(tensor, data=split_tensor, shape=split_tensor.shape)
for split_tensor in func_out
]
if func == aten.new_zeros.default:
tensor = args[0]
data = tensor._data
@@ -412,7 +415,7 @@
[data] + list(args[1:]),
kwargs,
)
return Float8Tensor.make_like(tensor, data=func_out)
return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape)
if func == torch.ops.aten.as_strided.default:
tensor = args[0]
data = tensor._data
@@ -422,7 +425,7 @@
[data] + list(args[1:]),
kwargs,
)
return Float8Tensor.make_like(tensor, data=func_out)
return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape)
if func == torch.ops.aten.detach.default:
return cls.detach(args[0])
if func == torch.ops.aten.clone.default:
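These three call sites pass the dispatched data's shape through to `make_like`, so views produced by split/chunk, `new_zeros`, and `as_strided` report their own shape rather than the parent tensor's. The new `test_chunk_op` above exercises exactly this; a condensed sketch of the expected behavior, assuming the `to_float8` helper from tests/pytorch/test_float8tensor.py:

```python
import torch

x_fp8 = to_float8(torch.randn(8, 4))  # Float8Tensor wrapping uint8 data of shape (8, 4)
a, b = x_fp8.chunk(2, dim=0)          # each chunk's underlying data has shape (4, 4)
assert a.shape == torch.Size([4, 4])  # metadata shape now follows the split data
assert b.shape == torch.Size([4, 4])
```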
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/tensor/quantized_tensor.py
@@ -433,7 +433,8 @@ def make_like(
data.

"""
shape = shape if shape is not None else tensor.shape
if shape is None:
shape = data.shape if data is not None else tensor.shape
dtype = dtype if dtype is not None else tensor.dtype
kwargs = tensor.get_metadata()
if data is not None:
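The `make_like` change encodes a simple precedence for the wrapper's logical shape: an explicit `shape` argument wins, otherwise the new `data`'s shape, otherwise the template tensor's shape. A standalone sketch of that resolution order (an illustration, not the method itself):

```python
def resolve_shape(shape, data, tensor):
    # explicit shape > shape of the new data > shape of the template tensor
    if shape is not None:
        return shape
    return data.shape if data is not None else tensor.shape
```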
7 changes: 0 additions & 7 deletions transformer_engine/pytorch/utils.py
@@ -8,7 +8,6 @@
import math
import os
from typing import Any, Callable, List, Optional, Tuple
from packaging.version import Version as PkgVersion

import torch
import transformer_engine.pytorch.cpp_extensions as ext
@@ -387,9 +386,3 @@ def nvtx_range_pop(msg: Optional[str] = None) -> None:

# Pop NVTX range
torch.cuda.nvtx.range_pop()


@functools.lru_cache(maxsize=None)
def torch_version() -> tuple[int, ...]:
"""Get PyTorch version"""
return PkgVersion(str(torch.__version__)).release