From 13f2c26eb1be6aae9a95f679adf5682c7e006dc5 Mon Sep 17 00:00:00 2001
From: Vasiliy Kuznetsov
Date: Tue, 9 Jul 2024 11:02:51 -0700
Subject: [PATCH] fix nits from deletion of Float8DynamicLinear (#308)

Summary:
Pull Request resolved: https://github.com/pytorch-labs/float8_experimental/pull/308

Addressing a couple of nits that slipped in
https://github.com/pytorch-labs/float8_experimental/pull/304

* more defaults to dynamic
* undo repr change
* fix comment

Reviewed By: drisspg

Differential Revision: D59521233

fbshipit-source-id: 5f69855cc2d19c6057a230b0963185c4396dcd99
---
 float8_experimental/float8_linear.py | 16 ++++++++--------
 test/test_dtensor.py                 |  3 +--
 2 files changed, 9 insertions(+), 10 deletions(-)

diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index 664a03a..a7dd2d2 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -165,9 +165,9 @@ def __init__(self, *args, **kwargs):
         # Amax scales should always be kept as float32.
         self.always_float32_buffers = set()
         emulate = kwargs.pop("emulate", False)
-        scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DELAYED)
-        scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DELAYED)
-        scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DELAYED)
+        scaling_type_x = kwargs.pop("scaling_type_x", TensorScalingType.DYNAMIC)
+        scaling_type_w = kwargs.pop("scaling_type_w", TensorScalingType.DYNAMIC)
+        scaling_type_dL_dY = kwargs.pop("scaling_type_dL_dY", TensorScalingType.DYNAMIC)
         super().__init__(*args, **kwargs)

         # Defines the scaling behavior of x, w, dL_dY
@@ -402,8 +402,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

     def scaling_repr(self):
         # add scaling settings without using too many characters
-        # example: "x_del_w_del_dldy_dyn"
-        return f"x_{self.scaling_type_x.short_str()}_w_{self.scaling_type_w.short_str()}_dldy_{self.scaling_type_dL_dY.short_str()}"
+        # example: "x:del,w:del,dldy:dyn"
+        return f"x:{self.scaling_type_x.short_str()},w:{self.scaling_type_w.short_str()},dldy:{self.scaling_type_dL_dY.short_str()}"

     def extra_repr(self):
         s = f'{super().extra_repr()}, scaling="{self.scaling_repr()}"'
@@ -414,9 +414,9 @@ def from_float(
         cls,
         mod,
         emulate: bool = False,
-        scaling_type_x=TensorScalingType.DELAYED,
-        scaling_type_w=TensorScalingType.DELAYED,
-        scaling_type_dL_dY=TensorScalingType.DELAYED,
+        scaling_type_x=TensorScalingType.DYNAMIC,
+        scaling_type_w=TensorScalingType.DYNAMIC,
+        scaling_type_dL_dY=TensorScalingType.DYNAMIC,
     ):
         """
         Create an nn.Linear with fp8 compute from a regular nn.Linear
diff --git a/test/test_dtensor.py b/test/test_dtensor.py
index 6506ee7..2088b78 100644
--- a/test/test_dtensor.py
+++ b/test/test_dtensor.py
@@ -171,8 +171,7 @@ def _test_fp8_mlp_tensor_parallelism_base(
     mesh: DeviceMesh, size=16, compile: bool = False
 ):
     device = mesh.device_type
-    # For now, just use Float8Linear with dynamic scaling, which is the
-    # same behavior as Float8Linear.
+    # For now, only supports dynamic scaling of `x` and `dL_dY`.
     # TODO(future): add support for float8 all-gather with delayed scaling
     # for activations and gradients.
     extra_kwargs = {
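
Usage note (not part of the patch): a minimal sketch of how the new defaults behave after
this change, assuming `float8_experimental` is installed and that `Float8Linear.from_float`
and `TensorScalingType` are importable from `float8_experimental.float8_linear` as shown in
the diff above. With this patch, omitting the `scaling_type_*` arguments selects dynamic
scaling, and the module repr uses the restored `x:dyn,w:dyn,dldy:dyn` format.

    import torch.nn as nn

    from float8_experimental.float8_linear import Float8Linear, TensorScalingType

    # Plain nn.Linear to convert; emulate=True avoids requiring fp8-capable hardware.
    m = nn.Linear(16, 16)

    # No scaling_type_* arguments needed: x, w, and dL_dY now default to
    # TensorScalingType.DYNAMIC instead of TensorScalingType.DELAYED.
    fp8_m = Float8Linear.from_float(m, emulate=True)

    # extra_repr includes the compact scaling string, e.g. scaling="x:dyn,w:dyn,dldy:dyn".
    print(fp8_m)

    # Delayed scaling remains available by passing it explicitly, e.g.:
    # Float8Linear.from_float(m, emulate=True, scaling_type_w=TensorScalingType.DELAYED)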