diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 889da6b479d32..12b9d97091274 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1,6 +1,6 @@
 import contextlib
 import functools
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.library
@@ -491,6 +491,22 @@ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
     return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
 
 
+def scaled_mm_torch(a: torch.Tensor,
+                    b: torch.Tensor,
+                    scale_a: torch.Tensor,
+                    scale_b: torch.Tensor,
+                    out_dtype: Type[torch.dtype],
+                    bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    out = torch.mm(a.to(torch.float32), b.to(torch.float32))
+    out = scale_a * out
+    out = scale_b.T * out
+    out = out.to(out_dtype)
+    if bias is not None:
+        out = out + bias
+
+    return out
+
+
 def cutlass_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
@@ -504,9 +520,12 @@ def cutlass_scaled_mm(a: torch.Tensor,
     m = a.shape[0]
     n = b.shape[1]
 
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
+    if is_hip():
+        return scaled_mm_torch(a, b, scale_a, scale_b, out_dtype, bias)
+    else:
+        out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+        torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
 
     return out
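
For reference, a minimal standalone sketch of how the torch fallback's scaling broadcasts (shapes and dtypes below are illustrative assumptions, not taken from the diff): scale_a is a per-row scale of a and scale_b is a per-column scale of b, both applied to the fp32 product before casting to out_dtype.

    import torch

    # Illustrative shapes only; real callers would pass quantized tensors.
    m, k, n = 4, 64, 32
    a = torch.randn(m, k)
    b = torch.randn(k, n)
    scale_a = torch.rand(m, 1)   # (m, 1): broadcasts across output columns
    scale_b = torch.rand(n, 1)   # (n, 1): .T gives (1, n), broadcasts across rows

    out = torch.mm(a.to(torch.float32), b.to(torch.float32))
    out = scale_a * out          # scale each output row (per-token)
    out = scale_b.T * out        # scale each output column (per-channel)
    out = out.to(torch.bfloat16) # cast to the requested out_dtype
    print(out.shape, out.dtype)  # torch.Size([4, 32]) torch.bfloat16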