This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

rename DelayedScalingRecipe to DelayedScalingConfig (#333)
Summary:
Pull Request resolved: #333

1. rename `DelayedScalingRecipe` to `DelayedScalingConfig`
2. move this to `config.py` and make user facing

Reviewed By: weifengpy

Differential Revision: D60252067

fbshipit-source-id: ec233df1e0d03fdc649a19de1722ee45d5029aa6
vkuzo authored and facebook-github-bot committed Jul 25, 2024
1 parent eff4ba6 commit ed1693e
Showing 4 changed files with 38 additions and 31 deletions.
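
For orientation, a minimal usage sketch of the renamed, user-facing config, using only names and defaults that appear in the diff below; it assumes the other `Float8LinearConfig` fields keep their defaults and is illustrative rather than documented API:

from float8_experimental import DelayedScalingConfig, Float8LinearConfig

# Delayed scaling is now configured through the user-facing config object
# rather than via a `delayed_scaling_recipe` kwarg on Float8Linear (removed below).
config = Float8LinearConfig(
    delayed_scaling_config=DelayedScalingConfig(history_len=16, scale_fn_name="max"),
)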
float8_experimental/__init__.py (2 additions, 0 deletions)

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
 from float8_experimental.config import (
+    DelayedScalingConfig,
     Float8LinearConfig,
     Float8TensorCastConfig,
     TensorScalingType,
@@ -30,6 +31,7 @@
 
 __all__ = [
     # configuration
+    "DelayedScalingConfig",
     "TensorScalingType",
     "Float8LinearConfig",
     "Float8TensorCastConfig",
float8_experimental/config.py (31 additions, 0 deletions)

@@ -29,6 +29,30 @@ class Float8TensorCastConfig:
     scaling_type: TensorScalingType = TensorScalingType.DYNAMIC
 
 
+@dataclass(frozen=True)
+class DelayedScalingConfig:
+    """
+    Configuration for delayed scaling.
+
+    Note: for now, `history_len` values must be the same for all layers in the
+    model using delayed scaling.
+
+    TODO(future): serialization for recipes
+    """
+
+    # Controls the history length of amax buffers
+    history_len: int = 16
+
+    # Controls the way to calculate current scale from amax history
+    # TODO(future): add other functions as needed, hardcoded or user defined
+    scale_fn_name: str = "max"
+
+    def __post_init__(self):
+        assert (
+            self.scale_fn_name == "max"
+        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
+
+
 @dataclass(frozen=True)
 class Float8LinearConfig:
     """
@@ -71,6 +95,13 @@ class Float8LinearConfig:
     # If True, emulation is used instead of hardware accelerated gemm
     emulate: bool = False
 
+    # Configuration for delayed scaling
+    # Note: this is actually applied per-tensor, but only using the same
+    # configuration for all tensors and layers in the model is currently
+    # supported. If in the future we add support for a more fine grained
+    # configuration, this field may move to per-tensor configs.
+    delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig()
+
 
     # If True, use 'fnuz' float8 types for calculations.
     # Currently, ROCm only supports fnuz variants.
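
A short sketch of how the `__post_init__` validation added above behaves; "mean" is a hypothetical unsupported value used only for illustration:

from float8_experimental import DelayedScalingConfig

ok = DelayedScalingConfig(history_len=32)  # valid; scale_fn_name defaults to "max"
try:
    DelayedScalingConfig(scale_fn_name="mean")  # hypothetical, not implemented
except AssertionError as e:
    print(e)  # "mean is not implemented yet. Only max is supported for now."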
float8_experimental/float8_linear.py (4 additions, 30 deletions)

@@ -131,23 +131,6 @@ def backward(ctx, go):
         return res, *empty_grads
 
 
-@dataclasses.dataclass
-class DelayedScalingRecipe:
-    # Controls the history length of amax buffers
-    history_len: int
-
-    # Controls the way to calculate current scale from amax history
-    # TODO(future): add other functions as needed, hardcoded or user defined
-    scale_fn_name: str
-
-    def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
-        self.history_len = history_len
-        self.scale_fn_name = scale_fn_name
-        assert (
-            self.scale_fn_name == "max"
-        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
-
-
 class Float8Linear(torch.nn.Linear):
     """
     Note: this is **not** a public API and is only intended to be used
@@ -161,13 +144,9 @@ class Float8Linear(torch.nn.Linear):
     def __init__(self, *args, **kwargs):
         """
         Additional arguments on top of `torch.nn.Linear`'s arguments:
-        * `delayed_scaling_recipe`: configuration for delayed scaling
         * `config`: Float8LinearConfig
         """
 
-        delayed_scaling_recipe = kwargs.pop(
-            "delayed_scaling_recipe", DelayedScalingRecipe()
-        )
         # Amax scales should always be kept as float32.
         self.always_float32_buffers = set()
         config = kwargs.pop("config")
@@ -187,11 +166,6 @@ def __init__(self, *args, **kwargs):
 
         self.config = config
 
-        # TODO(future): have a unique recipe per buffer instead of one per
-        # module, saving implementing that until we need it.
-        # TODO(future): serialization for recipes
-        self.recipe = delayed_scaling_recipe
-
         self.create_buffers()
 
         # TODO(future): user level configuration of gemms
@@ -237,7 +211,7 @@ def __init__(self, *args, **kwargs):
 
     def create_buffers(self):
         # Default values for history buffers, see above TODO
-        history_len = self.recipe.history_len
+        history_len = self.config.delayed_scaling_config.history_len
         device = self.weight.device
         # TODO(future PR): dtype values below don't have the other float8
         # flavors, fix it
@@ -307,7 +281,7 @@ def cast_x_to_float8(
             x = x.to(autocast_dtype)
 
         if self.scaling_type_input is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 x,
                 self.fp8_amax_input,
@@ -338,7 +312,7 @@ def cast_w_to_float8(
         if isinstance(self.weight, Float8Tensor):  # cast by FSDP
             w_fp8 = self.weight
         else:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 w,
                 self.fp8_amax_weight,
@@ -370,7 +344,7 @@ def cast_w_to_float8(
 
     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             y = NoopFwToFloat8E5M2Bw.apply(
                 y,
                 self.fp8_amax_grad_output,
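
The hunks above only swap `self.recipe` for `self.config.delayed_scaling_config`; the scale computation itself is unchanged and not part of this commit. As a hedged illustration of what `scale_fn_name="max"` conceptually selects (the function name and exact formula below are assumptions, not code from this repository):

import torch

def amax_history_to_scale_sketch(
    amax_history: torch.Tensor,
    float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    # The "max" scale function: reduce the rolling amax history with a max ...
    amax = torch.max(amax_history)
    # ... then derive a scale from the largest representable float8 value,
    # clamping amax so the scale stays finite if the history is still zero.
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)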
float8_experimental/float8_linear_utils.py (1 addition, 1 deletion)

@@ -237,7 +237,7 @@ def inner_func():
             fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output
 
             x_dtypes.add(child.last_seen_input_dtype)
-            scale_fn_recipes.add(child.recipe.scale_fn_name)
+            scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name)
 
         # TODO This way to get the activation dtype is not ideal
         if len(x_dtypes) != 1:
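
The hunk above only changes where `scale_fn_name` is read from; the surrounding code of `inner_func()` is not shown in this commit. A hedged, stand-alone sketch of the kind of consistency check the `scale_fn_recipes` set feeds into, consistent with the note in `DelayedScalingConfig` that all layers must share the same settings (`fp8_layers` and the function name here are assumptions):

from typing import Iterable

def check_scale_fn_consistency(fp8_layers: Iterable) -> str:
    # Collect the scale_fn_name from every layer's delayed scaling config.
    scale_fn_recipes = {
        child.config.delayed_scaling_config.scale_fn_name for child in fp8_layers
    }
    # Amax/scale histories are synced in one batched step, so every layer must
    # agree on a single scale function; mixed values would be ambiguous.
    if len(scale_fn_recipes) != 1:
        raise ValueError(
            f"All layers must use the same scale_fn_name, got {scale_fn_recipes}"
        )
    return scale_fn_recipes.pop()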
