Support writing optimizer checkpoint only on rank0 and make UT pass on A100 (#142)

**Description**
Support writing the optimizer checkpoint only on rank 0 and make the unit tests pass on A100.
- Support writing the checkpoint on rank 0 only; a minimal sketch of the pattern follows this list. With this PR, we no longer need to modify [checkpointing](https://github.com/NVIDIA/Megatron-LM/blob/0609f27fe8376f17ab65c001d3d8f35cd8175950/megatron/checkpointing.py) in [MS-AMP-Examples](https://github.com/Azure/MS-AMP-Examples/tree/main/gpt3).
- Fix some bugs in the TransformerEngine integration and make the unit tests pass on A100.
- Improve the documentation.
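
The rank-0 write itself is implemented in `msamp/megatron/optimizer/distrib_optimizer.py`, whose large diff is not rendered below. The following is only a rough, minimal sketch of the general pattern, assuming a `torch.distributed` process group; the helper name `save_optimizer_checkpoint` and the use of `all_gather_object` are illustrative choices, not the code in this PR.

```python
import torch
import torch.distributed as dist


def save_optimizer_checkpoint(optimizer, path):
    """Collect per-rank optimizer state and write a single checkpoint from rank 0."""
    local_state = optimizer.state_dict()               # this rank's shard of the state
    gathered = [None] * dist.get_world_size()
    dist.all_gather_object(gathered, local_state)      # every rank receives all shards
    if dist.get_rank() == 0:
        torch.save({'optimizer_shards': gathered}, path)   # the file is written exactly once
    dist.barrier()                                     # keep ranks aligned before training resumes
```

A gather-to-rank-0 variant would avoid materializing every shard on every rank; `all_gather_object` is used here only because it is the simplest collective that works across backends.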
tocean authored Mar 1, 2024
1 parent 51f34ac commit 0a2cd72
Showing 9 changed files with 356 additions and 60 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/build-image.yaml
@@ -31,6 +31,7 @@ jobs:
uses: actions/checkout@v2
with:
submodules: true
+ path: buildimage
- name: Free disk space
run: |
mkdir -p /tmp/emptydir
@@ -54,7 +55,7 @@ jobs:
if [[ "${{ github.event_name }}" == "release" ]]; then
TAGS=$(sed "s/main/${GITHUB_REF##*/}/g" <<< ${TAGS})
fi
- DOCKERFILE=dockerfile/${{ matrix.name }}.dockerfile
+ DOCKERFILE=buildimage/dockerfile/${{ matrix.name }}.dockerfile
CACHE_FROM="type=registry,ref=$(cut -d, -f1 <<< ${TAGS})"
CACHE_TO=""
@@ -87,7 +88,7 @@ jobs:
uses: docker/build-push-action@v2
with:
platforms: linux/amd64
- context: .
+ context: ./buildimage
file: ${{ steps.metadata.outputs.dockerfile }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}
4 changes: 0 additions & 4 deletions .github/workflows/unit-tests.yaml
@@ -57,10 +57,6 @@ jobs:
export LD_PRELOAD="/usr/local/lib/libmsamp_dist.so:/usr/local/lib/libnccl.so:${LD_PRELOAD}"
cd ${{ matrix.dir }}/
python3 setup.py test
- - name: Clean repository
- if: always()
- run: |
- rm -rf ${{ matrix.dir }}/
# - name: Report coverage results
# run: |
# bash <(curl -s https://codecov.io/bash)
4 changes: 4 additions & 0 deletions docs/getting-started/run-msamp.md
@@ -52,4 +52,8 @@ deepspeed cifar10_deepspeed.py --deepspeed --deepspeed_config ds_config_zero_msa
deepspeed cifar10_deepspeed_te.py --deepspeed --deepspeed_config ds_config_zero_te_msamp.json
```

:::note Note
If you get "ModuleNotFoundError: No module named 'timm'" error when running this example, you need to install timm using `pip install timm`.
:::

For more comprehensive examples, please go to [MS-AMP-Examples](https://github.com/Azure/MS-AMP-Examples).
2 changes: 1 addition & 1 deletion docs/introduction.md
@@ -36,7 +36,7 @@ Here are the results for GPT-3, Swin-T, DeiT-S and RoBERTa-B.

### System performance

- MS-AMP preserves high-precision's accuracy while using only a fraction of the memory footprint on a range of tasks, including GPT-3, DeiT and Swin Transformer. For example, when training GPT-175B on NVIDIA H100 platform, MS-AMP achieves a notable 42% reduction in real memory usage compared with BF16 mixed-precision approach and reduces training time by 17% compared with Transformer Engine. For small models, MS-AMP with O2 mode can achieve 44% memory saving for Swin-1.0B and 26% memory saving for ViT-1.2B, comparing with FP16 AMP.
+ MS-AMP preserves high-precision's accuracy while using only a fraction of the memory footprint on a range of tasks, including GPT-3, DeiT and Swin Transformer. For example, when training GPT-175B on NVIDIA H100 platform, MS-AMP achieves a notable 39% reduction in real memory usage compared with BF16 mixed-precision approach and reduces training time by 37% compared with Transformer Engine. For small models, MS-AMP with O2 mode can achieve 44% memory saving for Swin-1.0B and 26% memory saving for ViT-1.2B, comparing with FP16 AMP.

Here are the results for GPT-3:

327 changes: 297 additions & 30 deletions msamp/megatron/optimizer/distrib_optimizer.py

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions msamp/optim/adamw.py
@@ -185,8 +185,12 @@ def adamw_fn( # noqa: C901

for i, param in enumerate(params):
grad = grads[i].float() if not maximize else -grads[i].float()
- exp_avgs[i].meta.scale = _new_exp_avg_factors[i] if self.tensor_scale else 1.0
- exp_avg_sqs[i].meta.scale = _new_exp_avg_sq_factors[i] if self.tensor_scale else 1.0
+ exp_avgs[i].meta.scale = _new_exp_avg_factors[i] if self.tensor_scale else torch.ones((), device='cuda')
+ exp_avgs[i].meta.scale_inv.fill_(1.0 / exp_avgs[i].meta.scale)
+ exp_avg_sqs[i].meta.scale = _new_exp_avg_sq_factors[i] if self.tensor_scale else torch.ones(
+     (), device='cuda'
+ )
+ exp_avg_sqs[i].meta.scale_inv.fill_(1.0 / exp_avg_sqs[i].meta.scale)
# update state
msamp_adamw.adamw_fp8_stage2_compute(
grad, exp_avgs[i].value, _exp_avg_inv_factors[i], exp_avgs[i].meta.scale, beta1,
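
The `adamw.py` change above keeps `meta.scale` tensor-valued even when tensor scaling is disabled and refreshes `meta.scale_inv` whenever the scale changes. Below is a minimal sketch of that invariant, using a simplified `Meta` stand-in rather than MS-AMP's real `ScalingMeta`, and assuming a CUDA device to match the diff.

```python
import torch


class Meta:
    """Simplified stand-in for MS-AMP's scaling metadata (illustrative only)."""
    def __init__(self):
        self.scale = torch.ones((), device='cuda')       # scaling factor kept as a 0-dim tensor
        self.scale_inv = torch.ones((), device='cuda')   # cached reciprocal used by the kernels


def refresh_scale(meta, new_scale=None):
    # Fall back to a tensor-valued 1.0 (not a Python float) when scaling is off,
    # so downstream FP8 kernels always receive a tensor argument.
    meta.scale = new_scale if new_scale is not None else torch.ones((), device='cuda')
    # Keep the cached reciprocal consistent with the scale it inverts.
    meta.scale_inv.fill_(1.0 / meta.scale)
```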
24 changes: 24 additions & 0 deletions msamp/te/extension.py
@@ -9,6 +9,7 @@

from msamp.common.dtype import Dtypes
from msamp.common.tensor import ScalingTensor
from msamp.nn import ScalingParameter


class TeExtensionOverrider:
@@ -24,6 +25,7 @@ class TeExtensionOverrider:
original_fused_cast_transpose = tex.fused_cast_transpose
original_cast_to_fp8 = te.cpp_extensions.cast_to_fp8
original_fp8_cast_transpose_fused = te.cpp_extensions.fp8_cast_transpose_fused
original_cast_if_needed = te.utils.cast_if_needed

@staticmethod
@torch.no_grad()
@@ -119,6 +121,24 @@ def cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out=None):
return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype)
return TeExtensionOverrider.original_cast_to_fp8(inp, fp8_meta_tensor, fp8_tensor, otype, out)

@staticmethod
def cast_if_needed(tensor, dtype):
"""Cast tensor to dtype.
Args:
tensor (torch.Tensor or ScalingParameter): Input tensor.
dtype (torch.dtype): Output dtype.
Returns:
torch.Tensor: Output tensor.
"""
with torch.enable_grad():
if isinstance(tensor, ScalingParameter):
new_tensor = tensor.to(dtype)
new_tensor.requires_grad = tensor.requires_grad
return new_tensor
return TeExtensionOverrider.original_cast_if_needed(tensor, dtype)

@staticmethod
def override():
"""Override transformer engine extension functions."""
@@ -127,5 +147,9 @@ def override():
te.module.linear.cast_to_fp8 = TeExtensionOverrider.cast_to_fp8
te.cpp_extensions.fp8_cast_transpose_fused = TeExtensionOverrider.fp8_cast_transpose_fused

te.module.layernorm_linear.cast_if_needed = TeExtensionOverrider.cast_if_needed
te.module.linear.cast_if_needed = TeExtensionOverrider.cast_if_needed
te.module.layernorm_mlp.cast_if_needed = TeExtensionOverrider.cast_if_needed


TeExtensionOverrider.override()
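
The new `cast_if_needed` override is assigned into each TE submodule (`layernorm_linear`, `linear`, `layernorm_mlp`) rather than only into `te.utils`, because `from ... import cast_if_needed` copies the function reference into the importing module at import time. The self-contained toy below illustrates that binding behavior with stand-in module names; it is not MS-AMP or TransformerEngine code.

```python
import types

# A fake "source" module and a fake consumer that did `from source import helper`.
source = types.ModuleType('source')
source.helper = lambda x: f'original({x})'

consumer = types.ModuleType('consumer')
consumer.helper = source.helper          # what `from source import helper` effectively does


def patched(x):
    return f'patched({x})'


source.helper = patched                  # patching only the source module...
print(consumer.helper('t'))              # ...still prints original(t)

consumer.helper = patched                # patch the module that holds its own reference
print(consumer.helper('t'))              # now prints patched(t)
```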
2 changes: 0 additions & 2 deletions msamp/te/modules.py
@@ -229,8 +229,6 @@ def _override_classes(cls):
te.attention.Linear = MSAMPLinear
te.attention.LayerNormLinear = MSAMPLayerNormLinear

- te.transformer.Linear = MSAMPLinear
- te.transformer.LayerNormLinear = MSAMPLayerNormLinear
te.transformer.LayerNormMLP = MSAMPLayerNormMLP

@staticmethod
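
The `_override_classes` hook swaps TE layer classes for their MS-AMP counterparts by reassigning module attributes, and this commit drops the two `te.transformer` assignments. As a hedged aside, not code from this PR, a guarded variant of the same pattern only touches names that the installed TE version actually exposes:

```python
def override_classes(te_module, replacements):
    """Replace class attributes on te_module, skipping names it does not expose."""
    for name, new_cls in replacements.items():
        if hasattr(te_module, name):
            setattr(te_module, name, new_cls)


# Hypothetical usage: override_classes(te.transformer, {'LayerNormMLP': MSAMPLayerNormMLP})
```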
40 changes: 21 additions & 19 deletions tests/te/test_te_replacer.py
@@ -5,6 +5,7 @@

import os
import unittest
+ from contextlib import nullcontext

import torch
import torch.distributed as dist
@@ -65,17 +66,16 @@ def _check_model(model):

scaling_params = [p for p in model.parameters() if isinstance(p, ScalingParameter)]
assert len(scaling_params) == 4
- is_fp8_available = te.fp8.check_fp8_support()
- if is_fp8_available:
- # Do a forward pass to make sure the model is working.
- fp8_format = Format.HYBRID
- fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
- x = torch.rand(self.sequence_length, self.batch_size, self.hidden_size).cuda().to(dtype=self.dtype)
-
- with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
- y = model(x, attention_mask=None)
- assert y.shape == (self.sequence_length, self.batch_size, self.hidden_size)
- y.sum().backward()
+ is_fp8_available, _ = te.fp8.check_fp8_support()
+ # Do a forward pass to make sure the model is working.
+ fp8_format = Format.HYBRID
+ fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
+ x = torch.rand(self.sequence_length, self.batch_size, self.hidden_size).cuda().to(dtype=self.dtype)
+
+ with te.fp8_autocast(enabled=is_fp8_available, fp8_recipe=fp8_recipe) if is_fp8_available else nullcontext():
+ y = model(x, attention_mask=None)
+ assert y.shape == (self.sequence_length, self.batch_size, self.hidden_size)
+ y.sum().backward()

@decorator.cuda_test
def test_te_with_deepspeed(self):
@@ -100,12 +100,13 @@ def test_te_with_deepspeed(self):

fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
- with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+ is_fp8_available, _ = te.fp8.check_fp8_support()
+ with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe) if is_fp8_available else nullcontext():
input = torch.randn(self.sequence_length, self.batch_size, self.hidden_size).cuda().to(dtype=self.dtype)
output = model(input, attention_mask=None)
- loss = output.sum()
- model.backward(loss)
- model.step()
+ loss = output.sum()
+ model.backward(loss)
+ model.step()


class TeReplacerDistributedTestCast(MultiProcessTestCase):
@@ -163,9 +164,10 @@ def test_fp8_ddp_with_te(self):
x = torch.randn(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
- with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+ is_fp8_available, _ = te.fp8.check_fp8_support()
+ with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe) if is_fp8_available else nullcontext():
output = model(x, attention_mask=None, is_first_microbatch=True)
- output.sum().backward()
- with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+ output.sum().backward()
+ with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe) if is_fp8_available else nullcontext():
output = model(x, attention_mask=None, is_first_microbatch=False)
- output.sum().backward()
+ output.sum().backward()

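The recurring change in these tests is a fallback pattern: run under `te.fp8_autocast` when the device reports FP8 support, otherwise run the same body under `contextlib.nullcontext()` so the test still passes on A100. Below is a condensed sketch of that pattern; the helper name `forward_with_optional_fp8` is illustrative, not part of the test suite.

```python
from contextlib import nullcontext

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format


def forward_with_optional_fp8(model, x):
    """Forward (and backward) under fp8_autocast when supported, else without it."""
    is_fp8_available, _ = te.fp8.check_fp8_support()
    recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo='max')
    ctx = te.fp8_autocast(enabled=True, fp8_recipe=recipe) if is_fp8_available else nullcontext()
    with ctx:
        y = model(x)
    y.sum().backward()    # backward outside the autocast region, mirroring the updated tests
    return y
```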