This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Commit

Add utility for filtering out skipped tests in large parametrization groups

ghstack-source-id: d99192cee644bc03f310fc113c0c48251de5a88c
Pull Request resolved: #303
drisspg committed Jul 17, 2024
1 parent 7e7fbec commit 52e5d0a
Showing 3 changed files with 56 additions and 28 deletions.
float8_experimental/float8_linear.py (2 changes: 1 addition & 1 deletion)
@@ -300,7 +300,7 @@ def cast_x_to_float8(
         if torch.is_autocast_enabled():
             # For now, hardcode to GPU's autocast dtype
             # if we need CPU support in the future, we can add it
-            autocast_dtype = torch.get_autocast_gpu_dtype()
+            autocast_dtype = torch.get_autocast_dtype("cuda")
             x = x.to(autocast_dtype)

         if self.scaling_type_x is TensorScalingType.DELAYED:
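Note on the change above (not part of the commit): torch.get_autocast_gpu_dtype() is deprecated in recent PyTorch releases in favor of the device-generic torch.get_autocast_dtype(device_type), which is what the diff switches to. A minimal sketch of the newer call follows; it uses "cpu" only so it runs without a GPU, while the commit itself queries "cuda".

import torch

# Device-generic replacement for the deprecated torch.get_autocast_gpu_dtype();
# assumes a PyTorch version that provides torch.get_autocast_dtype (2.4+).
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # Inside the autocast region this reports the dtype set by the context manager.
    print(torch.get_autocast_dtype("cpu"))  # torch.bfloat16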
float8_experimental/float8_linear_utils.py (2 changes: 1 addition & 1 deletion)
@@ -227,7 +227,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         fp8_layers = get_float8_layers(model)

     if len(fp8_layers) == 0:
-        log.warn(
+        log.warning(
             "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers"
         )
         return
test/test_base.py (80 changes: 54 additions & 26 deletions)
@@ -10,6 +10,8 @@
 import re
 import unittest
 import warnings
+from itertools import product
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import pytest

@@ -50,6 +52,42 @@
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


+def filtered_parametrize(
+    param_list: List[Tuple[str, List[Any]]],
+    filter_func: Optional[Callable[[Dict[str, Any]], bool]] = None,
+):
+    """
+    A decorator that works like pytest.mark.parametrize but filters out
+    unwanted parameter combinations.
+
+    Args:
+        param_list: A list of tuples, each containing (arg_name, [arg_values])
+        filter_func: A function that takes a dictionary of parameter names and values,
+            and returns True for valid combinations, False otherwise
+    """
+
+    def decorator(func):
+        arg_names = [param[0] for param in param_list]
+        arg_values = [param[1] for param in param_list]
+
+        all_combinations = product(*arg_values)
+        if filter_func:
+            valid_combinations = [
+                combo
+                for combo in all_combinations
+                if filter_func(dict(zip(arg_names, combo)))
+            ]
+        else:
+            valid_combinations = list(all_combinations)
+
+        return pytest.mark.parametrize(
+            argnames=arg_names, argvalues=valid_combinations
+        )(func)
+
+    return decorator


 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -243,48 +281,38 @@ def test_linear(
         scaling_type_x: TensorScalingType,
         scaling_type_w: TensorScalingType,
         scaling_type_dL_dY: TensorScalingType,
-        linear_dtype: torch.dtype,
-        linear_bias: bool,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-        m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
+        x = torch.randn(*x_shape, device="cuda")
+        m_ref = nn.Linear(16, 32, bias=False, device="cuda")
         self._test_linear_impl(
             x,
             m_ref,
             linear_type,
             emulate,
             scaling_type_x,
             scaling_type_w,
             scaling_type_dL_dY,
         )

-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize(
-        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    @filtered_parametrize(
+        [
+            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+            ("emulate", [True, False] if is_H100 else [True]),
+            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            (
+                "scaling_type_dL_dY",
+                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+            ),
+            ("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]),
+        ],
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_autocast_outputs(
         self,
         emulate: bool,
         linear_dtype: torch.dtype,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()

         m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         kwargs = {
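For illustration only (not part of the commit): a sketch of how filtered_parametrize might be used to drop unrunnable combinations at collection time, which is what the runtime if not emulate: ... pytest.skip() blocks removed above used to handle. It assumes the code lives in test/test_base.py so that torch, filtered_parametrize, and is_H100 are in scope; the filter function, class, and test names below are hypothetical.

def _can_run(params) -> bool:
    # Keep a combination only if it can actually execute on this machine:
    # non-emulated float8 needs an H100-class GPU (compute capability >= 9.0).
    return params["emulate"] or is_H100


class TestFilteredParametrizeExample:
    @filtered_parametrize(
        [
            ("emulate", [True, False]),
            ("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]),
        ],
        filter_func=_can_run,
    )
    def test_example(self, emulate: bool, linear_dtype: torch.dtype):
        # On a machine without an H100, only the emulate=True cases are collected,
        # so the report is not padded with skipped tests.
        assert emulate or is_H100

Filtering at parametrize time keeps never-runnable cases out of the collected set entirely, which is the point of the utility when the cross product of parameters is large.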
