Remove workaround (WA) for tuned GEMMs
Aleksandr Malyshev committed Oct 4, 2024
1 parent 34d2658 commit 9ed31a8
Showing 3 changed files with 2 additions and 12 deletions.
5 changes: 0 additions & 5 deletions vllm/envs.py

```diff
@@ -74,7 +74,6 @@
     VLLM_SYNC_SERVER_ENGINE_STEPS_BETWEEN_POLLS: int = 1
     VLLM_MOE_PADDING: bool = False
     VLLM_FP8_PADDING: bool = True
-    VLLM_NO_TUNED_GEMM: bool = False
 
 
 def get_default_cache_root():
@@ -487,10 +486,6 @@ def get_default_config_root():
     # Pad the weight for moe kernel or not
     "VLLM_FP8_PADDING":
     lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))),
-
-    # Not supported by mllama3.2
-    "VLLM_NO_TUNED_GEMM":
-    lambda: bool(int(os.getenv("VLLM_NO_TUNED_GEMM", "0"))),
 }
 
 # end-env-vars-definition
```
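For context, the deleted entry followed the same boolean-flag pattern as the surviving flags in `vllm/envs.py`: the raw string from `os.getenv` is converted with `int()` and then `bool()`. A minimal standalone sketch of that pattern (the `_env_flag` helper is hypothetical, not part of vLLM):

```python
import os

def _env_flag(name: str, default: str = "0") -> bool:
    # Hypothetical helper mirroring the pattern in vllm/envs.py:
    # int() rejects junk values loudly instead of treating any
    # non-empty string (even "0") as truthy.
    return bool(int(os.getenv(name, default)))

# Before this commit, VLLM_NO_TUNED_GEMM=1 parsed to True and
# disabled the tuned-GEMM path; the flag no longer exists.
os.environ["VLLM_FP8_PADDING"] = "0"
print(_env_flag("VLLM_FP8_PADDING", "1"))  # False
```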
7 changes: 1 addition & 6 deletions vllm/model_executor/layers/linear.py

```diff
@@ -2,10 +2,8 @@
 from typing import Dict, List, Optional, Tuple
 
 import torch
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
 
-import vllm.envs as envs
 from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               split_tensor_along_last_dim,
@@ -133,10 +131,7 @@ def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        if envs.VLLM_NO_TUNED_GEMM:
-            return F.linear(x, layer.weight, bias)
-        else:
-            return tgemm.mm(x, layer.weight, bias)
+        return tgemm.mm(x, layer.weight, bias)
 
 
 class LinearBase(torch.nn.Module):
```
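For readers tracking the behavioral change: before this commit, the `apply` method shown above could be steered away from the tuned path via the env flag. A standalone sketch of the removed gating, with a stand-in `tuned_mm` in place of vLLM's `tgemm.mm`:

```python
import os
import torch
import torch.nn.functional as F

def tuned_mm(x: torch.Tensor, w: torch.Tensor, bias=None) -> torch.Tensor:
    # Stand-in for vLLM's tgemm.mm tuned-GEMM dispatcher; a plain
    # matmul keeps the sketch self-contained.
    out = x @ w.t()
    return out if bias is None else out + bias

def apply_linear(x, weight, bias=None):
    # Removed workaround: VLLM_NO_TUNED_GEMM=1 forced the stock
    # F.linear path (added because tuned GEMMs were "not supported
    # by mllama3.2"). After this commit the tuned path is unconditional.
    if bool(int(os.getenv("VLLM_NO_TUNED_GEMM", "0"))):
        return F.linear(x, weight, bias)
    return tuned_mm(x, weight, bias)

x = torch.randn(4, 8)
w = torch.randn(16, 8)
print(apply_linear(x, w).shape)  # torch.Size([4, 16])
```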
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/tuned_gemm.py

```diff
@@ -84,7 +84,7 @@ def mm(self, inp, weights, bias=None):
         # uses this for linear units. However, sampler
         # will use torch.matmul with 2 dimensions only
         if inp.dim() == 3:
-            inp_view = inp.view(-1, inp.size(-1))
+            inp_view = inp.reshape(-1, inp.size(-1))
             batched = True
         else:
             inp_view = inp
```
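The one-line `view` → `reshape` swap matters because `Tensor.view` requires the requested shape to be expressible over the tensor's existing strides and raises on many non-contiguous inputs, whereas `Tensor.reshape` silently falls back to a copy when needed. A short demonstration of the difference (standard PyTorch semantics, not vLLM code):

```python
import torch

# A non-contiguous 3-D tensor, e.g. produced by a transpose upstream.
x = torch.randn(2, 3, 4).transpose(0, 1)
assert not x.is_contiguous()

flat = x.reshape(-1, x.size(-1))  # ok: copies since strides don't permit a view
print(flat.shape)                 # torch.Size([6, 4])

try:
    x.view(-1, x.size(-1))        # raises: size not compatible with stride
except RuntimeError as err:
    print("view failed:", err)
```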
