Skip to content
This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Permalink
Skipping SAM for now since it hangs
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg committed May 28, 2024
1 parent cdb7867 commit e29cc35
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 19 deletions.
87 changes: 69 additions & 18 deletions float8_experimental/float8_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple
from typing import Literal, Tuple

import torch
import torch.distributed as dist
Expand All @@ -14,22 +14,40 @@

# define the e4m3/e5m2 constants
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E4M3_FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
E5M2_FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max

FP16_MAX_POS = torch.finfo(torch.float16).max

# avoid division by zero when calculating scale
# TODO: align this value with NVIDIA's assumptions (current value is a guess)
EPS = 1e-12

IS_AMD = torch.cuda.is_available() and torch.version.hip is not None


@torch.no_grad()
def amax_to_scale(amax, float8_dtype, orig_dtype):
def amax_to_scale(
amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
):
"""Converts the amax value of a tensor to the fp8 scale.
Args:
amax: The amax value of the tensor.
float8_dtype: The float8 dtype.
orig_dtype: The original dtype of the tensor.
"""
scale = torch.empty_like(amax, dtype=torch.float32)
if float8_dtype == torch.float8_e4m3fn:
res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
else: # e5m2
elif float8_dtype == torch.float8_e4m3fnuz:
res = E4M3_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == torch.float8_e5m2:
res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)
elif float8_dtype == torch.float8_e5m2fnuz:
res = E5M2_FNUZ_MAX_POS / torch.clamp(amax, min=EPS)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")

# Ensure that the scale is representable in float16,
# this helps when amax is small. We are assuming that we don't need
Expand All @@ -42,11 +60,18 @@ def amax_to_scale(amax, float8_dtype, orig_dtype):

@torch.no_grad()
def amax_history_to_scale(
amax_history,
float8_dtype,
orig_dtype,
history_to_scale_fn_type,
amax_history: torch.Tensor,
float8_dtype: torch.Tensor,
orig_dtype: torch.dtype,
history_to_scale_fn_type: Literal["max"],
):
"""Takes in a history of amax values and returns a scale tensor.
Args:
amax_history: A tensor containing the history of amax values.
float8_dtype: The float8 dtype.
orig_dtype: The original dtype of the tensor.
history_to_scale_fn_type: The type of function to use to convert the history to a scale.
"""
if history_to_scale_fn_type == "max":
amax = torch.max(amax_history)
return amax_to_scale(amax, float8_dtype, orig_dtype)
Expand All @@ -58,9 +83,15 @@ def amax_history_to_scale_stack(
amax_history: torch.Tensor,
float8_dtype: torch.dtype,
orig_dtype: torch.dtype,
history_to_scale_fn_type: str,
history_to_scale_fn_type: Literal["max"],
) -> torch.Tensor:
"""Takes in a stack of amax_history tensors and returns a scale tensor."""
"""Takes in a stack of amax_history tensors and returns a scale tensor.
Args:
amax_history: A 2D tensor containing a stack of amax histories.
float8_dtype: The float8 dtype.
orig_dtype: The original dtype of the tensor.
history_to_scale_fn_type: The type of function to use to convert the history to a scale.
"""
if history_to_scale_fn_type == "max":
amax_stack = torch.max(amax_history, dim=1).values
return amax_to_scale(amax_stack, float8_dtype, orig_dtype)
Expand Down Expand Up @@ -90,21 +121,41 @@ def tensor_to_scale(
return amax_to_scale(amax, float8_dtype, x.dtype)


def to_fp8_saturated(x, float8_dtype: torch.dtype):
# The default behavior in PyTorch for casting to `float8_e4m3fn`
# and `e5m2` is to not saturate. In this context, we should saturate.
# A common case where we want to saturate is when the history of a
# tensor has a maximum value of `amax1`, and the current amax value
# is `amax2`, where `amax1 < amax2`. This is common when using delayed
# scaling.
def to_fp8_saturated(x: torch.Tensor, float8_dtype: torch.dtype):
"""Converts a tensor to a saturated fp8 tensor.
Note:
The default behavior in PyTorch for casting to `float8_e4m3fn`
and `e5m2` is to not saturate. In this context, we should saturate.
A common case where we want to saturate is when the history of a
tensor has a maximum value of `amax1`, and the current amax value
is `amax2`, where `amax1 < amax2`. This is common when using delayed
scaling.
"""

if float8_dtype == torch.float8_e4m3fn:
x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
else:
elif float8_dtype == torch.float8_e4m3fnuz:
x = x.clamp(min=-1 * E4M3_FNUZ_MAX_POS, max=E4M3_FNUZ_MAX_POS)
elif float8_dtype == torch.float8_e5m2:
x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
elif float8_dtype == torch.float8_e5m2fnuz:
x = x.clamp(min=-1 * E5M2_FNUZ_MAX_POS, max=E5M2_FNUZ_MAX_POS)
else:
raise ValueError(f"Unsupported float8_dtype: {float8_dtype}")
return x.to(float8_dtype)


def compute_error(x, y):
def compute_error(x: torch.Tensor, y: torch.Tensor):
"""Computes the error between two tensors in dB.
For more details see:
https://en.wikipedia.org/wiki/Signal-to-noise_ratio
Args:
x: The original tensor.
y: The tensor to compare to the original tensor.
"""
Ps = torch.norm(x)
Pn = torch.norm(x - y)
return 20 * torch.log10(Ps / Pn)
Expand Down
2 changes: 1 addition & 1 deletion test/test_everything.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
set -e

pytest test/test_base.py
pytest test/test_sam.py
# pytest test/test_sam.py
pytest test/test_compile.py
./test/test_fsdp.sh
./test/test_fsdp_compile.sh
Expand Down

0 comments on commit e29cc35

Please sign in to comment.