diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 0aaf2bd..9e8285e 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -300,7 +300,7 @@ def cast_x_to_float8(
         if torch.is_autocast_enabled():
             # For now, hardcode to GPU's autocast dtype
             # if we need CPU support in the future, we can add it
-            autocast_dtype = torch.get_autocast_gpu_dtype()
+            autocast_dtype = torch.get_autocast_dtype("cuda")
             x = x.to(autocast_dtype)
 
         if self.scaling_type_x is TensorScalingType.DELAYED:
diff --git a/float8_experimental/float8_linear_utils.py b/float8_experimental/float8_linear_utils.py
index 818fef0..1156d18 100644
--- a/float8_experimental/float8_linear_utils.py
+++ b/float8_experimental/float8_linear_utils.py
@@ -227,7 +227,7 @@ def sync_float8_amax_and_scale_history(model: torch.nn.Module, fp8_layers=None)
         fp8_layers = get_float8_layers(model)
 
     if len(fp8_layers) == 0:
-        log.warn(
+        log.warning(
            "Calling sync_float8_amax_and_scale_history on a module with no Float8Linear layers"
         )
         return
diff --git a/test/test_base.py b/test/test_base.py
index 1470bdd..100dcef 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -10,6 +10,8 @@
 import re
 import unittest
 import warnings
+from itertools import product
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import pytest
 
@@ -50,6 +52,42 @@
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 
+def filtered_parametrize(
+    param_list: List[Tuple[str, List[Any]]],
+    filter_func: Optional[Callable[[Dict[str, Any]], bool]] = None,
+):
+    """
+    A decorator that works like pytest.mark.parametrize but filters out
+    unwanted parameter combinations.
+
+    Args:
+        param_list: A list of tuples, each containing (arg_name, [arg_values])
+        filter_func: A function that takes a dictionary of parameter names and values,
+            and returns True for valid combinations, False otherwise
+
+    """
+
+    def decorator(func):
+        arg_names = [param[0] for param in param_list]
+        arg_values = [param[1] for param in param_list]
+
+        all_combinations = product(*arg_values)
+        if filter_func:
+            valid_combinations = [
+                combo
+                for combo in all_combinations
+                if filter_func(dict(zip(arg_names, combo)))
+            ]
+        else:
+            valid_combinations = list(all_combinations)
+
+        return pytest.mark.parametrize(
+            argnames=arg_names, argvalues=valid_combinations
+        )(func)
+
+    return decorator
+
+
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -243,32 +281,31 @@ def test_linear(
         scaling_type_x: TensorScalingType,
         scaling_type_w: TensorScalingType,
         scaling_type_dL_dY: TensorScalingType,
-        linear_dtype: torch.dtype,
-        linear_bias: bool,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-        m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
+        x = torch.randn(*x_shape, device="cuda")
+        m_ref = nn.Linear(16, 32, bias=False, device="cuda")
         self._test_linear_impl(
             x,
             m_ref,
+            linear_type,
             emulate,
             scaling_type_x,
             scaling_type_w,
             scaling_type_dL_dY,
         )
-
-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize(
-        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+
+    @filtered_parametrize(
+        [
+            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+            ("emulate", [True, False] if is_H100 else [True]),
+            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            (
+                "scaling_type_dL_dY",
+                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+            ),
+            ("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]),
+        ],
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_autocast_outputs(
@@ -276,15 +313,6 @@ def test_autocast_outputs(
         emulate: bool,
         linear_dtype: torch.dtype,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
 
         m_ref = nn.Linear(32, 16, device="cuda", dtype=linear_dtype)
         kwargs = {