Update te to latest stable version #157

Merged · 22 commits · Feb 7, 2024
20 changes: 13 additions & 7 deletions .github/workflows/build-image.yaml
@@ -14,7 +14,8 @@ on:
jobs:
docker:
name: Docker build ${{ matrix.name }}
runs-on: ubuntu-20.04
runs-on: [self-hosted, linux, x64, gpu]
timeout-minutes: 600
permissions:
contents: read
packages: write
@@ -23,21 +24,28 @@ jobs:
include:
- name: torch1.14-cuda11.8
tags: ghcr.io/azure/msamp:main-cuda11.8
- name: torch2.1-cuda12.1
tags: ghcr.io/azure/msamp:main-cuda12.1,ghcr.io/azure/msamp:latest
- name: torch2.1-cuda12.2
tags: ghcr.io/azure/msamp:main-cuda12.2,ghcr.io/azure/msamp:latest
steps:
- name: Checkout
uses: actions/checkout@v2
with:
submodules: true
- name: Free disk space
run: |
mkdir /tmp/emptydir
mkdir -p /tmp/emptydir
for dir in /usr/share/swift /usr/share/dotnet /usr/local/share/powershell /usr/local/share/chromium /usr/local/lib/android /opt/ghc; do
sudo rsync -a --delete /tmp/emptydir/ ${dir}
done
sudo apt-get clean
sudo docker rmi $(sudo docker images --format "{{.Repository}}:{{.Tag}}" --filter=reference="node" --filter=reference="buildpack-deps")

# Check if Docker images exist before trying to remove them
if sudo docker images -q --filter=reference="node" --filter=reference="buildpack-deps" | grep -q .; then
sudo docker rmi $(sudo docker images --format "{{.Repository}}:{{.Tag}}" --filter=reference="node" --filter=reference="buildpack-deps")
else
echo "No Docker images found with the specified references."
fi

df -h
- name: Prepare metadata
id: metadata
@@ -48,8 +56,6 @@ jobs:
fi
DOCKERFILE=dockerfile/${{ matrix.name }}.dockerfile

BUILD_ARGS="NUM_MAKE_JOBS=8"

CACHE_FROM="type=registry,ref=$(cut -d, -f1 <<< ${TAGS})"
CACHE_TO=""
if [[ "${{ github.event_name }}" != "pull_request" ]]; then
12 changes: 10 additions & 2 deletions .github/workflows/unit-tests.yaml
@@ -16,9 +16,11 @@ jobs:
# 1.14.0a0+410ce96
- torch: "1.14"
nvcr: 22.12-py3
dir: torch1
# 2.1.0a0+fe05266f
- torch: "2.1"
nvcr: 23.04-py3
nvcr: 23.10-py3
dir: torch2
container:
image: nvcr.io/nvidia/pytorch:${{ matrix.nvcr }}
options: --privileged --ipc=host --gpus=all
@@ -27,9 +29,10 @@
uses: actions/checkout@v2
with:
submodules: true
path: ${{ matrix.dir }}
- name: Install MSCCL
run: |
cd third_party/msccl
cd ${{ matrix.dir }}/third_party/msccl
make -j src.build NVCC_GENCODE="\
-gencode=arch=compute_70,code=sm_70 \
-gencode=arch=compute_80,code=sm_80 \
@@ -38,16 +41,21 @@
- name: Install dependencies
run: |
export LD_LIBRARY_PATH="/usr/local/lib:$LD_LIBRARY_PATH"
export DEBIAN_FRONTEND=noninteractive
python3 -m pip install --upgrade pip
apt-get update && apt-get install -y python3-mpi4py
cd ${{ matrix.dir }}/
python3 -m pip install .[test]
make postinstall
- name: Run code lint
run: |
cd ${{ matrix.dir }}/
python3 setup.py lint
- name: Run unit tests
run: |
export LD_LIBRARY_PATH="/usr/local/lib:$LD_LIBRARY_PATH"
export LD_PRELOAD="/usr/local/lib/libmsamp_dist.so:/usr/local/lib/libnccl.so:${LD_PRELOAD}"
cd ${{ matrix.dir }}/
python3 setup.py test
# - name: Report coverage results
# run: |
4 changes: 2 additions & 2 deletions dockerfile/torch1.14-cuda11.8.dockerfile
@@ -29,6 +29,7 @@ RUN apt-get update && \
util-linux \
vim \
wget \
python3-mpi4py \
&& \
apt-get autoremove && \
apt-get clean && \
@@ -48,9 +49,8 @@ RUN cd third_party/msccl && \
-gencode=arch=compute_90,code=sm_90" && \
make install
# cache TE build to save time in CI
ENV MAX_JOBS=1
RUN python3 -m pip install --upgrade pip && \
python3 -m pip install flash-attn==1.0.9 git+https://github.com/NVIDIA/TransformerEngine.git@v0.11
python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

ADD . .
RUN python3 -m pip install . && \
dockerfile/torch2.1-cuda12.2.dockerfile
@@ -1,9 +1,9 @@
FROM nvcr.io/nvidia/pytorch:23.04-py3
FROM nvcr.io/nvidia/pytorch:23.10-py3

# Ubuntu: 20.04
# Ubuntu: 22.04
# Python: 3.8
# CUDA: 12.1.0
# cuDNN: 8.9.0
# CUDA: 12.2.0
# cuDNN: 8.9.5
# NCCL: v2.16.2-1 + FP8 Support
# PyTorch: 2.1.0a0+fe05266f

@@ -29,6 +29,7 @@ RUN apt-get update && \
util-linux \
vim \
wget \
python3-mpi4py \
&& \
apt-get autoremove && \
apt-get clean && \
@@ -48,9 +49,8 @@ RUN cd third_party/msccl && \
-gencode=arch=compute_90,code=sm_90" && \
make install
# cache TE build to save time in CI
ENV MAX_JOBS=1
RUN python3 -m pip install --upgrade pip && \
python3 -m pip install flash-attn==1.0.9 git+https://github.com/NVIDIA/TransformerEngine.git@v0.11
python3 -m pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

ADD . .
RUN python3 -m pip install . && \
4 changes: 2 additions & 2 deletions examples/cifar10_deepspeed_te.py
@@ -138,7 +138,7 @@ def __init__(
num_heads,
hidden_dropout=drop,
attention_dropout=attn_drop,
self_attn_mask_type='padding',
self_attn_mask_type='no_mask',
layer_type='encoder',
init_method=init_method,
output_layer_init_method=init_method,
@@ -152,7 +152,7 @@ def forward(self, x):
padding = batch_size % 16 > 0
if padding:
x = F.pad(x, (0, 0, 0, 16 - batch_size % 16))
out = self.m(x, attention_mask=None)
out = self.m(x)
if padding:
out = out[:, :batch_size]
return out
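With the explicit attention mask gone, the example relies entirely on the pad-then-slice logic in `forward`: TE's fp8 kernels need the padded dimension to be a multiple of 16. A minimal sketch of that pattern, assuming a (seq, batch, hidden) layout and a hypothetical helper name:

```python
import torch
import torch.nn.functional as F

# Sketch of the pad-to-16 pattern from forward() above (hypothetical
# helper). FP8 kernels need the padded dimension to be a multiple of
# 16, so pad it up, run the module, then drop the padding again.
def run_with_fp8_padding(module, x):
    batch_size = x.shape[1]           # assumes (seq, batch, hidden) layout
    pad = (16 - batch_size % 16) % 16
    if pad:
        x = F.pad(x, (0, 0, 0, pad))  # pads the second-to-last dimension
    out = module(x)
    return out[:, :batch_size] if pad else out
```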
8 changes: 5 additions & 3 deletions msamp/fsdp/_runtime_utils.py
@@ -18,7 +18,9 @@ def _fp8_post_backward_hook(state, handle, *unused):
if accumulate_grad and torch.count_nonzero(state._flat_param._saved_grad_shard).item() > 0:
raise NotImplementedError('accumulate_grad is not supported yet for fp8')

old_communication_hook = state._communication_hook
state._communication_hook = state._get_fp8_comm_hook()
comm_hook_attr = '_communication_hook' if hasattr(state, '_communication_hook') else '_comm_hook'

old_communication_hook = getattr(state, comm_hook_attr)
setattr(state, comm_hook_attr, state._get_fp8_comm_hook())
old_post_backward_hook(state, handle, *unused)
state._communication_hook = old_communication_hook
setattr(state, comm_hook_attr, old_communication_hook)
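The attribute shim above exists because newer PyTorch releases renamed FSDP's internal `_communication_hook` to `_comm_hook`. A minimal sketch of the swap-and-restore pattern, with hypothetical names (`fp8_hook`, `run_post_backward`); the `try`/`finally` is an extra hardening choice not present in the hook above:

```python
# Sketch of the version-tolerant hook swap used above. `state` is an
# FSDP state object whose communication-hook attribute name differs
# across PyTorch releases.
def with_fp8_comm_hook(state, fp8_hook, run_post_backward):
    attr = '_communication_hook' if hasattr(state, '_communication_hook') else '_comm_hook'
    old_hook = getattr(state, attr)
    setattr(state, attr, fp8_hook)      # route gradient communication through fp8_hook
    try:
        run_post_backward()             # original post-backward logic runs here
    finally:
        setattr(state, attr, old_hook)  # restore the previous hook afterwards
```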
27 changes: 11 additions & 16 deletions msamp/fsdp/flat_param.py
@@ -3,24 +3,19 @@

"""MS-AMP fsdp.flat_param module."""

from typing import Optional, Sequence

import torch
import torch.nn as nn
from torch.distributed.fsdp.flat_param import FlatParamHandle


class FP8FlatParamHandle(FlatParamHandle):
"""A handle for a flat parameter which may have fp32 and fp8."""
def _init_flat_param(
self,
params: Sequence[Optional[nn.Parameter]],
module: nn.Module,
use_orig_params: bool,
) -> None:
"""Initialize the flat parameter and save fp8 related metadata."""
super()._init_flat_param(params, module, use_orig_params)
def __init__(self, *args, **kwargs):
"""Constructor."""
super().__init__(*args, **kwargs)
self._init_fp8_meta()

def _init_fp8_meta(self):
"""Save fp8 related metadata."""
metas = []
paddeds = []
original_shapes = []
@@ -52,11 +47,11 @@ def _use_unsharded_views(self, as_params: bool) -> None:
for i, param_info in enumerate(self.flat_param._param_infos):
if hasattr(param_info.module, param_info.param_name):
param = getattr(param_info.module, param_info.param_name)

param._scaling_metas = self.flat_param._scaling_metas[i]
param._meta = self.flat_param._metas[i]
param._padded = self.flat_param._paddeds[i]
param._original_shape = self.flat_param._original_shapes[i]
if hasattr(self.flat_param, '_scaling_metas'):
param._scaling_metas = self.flat_param._scaling_metas[i]
param._meta = self.flat_param._metas[i]
param._padded = self.flat_param._paddeds[i]
param._original_shape = self.flat_param._original_shapes[i]

@torch.no_grad()
def _use_sharded_views(self) -> None:
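The new `hasattr` guard is needed because the base `FlatParamHandle.__init__` may rebuild the unsharded views before `_init_fp8_meta()` has attached the fp8 metadata, so the first pass has nothing to propagate. A sketch of that guard in isolation, with a hypothetical helper name:

```python
import torch.nn as nn

# Sketch: copy per-parameter fp8 metadata from the flat parameter onto
# an original parameter, but only once _init_fp8_meta() has run.
def propagate_fp8_meta(flat_param, param: nn.Parameter, i: int) -> None:
    if not hasattr(flat_param, '_scaling_metas'):
        return  # views rebuilt before metadata init; skip for now
    param._scaling_metas = flat_param._scaling_metas[i]
    param._meta = flat_param._metas[i]
    param._padded = flat_param._paddeds[i]
    param._original_shape = flat_param._original_shapes[i]
```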
11 changes: 5 additions & 6 deletions msamp/fsdp/fully_sharded_data_parallel.py
@@ -5,7 +5,6 @@

import torch
from torch.distributed.fsdp import FullyShardedDataParallel
from torch.distributed.algorithms._comm_hooks import default_hooks
from torch.distributed.fsdp._init_utils import _get_default_comm_hook

from msamp.fsdp.flat_param import FP8FlatParamHandle
@@ -31,13 +30,13 @@ def _fp8_allreduce_hook(state, grad, output):
from msamp.operators.dist_op import DistOp
dtype = Dtypes.get_dtype_from_qtype(meta.qtype)
DistOp.enable_fp8(meta.qtype)
torch.distributed.all_reduce(grad[start:end].view(dtype), group=state.process_group)
torch.distributed.all_reduce(
grad[start:end].view(dtype), group=state.process_group if state else None
)
DistOp.disable_fp8()
else:
default_hooks.allreduce_hook(
state=state,
grad=grad[start:end],
)
torch.distributed.all_reduce(grad[start:end], group=state.process_group if state else None)

start = self.rank * output.numel()
end = (self.rank + 1) * output.numel()
output.copy_(grad[start:end])
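Dropping `default_hooks.allreduce_hook` in favor of a direct `torch.distributed.all_reduce` also lets the hook tolerate a `None` state. A sketch of the per-segment reduction, assuming a hypothetical helper and a generic `fp8_dtype` argument (the real hook derives the dtype from the segment's scaling meta via `Dtypes.get_dtype_from_qtype`):

```python
from typing import Optional

import torch
import torch.distributed as dist

# Sketch: all-reduce one flat-gradient segment. FP8 segments are stored
# as uint8, so reinterpret them as the fp8 dtype before reducing; other
# segments go through a plain all-reduce.
def allreduce_segment(grad: torch.Tensor, start: int, end: int,
                      fp8_dtype: Optional[torch.dtype] = None, group=None) -> None:
    segment = grad[start:end]
    if fp8_dtype is not None:
        dist.all_reduce(segment.view(fp8_dtype), group=group)
    else:
        dist.all_reduce(segment, group=group)
```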
37 changes: 23 additions & 14 deletions msamp/te/modules.py
@@ -5,6 +5,8 @@

import torch
import transformer_engine.pytorch as te
import transformer_engine_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule

from msamp.common.tensor import ScalingTensor
@@ -64,22 +66,29 @@ def set_fp8_weights(self):
return

setattr(
self,
weight_cast_attr,
torch.empty(
(0, 0),
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
self, weight_cast_attr,
Float8Tensor(
data=torch.empty(
(0, 0),
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)

setattr(
self,
weight_transpose_attr,
torch.empty(
(0, 0),
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
self, weight_transpose_attr,
Float8Tensor(
data=torch.empty(
(0, 0),
device=torch.cuda.current_device(),
dtype=torch.uint8,
),
fp8_dtype=tex.DType.kFloat8E4M3,
fp8_scale_inv=1,
)
)

@property
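Both weight caches are now seeded with `Float8Tensor` placeholders instead of raw uint8 tensors, matching the type newer TE expects to find there. A sketch of the shared construction, reusing the exact constructor arguments from this diff (the helper name is hypothetical):

```python
import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.float8_tensor import Float8Tensor

# Sketch: build the empty Float8Tensor placeholder assigned to both the
# weight-cast and weight-transpose attributes above.
def empty_fp8_placeholder() -> Float8Tensor:
    return Float8Tensor(
        data=torch.empty((0, 0), device=torch.cuda.current_device(), dtype=torch.uint8),
        fp8_dtype=tex.DType.kFloat8E4M3,  # E4M3 storage format
        fp8_scale_inv=1,                  # identity inverse scale
    )
```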
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -38,8 +38,7 @@ classifiers=[
]
dependencies = [
"torch",
"transformer-engine@git+https://github.com/NVIDIA/[email protected]#egg=transformer-engine",
"flash-attn==1.0.9",
"transformer-engine@git+https://github.com/NVIDIA/TransformerEngine.git@stable",
"colorlog>=6.7.0",
"deepspeed==0.13.1",
"mpi4py",
2 changes: 1 addition & 1 deletion tests/te/test_te_replacer.py
@@ -65,7 +65,7 @@ def _check_model(model):

scaling_params = [p for p in model.parameters() if isinstance(p, ScalingParameter)]
assert len(scaling_params) == 4
is_fp8_available, _ = te.fp8.is_fp8_available()
is_fp8_available = te.fp8.check_fp8_support()
if is_fp8_available:
# Do a forward pass to make sure the model is working.
fp8_format = Format.HYBRID