diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index eb3811ad5dbfc..427a0336f4645 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -90,7 +90,8 @@ def mm(self, inp, weights): ]).drop_duplicates() self.tuned_df.to_csv(self.untune_path, index=False) - if ((n == 4 or n == 3 or n== 2 or n == 1) + if ((n == 4 or n == 3 or n == 2 or n == 1) + and k % 8 == 0 and inp_view.dtype == torch.float16): out = torch.empty(inp_view.shape[0], weights.shape[0],