Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into docs
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiWang1999 committed Aug 31, 2024
2 parents 6a04749 + b1f5e79 commit 9d90c40
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions bitblas/module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,6 @@ def warmup(self, topk=20):
self.bitblas_matmul.hardware_aware_finetune(topk=topk)

def forward(self, A, output=None):
if A.dtype != torch.float16:
A = A.half()
A = self.bitblas_matmul.transform_input(A)
stream = torch.cuda.current_stream()

Expand All @@ -277,7 +275,9 @@ def forward(self, A, output=None):
args = [A_void, *self.q_params]
if output is None:
output = torch.empty(
A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device)
A.shape[:-1] + (self.out_features,),
dtype=getattr(torch, self.bitblas_matmul.out_dtype),
device=A.device)
args.append(ctypes.c_void_p(output.data_ptr()))
if self.bitblas_matmul.dynamic_range is not None:
m = reduce(operator.mul, A.shape[:-1], 1)
Expand Down

0 comments on commit 9d90c40

Please sign in to comment.