update float8 integration after UX changes
Summary:

float8_experimental landed various BC-breaking UX changes last week. This PR updates torchtitan to work with the version of float8_experimental after pytorch-labs/float8_experimental#332.
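
For a quick picture of the UX change, here is a condensed before/after sketch distilled from the diff below; the toy `model` and the exact argument values are illustrative, not torchtitan's actual wiring:

```python
import torch.nn as nn
from float8_experimental import (
    CastConfig,
    convert_to_float8_training,
    Float8LinearConfig,
    ScalingType,
)

# Old UX (pre-#332), shown for comparison only:
#   import float8_experimental.config as config
#   from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
#   config.enable_fsdp_fp8_all_gather = True   # process-global flag, toggled via a context manager
#   swap_linear_with_float8_linear(model, scaling_type_w=TensorScalingType.DYNAMIC)

# New UX (post-#332): one explicit config object, one conversion entry point.
model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
convert_to_float8_training(
    model,
    config=Float8LinearConfig(
        cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
    ),
)
```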

Test Plan:

```
with-proxy CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 NGPU=8 CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.enable_float8_linear --training.compile
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed Jul 25, 2024
1 parent 0f70507 commit cffbf1e
Showing 1 changed file with 15 additions and 25 deletions: torchtitan/float8_linear.py
Expand Up @@ -12,7 +12,6 @@

# Note: Performance
# Float8 experimental is intended to be ran under `torch.compile`` for competitive performance
import contextlib
import functools
from typing import Optional

Expand All @@ -24,20 +23,6 @@
from torchtitan.logging_utils import logger


@contextlib.contextmanager
def set_enable_fsdp_float8_all_gather(enable_fsdp_fp8_all_gather: bool):
import float8_experimental.config as config

prev = config.enable_fsdp_fp8_all_gather
torch.distributed.barrier()
config.enable_fsdp_fp8_all_gather = enable_fsdp_fp8_all_gather
try:
yield
finally:
torch.distributed.barrier()
config.enable_fsdp_fp8_all_gather = prev


@functools.lru_cache(None)
def is_sm90_or_later():
# Float8 is only supported on H100+ GPUs
Expand All @@ -63,21 +48,26 @@ def maybe_build_fp8_linear(
)
return
try:
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import (
swap_linear_with_float8_linear,
from float8_experimental import (
CastConfig,
convert_to_float8_training,
Float8LinearConfig,
ScalingType,
)

# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
enable_fsdp_float8_all_gather = (
job_config.training.enable_fsdp_float8_all_gather and dp_enabled
)
with set_enable_fsdp_float8_all_gather(enable_fsdp_float8_all_gather):
swap_linear_with_float8_linear(
model,
scaling_type_w=TensorScalingType.DYNAMIC,
skip_fqn_list=["output"],
)
float8_config = Float8LinearConfig(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
)
convert_to_float8_training(
model,
config=float8_config,
module_filter_fn=lambda mod, fqn: fqn != "output",
)
logger.info(
f"Swapped to Float8Linear layers with {enable_fsdp_float8_all_gather=}"
)
Expand All @@ -102,6 +92,6 @@ def maybe_precompute_fp8_dynamic_scale_for_fsdp(
"Skipped precomputing fp8 scales because SM90 or later is not available",
)
return
from float8_experimental.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
from float8_experimental import precompute_float8_dynamic_scale_for_fsdp

precompute_float8_dynamic_scale_for_fsdp(model)
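
As a reference for where these helpers sit relative to the optimizer, below is a minimal sketch of the intended call order under the post-#332 API. The stand-in `model`, `optimizer`, and random batches, and the omission of actual FSDP wrapping and `torch.compile`, are simplifications rather than torchtitan's real training loop; it also assumes an SM90-class GPU with float8_experimental installed:

```python
import torch
import torch.nn as nn
from float8_experimental import (
    CastConfig,
    convert_to_float8_training,
    Float8LinearConfig,
    ScalingType,
    precompute_float8_dynamic_scale_for_fsdp,
)

# Stand-in model; torchtitan builds the Llama model from its job config instead.
model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512)).cuda()

# Swap nn.Linear -> Float8Linear before FSDP wrapping / compilation.
convert_to_float8_training(
    model,
    config=Float8LinearConfig(
        enable_fsdp_float8_all_gather=True,  # only has an effect once the model is FSDP-wrapped
        cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
    ),
)
# (FSDP wrapping and torch.compile would happen here in the real setup.)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for _ in range(10):
    batch = torch.randn(16, 512, device="cuda")
    loss = model(batch).float().mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # With float8 all-gather enabled, refresh the dynamic weight scales once per
    # iteration after the optimizer step; this mirrors torchtitan's call to
    # maybe_precompute_fp8_dynamic_scale_for_fsdp in its training loop.
    precompute_float8_dynamic_scale_for_fsdp(model)
```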
