diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py
index 4c1f255..08c0ac4 100644
--- a/float8_experimental/__init__.py
+++ b/float8_experimental/__init__.py
@@ -6,6 +6,7 @@
 # Lets define a few top level things here
 from float8_experimental.config import (
     DelayedScalingConfig,
+    Float8GemmConfig,
     Float8LinearConfig,
     Float8TensorCastConfig,
     TensorScalingType,
@@ -33,6 +34,7 @@
     # configuration
     "DelayedScalingConfig",
     "TensorScalingType",
+    "Float8GemmConfig",
     "Float8LinearConfig",
     "Float8TensorCastConfig",
     # top level UX
diff --git a/float8_experimental/config.py b/float8_experimental/config.py
index ea088e3..6408ac7 100644
--- a/float8_experimental/config.py
+++ b/float8_experimental/config.py
@@ -53,6 +53,17 @@ def __post_init__(self):
         ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."
 
 
+@dataclass(frozen=True)
+class Float8GemmConfig:
+    """
+    Configuration for a float8 gemm.
+    """
+
+    # If True, fast accumulation in lower precision is used.
+    # Note: this flag is currently a no-op if emulation is turned on.
+    use_fast_accum: bool = False
+
+
 @dataclass(frozen=True)
 class Float8LinearConfig:
     """
@@ -67,6 +78,14 @@ class Float8LinearConfig:
     cast_config_weight: Float8TensorCastConfig = Float8TensorCastConfig()
     cast_config_grad_output: Float8TensorCastConfig = Float8TensorCastConfig()
 
+    #
+    # Per-gemm configuration for gemms calculating `output`, `grad_input` and
+    # `grad_weight`
+    #
+    gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
+    gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig()
+    gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig()
+
     #
     # Per-linear configuration
     #
diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 581f9f3..c598a93 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -168,24 +168,28 @@ def __init__(self, *args, **kwargs):
 
         self.create_buffers()
 
-        # TODO(future): user level configuration of gemms
         self.linear_mm_config = LinearMMConfig(
-            # input
+            # output
             ScaledMMConfig(
                 emulate,
-                True if not emulate else False,
+                self.config.gemm_config_output.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
             ),
-            # weight
+            # grad_input
             ScaledMMConfig(
                 emulate,
-                True if not emulate else False,
+                self.config.gemm_config_grad_input.use_fast_accum,
+                False,
+                self.config.pad_inner_dim,
+            ),
+            # grad_weight
+            ScaledMMConfig(
+                emulate,
+                self.config.gemm_config_grad_weight.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
             ),
-            # grad_output
-            ScaledMMConfig(emulate, False, False, self.config.pad_inner_dim),
         )
 
         # Note: is_amax_initialized is not a buffer to avoid data dependent
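
Usage sketch (not part of the diff): with `Float8GemmConfig` now exported from `float8_experimental`, the per-gemm fast-accum behavior can be chosen when building a `Float8LinearConfig`, which `Float8Linear.__init__` then reads via `config.gemm_config_*.use_fast_accum` when constructing its `LinearMMConfig`. The class, field names, and defaults below come directly from this diff; the assumption that the remaining `Float8LinearConfig` fields can be left at their defaults is mine.

```python
from float8_experimental import Float8GemmConfig, Float8LinearConfig

# Mirror the new defaults explicitly: fast accumulation for the `output` gemm,
# full-precision accumulation for the two backward gemms. Per the config.py
# docstring, `use_fast_accum` is a no-op when emulation is enabled.
config = Float8LinearConfig(
    gemm_config_output=Float8GemmConfig(use_fast_accum=True),
    gemm_config_grad_input=Float8GemmConfig(use_fast_accum=False),
    gemm_config_grad_weight=Float8GemmConfig(use_fast_accum=False),
)
```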