From e7c04637dd5318a5738bb76b8adce2187b084b02 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 26 Jul 2024 08:51:36 -0700 Subject: [PATCH] Update on "[2/x] clean up casting functions: delayed scaling" Summary: Removes delayed scaling from `float8_tensor.py`. After this PR, the invariant is that everything in `float8_tensor.py` requires the scale to be calculated elsewhere. This moves the codebase towards separation of concerns for calculating the scale (via various scaling strategies), separated from creating an instance of `Float8Tensor`. Note that stateful delayed scaling is the reason we need this separation. Test Plan: ``` ./test/test_everything.sh ``` Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D60291447](https://our.internmc.facebook.com/intern/diff/D60291447) [ghstack-poisoned] --- test/test_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_compile.py b/test/test_compile.py index 05bb57a..db73471 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -20,9 +20,9 @@ get_float8_layers, sync_float8_amax_and_scale_history, ) +from float8_experimental.float8_scaling_utils import cast_to_float8_delayed from float8_experimental.float8_tensor import LinearMMConfig from float8_experimental.float8_utils import e4m3_dtype -from float8_experimental.float8_scaling_utils import cast_to_float8_delayed from torch._dynamo.test_case import TestCase as DynamoTestCase from torch._dynamo.testing import CompileCounterWithBackend