Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dev] Improve General Matmul With Splitk #50

Merged
merged 19 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions docs/QuickStart.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ Here is an example for a $W_{INT4}A_{FP16}$ mixed-precision matrix multiplicatio
import bitblas
import torch

# enabling debug output

bitblas.set_debug_level("Debug")
matmul_config = bitblas.MatmulConfig(
M=1, # M dimension
N=1024, # N dimension
Expand Down Expand Up @@ -125,6 +128,9 @@ Here is an example to define a ```bitblas.Linear``` of $W_{INT4}A_{FP16}$:
import bitblas
import torch

# enabling debug output
bitblas.set_debug_level("Debug")

model = bitblas.Linear(
in_features=1024,
out_features=1024,
Expand Down Expand Up @@ -178,6 +184,9 @@ from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import (
QuantLinear as CudaOldQuantLinear,
)

# enabling debug output
bitblas.set_debug_level("Debug")

in_features = 1024
out_features = 1024
group_size = 128
Expand Down
7 changes: 5 additions & 2 deletions python/bitblas/module/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,18 @@ def forward(self, A, output=None):
A = A.half()
# can be lifted to post init.
self.init_params()

if output is None:
output = torch.empty(
A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device)
m = ctypes.c_int32(reduce(operator.mul, A.shape[:-1], 1))
A = self.bitblas_matmul.transform_input(A)
stream = torch.cuda.current_stream()

A_void = ctypes.c_void_p(A.data_ptr())
stream_handle = ctypes.c_void_p(stream.cuda_stream)
# m is the product of the last n - 1 dimensions of A
self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m)
self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m, stream_handle)

return output

Expand Down
11 changes: 8 additions & 3 deletions python/bitblas/ops/general_matmul_splitk.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:

if output is None:
output = torch.empty(
(self.k_split,) + A.shape[:-1] + (self.N,),
A.shape[:-1] + (self.N,),
dtype=self.torch_output_dtype,
device=A.device)
if scale is not None:
Expand All @@ -169,7 +169,12 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
args.append(zeros)
if bias is not None:
args.append(bias)
args.append(output)

sk_output = torch.empty((self.k_split,) +
A.shape[:-1] + (self.N,),
dtype=self.torch_output_dtype,
device=A.device)
args.append(sk_output)

if self.dynamic_range is not None:
m = reduce(operator.mul, A.shape[:-1], 1)
Expand All @@ -180,7 +185,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
if self.lib is None:
self._forward_from_torch_func(*args)
self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream)
output = torch.sum(output, dim=0)
torch.sum(sk_output, dim=0, out=output)
return output

def __call__(self, *args: Any, **kwds: Any) -> Any:
Expand Down
4 changes: 3 additions & 1 deletion testing/python/operators/test_general_matmul_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,6 @@ def map_torch_type(intype):

# fmt: on
if __name__ == "__main__":
bitblas.testing.main()
# bitblas.testing.main()
test_matmul_torch_forward_weight_dequantize(1024, 1024, 1024, "float16", "e4m3_float8", "float16", "float16", "nt", None, None, None,
None, None)
70 changes: 50 additions & 20 deletions testing/python/operators/test_general_matmul_splitk_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,20 +41,22 @@ def test_matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtyp
matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False)
assert get_codegen_result(matmul)


@pytest.mark.parametrize(
"M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode",
"SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode",
[
(1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False,
None),
(16, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False,
None),
(1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False,
False, None),
(4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False,
False, None),
],
)
def test_matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias,
group_size, with_scaling, with_zeros, zeros_mode):

def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype,
layout, with_bias, group_size, with_scaling, with_zeros,
zeros_mode):
import torch
torch.random.manual_seed(0)
matmul_config = MatmulConfigWithSplitK(
k_split=SplitK,
M=M,
N=N,
K=K,
Expand All @@ -70,20 +72,27 @@ def test_matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo
zeros_mode=zeros_mode,
)
matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False)
matmul.hardware_aware_finetune(topk=10)
assert get_codegen_result(matmul)

input_shape = (M, K)
weight_shape = (N, K) if layout == "nt" else (K, N)
inputs = []
inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda() - 0.5)

output_bitblas = matmul.forward(*inputs)
output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1])
torch.testing.assert_close(output_bitblas, output_torch, rtol=1e-2, atol=1e-1)

@pytest.mark.parametrize(
"SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode",
[
(1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False,
(1, 16, 4096, 12800, "float16", "e4m3_float8", "float32", "float16", "nt", False, -1, False,
False, None),
(4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False,
(4, 16, 4096, 12800, "float16", "e4m3_float8", "float32", "float16", "nt", False, -1, False,
False, None),
],
)
def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype,
def test_matmul_torch_forward_fp8e4m3(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype,
layout, with_bias, group_size, with_scaling, with_zeros,
zeros_mode):
import torch
Expand All @@ -103,18 +112,39 @@ def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accu
with_scaling=with_scaling,
with_zeros=with_zeros,
zeros_mode=zeros_mode,
propagate_a=False,
propagate_b=False,
)
matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False)

input_shape = (M, K)
weight_shape = (N, K) if layout == "nt" else (K, N)
inputs = []
inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5)
inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda() - 0.5)
def map_torch_type(intype):

output_bitblas = matmul.forward(*inputs)
output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1])
torch.testing.assert_close(output_bitblas, output_torch, rtol=1e-2, atol=1e-1)
typemap = {
'e4m3_float8': torch.float8_e4m3fn,
'e5m2_float8': torch.float8_e5m2,
}
if intype in typemap:
return typemap[intype]
else:
return getattr(torch, intype)

numpytype_a = map_torch_type(A_dtype)
numpytype_b = map_torch_type(W_dtype)

torch_a = torch.rand(M * K).uniform_(-1, 1).reshape(input_shape).type(numpytype_a).cuda()
torch_b = torch.rand(N * K).uniform_(-1, 1).reshape(weight_shape).type(numpytype_b).cuda()
ref_out = torch.matmul(torch_a.to(torch.float32),
torch_b.t().to(torch.float32)) if layout == "nt" else torch.matmul(
torch_a.to(torch.float32), torch_b.to(torch.float32))
ref_out = ref_out.to(torch.float16)
bitblas_out = torch.empty_like(ref_out)
matmul.forward(torch_a, torch_b, output=bitblas_out)
print("torch_ref_out", ref_out)
print("bitblas_out", bitblas_out)

torch.testing.assert_close(bitblas_out, ref_out, rtol=1e0, atol=1e-1)


# fmt: on
Expand Down
Loading