Skip to content

Commit

Permalink
[Bugfix][Kernel][Misc] Basic support for SmoothQuant, symmetric case (#…
Browse files Browse the repository at this point in the history
…237)

* basic support for symmetric smoothquant model

* yap

* yapf

* Put implementation into torch function

* yapf
  • Loading branch information
rasmith authored Oct 24, 2024
1 parent 842ea55 commit c9fc160
Showing 1 changed file with 22 additions and 3 deletions.
25 changes: 22 additions & 3 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
import functools
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union

import torch
import torch.library
Expand Down Expand Up @@ -491,6 +491,22 @@ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)


def scaled_mm_torch(a: torch.Tensor,
                    b: torch.Tensor,
                    scale_a: torch.Tensor,
                    scale_b: torch.Tensor,
                    out_dtype: torch.dtype,
                    bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Pure-PyTorch fallback for a scaled int8/fp8 matmul (SmoothQuant-style).

    Computes ``(scale_a * (a @ b) * scale_b.T).to(out_dtype) [+ bias]`` with
    the matmul performed in float32 for accuracy, mirroring the semantics of
    the CUTLASS ``cutlass_scaled_mm`` kernel on platforms where that kernel
    is unavailable (e.g. ROCm).

    Args:
        a: Left operand, shape (M, K); upcast to float32 before the matmul.
        b: Right operand, shape (K, N); upcast to float32 before the matmul.
        scale_a: Per-row (or scalar) scale for ``a``; broadcast over (M, N).
        scale_b: Per-column scale for ``b``; transposed so it broadcasts
            over the N dimension of the (M, N) result.
        out_dtype: Desired dtype of the result (a ``torch.dtype`` *instance*
            such as ``torch.float16`` — not a class, hence no ``Type[...]``).
        bias: Optional additive bias, broadcast over rows; applied after the
            cast to ``out_dtype``, matching the kernel's epilogue.

    Returns:
        The scaled product in ``out_dtype``, shape (M, N).
    """
    # Do the accumulation in float32: int8/fp8 matmuls would overflow or
    # lose precision if performed in the storage dtype.
    out = torch.mm(a.to(torch.float32), b.to(torch.float32))
    out = scale_a * out
    out = scale_b.T * out
    out = out.to(out_dtype)
    if bias is not None:
        out = out + bias

    return out


def cutlass_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
Expand All @@ -504,9 +520,12 @@ def cutlass_scaled_mm(a: torch.Tensor,

m = a.shape[0]
n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)

torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
if is_hip():
return scaled_mm_torch(a, b, scale_a, scale_b, out_dtype, bias)
else:
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)

return out

Expand Down

0 comments on commit c9fc160

Please sign in to comment.