diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py
index 95491f3..4c1f255 100644
--- a/float8_experimental/__init__.py
+++ b/float8_experimental/__init__.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 # Lets define a few top level things here
 from float8_experimental.config import (
+    DelayedScalingConfig,
     Float8LinearConfig,
     Float8TensorCastConfig,
     TensorScalingType,
@@ -30,6 +31,7 @@
 
 __all__ = [
     # configuration
+    "DelayedScalingConfig",
     "TensorScalingType",
     "Float8LinearConfig",
     "Float8TensorCastConfig",
diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index 2e9eacf..ea088e3 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -29,6 +29,30 @@ class Float8TensorCastConfig:
     scaling_type: TensorScalingType = TensorScalingType.DYNAMIC
 
 
+@dataclass(frozen=True)
+class DelayedScalingConfig:
+    """
+    Configuration for delayed scaling.
+
+    Note: for now, `history_len` values must be the same for all layers in the
+    model using delayed scaling.
+
+    TODO(future): serialization for recipes
+    """
+
+    # Controls the history length of amax buffers
+    history_len: int = 16
+
+    # Controls the way to calculate current scale from amax history
+    # TODO(future): add other functions as needed, hardcoded or user defined
+    scale_fn_name: str = "max"
+
+    def __post_init__(self):
+        assert (
+            self.scale_fn_name == "max"
+        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
+
+
 @dataclass(frozen=True)
 class Float8LinearConfig:
     """
@@ -71,6 +95,13 @@ class Float8LinearConfig:
     # If True, emulation is used instead of hardware accelerated gemm
     emulate: bool = False
 
+    # Configuration for delayed scaling
+    # Note: this is actually applied per-tensor, but only using the same
+    # configuration for all tensors and layers in the model is currently
+    # supported. If in the future we add support for a more fine grained
+    # configuration, this field may move to per-tensor configs.
+    delayed_scaling_config: DelayedScalingConfig = DelayedScalingConfig()
+
     # If True, use 'fnuz' float8 types for calculations.
     # Currently, ROCm only supports fnuz variants.
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 1d8519e..581f9f3 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -131,23 +131,6 @@ def backward(ctx, go):
         return res, *empty_grads
 
 
-@dataclasses.dataclass
-class DelayedScalingRecipe:
-    # Controls the history length of amax buffers
-    history_len: int
-
-    # Controls the way to calculate current scale from amax history
-    # TODO(future): add other functions as needed, hardcoded or user defined
-    scale_fn_name: str
-
-    def __init__(self, history_len: int = 16, scale_fn_name: str = "max"):
-        self.history_len = history_len
-        self.scale_fn_name = scale_fn_name
-        assert (
-            self.scale_fn_name == "max"
-        ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
-
-
 class Float8Linear(torch.nn.Linear):
     """
     Note: this is **not** a public API and is only intended to be used
@@ -161,13 +144,9 @@ class Float8Linear(torch.nn.Linear):
     def __init__(self, *args, **kwargs):
         """
         Additional arguments on top of `torch.nn.Linear`'s arguments:
-        * `delayed_scaling_recipe`: configuration for delayed scaling
         * `config`: Float8LinearConfig
         """
 
-        delayed_scaling_recipe = kwargs.pop(
-            "delayed_scaling_recipe", DelayedScalingRecipe()
-        )
         # Amax scales should always be kept as float32.
         self.always_float32_buffers = set()
         config = kwargs.pop("config")
@@ -187,11 +166,6 @@ def __init__(self, *args, **kwargs):
 
         self.config = config
 
-        # TODO(future): have a unique recipe per buffer instead of one per
-        # module, saving implementing that until we need it.
-        # TODO(future): serialization for recipes
-        self.recipe = delayed_scaling_recipe
-
         self.create_buffers()
 
         # TODO(future): user level configuration of gemms
@@ -237,7 +211,7 @@ def __init__(self, *args, **kwargs):
 
     def create_buffers(self):
         # Default values for history buffers, see above TODO
-        history_len = self.recipe.history_len
+        history_len = self.config.delayed_scaling_config.history_len
         device = self.weight.device
         # TODO(future PR): dtype values below don't have the other float8
         # flavors, fix it
@@ -307,7 +281,7 @@ def cast_x_to_float8(
             x = x.to(autocast_dtype)
 
         if self.scaling_type_input is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             _maybe_initialize_amaxes_scales_for_float8_cast(
                 x,
                 self.fp8_amax_input,
@@ -338,7 +312,7 @@ def cast_w_to_float8(
             if isinstance(self.weight, Float8Tensor):  # cast by FSDP
                 w_fp8 = self.weight
             else:
-                scale_fn_name = self.recipe.scale_fn_name
+                scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
                 _maybe_initialize_amaxes_scales_for_float8_cast(
                     w,
                     self.fp8_amax_weight,
@@ -370,7 +344,7 @@ def cast_w_to_float8(
 
     def cast_y_to_float8_in_bw(self, y: torch.Tensor) -> torch.Tensor:
         if self.scaling_type_grad_output is TensorScalingType.DELAYED:
-            scale_fn_name = self.recipe.scale_fn_name
+            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
             y = NoopFwToFloat8E5M2Bw.apply(
                 y,
                 self.fp8_amax_grad_output,
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index e3a758a..c72b620 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -237,7 +237,7 @@ def inner_func():
             fp8_grad_output_amax_history_stack[idx] = child.fp8_amax_history_grad_output
 
             x_dtypes.add(child.last_seen_input_dtype)
-            scale_fn_recipes.add(child.recipe.scale_fn_name)
+            scale_fn_recipes.add(child.config.delayed_scaling_config.scale_fn_name)
 
         # TODO This way to get the activation dtype is not ideal
         if len(x_dtypes) != 1:
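
For reviewers, a minimal usage sketch of the configuration surface after this change. It only uses names this diff introduces or re-exports (`DelayedScalingConfig`, `Float8LinearConfig` via the new `__init__.py` exports); the `history_len=32` value is an arbitrary example, and the `Float8Linear` wiring is the one shown in the `float8_linear.py` hunks above.

# Hypothetical usage sketch, not part of the patch itself.
from float8_experimental import DelayedScalingConfig, Float8LinearConfig

# Defaults match the removed DelayedScalingRecipe: history_len=16, scale_fn_name="max".
delayed = DelayedScalingConfig(history_len=32)

# Delayed scaling settings now ride on Float8LinearConfig and are read by
# Float8Linear through `config.delayed_scaling_config` instead of the removed
# `delayed_scaling_recipe` kwarg.
config = Float8LinearConfig(delayed_scaling_config=delayed)

# Anything other than "max" is rejected by DelayedScalingConfig.__post_init__:
# DelayedScalingConfig(scale_fn_name="mean")  # AssertionError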