diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 3ecacafd31977..2ab9b89eef7ed 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -91,7 +91,7 @@ def mm(self, inp, weights): self.tuned_df.to_csv(self.untune_path, index=False) if ((n == 4 or n == 3 or n == 2 or n == 1) and k % 8 == 0 - and inp_view.dtype == torch.float16): + and m > 8 and inp_view.dtype == torch.float16): out = torch.empty(inp_view.shape[0], weights.shape[0], dtype=inp_view.dtype,