Skip to content

Commit

Permalink
Moved bias addition inside tgemm
Browse files (browse the repository at this point in the history)
  • Loading branch information
gshtras committed Jun 20, 2024
1 parent aba49c6 commit fb920fa
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 11 deletions.
10 changes: 4 additions & 6 deletions vllm/model_executor/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
multiplication.
"""

def __init__(self, separate_bias_add: bool = True):
def __init__(self, separate_bias_add: bool = False):
self.separate_bias_add = separate_bias_add

def create_weights(self, layer: torch.nn.Module,
Expand All @@ -90,11 +90,9 @@ def apply(self,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight
if bias is not None:
if self.separate_bias_add:
return tgemm.mm(x, weight) + bias
return F.linear(x, weight, bias)
return tgemm.mm(x, weight)
if self.separate_bias_add and bias is not None:
return tgemm.mm(x, weight) + bias
return F.linear(x, weight, bias)


class LinearBase(torch.nn.Module):
Expand Down
11 changes: 6 additions & 5 deletions vllm/model_executor/layers/tuned_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def apply_skinny(self, m, n, k, inp_view, weights):
else:
return None

def mm(self, inp, weights):
def mm(self, inp, weights, bias = None):
# F.Linear can take a 3 dimensional input. vllm
# uses this for linear units. However, sampler
# will use torch.matmul with 2 dimensions only
Expand Down Expand Up @@ -107,11 +107,12 @@ def mm(self, inp, weights):
})
]).drop_duplicates()
self.tuned_df.to_csv(self.untune_path, index=False)
out = F.linear(inp_view, weights)
return F.linear(inp, weights, bias)
if batched:
return out.view(inp.shape[0], inp.shape[1], weights.shape[0])
else:
return out
out = out.view(inp.shape[0], inp.shape[1], weights.shape[0])
if bias is not None:
return out + bias
return out


tgemm = TunedGemm()

0 comments on commit fb920fa

Please sign in to comment.