This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

add per-gemm config to Float8LinearConfig #334

Closed · wants to merge 2 commits
2 changes: 2 additions & 0 deletions float8_experimental/__init__.py
@@ -6,6 +6,7 @@
# Lets define a few top level things here
from float8_experimental.config import (
DelayedScalingConfig,
Float8GemmConfig,
Float8LinearConfig,
Float8TensorCastConfig,
TensorScalingType,
@@ -33,6 +34,7 @@
# configuration
"DelayedScalingConfig",
"TensorScalingType",
"Float8GemmConfig",
"Float8LinearConfig",
"Float8TensorCastConfig",
# top level UX
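For reference, a hedged sanity check of the new top-level export; it assumes only that the package is importable as `float8_experimental` and relies on the default documented in config.py below:

# Hedged sketch: the new name is importable at the top level and its
# documented default (use_fast_accum=False) holds.
from float8_experimental import Float8GemmConfig

assert Float8GemmConfig().use_fast_accum is False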
19 changes: 19 additions & 0 deletions float8_experimental/config.py
@@ -53,6 +53,17 @@ def __post_init__(self):
), f"{self.scale_fn_name} is not implemented yet. Only max is supported for now."


@dataclass(frozen=True)
class Float8GemmConfig:
"""
Configuration for a float8 gemm.
"""

# If True, fast accumulation in lower precision is used.
# Note: this flag is currently a no-op if emulation is turned on.
use_fast_accum: bool = False


@dataclass(frozen=True)
class Float8LinearConfig:
"""
@@ -67,6 +78,14 @@ class Float8LinearConfig:
cast_config_weight: Float8TensorCastConfig = Float8TensorCastConfig()
cast_config_grad_output: Float8TensorCastConfig = Float8TensorCastConfig()

#
# Per-gemm configuration for gemms calculating `output`, `grad_input` and
# `grad_weight`
#
gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig()
gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig()

#
# Per-linear configuration
#
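A minimal usage sketch of the new per-gemm knobs, assuming only the fields added in this PR: enable fast accumulation for the `output` gemm and keep it off for the two backward gemms (matching the defaults above).

from float8_experimental import Float8GemmConfig, Float8LinearConfig

# Per-gemm configuration: fast accumulation for the forward (output) gemm only.
config = Float8LinearConfig(
    gemm_config_output=Float8GemmConfig(use_fast_accum=True),
    gemm_config_grad_input=Float8GemmConfig(use_fast_accum=False),
    gemm_config_grad_weight=Float8GemmConfig(use_fast_accum=False),
)

The config is then consumed by Float8Linear, as shown in the float8_linear.py diff below.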
18 changes: 11 additions & 7 deletions float8_experimental/float8_linear.py
@@ -168,24 +168,28 @@ def __init__(self, *args, **kwargs):

self.create_buffers()

# TODO(future): user level configuration of gemms
self.linear_mm_config = LinearMMConfig(
# input
# output
ScaledMMConfig(
emulate,
True if not emulate else False,
self.config.gemm_config_output.use_fast_accum,
False,
self.config.pad_inner_dim,
),
# weight
# grad_input
ScaledMMConfig(
emulate,
True if not emulate else False,
Contributor Author commented: this was mistakenly set to default True in #315, fixing

self.config.gemm_config_grad_input.use_fast_accum,
False,
self.config.pad_inner_dim,
),
# grad_weight
ScaledMMConfig(
emulate,
self.config.gemm_config_grad_weight.use_fast_accum,
False,
self.config.pad_inner_dim,
),
# grad_output
ScaledMMConfig(emulate, False, False, self.config.pad_inner_dim),
)

# Note: is_amax_initialized is not a buffer to avoid data dependent
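A standalone sketch of the mapping above, using a hypothetical local stand-in for ScaledMMConfig (the real namedtuple lives elsewhere in the library); the field order follows the positional arguments in the constructor calls in this diff, and the per-gemm flag is passed through unchanged since it is documented as a no-op under emulation.

from collections import namedtuple

# Hypothetical stand-in mirroring the positional argument order used above:
# (emulate, use_fast_accum, fp8_output, pad_inner_dim).
ScaledMMConfig = namedtuple(
    "ScaledMMConfig", ["emulate", "use_fast_accum", "fp8_output", "pad_inner_dim"]
)

def make_gemm_mm_config(emulate, gemm_config, pad_inner_dim):
    # Emulated gemms run the matmul in high precision, so fast accumulation
    # has no effect there; the flag is still recorded as-is.
    return ScaledMMConfig(emulate, gemm_config.use_fast_accum, False, pad_inner_dim)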