diff --git a/bitblas/module/__init__.py b/bitblas/module/__init__.py
index c1cf316f..b427a3c6 100644
--- a/bitblas/module/__init__.py
+++ b/bitblas/module/__init__.py
@@ -265,8 +265,6 @@ def warmup(self, topk=20):
         self.bitblas_matmul.hardware_aware_finetune(topk=topk)
 
     def forward(self, A, output=None):
-        if A.dtype != torch.float16:
-            A = A.half()
         A = self.bitblas_matmul.transform_input(A)
         stream = torch.cuda.current_stream()
 
@@ -277,7 +275,9 @@ def forward(self, A, output=None):
         args = [A_void, *self.q_params]
         if output is None:
             output = torch.empty(
-                A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device)
+                A.shape[:-1] + (self.out_features,),
+                dtype=getattr(torch, self.bitblas_matmul.out_dtype),
+                device=A.device)
         args.append(ctypes.c_void_p(output.data_ptr()))
         if self.bitblas_matmul.dynamic_range is not None:
             m = reduce(operator.mul, A.shape[:-1], 1)
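
Note: after this change, forward no longer coerces the input to fp16 (dtype handling is left to transform_input), and the output buffer's dtype follows the operator's configured out_dtype rather than the input tensor's dtype. A minimal sketch of the dtype resolution below, with hypothetical values for out_dtype, out_features, and the input shape (not taken from the patch):

    import torch

    out_dtype = "float16"   # hypothetical value of bitblas_matmul.out_dtype (a string)
    out_features = 32       # hypothetical layer width
    A = torch.randn(4, 16, dtype=torch.bfloat16)  # input dtype no longer forced to fp16

    # getattr(torch, "float16") resolves the string to torch.float16, so the
    # output buffer's dtype tracks the configured out_dtype, not A.dtype.
    output = torch.empty(
        A.shape[:-1] + (out_features,),
        dtype=getattr(torch, out_dtype),
        device=A.device)
    assert output.dtype == torch.float16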