Support latest TransformerEngine (#98)
**Description**
Support the latest Transformer Engine.

**Major Revision**
- Upgrade TE to the latest stable release, v0.11
- Integrate MS-AMP with TE (sketched below)
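
The diff below adds a new `msamp.te` package whose only public symbol is `TeReplacer`, and pins the container images to Transformer Engine v0.11 plus flash-attn 1.0.9. A minimal sketch of the new entry point, assuming the environment from the updated dockerfiles; `TeReplacer`'s concrete API lives in `msamp/te/replacer.py`, which is not part of this excerpt:

```python
# Minimal sketch (assumes the dockerfile environment: TE v0.11, flash-attn==1.0.9).
# Importing msamp.te is enough to apply the Transformer Engine patches defined in
# msamp/te/extension.py; TeReplacer is the only name the package exports.
from msamp.te import TeReplacer  # see msamp/te/__init__.py below

print(TeReplacer)  # its methods are defined in msamp/te/replacer.py (not shown in this diff)
```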
tocean authored Oct 18, 2023
1 parent 5b31b70 commit b0ab69c
Showing 13 changed files with 715 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-image.yaml
@@ -14,7 +14,7 @@ on:
 jobs:
   docker:
     name: Docker build ${{ matrix.name }}
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-20.04
     permissions:
       contents: read
       packages: write
3 changes: 0 additions & 3 deletions .github/workflows/unit-tests.yaml
@@ -13,9 +13,6 @@ jobs:
     strategy:
       matrix:
         include:
-          # 1.13.0a0+d0d6b1f
-          - torch: "1.13"
-            nvcr: 22.09-py3
           # 1.14.0a0+410ce96
           - torch: "1.14"
            nvcr: 22.12-py3
3 changes: 2 additions & 1 deletion dockerfile/torch1.14-cuda11.8.dockerfile
@@ -48,8 +48,9 @@ RUN cd third_party/msccl && \
     -gencode=arch=compute_90,code=sm_90" && \
     make install
 # cache TE build to save time in CI
+ENV MAX_JOBS=1
 RUN python3 -m pip install --upgrade pip && \
-    python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@v0.7
+    python3 -m pip install flash-attn==1.0.9 git+https://github.com/NVIDIA/TransformerEngine.git@v0.11

 ADD . .
 RUN python3 -m pip install . && \
3 changes: 2 additions & 1 deletion dockerfile/torch2.1-cuda12.1.dockerfile
@@ -48,8 +48,9 @@ RUN cd third_party/msccl && \
     -gencode=arch=compute_90,code=sm_90" && \
     make install
 # cache TE build to save time in CI
+ENV MAX_JOBS=1
 RUN python3 -m pip install --upgrade pip && \
-    python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@v0.7
+    python3 -m pip install flash-attn==1.0.9 git+https://github.com/NVIDIA/TransformerEngine.git@v0.11

 ADD . .
 RUN python3 -m pip install . && \
2 changes: 2 additions & 0 deletions msamp/operators/gemm/gemm.py
@@ -140,6 +140,7 @@ def fp8_gemm(
                 workspace.shape[0],
                 accumulate,
                 use_split_accumulator,
+                0,
             )
         else:
             # do gemm on device that doesn't supported fp8.
@@ -165,6 +166,7 @@ def fp8_gemm(
                 workspace.shape[0],
                 accumulate,
                 False,
+                0,
             )

         if pN > 0 or pM > 0:
13 changes: 13 additions & 0 deletions msamp/te/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Expose the interface of MS-AMP te package."""

from msamp.te import extension
from msamp.te import modules
from msamp.te.replacer import TeReplacer

del extension
del modules

__all__ = ['TeReplacer']
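
The import-then-`del` pattern above is deliberate: `extension` and `modules` are imported only for their side effects (they patch Transformer Engine at import time) and are then dropped so that `TeReplacer` remains the package's only public name. A small illustration of the same idiom, using hypothetical module names:

```python
# Hypothetical package __init__.py illustrating the import-for-side-effects idiom
# used above (mypkg, patches, and Client are made up for illustration only).
from mypkg import patches     # hypothetical module; registers its hooks when imported
from mypkg.api import Client  # hypothetical public class

del patches  # keep the side effects, hide the module from the public namespace

__all__ = ['Client']
```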
130 changes: 130 additions & 0 deletions msamp/te/extension.py
@@ -0,0 +1,130 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""MS-AMP te.extension module."""

import torch
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex

from msamp.common.dtype import Dtypes
from msamp.common.tensor import ScalingTensor


class TeExtensionOverrider:
"""An Overrider to override some extension functions in transformer engine."""
dtype_map = {
tex.DType.kFloat8E4M3: Dtypes.kfloat8_e4m3,
tex.DType.kFloat8E5M2: Dtypes.kfloat8_e5m2,
tex.DType.kBFloat16: Dtypes.kbfloat16,
tex.DType.kFloat16: Dtypes.kfloat16,
tex.DType.kFloat32: Dtypes.kfloat32,
}

original_fused_cast_transpose = tex.fused_cast_transpose
original_cast_to_fp8 = te.cpp_extensions.cast_to_fp8
original_fp8_cast_transpose_fused = te.cpp_extensions.fp8_cast_transpose_fused

    @staticmethod
    @torch.no_grad()
    def fused_cast_transpose(input, scale, amax, scale_inv, input_cast, input_transpose, otype):
        """Fused cast and transpose, support ScalingTensor.
        Args:
            input (torch.Tensor or ScalingTensor): Input tensor.
            scale (torch.Tensor): Scale tensor.
            amax (torch.Tensor): Amax tensor.
            scale_inv (torch.Tensor): Scale inverse tensor.
            input_cast (torch.Tensor): Casted input tensor.
            input_transpose (torch.Tensor): Transposed input tensor.
            otype (tex.DType): Output type.
        """
        if isinstance(input, ScalingTensor):
            qtype = TeExtensionOverrider.dtype_map[otype]
            if input_transpose is not None:
                sv = input.cast(qtype)
                # data should be contiguous, and TE does not check it.
                st = sv.t().contiguous()
                v, t = sv.value, st.value
                input_transpose.data = t
            else:
                sv = input.cast(qtype)
                v = sv.value

            if input_cast is not None:
                input_cast.data = v
            scale_inv.copy_(sv.meta.scale_inv)
        else:
            TeExtensionOverrider.original_fused_cast_transpose(
                input, scale, amax, scale_inv, input_cast, input_transpose, otype
            )

    @staticmethod
    @torch.no_grad()
    def fp8_cast_transpose_fused(inp, fp8_meta_tensor, fp8_tensor, dtype, cast_out=None, transpose_out=None):
        """Cast + Transpose with FP8 output, support ScalingTensor.
        Args:
            inp (torch.Tensor or ScalingTensor): Input tensor.
            fp8_meta_tensor (tex.FP8TensorMeta): FP8 meta tensor.
            fp8_tensor (Union[tex.FP8FwdTensors, tex.FP8BwdTensors]): FP8 tensor.
            dtype (tex.DType): Output type.
            cast_out (torch.Tensor, optional): Output tensor.
            transpose_out (torch.Tensor, optional): Output tensor.
        Returns:
            Union[Tuple[torch.Tensor, torch.Tensor], None]: Output tensor.
        """
        if isinstance(inp, ScalingTensor):
            qtype = TeExtensionOverrider.dtype_map[dtype]
            sv = inp.cast(qtype)
            v = sv.value
            t = sv.t().contiguous().value
            if transpose_out is not None:
                transpose_out.data = t
            if cast_out is not None:
                cast_out.data = v
            fp8_meta_tensor.scale_inv[fp8_tensor].copy_(sv.meta.scale_inv)
            return v, t

        return TeExtensionOverrider.original_fp8_cast_transpose_fused(
            inp, fp8_meta_tensor, fp8_tensor, dtype, cast_out, transpose_out
        )

    @staticmethod
    @torch.no_grad()
    def cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out=None):
        """Cast to fp8, support ScalingTensor.
        Args:
            inp (torch.Tensor or ScalingTensor): Input tensor.
            fp8_meta_tensor (tex.FP8TensorMeta): Fp8 meta tensor.
            fp8_tensor (Union[tex.FP8FwdTensors, tex.FP8BwdTensors]): Fp8 tensor.
            otype (tex.DType): Output type.
            out (torch.Tensor, optional): Output tensor.
        Returns:
            torch.Tensor: Output tensor.
        """
        if isinstance(inp, ScalingTensor):
            qtype = TeExtensionOverrider.dtype_map[otype]
            sv = inp.cast(qtype)
            v = sv.value
            if out is not None:
                out.data = v
            fp8_meta_tensor.scale_inv[fp8_tensor].copy_(sv.meta.scale_inv)
            return v

        if out is None:
            return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype)
        return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out)

    @staticmethod
    def override():
        """Override transformer engine extension functions."""
        tex.fused_cast_transpose = TeExtensionOverrider.fused_cast_transpose
        te.cpp_extensions.cast_to_fp8 = TeExtensionOverrider.cast_to_fp8
        te.cpp_extensions.fp8_cast_transpose_fused = TeExtensionOverrider.fp8_cast_transpose_fused


TeExtensionOverrider.override()
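
Because `TeExtensionOverrider.override()` runs at module import time, importing `msamp.te.extension` (directly or via `msamp.te`) swaps the three TE entry points for the MS-AMP wrappers above, which fall back to the saved originals for plain `torch.Tensor` inputs. A minimal check, assuming the TE v0.11 environment pinned in the updated dockerfiles:

```python
# Minimal sanity check (assumes TE v0.11 as pinned in the updated dockerfiles).
# Importing the module applies TeExtensionOverrider.override(), so the TE-level
# names now point at the ScalingTensor-aware wrappers defined above.
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex

from msamp.te.extension import TeExtensionOverrider

assert tex.fused_cast_transpose is TeExtensionOverrider.fused_cast_transpose
assert te.cpp_extensions.cast_to_fp8 is TeExtensionOverrider.cast_to_fp8
assert te.cpp_extensions.fp8_cast_transpose_fused is TeExtensionOverrider.fp8_cast_transpose_fused
```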
