This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[bc-breaking] rename TensorScalingType->ScalingType, Float8TensorCastConfig->CastConfig #337

Status: Closed · wants to merge 2 commits · showing changes from 1 commit
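Since the rename is BC-breaking, downstream code that imports the old names has to update in lockstep. A minimal migration sketch assembled from the renames in this diff (your own scaling choices will differ):

```python
# Old names (before this PR):
#   from float8_experimental.config import Float8LinearConfig, Float8TensorCastConfig, TensorScalingType
#   config = Float8LinearConfig(
#       cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
#   )

# New names (after this PR):
from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType

config = Float8LinearConfig(
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)
```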
20 changes: 7 additions & 13 deletions benchmarks/bench_linear_float8.py
```diff
@@ -14,11 +14,7 @@
 
 import torch
 import torch.utils.benchmark as benchmark
-from float8_experimental.config import (
-    Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
-)
+from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
     linear_requires_sync,
@@ -107,15 +103,13 @@ def main(
     device = "cuda"
     print(f"Compile is set to | {compile}")
 
-    scaling_type_input = TensorScalingType(scaling_type_input)
-    scaling_type_weight = TensorScalingType(scaling_type_weight)
-    scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)
+    scaling_type_input = ScalingType(scaling_type_input)
+    scaling_type_weight = ScalingType(scaling_type_weight)
+    scaling_type_grad_output = ScalingType(scaling_type_grad_output)
     config = Float8LinearConfig(
-        cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
-        cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
-        cast_config_grad_output=Float8TensorCastConfig(
-            scaling_type=scaling_type_grad_output
-        ),
+        cast_config_input=CastConfig(scaling_type=scaling_type_input),
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
     )
 
     # LLaMa 2 70B single-node weight shapes
```
14 changes: 4 additions & 10 deletions benchmarks/bench_multi_gpu.py
```diff
@@ -14,11 +14,7 @@
 import torch.multiprocessing as mp
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
-from float8_experimental.config import (
-    Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
-)
+from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear_utils import (
     convert_to_float8_training,
     sync_float8_amax_and_scale_history,
@@ -33,11 +29,9 @@
 lr = 0.01
 
 config = Float8LinearConfig(
-    cast_config_input=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
-    cast_config_weight=Float8TensorCastConfig(scaling_type=TensorScalingType.DELAYED),
-    cast_config_grad_output=Float8TensorCastConfig(
-        scaling_type=TensorScalingType.DELAYED
-    ),
+    cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
+    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
+    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
 )
 
 
```
20 changes: 7 additions & 13 deletions benchmarks/profile_linear_float8.py
```diff
@@ -18,11 +18,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from float8_experimental.config import (
-    Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
-)
+from float8_experimental.config import CastConfig, Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear_utils import (
     convert_to_float8_training,
     linear_requires_sync,
@@ -217,15 +213,13 @@ def main(
     assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
     assert dtype_filter in ("both", "float8", "bfloat16")
 
-    scaling_type_input = TensorScalingType(scaling_type_input)
-    scaling_type_weight = TensorScalingType(scaling_type_weight)
-    scaling_type_grad_output = TensorScalingType(scaling_type_grad_output)
+    scaling_type_input = ScalingType(scaling_type_input)
+    scaling_type_weight = ScalingType(scaling_type_weight)
+    scaling_type_grad_output = ScalingType(scaling_type_grad_output)
     config = Float8LinearConfig(
-        cast_config_input=Float8TensorCastConfig(scaling_type=scaling_type_input),
-        cast_config_weight=Float8TensorCastConfig(scaling_type=scaling_type_weight),
-        cast_config_grad_output=Float8TensorCastConfig(
-            scaling_type=scaling_type_grad_output
-        ),
+        cast_config_input=CastConfig(scaling_type=scaling_type_input),
+        cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
+        cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
     )
     scaling_repr = "_".join(
         [
```
8 changes: 4 additions & 4 deletions float8_experimental/__init__.py
```diff
@@ -5,11 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
 from float8_experimental.config import (
+    CastConfig,
     DelayedScalingConfig,
     Float8GemmConfig,
     Float8LinearConfig,
-    Float8TensorCastConfig,
-    TensorScalingType,
+    ScalingType,
 )
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_utils import (
@@ -33,10 +33,10 @@
 __all__ = [
     # configuration
     "DelayedScalingConfig",
-    "TensorScalingType",
+    "ScalingType",
     "Float8GemmConfig",
     "Float8LinearConfig",
-    "Float8TensorCastConfig",
+    "CastConfig",
     # top level UX
     "convert_to_float8_training",
     "linear_requires_sync",
```
17 changes: 9 additions & 8 deletions float8_experimental/config.py
```diff
@@ -8,25 +8,26 @@
 from dataclasses import dataclass
 
 
-class TensorScalingType(enum.Enum):
+# TODO(future): consider renaming to ScalingType
+class ScalingType(enum.Enum):
     DELAYED = "delayed"
     DYNAMIC = "dynamic"
 
     def short_str(self):
-        if self is TensorScalingType.DELAYED:
+        if self is ScalingType.DELAYED:
             return "del"
         else:
-            assert self is TensorScalingType.DYNAMIC
+            assert self is ScalingType.DYNAMIC
             return "dyn"
 
 
 @dataclass(frozen=True)
-class Float8TensorCastConfig:
+class CastConfig:
     """
     Configuration for casting a single tensor to float8
     """
 
-    scaling_type: TensorScalingType = TensorScalingType.DYNAMIC
+    scaling_type: ScalingType = ScalingType.DYNAMIC
 
 
 @dataclass(frozen=True)
@@ -74,9 +75,9 @@ class Float8LinearConfig:
     #
     # Per-tensor configuration for `input`, `weight`, `grad_output`
     #
-    cast_config_input: Float8TensorCastConfig = Float8TensorCastConfig()
-    cast_config_weight: Float8TensorCastConfig = Float8TensorCastConfig()
-    cast_config_grad_output: Float8TensorCastConfig = Float8TensorCastConfig()
+    cast_config_input: CastConfig = CastConfig()
+    cast_config_weight: CastConfig = CastConfig()
+    cast_config_grad_output: CastConfig = CastConfig()
 
     #
     # Per-gemm configuration for gemms calculating `output`, `grad_input` and
```
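The benchmarks above construct the enum directly from its string value; a quick sketch of how the renamed types behave, based only on the definitions in the config.py hunk above:

```python
from float8_experimental.config import CastConfig, ScalingType

# ScalingType round-trips from its string value, as the benchmark CLIs rely on.
scaling = ScalingType("delayed")
assert scaling is ScalingType.DELAYED
assert scaling.short_str() == "del"
assert ScalingType("dynamic").short_str() == "dyn"

# CastConfig defaults to dynamic scaling for the tensor it configures.
assert CastConfig().scaling_type is ScalingType.DYNAMIC
```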
26 changes: 12 additions & 14 deletions float8_experimental/float8_linear.py
```diff
@@ -14,7 +14,7 @@
 
 import torch
 
-from float8_experimental.config import Float8LinearConfig, TensorScalingType
+from float8_experimental.config import Float8LinearConfig, ScalingType
 
 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
@@ -159,9 +159,9 @@ def __init__(self, *args, **kwargs):
         self.scaling_type_grad_output = config.cast_config_grad_output.scaling_type
         # Convenience flag to skip code related to delayed scaling
         self.has_any_delayed_scaling = (
-            self.scaling_type_input is TensorScalingType.DELAYED
-            or self.scaling_type_weight is TensorScalingType.DELAYED
-            or self.scaling_type_grad_output is TensorScalingType.DELAYED
+            self.scaling_type_input is ScalingType.DELAYED
+            or self.scaling_type_weight is ScalingType.DELAYED
+            or self.scaling_type_grad_output is ScalingType.DELAYED
         )
 
         self.config = config
@@ -284,7 +284,7 @@ def cast_input_to_float8(
             autocast_dtype = torch.get_autocast_gpu_dtype()
             input = input.to(autocast_dtype)
 
-        if self.scaling_type_input is TensorScalingType.DELAYED:
+        if self.scaling_type_input is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 input,
@@ -305,14 +305,14 @@ def cast_input_to_float8(
                 gemm_input_role=GemmInputRole.INPUT,
             )
         else:
-            assert self.scaling_type_input is TensorScalingType.DYNAMIC
+            assert self.scaling_type_input is ScalingType.DYNAMIC
             input_fp8 = cast_to_float8_e4m3_dynamic(input, self.linear_mm_config)
         return input_fp8
 
     def cast_weight_to_float8(
         self, weight: torch.Tensor, is_amax_initialized: bool
     ) -> torch.Tensor:
-        if self.scaling_type_weight is TensorScalingType.DELAYED:
+        if self.scaling_type_weight is ScalingType.DELAYED:
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 weight_fp8 = self.weight
             else:
@@ -337,7 +337,7 @@ def cast_weight_to_float8(
                     gemm_input_role=GemmInputRole.WEIGHT,
                 )
         else:
-            assert self.scaling_type_weight is TensorScalingType.DYNAMIC
+            assert self.scaling_type_weight is ScalingType.DYNAMIC
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 weight_fp8 = self.weight
             else:
@@ -349,7 +349,7 @@ def cast_weight_to_float8(
         return weight_fp8
 
     def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
-        if self.scaling_type_grad_output is TensorScalingType.DELAYED:
+        if self.scaling_type_grad_output is ScalingType.DELAYED:
             scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             output = NoopFwToFloat8E5M2Bw.apply(
                 output,
@@ -361,7 +361,7 @@ def cast_output_to_float8_in_bw(self, output: torch.Tensor) -> torch.Tensor:
                 self.linear_mm_config,
             )
         else:
-            assert self.scaling_type_grad_output is TensorScalingType.DYNAMIC
+            assert self.scaling_type_grad_output is ScalingType.DYNAMIC
             output = cast_to_float8_e5m2_dynamic_bw(output, self.linear_mm_config)
         return output
 
@@ -448,17 +448,15 @@ def from_float(
         # 2. buffers need to be already created for the delayed scaling version
         # of the weight wrapper to be initialized
         if config.enable_fsdp_float8_all_gather:
-            if config.cast_config_weight.scaling_type is TensorScalingType.DYNAMIC:
+            if config.cast_config_weight.scaling_type is ScalingType.DYNAMIC:
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDynamicFloat8CastTensor(
                         new_mod.weight,
                         new_mod.linear_mm_config,
                     )
                 )
             else:
-                assert (
-                    config.cast_config_weight.scaling_type is TensorScalingType.DELAYED
-                )
+                assert config.cast_config_weight.scaling_type is ScalingType.DELAYED
                 new_mod.weight = torch.nn.Parameter(
                     WeightWithDelayedFloat8CastTensor(
                         new_mod.weight,
```
8 changes: 4 additions & 4 deletions float8_experimental/float8_linear_utils.py
```diff
@@ -9,7 +9,7 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-from float8_experimental.config import Float8LinearConfig, TensorScalingType
+from float8_experimental.config import Float8LinearConfig, ScalingType
 from float8_experimental.float8_linear import Float8Linear
 
 from float8_experimental.float8_utils import (
@@ -27,9 +27,9 @@ def linear_requires_sync(config: Float8LinearConfig):
     """Returns whether the given linear_type requires sync before forward."""
     return any(
         [
-            config.cast_config_input.scaling_type is TensorScalingType.DELAYED,
-            config.cast_config_weight.scaling_type is TensorScalingType.DELAYED,
-            config.cast_config_grad_output.scaling_type is TensorScalingType.DELAYED,
+            config.cast_config_input.scaling_type is ScalingType.DELAYED,
+            config.cast_config_weight.scaling_type is ScalingType.DELAYED,
+            config.cast_config_grad_output.scaling_type is ScalingType.DELAYED,
         ]
     )
 
```
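Per the updated `linear_requires_sync` above, delayed scaling on any of the three tensors is what triggers the amax/scale sync requirement. A small illustration, assuming the top-level exports shown in the `__init__.py` hunk earlier:

```python
from float8_experimental import CastConfig, Float8LinearConfig, ScalingType, linear_requires_sync

# All cast configs default to ScalingType.DYNAMIC, so no sync is needed.
dynamic_only = Float8LinearConfig()
assert not linear_requires_sync(dynamic_only)

# Delayed scaling on even one tensor (here, the weight) requires syncing
# amax and scale history before the forward pass.
delayed_weight = Float8LinearConfig(
    cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
)
assert linear_requires_sync(delayed_weight)
```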
6 changes: 3 additions & 3 deletions float8_experimental/float8_tensor_parallel.py
```diff
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from float8_experimental.config import TensorScalingType
+from float8_experimental.config import ScalingType
 from float8_experimental.float8_dynamic_utils import (
     cast_to_float8_e4m3_dynamic,
     cast_to_float8_e5m2_dynamic_bw,
@@ -28,8 +28,8 @@ def _float8_linear_supports_float8_allgather(m):
     # TODO(future): add support for delayed scaling for activations
     # and gradients
     return (
-        m.scaling_type_input == TensorScalingType.DYNAMIC
-        and m.scaling_type_grad_output == TensorScalingType.DYNAMIC
+        m.scaling_type_input == ScalingType.DYNAMIC
+        and m.scaling_type_grad_output == ScalingType.DYNAMIC
     )
 
 
```
5 changes: 2 additions & 3 deletions float8_experimental/fsdp_utils.py
```diff
@@ -33,13 +33,12 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
         optim.step()
         precompute_float8_dynamic_scale_for_fsdp(model)
     """
-    from float8_experimental.config import TensorScalingType
+    from float8_experimental.config import ScalingType
     from float8_experimental.float8_linear import Float8Linear
     from torch.distributed._tensor import DTensor
 
     if any(
-        isinstance(m, Float8Linear)
-        and m.scaling_type_weight is TensorScalingType.DELAYED
+        isinstance(m, Float8Linear) and m.scaling_type_weight is ScalingType.DELAYED
         for m in module.modules()
     ):
         raise NotImplementedError("Only supports delayed scaling")
```