From 994057c0d2e2805fb590d19d326bec14cb0b6e4c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 25 Jul 2024 09:40:26 -0700 Subject: [PATCH 1/2] add per-gemm config to `Float8LinearConfig` Summary: Previously the per-gemm configuration had to be hardcoded in library code. This PR exposes it to the top-level UX by adding a `Float8GemmConfig` field to `Float8LinearConfig`. Note that today the only supported configuration option is `use_fast_accum`. In the future, configuring output_dtype and whether to keep a gemm in higher precision would go here. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- float8_experimental/__init__.py | 2 ++ float8_experimental/config.py | 19 +++++++++++++++++++ float8_experimental/float8_linear.py | 18 +++++++++++------- 3 files changed, 32 insertions(+), 7 deletions(-) diff --git a/float8_experimental/__init__.py b/float8_experimental/__init__.py index 4c1f255..08c0ac4 100644 --- a/float8_experimental/__init__.py +++ b/float8_experimental/__init__.py @@ -6,6 +6,7 @@ # Lets define a few top level things here from float8_experimental.config import ( DelayedScalingConfig, + Float8GemmConfig, Float8LinearConfig, Float8TensorCastConfig, TensorScalingType, @@ -33,6 +34,7 @@ # configuration "DelayedScalingConfig", "TensorScalingType", + "Float8GemmConfig", "Float8LinearConfig", "Float8TensorCastConfig", # top level UX diff --git a/float8_experimental/config.py b/float8_experimental/config.py index ea088e3..7cb7230 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -53,6 +53,17 @@ def __post_init__(self): ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now." +@dataclass(frozen=True) +class Float8GemmConfig: + """ + Configuration for a float8 gemm. + """ + + # If True, fast accumulation in lower precision is used. + # Note: this flag is currently a no-op if emulation is turned on. 
+ use_fast_accum: bool = False + + @dataclass(frozen=True) class Float8LinearConfig: """ @@ -67,6 +78,14 @@ class Float8LinearConfig: cast_config_weight: Float8TensorCastConfig = Float8TensorCastConfig() cast_config_grad_output: Float8TensorCastConfig = Float8TensorCastConfig() + # + # Per-gemm configuration for gemms calculating `output`, `grad_input` and + # `grad_weight` + # + gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True) + gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig() + gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True) + # # Per-linear configuration # diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py index 581f9f3..c598a93 100644 --- a/float8_experimental/float8_linear.py +++ b/float8_experimental/float8_linear.py @@ -168,24 +168,28 @@ def __init__(self, *args, **kwargs): self.create_buffers() - # TODO(future): user level configuration of gemms self.linear_mm_config = LinearMMConfig( - # input + # output ScaledMMConfig( emulate, - True if not emulate else False, + self.config.gemm_config_output.use_fast_accum, False, self.config.pad_inner_dim, ), - # weight + # grad_input ScaledMMConfig( emulate, - True if not emulate else False, + self.config.gemm_config_grad_input.use_fast_accum, + False, + self.config.pad_inner_dim, + ), + # grad_weight + ScaledMMConfig( + emulate, + self.config.gemm_config_grad_weight.use_fast_accum, False, self.config.pad_inner_dim, ), - # grad_output - ScaledMMConfig(emulate, False, False, self.config.pad_inner_dim), ) # Note: is_amax_initialized is not a buffer to avoid data dependent From 962b65202d91c9f706ec197cda8098312f37588c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Thu, 25 Jul 2024 09:43:02 -0700 Subject: [PATCH 2/2] Update on "add per-gemm config to `Float8LinearConfig`" Summary: Previously the per-gemm configuration had to be hardcoded in library code. This PR exposes it to the top-level UX by adding a `Float8GemmConfig` field to `Float8LinearConfig`. Note that today the only supported configuration option is `use_fast_accum`. In the future, configuring output_dtype and whether to keep a gemm in higher precision would go here. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] --- float8_experimental/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/float8_experimental/config.py b/float8_experimental/config.py index 7cb7230..6408ac7 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -84,7 +84,7 @@ class Float8LinearConfig: # gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True) gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig() - gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True) + gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig() # # Per-linear configuration
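
For reference, a minimal usage sketch of the per-gemm knobs introduced in this stack. It uses only the names added by these patches and assumes the remaining `Float8LinearConfig` fields keep their defaults; the conversion entry point that consumes the config is not shown here.

```python
# Minimal sketch of the per-gemm configuration added in this stack.
# Assumes the other Float8LinearConfig fields keep their defaults.
from float8_experimental import Float8GemmConfig, Float8LinearConfig

config = Float8LinearConfig(
    # fast accumulation for the `output` gemm (matches the new default)
    gemm_config_output=Float8GemmConfig(use_fast_accum=True),
    # full-precision accumulation for the backward gemms
    gemm_config_grad_input=Float8GemmConfig(),
    gemm_config_grad_weight=Float8GemmConfig(),
)

# Float8Linear reads these fields when building its LinearMMConfig,
# one ScaledMMConfig per gemm (output, grad_input, grad_weight).
assert config.gemm_config_output.use_fast_accum
assert not config.gemm_config_grad_input.use_fast_accum
assert not config.gemm_config_grad_weight.use_fast_accum
```

Note that after the second patch in the stack, only the `output` gemm defaults to fast accumulation; `grad_input` and `grad_weight` default to `use_fast_accum=False`.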