From 06d4546832d11923cb74881e56e524e7b85f0111 Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Tue, 17 Dec 2024 06:25:49 +0900
Subject: [PATCH] fix column major check & add dtype check of data mat

Signed-off-by: Masaki Kozuki
---
 thunder/executors/torchex.py | 7 ++++++-
 thunder/torch/__init__.py    | 1 +
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 5670a0067e..d90fb22c10 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1417,8 +1417,13 @@ def _scaled_mm_transform(
     out_dtype: dtypeLike | None = None,
     use_fast_accum: bool = False,
 ):
+
+    def is_column_major(mat: TensorLike) -> bool:
+        # Column-major 2D layout: unit stride along dim 0 and a non-unit stride along dim 1.
+        return mat.stride()[0] == 1 and mat.stride()[1] > 1
+
     result_dtype: torch.dtype = to_torch_dtype(a.dtype if out_dtype is None else out_dtype)
-    if b.stride()[0] != 1 and b.stride()[1] > 1:
+    if not is_column_major(b):
         b = b.t().contiguous().t()
 
     return _scaled_mm(a, b, scale_a, scale_b, bias, scale_result, result_dtype, use_fast_accum)
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index d5d5efecf8..66ef8e5727 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -3608,6 +3608,7 @@ def _scaled_mm(
             and (a.shape[1] == b.shape[0])
             and (a.shape[1] % 16 == 0 and b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
             and (to_dtype(a.dtype) in fp8_dtypes and to_dtype(b.dtype) in fp8_dtypes)
+            and not (a.dtype == dtypes.float8_e5m2 and b.dtype == dtypes.float8_e5m2)
         ),
         lambda: f"data matrices of {a=} and {b=} do not satisfy the condition.",
     )