
Commit

add per-gemm config to Float8LinearConfig (#334)
Summary:
Pull Request resolved: #334

Previously the per-gemm configuration had to be hardcoded in library
code. This PR exposes it to the top-level UX by adding a
`Float8GemmConfig` field to `Float8LinearConfig`.

Note that today the only supported configuration option is
`use_fast_accum`. In the future, options such as `output_dtype`
and whether to keep a gemm in higher precision would also go here.
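
As a usage sketch of the new top-level UX (illustrative, not taken from the
commit itself), the new fields can be set when constructing a
`Float8LinearConfig`. The class and field names below come from this commit's
diff; all other fields are left at their defaults.

    from float8_experimental import Float8GemmConfig, Float8LinearConfig

    # Use fast accumulation only for the forward (`output`) gemm; keep the two
    # backward gemms on the slower, higher-precision accumulation path.
    config = Float8LinearConfig(
        gemm_config_output=Float8GemmConfig(use_fast_accum=True),
        gemm_config_grad_input=Float8GemmConfig(use_fast_accum=False),
        gemm_config_grad_weight=Float8GemmConfig(use_fast_accum=False),
    )

This mirrors the defaults added in config.py below, where only
`gemm_config_output` defaults to `use_fast_accum=True`.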

Reviewed By: weifengpy

Differential Revision: D60252069

fbshipit-source-id: bca34eb49e1bf046f937e32b11b2369b535d56e6
vkuzo authored and facebook-github-bot committed Jul 25, 2024
1 parent ed1693e commit b9b606e
Showing 3 changed files with 32 additions and 7 deletions.
2 changes: 2 additions & 0 deletions float8_experimental/__init__.py
@@ -6,6 +6,7 @@
 # Lets define a few top level things here
 from float8_experimental.config import (
     DelayedScalingConfig,
+    Float8GemmConfig,
     Float8LinearConfig,
     Float8TensorCastConfig,
     TensorScalingType,
@@ -33,6 +34,7 @@
     # configuration
     "DelayedScalingConfig",
     "TensorScalingType",
+    "Float8GemmConfig",
     "Float8LinearConfig",
     "Float8TensorCastConfig",
     # top level UX
19 changes: 19 additions & 0 deletions float8_experimental/config.py
@@ -53,6 +53,17 @@ def __post_init__(self):
         ), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."


+@dataclass(frozen=True)
+class Float8GemmConfig:
+    """
+    Configuration for a float8 gemm.
+    """
+
+    # If True, fast accumulation in lower precision is used.
+    # Note: this flag is currently a no-op if emulation is turned on.
+    use_fast_accum: bool = False
+
+
 @dataclass(frozen=True)
 class Float8LinearConfig:
     """
@@ -67,6 +78,14 @@ class Float8LinearConfig:
     cast_config_weight: Float8TensorCastConfig = Float8TensorCastConfig()
     cast_config_grad_output: Float8TensorCastConfig = Float8TensorCastConfig()

+    #
+    # Per-gemm configuration for gemms calculating `output`, `grad_input` and
+    # `grad_weight`
+    #
+    gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
+    gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig()
+    gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig()
+
     #
     # Per-linear configuration
     #
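
A minimal sketch of the semantics defined above, assuming only what the diff
shows: `Float8GemmConfig` is a frozen dataclass whose `use_fast_accum` field
defaults to False, so variants are derived rather than mutated in place.

    import dataclasses

    from float8_experimental import Float8GemmConfig

    base = Float8GemmConfig()  # use_fast_accum defaults to False
    fast = dataclasses.replace(base, use_fast_accum=True)  # frozen: derive a copy
    assert (base.use_fast_accum, fast.use_fast_accum) == (False, True)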
18 changes: 11 additions & 7 deletions float8_experimental/float8_linear.py
@@ -168,24 +168,28 @@ def __init__(self, *args, **kwargs):

         self.create_buffers()

-        # TODO(future): user level configuration of gemms
         self.linear_mm_config = LinearMMConfig(
-            # input
+            # output
             ScaledMMConfig(
                 emulate,
-                True if not emulate else False,
+                self.config.gemm_config_output.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
             ),
-            # weight
+            # grad_input
             ScaledMMConfig(
                 emulate,
-                True if not emulate else False,
+                self.config.gemm_config_grad_input.use_fast_accum,
                 False,
                 self.config.pad_inner_dim,
             ),
-            # grad_output
-            ScaledMMConfig(emulate, False, False, self.config.pad_inner_dim),
+            # grad_weight
+            ScaledMMConfig(
+                emulate,
+                self.config.gemm_config_grad_weight.use_fast_accum,
+                False,
+                self.config.pad_inner_dim,
+            ),
         )

         # Note: is_amax_initialized is not a buffer to avoid data dependent
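
Purely as an illustration of the wiring above (the helper below is
hypothetical, not library code): the per-gemm `use_fast_accum` flags are read
off the config in the order `output`, `grad_input`, `grad_weight` and passed
to the corresponding `ScaledMMConfig`; per the comment in config.py, the flag
is a no-op when emulation is enabled.

    from float8_experimental import Float8LinearConfig

    def fast_accum_flags(config: Float8LinearConfig):
        # mirrors the attribute accesses in Float8Linear.__init__ above
        return (
            config.gemm_config_output.use_fast_accum,
            config.gemm_config_grad_input.use_fast_accum,
            config.gemm_config_grad_weight.use_fast_accum,
        )

    print(fast_accum_flags(Float8LinearConfig()))  # (True, False, False) with defaults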
