This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

[bc-breaking] rename DelayedScalingRecipe to DelayedScalingConfig #333

Closed · wants to merge 1 commit
2 changes: 2 additions & 0 deletions float8_experimental/__init__.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
 from float8_experimental.config import (
+    DelayedScalingConfig,
     Float8LinearConfig,
     Float8TensorCastConfig,
     TensorScalingType,
@@ -30,6 +31,7 @@

 __all__ = [
     # configuration
+    "DelayedScalingConfig",
     "TensorScalingType",
     "Float8LinearConfig",
     "Float8TensorCastConfig",
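
After this change, the renamed config is part of the package's public surface, so it can be imported directly from the package root. A minimal sketch of what that looks like, assuming the package is installed as `float8_experimental`:

```python
# Minimal sketch: the renamed config is importable from the package root.
from float8_experimental import DelayedScalingConfig

# Defaults mirror the removed DelayedScalingRecipe:
# history_len=16, scale_fn_name="max".
print(DelayedScalingConfig())
```
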
31 changes: 31 additions & 0 deletions float8_experimental/config.py
@@ -29,6 +29,30 @@ class Float8TensorCastConfig:
     scaling_type: TensorScalingType = TensorScalingType.DYNAMIC


+@dataclass(frozen=True)
+class DelayedScalingConfig:
+    """
+    Configuration for delayed scaling.
+
+    Note: for now, `history_len` values must be the same for all layers in the
+    model using delayed scaling.
+
+    TODO(future): serialization for recipes
+    """
+
+    # Controls the history length of amax buffers
+    history_len: int = 16
+
+    # Controls the way to calculate current scale from amax history
+    # TODO(future): add other functions as needed, hardcoded or user defined
+    scale_fn_name: str = "max"
+
+    def __post_init__(self):
+        assert (
+            self.scale_fn_name == "max"
+        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
+
+
 @dataclass(frozen=True)
 class Float8LinearConfig:
     """
@@ -71,6 +95,13 @@ class Float8LinearConfig:
     # If True, emulation is used instead of hardware accelerated gemm
     emulate: bool = False

+    # Configuration for delayed scaling
+    # Note: this is actually applied per-tensor, but only using the same
+    # configuration for all tensors and layers in the model is currently
+    # supported. If in the future we add support for a more fine grained
+    # configuration, this field may move to per-tensor configs.
+    delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig()
+

 # If True, use 'fnuz' float8 types for calculations.
 # Currently, ROCm only supports fnuz variants.
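
With the dataclass now living in `config.py`, delayed scaling is configured through `Float8LinearConfig` instead of a standalone recipe object. A sketch of how the new field might be set, using `history_len=32` purely as an illustrative non-default value (only `"max"` is accepted as `scale_fn_name`, enforced by `__post_init__`):

```python
from float8_experimental import DelayedScalingConfig, Float8LinearConfig

# Illustrative values; history_len defaults to 16, and "max" is currently the
# only supported scale_fn_name (DelayedScalingConfig.__post_init__ asserts this).
config = Float8LinearConfig(
    delayed_scaling_config=DelayedScalingConfig(history_len=32, scale_fn_name="max"),
)
assert config.delayed_scaling_config.history_len == 32
```
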
34 changes: 4 additions & 30 deletions float8_experimental/float8_linear.py
@@ -131,23 +131,6 @@ def backward(ctx, go):
         return res, *empty_grads


-@dataclasses.dataclass
-class DelayedScalingRecipe:
-    # Controls the history length of amax buffers
-    history_len: int
-
-    # Controls the way to calculate current scale from amax history
-    # TODO(future): add other functions as needed, hardcoded or user defined
-    scale_fn_name: str
-
-    def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
-        self.history_len = history_len
-        self.scale_fn_name = scale_fn_name
-        assert (
-            self.scale_fn_name == "max"
-        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
-
-
 class Float8Linear(torch.nn.Linear):
     """
     Note: this is **not** a public API and is only intended to be used
@@ -161,13 +144,9 @@ class Float8Linear(torch.nn.Linear):
     def __init__(self, *args, **kwargs):
         """
         Additional arguments on top of `torch.nn.Linear`'s arguments:
-        * `delayed_scaling_recipe`: configuration for delayed scaling
         * `config`: Float8LinearConfig
         """

-        delayed_scaling_recipe = kwargs.pop(
-            "delayed_scaling_recipe", DelayedScalingRecipe()
-        )
         # Amax scales should always be kept as float32.
         self.always_float32_buffers = set()
         config = kwargs.pop("config")
@@ -187,11 +166,6 @@ def __init__(self, *args, **kwargs):

         self.config = config

-        # TODO(future): have a unique recipe per buffer instead of one per
-        # module, saving implementing that until we need it.
-        # TODO(future): serialization for recipes
-        self.recipe = delayed_scaling_recipe
-
         self.create_buffers()

         # TODO(future): user level configuration of gemms
@@ -237,7 +211,7 @@ def __init__(self, *args, **kwargs):

     def create_buffers(self):
         # Default values for history buffers, see above TODO
-        history_len = self.recipe.history_len
+        history_len = self.config.delayed_scaling_config.history_len
         device = self.weight.device
         # TODO(future PR): dtype values below don't have the other float8
         # flavors, fix it
@@ -307,7 +281,7 @@ def cast_x_to_float8(
             x = x.to(autocast_dtype)

         if self.scaling_type_input is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 x,
                 self.fp8_amax_input,
@@ -338,7 +312,7 @@ def cast_w_to_float8(
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 w_fp8 = self.weight
             else:
-                scale_fn_name = self.recipe.scale_fn_name
+                scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
                 _maybe_initialize_amaxes_scales_for_float8_cast(
                     w,
                     self.fp8_amax_weight,
@@ -370,7 +344,7 @@ def cast_w_to_float8(
     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             y = NoopFwToFloat8E5M2Bw.apply(
                 y,
                 self.fp8_amax_grad_output,
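
Because `Float8Linear` no longer stores a `recipe` attribute, any code that read `module.recipe.history_len` or `module.recipe.scale_fn_name` needs to go through the module's config instead, as the hunks above do. A small hypothetical helper (not part of the library) showing the new attribute path:

```python
from float8_experimental.float8_linear import Float8Linear


def delayed_scaling_settings(m: Float8Linear) -> tuple[int, str]:
    # Hypothetical helper: the values that used to live on `m.recipe`
    # are now reached via the module's Float8LinearConfig.
    dsc = m.config.delayed_scaling_config
    return dsc.history_len, dsc.scale_fn_name
```
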
2 changes: 1 addition & 1 deletion float8_experimental/float8_linear_utils.py
@@ -237,7 +237,7 @@ def inner_func():
             fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output

             x_dtypes.add(child.last_seen_input_dtype)
-            scale_fn_recipes.add(child.recipe.scale_fn_name)
+            scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name)

         # TODO This way to get the activation dtype is not ideal
         if len(x_dtypes) != 1: