diff --git a/bitblas/ops/impl/ladder_permutate_impl.py b/bitblas/ops/impl/ladder_permutate_impl.py index 8086bf584..76b5a01fb 100644 --- a/bitblas/ops/impl/ladder_permutate_impl.py +++ b/bitblas/ops/impl/ladder_permutate_impl.py @@ -49,6 +49,7 @@ def select_implementation( inp = te.placeholder((M, N // scaling_factor), name="inp", dtype=storage_dtype) args = [inp] + assert transform_kind != 0, "Permute only apply when transform_kind >= 1" if transform_kind >= 1: arg = args[-1] diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index d1555a674..f3d391778 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -106,8 +106,7 @@ def tvm_callback_cuda_postproc(code, _): **self.pass_context }): rt_mod = tvm.build(self.optimized_func, target=target, name=self.name) - except Exception as e: - rt_build_error = e # noqa + except Exception: # noqa: F841 logger.debug( "Failed to build optimized function for CUDA target with default schedule, Please consider enable hardware aware tuning!" ) diff --git a/integration/BitNet/modeling_bitnet.py b/integration/BitNet/modeling_bitnet.py index df3ae39d9..11be4059f 100644 --- a/integration/BitNet/modeling_bitnet.py +++ b/integration/BitNet/modeling_bitnet.py @@ -54,7 +54,7 @@ if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa: F401 logger = logging.get_logger(__name__) diff --git a/testing/python/operators/test_general_matmul_splitk_ops.py b/testing/python/operators/test_general_matmul_splitk_ops.py index 307308bf1..fcdf90239 100644 --- a/testing/python/operators/test_general_matmul_splitk_ops.py +++ b/testing/python/operators/test_general_matmul_splitk_ops.py @@ -11,16 +11,7 @@ def get_codegen_result(ops): # fmt: off -@pytest.mark.parametrize( - "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", - [ - (1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, - None), - (16, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, - None), - ], -) -def test_matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, +def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode): matmul_config = MatmulConfigWithSplitK( @@ -37,21 +28,21 @@ def test_matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtyp with_scaling=with_scaling, with_zeros=with_zeros, zeros_mode=zeros_mode, + propagate_a=False, + propagate_b=False, ) matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) assert get_codegen_result(matmul) -@pytest.mark.parametrize( - "SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", - [ - (1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, - False, None), - (4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, - False, None), - ], -) -def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, +def test_matmul_codegen_default(): + matmul_codegen_default(1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, + None) + matmul_codegen_default(16, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, + None) + + +def matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode): import torch @@ -71,6 +62,8 @@ def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accu with_scaling=with_scaling, with_zeros=with_zeros, zeros_mode=zeros_mode, + propagate_a=False, + propagate_b=False, ) matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) @@ -84,17 +77,13 @@ def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accu output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) torch.testing.assert_close(output_bitblas, output_torch, rtol=1e-2, atol=1e-1) +def test_matmul_torch_forward_consistent(): + matmul_torch_forward_consistent(1, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, + False, None) + matmul_torch_forward_consistent(4, 1, 4096, 12800, "float16", "float16", "float16", "float16", "nt", False, -1, False, + False, None) -@pytest.mark.parametrize( - "SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", - [ - (1, 16, 4096, 12800, "float16", "e4m3_float8", "float32", "float16", "nt", False, -1, False, - False, None), - (4, 16, 4096, 12800, "float16", "e4m3_float8", "float32", "float16", "nt", False, -1, False, - False, None), - ], -) -def test_matmul_torch_forward_fp8e4m3(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, +def matmul_torch_forward_fp8e4m3(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, group_size, with_scaling, with_zeros, zeros_mode): import torch @@ -157,6 +146,12 @@ def map_torch_type(intype): torch.testing.assert_close(bitblas_out, ref_out, rtol=1e0, atol=1e-1) +@bitblas.testing.requires_cuda_compute_version(8, 9) +def test_matmul_torch_forward_fp8e4m3(): + matmul_torch_forward_fp8e4m3(1, 16, 4096, 12800, "e4m3_float8", "e4m3_float8", "float32", "float16", "nt", False, -1, False, + False, None) + matmul_torch_forward_fp8e4m3(4, 16, 4096, 12800, "e4m3_float8", "e4m3_float8", "float32", "float16", "nt", False, -1, False, + False, None) # fmt: on if __name__ == "__main__": diff --git a/testing/python/operators/test_ladder_permutate_ops.py b/testing/python/operators/test_ladder_permutate_ops.py index b302edcc7..8fa54a4ca 100644 --- a/testing/python/operators/test_ladder_permutate_ops.py +++ b/testing/python/operators/test_ladder_permutate_ops.py @@ -9,16 +9,7 @@ # fmt: off -@pytest.mark.parametrize( - "M,N,datatype,dequantize_bits,storage_dtype,propagate_kind,transpose_matrix,transform_kind,target_instruction", - [ - (1024, 1024, "float16", -1, "float16", "B", True, 0, "nvidia-mma"), - (1024, 1024, "float16", -1, "float16", "B", True, 1, "nvidia-mma"), - (1024, 1024, "float16", -1, "float16", "B", True, 2, "nvidia-mma"), - # dequantize propagation - (1024, 1024, "float16", 4, "uint32", "B", True, 2, "nvidia-mma"), - ]) -def test_ladder_permutate_profile_latency( +def ladder_permutate_profile_latency( M, N, datatype, @@ -49,16 +40,13 @@ def test_ladder_permutate_profile_latency( assert latency -@pytest.mark.parametrize( - "M,N,datatype,dequantize_bits,storage_dtype,propagate_kind,transpose_matrix,transform_kind,target_instruction", - [ - (1024, 1024, "float16", -1, "float16", "A", True, 0, "nvidia-mma"), - (1024, 1024, "float16", -1, "float16", "A", True, 1, "nvidia-mma"), - (1024, 1024, "float16", -1, "float16", "A", True, 2, "nvidia-mma"), - # dequantize propagation - (1024, 1024, "float16", 4, "uint32", "A", True, 2, "nvidia-mma"), - ]) -def test_ladder_permutate_profile_latency_cuda( +def test_ladder_permutate_profile_latency(): + ladder_permutate_profile_latency(1024, 1024, "float16", -1, "float16", "B", True, 1, "nvidia-mma") + ladder_permutate_profile_latency(1024, 1024, "float16", -1, "float16", "B", True, 2, "nvidia-mma") + ladder_permutate_profile_latency(1024, 1024, "float16", 4, "uint32", "B", True, 2, "nvidia-mma") + + +def ladder_permutate_profile_latency_cuda( M, N, datatype, @@ -91,6 +79,10 @@ def test_ladder_permutate_profile_latency_cuda( assert latency +def test_ladder_permutate_profile_latency_cuda(): + ladder_permutate_profile_latency_cuda(1024, 1024, "float16", -1, "float16", "A", True, 1, "nvidia-mma") + ladder_permutate_profile_latency_cuda(1024, 1024, "float16", -1, "float16", "A", True, 2, "nvidia-mma") + ladder_permutate_profile_latency_cuda(1024, 1024, "float16", 4, "uint32", "A", True, 2, "nvidia-mma") # fmt: on if __name__ == "__main__": diff --git a/testing/python/operators/test_matmul_dequantize_ops.py b/testing/python/operators/test_matmul_dequantize_ops.py deleted file mode 100644 index c727dc7e8..000000000 --- a/testing/python/operators/test_matmul_dequantize_ops.py +++ /dev/null @@ -1,922 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import pytest -import bitblas -from bitblas.utils import auto_detect_nvidia_target -from bitblas.ops.matmul_dequantize import ( - MatmulWeightOnlyDequantize, - MatmulWeightOnlyDequantizeConfig, -) -from bitblas import tvm -import logging -from bitblas import set_log_level - -set_log_level(logging.DEBUG) -target = tvm.target.Target(auto_detect_nvidia_target()) - - -def get_codegen_result(ops, target): - code = ops.get_source(target=target) - return code - - -# fmt: off -@pytest.mark.parametrize( - "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,layout,propagate_a,propagate_b,zeros_mode", - [ - (16, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True, - False, "nt", False, False, "original"), - (1, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True, - False, "nt", False, False, "original"), - (1, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", True, True, -1, True, - True, "nt", False, True, "original"), - ], -) -def test_matmul_dequantize_codegen_default(M, N, K, in_dtype, out_dtype, accum_dtype, bit, - storage_dtype, source_format, with_scaling, with_zeros, - group_size, fast_decoding, with_bias, layout, - propagate_a, propagate_b, zeros_mode): - - matmul_config = MatmulWeightOnlyDequantizeConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - bit=bit, - storage_dtype=storage_dtype, - source_format=source_format, - with_scaling=with_scaling, - with_zeros=with_zeros, - group_size=group_size, - fast_decoding=fast_decoding, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, - zeros_mode=zeros_mode, - ) - matmul = MatmulWeightOnlyDequantize( - config=matmul_config, - target=target, - ) - assert get_codegen_result(matmul, target) - - -@pytest.mark.parametrize( - "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout", - [ - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, False, - False, False, False, "nt"), - ], -) -def test_matmul_dequantize_retrieve_weight_shape( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - bit, - storage_dtype, - source_format, - with_scaling, - with_zeros, - group_size, - fast_decoding, - with_bias, - propagate_a, - propagate_b, - layout, -): - - matmul_config = MatmulWeightOnlyDequantizeConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - bit=bit, - storage_dtype=storage_dtype, - source_format=source_format, - with_scaling=with_scaling, - with_zeros=with_zeros, - group_size=group_size, - fast_decoding=fast_decoding, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, - ) - matmul = MatmulWeightOnlyDequantize( - config=matmul_config, - target=target, - ) - assert matmul.retrieve_weight_shape() - - -@pytest.mark.parametrize( - "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout", - [ - ( - 1, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "uint", - False, - False, - -1, - False, - False, - False, - False, - "nt", - ), - ( - 1, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "uint", - False, - False, - -1, - False, - False, - False, - True, - "nt", - ), - ], -) -def test_matmul_dequantize_codegen_finetune( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - bit, - storage_dtype, - source_format, - with_scaling, - with_zeros, - group_size, - fast_decoding, - with_bias, - propagate_a, - propagate_b, - layout, -): - - matmul_config = MatmulWeightOnlyDequantizeConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - bit=bit, - storage_dtype=storage_dtype, - source_format=source_format, - with_scaling=with_scaling, - with_zeros=with_zeros, - group_size=group_size, - fast_decoding=fast_decoding, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, - ) - matmul = MatmulWeightOnlyDequantize( - config=matmul_config, - target=target, - ) - matmul.hardware_aware_finetune(topk=20) - assert get_codegen_result(matmul, target) - - -@pytest.mark.parametrize( - "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout", - [ - ( - 1, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "uint", - False, - False, - -1, - False, - False, - False, - False, - "nt", - ), - ( - 1, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "nf", - False, - False, - -1, - False, - False, - False, - False, - "nt", - ), - ( - 1024, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "nf", - False, - False, - -1, - False, - False, - False, - False, - "nt", - ), - ( - 1024, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "nf", - False, - False, - -1, - False, - False, - False, - True, - "nt", - ), - ( - 1024, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "nf", - False, - False, - -1, - False, - False, - True, - True, - "nt", - ), - ( - 1024, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "nf", - True, - False, - -1, - False, - False, - True, - True, - "nt", - ), - ( - 1024, - 1024, - 1024, - "float16", - "float16", - "float16", - 4, - "int8", - "nf", - True, - False, - 128, - False, - False, - True, - True, - "nt", - ), - ], -) -def test_matmul_dequantize_profile_latency( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - bit, - storage_dtype, - source_format, - with_scaling, - with_zeros, - group_size, - fast_decoding, - with_bias, - propagate_a, - propagate_b, - layout, -): - - matmul_config = MatmulWeightOnlyDequantizeConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - bit=bit, - storage_dtype=storage_dtype, - source_format=source_format, - with_scaling=with_scaling, - with_zeros=with_zeros, - group_size=group_size, - fast_decoding=fast_decoding, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, - ) - matmul = MatmulWeightOnlyDequantize( - config=matmul_config, - target=target, - ) - matmul.hardware_aware_finetune(topk=20) - latency = matmul.profile_latency() - assert latency - - -@pytest.mark.parametrize( - "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout,zeros_mode", - [ - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, False, - False, False, False, "nt", "rescale"), - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True, - False, False, False, "nt", "rescale"), - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "int", False, False, -1, False, - False, False, False, "nt", "rescale"), - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "int", False, False, -1, True, - False, False, False, "nt", "rescale"), - (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", False, False, -1, True, - False, False, False, "nt", "rescale"), - (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, -1, True, - False, False, False, "nt", "rescale"), - (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, True, - False, False, False, "nt", "rescale"), - (1, 1024, 1024, "float16", "float16", "float16", 2, "int8", "uint", True, True, 128, False, - False, False, False, "nt", "rescale"), - (1, 1024, 4096, "float16", "float16", "float16", 2, "int8", "uint", True, True, 128, True, - False, False, False, "nt", "rescale"), - (1024, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, - False, False, False, False, "nt", "rescale"), - (1024, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, - False, False, False, True, "nt", "rescale"), - (1024, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, - False, False, True, True, "nt", "rescale"), - (1024, 1024, 1024, "float16", "float16", "float16", 2, "int8", "int", True, False, 128, - False, False, True, True, "nt", "original"), - ([1, 1024], 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, - -1, False, False, False, False, "nt", "original"), - ], -) -def test_matmul_dequantize_torch_forward(M, N, K, in_dtype, out_dtype, accum_dtype, bit, - storage_dtype, source_format, with_scaling, with_zeros, - group_size, fast_decoding, with_bias, propagate_a, - propagate_b, layout, zeros_mode): - import torch - torch.random.manual_seed(0) - import numpy as np - from bitblas.quantization.utils import general_compress - - matmul_config = MatmulWeightOnlyDequantizeConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - bit=bit, - storage_dtype=storage_dtype, - source_format=source_format, - with_scaling=with_scaling, - with_zeros=with_zeros, - group_size=group_size, - fast_decoding=fast_decoding, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, - zeros_mode=zeros_mode) - matmul = MatmulWeightOnlyDequantize( - config=matmul_config, - target=target, - ) - if not isinstance(M, int): - M = 32 - matmul.hardware_aware_finetune(topk=20) - input_shape = (M, K) - weight_shape = (N, K) if layout == "nt" else (K, N) - output_shape = (M, N) - inputs = [] - inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) - maxq = 2 ** (bit - 1) - zeros = maxq - if source_format == "uint": - inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()) - elif source_format == "int": - inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda()) - else: - raise NotImplementedError - - inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) - - intweight = inputs[1] - intweight = intweight.cpu().numpy().astype(np.int8) - if source_format == "int": - intweight = intweight + maxq - if with_zeros: - inputs[1] = inputs[1] - zeros - bias = torch.rand((output_shape[-1],), dtype=torch.float16).cuda() - ref_result = torch.matmul(inputs[0], - (inputs[1].t() if layout == "nt" else inputs[1]).to(torch.float16)) - if with_bias: - ref_result = ref_result + bias - qw_np = general_compress(intweight, source_bits=bit, storage_dtype=np.int8) - qw_torch = torch.from_numpy(qw_np).cuda() - permuted_inputs = [] - if matmul.input_transform is not None: - permuted_inputs.append(matmul.input_transform(inputs[0].cpu()).cuda()) - else: - permuted_inputs.append(inputs[0]) - if matmul.weight_transform is not None: - permuted_inputs.append(matmul.weight_transform(qw_torch.cpu()).cuda()) - else: - permuted_inputs.append(qw_torch) - if with_scaling: - if group_size == -1: - group_size = K - # Note that scaling is default to all 1 - permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda()) - if with_zeros: - permuted_inputs.append(torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) - if with_bias: - permuted_inputs.append(bias) - permuted_inputs.append(inputs[2]) - matmul(*permuted_inputs) - if zeros_mode == "rescale": - torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-0, atol=1e-0) - else: - torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-0, atol=1e-1) - - -@pytest.mark.parametrize( - "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout,zeros_mode", - [ - (1, 768, 768, "float16", "float16", "float16", 2, "int8", "uint", True, False, 128, False, - False, False, False, "nt", "quantized"), - (1, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", True, True, 128, False, - False, False, False, "nt", "quantized"), - ], -) -def test_matmul_dequantize_torch_forward_with_asym_quantized_zeros(M, N, K, in_dtype, out_dtype, accum_dtype, bit, - storage_dtype, source_format, with_scaling, with_zeros, - group_size, fast_decoding, with_bias, propagate_a, - propagate_b, layout, zeros_mode): - import torch - import numpy as np - torch.random.manual_seed(0) - from bitblas.quantization.utils import general_compress - matmul_config = MatmulWeightOnlyDequantizeConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - bit=bit, - storage_dtype=storage_dtype, - source_format=source_format, - with_scaling=with_scaling, - with_zeros=with_zeros, - group_size=group_size, - fast_decoding=fast_decoding, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, - zeros_mode=zeros_mode) - matmul = MatmulWeightOnlyDequantize( - config=matmul_config, - target=target, - ) - if not isinstance(M, int): - M = int(32) - # matmul.hardware_aware_finetune(topk=20) - input_shape = (M, K) - weight_shape = (N, K) if layout == "nt" else (K, N) - output_shape = (M, N) - scaling_shape = (N, K // group_size) - zeros_shape = (K // group_size, N) - - input_A = torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5 - max_quantization = 2 ** (bit - 1) - scaling_matrix = torch.rand(scaling_shape, dtype=torch.float16).cuda() - zeros_matrix = torch.randint(0, max_quantization, zeros_shape, dtype=torch.int8).cuda() - bias = torch.rand((output_shape[-1],), dtype=torch.float16).cuda() - - if source_format == "uint": - input_W = torch.randint(0, max_quantization, weight_shape, dtype=torch.int8).cuda() - elif source_format == "int": - input_W = torch.randint(-max_quantization, max_quantization, weight_shape, dtype=torch.int8).cuda() - else: - raise NotImplementedError - - # Now begin bitblas matmul - input_W_int = input_W.cpu().numpy().astype(np.int8) - if source_format == "int": - input_W_int = input_W_int + max_quantization - qw_np = general_compress(input_W_int, source_bits=bit, storage_dtype=np.int8) - qw_torch = torch.from_numpy(qw_np).cuda() - - permuted_inputs = [] - # input and weight - if matmul.input_transform is not None: - permuted_inputs.append(matmul.input_transform(input_A.cpu()).cuda()) - else: - permuted_inputs.append(input_A) - if matmul.weight_transform is not None: - permuted_inputs.append(matmul.weight_transform(qw_torch.cpu()).cuda()) - else: - permuted_inputs.append(qw_torch) - # scale - if with_scaling: - if group_size == -1: - group_size = K - permuted_inputs.append(scaling_matrix) - # zeros - if with_zeros: - if zeros_mode == "quantized": - original_zeros = zeros_matrix - qzeros = general_compress( - original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) - permuted_inputs.append(torch.from_numpy(qzeros).cuda()) - else: - raise NotImplementedError - # bias - if with_bias: - permuted_inputs.append(bias) - # output - permuted_inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) - matmul(*permuted_inputs) - bitblas_result = permuted_inputs[-1] - - # Now begin torch matmul - if with_scaling and with_zeros and zeros_mode == "quantized": - rescaling_tensor = torch.zeros_like(input_W, dtype=torch.float16).cuda() - for i in range(K // group_size): - for j in range(group_size): - rescaling_tensor[:, i * group_size + j] = ( - input_W[:, i * group_size + j].to(torch.float16) - zeros_matrix[i, :] - ) * scaling_matrix[:, i] - elif with_scaling: - rescaling_tensor = torch.zeros_like(input_W, dtype=torch.float16).cuda() - for i in range(K // group_size): - for j in range(group_size): - rescaling_tensor[:, i * group_size + j] = input_W[:, i * group_size + j].to(torch.float16) * scaling_matrix[:, i] - ref_result = torch.matmul(input_A, rescaling_tensor.t().to(torch.float16)) - - torch.testing.assert_close(bitblas_result, ref_result, rtol=1e-1, atol=1e-1) - - -@pytest.mark.parametrize( - "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,layout,zeros_mode", - [ - (16, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True, - False, "nt", "original"), - (16, 768, 768, "float16", "float16", "float16", 4, "int8", "uint", False, True, -1, True, - True, "nt", "original"), - (16, 3072, 768, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True, - False, "nt", "original"), - (16, 768, 3072, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, True, - False, "nt", "original"), - ], -) -def test_matmul_dequantize_propgate_comparison(M, N, K, in_dtype, out_dtype, accum_dtype, bit, - storage_dtype, source_format, with_scaling, - with_zeros, group_size, fast_decoding, with_bias, - layout, zeros_mode): - import torch - torch.random.manual_seed(0) - original_matmul_config = MatmulWeightOnlyDequantizeConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - bit=bit, - storage_dtype=storage_dtype, - source_format=source_format, - with_scaling=with_scaling, - with_zeros=with_zeros, - group_size=group_size, - fast_decoding=False, - with_bias=with_bias, - propagate_a=False, - propagate_b=False, - layout=layout, - zeros_mode=zeros_mode) - original_matmul = MatmulWeightOnlyDequantize( - config=original_matmul_config, - target=target, - ) - if not isinstance(M, int): - M = 32 - - if group_size == -1: - group_size = K - input_shape = (M, K) - weight_shape = (N, K // 2) if layout == "nt" else (K, N) - output_shape = (M, N) - scales_shape = (N, K // group_size) - zeros_shape = (N, K // group_size) - bias_shape = (N,) - - inputs = [] - input_tensor = torch.rand(input_shape, dtype=torch.float16).cuda() - weight_tensor = torch.randint(0, 2**(bit - 1) - 1, weight_shape, dtype=torch.int8).cuda() - scales_tensor = torch.rand(scales_shape, dtype=torch.float16).cuda() - zeros_tensor = torch.rand(zeros_shape, dtype=torch.float16).cuda() - bias_tensor = torch.rand(bias_shape, dtype=torch.float16).cuda() - output_tensor = torch.zeros(output_shape, dtype=torch.float16).cuda() - inputs.append(input_tensor) - inputs.append(weight_tensor) - if with_scaling: - inputs.append(scales_tensor) - if with_zeros: - inputs.append(zeros_tensor) - if with_bias: - inputs.append(bias_tensor) - inputs.append(output_tensor) - ref_result = original_matmul(*inputs) - - propagated_matmul_config = MatmulWeightOnlyDequantizeConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - bit=bit, - storage_dtype=storage_dtype, - source_format=source_format, - with_scaling=with_scaling, - with_zeros=with_zeros, - group_size=group_size, - fast_decoding=fast_decoding, - with_bias=with_bias, - propagate_a=False, - propagate_b=True, - layout=layout, - zeros_mode=zeros_mode) - propagated_matmul = MatmulWeightOnlyDequantize( - config=propagated_matmul_config, - target=target, - ) - - propagated_matmul.hardware_aware_finetune(topk=20) - propagated_inputs = [] - propagated_inputs.append(input_tensor) - if propagated_matmul.weight_transform is not None: - propagated_inputs.append(propagated_matmul.weight_transform(weight_tensor.cpu()).cuda()) - else: - propagated_inputs.append(weight_tensor) - if with_scaling: - propagated_inputs.append(scales_tensor) - if with_zeros: - propagated_inputs.append(zeros_tensor) - if with_bias: - propagated_inputs.append(bias_tensor) - propagated_inputs.append(torch.zeros(output_shape, dtype=torch.float16).cuda()) - - propagated_result = propagated_matmul(*propagated_inputs) - torch.testing.assert_close(ref_result, propagated_result, rtol=1e-2, atol=1e-2) - - -@pytest.mark.parametrize( - "M,N,K,in_dtype,out_dtype,accum_dtype,bit,storage_dtype,source_format,with_scaling,with_zeros,group_size,fast_decoding,with_bias,propagate_a,propagate_b,layout,source_zeros_mode,target_zeros_mode", - [ - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, False, - False, False, False, "nt", "rescale", "quantized"), - (1, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, False, - False, False, False, "nt", "rescale", "original"), - (1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, - False, False, False, False, "nt", "rescale", "quantized"), - (1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", False, False, -1, - False, False, False, False, "nt", "rescale", "original"), - ], -) -def test_matmul_dequantize_diff_zero_types(M, N, K, in_dtype, out_dtype, accum_dtype, bit, - storage_dtype, source_format, with_scaling, with_zeros, - group_size, fast_decoding, with_bias, propagate_a, - propagate_b, layout, source_zeros_mode, - target_zeros_mode): - import torch - torch.random.manual_seed(0) - import numpy as np - from bitblas.quantization.utils import general_compress - - source_quantized_matmul_config = MatmulWeightOnlyDequantizeConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - bit=bit, - storage_dtype=storage_dtype, - source_format=source_format, - with_scaling=with_scaling, - with_zeros=with_zeros, - group_size=group_size, - fast_decoding=fast_decoding, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, - zeros_mode=source_zeros_mode) - source_quantized_matmul = MatmulWeightOnlyDequantize( - config=source_quantized_matmul_config, - target=target, - ) - if not isinstance(M, int): - M = 32 - source_quantized_matmul.hardware_aware_finetune(topk=20) - input_shape = (M, K) - weight_shape = (N, K) if layout == "nt" else (K, N) - output_shape = (M, N) - inputs = [] - inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda() - 0.5) - maxq = 2**(bit - 1) - 1 - zeros = maxq - if source_format == "uint": - inputs.append(torch.randint(0, maxq, weight_shape, dtype=torch.int8).cuda()) - elif source_format == "int": - inputs.append(torch.randint(-maxq, maxq, weight_shape, dtype=torch.int8).cuda()) - else: - raise NotImplementedError - - inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) - - intweight = inputs[1] - intweight = intweight.cpu().numpy().astype(np.int8) - if source_format == "int": - intweight = intweight + maxq - if with_zeros: - inputs[1] = inputs[1] - zeros - bias = torch.rand((output_shape[-1],), dtype=torch.float16).cuda() - qw_np = general_compress(intweight, source_bits=bit, storage_dtype=np.int8) - qw_torch = torch.from_numpy(qw_np).cuda() - permuted_inputs = [] - if source_quantized_matmul.input_transform is not None: - permuted_inputs.append(source_quantized_matmul.input_transform(qw_torch.cpu()).cuda()) - else: - permuted_inputs.append(inputs[0]) - if source_quantized_matmul.weight_transform is not None: - permuted_inputs.append(source_quantized_matmul.weight_transform(qw_torch.cpu()).cuda()) - else: - permuted_inputs.append(qw_torch) - if with_scaling: - if group_size == -1: - group_size = K - permuted_inputs.append(torch.rand([N, K // group_size], dtype=torch.float16).cuda()) - if with_zeros: - if source_zeros_mode == "original": - permuted_inputs.append( - torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) - elif source_zeros_mode == "rescale": - original_zeros = torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros - scaled_zeros = original_zeros * permuted_inputs[-1] - permuted_inputs.append(scaled_zeros) - elif source_zeros_mode == "quantized": - original_zeros = torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros - qzeros = general_compress( - original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) - permuted_inputs.append(torch.from_numpy(qzeros).cuda()) - else: - raise NotImplementedError - - if with_bias: - permuted_inputs.append(bias) - permuted_inputs.append(inputs[2]) - source_quantized_matmul(*permuted_inputs) - ref_result = permuted_inputs[-1] - target_quantized_matmul_config = MatmulWeightOnlyDequantizeConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - bit=bit, - storage_dtype=storage_dtype, - source_format=source_format, - with_scaling=with_scaling, - with_zeros=with_zeros, - group_size=group_size, - fast_decoding=fast_decoding, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, - zeros_mode=target_zeros_mode) - target_quantized_matmul = MatmulWeightOnlyDequantize( - config=target_quantized_matmul_config, - target=target, - ) - if not isinstance(M, int): - M = 32 - target_quantized_matmul.hardware_aware_finetune(topk=20) - input_shape = (M, K) - weight_shape = (N, K) if layout == "nt" else (K, N) - output_shape = (M, N) - - target_inputs = [] - target_inputs.append(permuted_inputs[0]) - target_inputs.append(permuted_inputs[1]) - - if with_scaling: - target_inputs.append(permuted_inputs[2]) - if with_zeros: - if target_zeros_mode == "original": - target_inputs.append( - torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros) - elif target_zeros_mode == "rescale": - original_zeros = torch.ones([N, K // group_size], dtype=torch.float16).cuda() * zeros - scaled_zeros = original_zeros * target_inputs[-1] - target_inputs.append(scaled_zeros) - elif target_zeros_mode == "quantized": - original_zeros = torch.ones([K // group_size, N], dtype=torch.int8).cuda() * zeros - qzeros = general_compress( - original_zeros.cpu().numpy(), source_bits=bit, storage_dtype=np.int8) - target_inputs.append(torch.from_numpy(qzeros).cuda()) - else: - raise NotImplementedError - if with_bias: - target_inputs.append(bias) - target_inputs.append(torch.zeros_like(inputs[2])) - target_quantized_matmul(*target_inputs) - torch.testing.assert_close(target_inputs[-1], ref_result, rtol=1e-2, atol=1e-2) - - -# fmt: on - -if __name__ == "__main__": - bitblas.testing.main() diff --git a/testing/python/operators/test_matmul_ops.py b/testing/python/operators/test_matmul_ops.py deleted file mode 100644 index edf3fd765..000000000 --- a/testing/python/operators/test_matmul_ops.py +++ /dev/null @@ -1,274 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import pytest -import bitblas -from bitblas.ops.matmul import Matmul, MatmulConfig -from bitblas.utils import auto_detect_nvidia_target - -target = auto_detect_nvidia_target() - - -def get_codegen_result(ops, target): - code = ops.get_source(target=target) - return code - - -# fmt: off -@pytest.mark.parametrize( - "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout,enable_tuning", - [ - (16384, 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", False), - # dynamic shape - ([1], 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", False), - ([1, 32], 16384, 16384, "float16", "float16", "float16", False, False, False, "nt", True), - ], -) -def test_matmul_codegen_default( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - with_bias, - propagate_a, - propagate_b, - layout, - enable_tuning, -): - - matmul_config = MatmulConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, - ) - matmul = Matmul( - config=matmul_config, - target=target, - ) - if enable_tuning: - matmul.hardware_aware_finetune(topk=20) - assert get_codegen_result(matmul, target) - - -@pytest.mark.parametrize( - "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout", - [ - (16384, 16384, 16384, "float16", "float16", "float16", False, False, False, "nt"), - # dynamic shape - ([1], 16384, 16384, "float16", "float16", "float16", False, False, False, "nt"), - ], -) -def test_matmul_codegen_finetune( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - with_bias, - propagate_a, - propagate_b, - layout, -): - - matmul_config = MatmulConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, - ) - matmul = Matmul( - config=matmul_config, - target=target, - ) - matmul.hardware_aware_finetune(topk=20) - assert get_codegen_result(matmul, target) - - -@pytest.mark.parametrize( - "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout", - [ - (1024, 1024, 1024, "float16", "float16", "float16", False, False, False, "nt"), - ], -) -def test_matmul_profile_latency( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - with_bias, - propagate_a, - propagate_b, - layout, -): - matmul_config = MatmulConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, - ) - matmul = Matmul( - config=matmul_config, - target=target, - ) - latency = matmul.profile_latency() - assert latency - - -@pytest.mark.parametrize( - "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout", - [ - (256, 256, 256, "float16", "float16", "float16", False, False, False, "nt"), - (256, 256, 256, "float16", "float16", "float16", False, False, True, "nt"), - (256, 256, 256, "float16", "float16", "float16", False, False, 0, "nt"), - (256, 256, 256, "float16", "float16", "float16", False, False, 1, "nt"), - (256, 256, 256, "float16", "float16", "float16", False, False, 2, "nt"), - ], -) -def test_matmul_torch_forward( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - with_bias, - propagate_a, - propagate_b, - layout, -): - import torch - - matmul_config = MatmulConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, - ) - matmul = Matmul( - config=matmul_config, - target=target, - ) - # convert tensors to torch - input_shape = (M, K) - weight_shape = (N, K) if layout == "nt" else (K, N) - output_shape = (M, N) - inputs = [] - inputs.append(torch.rand(input_shape, dtype=torch.float16).cuda()) - inputs.append(torch.rand(weight_shape, dtype=torch.float16).cuda()) - inputs.append(torch.rand(output_shape, dtype=torch.float16).cuda()) - ref_result = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) - - permuted_inputs = [] - if matmul.input_transform is not None: - permuted_inputs.append(matmul.input_transform(inputs[0].cpu())).cuda() - else: - permuted_inputs.append(inputs[0]) - if matmul.weight_transform is not None: - permuted_inputs.append(matmul.weight_transform(inputs[1].cpu()).cuda()) - else: - permuted_inputs.append(inputs[1]) - permuted_inputs.append(inputs[2]) - matmul(*permuted_inputs) - torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e-2, atol=1e-2) - - -@pytest.mark.parametrize( - "M,N,K,in_dtype,out_dtype,accum_dtype,with_bias,propagate_a,propagate_b,layout", - [ - (256, 256, 256, "int8", "int8", "int32", False, False, False, "nt"), - ], -) -def test_matmul_torch_forward_int( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - with_bias, - propagate_a, - propagate_b, - layout, -): - import torch - torch.random.manual_seed(0) - - matmul_config = MatmulConfig( - M=M, - N=N, - K=K, - in_dtype=in_dtype, - out_dtype=out_dtype, - accum_dtype=accum_dtype, - with_bias=with_bias, - propagate_a=propagate_a, - propagate_b=propagate_b, - layout=layout, - ) - matmul = Matmul( - config=matmul_config, - target=target, - ) - - # convert tensors to torch - input_shape = (M, K) - weight_shape = (N, K) if layout == "nt" else (K, N) - output_shape = (M, N) - inputs = [] - inputs.append(torch.randint(-16, 16, input_shape, dtype=torch.int8).cuda()) - inputs.append(torch.randint(-1, 2, weight_shape, dtype=torch.int8).cuda()) - ref_result = torch.matmul( - inputs[0].to(torch.float32), - inputs[1].t().to(torch.float32) if layout == "nt" else inputs[1].to(torch.float32)) - - permuted_inputs = [] - if matmul.input_transform is not None: - permuted_inputs.append(matmul.input_transform(inputs[0].cpu())).cuda() - else: - permuted_inputs.append(inputs[0]) - if matmul.weight_transform is not None: - permuted_inputs.append(matmul.weight_transform(inputs[1].cpu()).cuda()) - else: - permuted_inputs.append(inputs[1]) - - permuted_inputs.append(torch.randint(-7, 7, output_shape, dtype=torch.int32).cuda()) - matmul(*permuted_inputs) - print(permuted_inputs[-1]) - print(ref_result) - torch.testing.assert_close( - permuted_inputs[-1].to(torch.float32), ref_result, rtol=1e-2, atol=1e-2) - - -# fmt: on - -if __name__ == "__main__": - bitblas.testing.main() diff --git a/testing/python/operators/test_param_permutate_ops.py b/testing/python/operators/test_param_permutate_ops.py deleted file mode 100644 index eb85e34cb..000000000 --- a/testing/python/operators/test_param_permutate_ops.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import pytest -import bitblas -from bitblas.ops.param_permutate import ParamPermutate, ParamPermutateConfig - -from bitblas import tvm - -target = tvm.target.Target("llvm") - - -# fmt: off -@pytest.mark.parametrize( - "M,N,datatype,transpose_matrix,group_size,propagate_kind,target_instruction", [ - (1024, 1024, "float16", True, 1, True, "nvidia-mma"), - ]) -def test_param_permutate_profile_latency( - M, - N, - datatype, - transpose_matrix, - group_size, - propagate_kind, - target_instruction, -): - param_permutate_config = ParamPermutateConfig( - M=M, - N=N, - datatype=datatype, - propagate_kind=propagate_kind, - group_size=group_size, - transpose_matrix=transpose_matrix, - target_instruction=target_instruction, - ) - param_permutate = ParamPermutate( - config=param_permutate_config, - target=target, - ) - latency = param_permutate.profile_latency() - assert latency - - -# fmt: on - -if __name__ == "__main__": - bitblas.testing.main() diff --git a/testing/python/transform/test_weight_only_transform.py b/testing/python/transform/test_weight_only_transform.py index 11d7ce6d7..7f6887277 100644 --- a/testing/python/transform/test_weight_only_transform.py +++ b/testing/python/transform/test_weight_only_transform.py @@ -1,16 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from bitblas import tvm + from tvm.script import ir as I from tvm.script import tir as T from tvm.script import relax as R - -from bitblas import tvm -from bitblas import tvm.testing from tvm import relax -from tvm.script import ir as I, relax as R, tir as T -from tvm import tir -from tvm.ir import IRModule -from tvm.ir.transform import PassContext, module_pass from bitblas.base.utils import get_dummy_input_arrays from copy import deepcopy import bitblas @@ -329,16 +324,16 @@ def main( input_tensors = get_dummy_input_arrays(ref_mod["main"], device) print(relax_mod) print("=======================ref llvm result=======================") - # ref_res = get_ref_result(ref_mod, input_tensors) - # print("ref_mod", ref_res) - # bitblas_res = get_ref_result(relax_mod, input_tensors) - # print("bitblas_res", bitblas_res) + ref_res = get_ref_result(ref_mod, input_tensors) + print("ref_mod", ref_res) + bitblas_res = get_ref_result(relax_mod, input_tensors) + print("bitblas_res", bitblas_res) print("=======================default gpu result=======================") - # ref_res = get_default_result(ref_mod, input_tensors, dispatch_target, device) - # print("ref_mod", ref_res) - # bitblas_res = get_default_result(relax_mod, input_tensors, dispatch_target, device) - # print("bitblas_res", bitblas_res) - # print("=======================fast tune gpu result=======================") + ref_res = get_default_result(ref_mod, input_tensors, dispatch_target, device) + print("ref_mod", ref_res) + bitblas_res = get_default_result(relax_mod, input_tensors, dispatch_target, device) + print("bitblas_res", bitblas_res) + print("=======================fast tune gpu result=======================") ref_res = get_fast_tune_result(ref_mod, input_tensors, dispatch_target, device) print("ref_mod", ref_res) print(relax_mod) @@ -348,6 +343,5 @@ def main( print("bitblas_res", bitblas_res) -# test_lop3_transform() -# test_matmul_transform() -test_dequantize_matmul_transform() +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/type_conversion/int4b_fp16_convert.py b/testing/python/type_conversion/test_int4b_fp16_convert.py similarity index 99% rename from testing/python/type_conversion/int4b_fp16_convert.py rename to testing/python/type_conversion/test_int4b_fp16_convert.py index 43474d062..92b0e0788 100644 --- a/testing/python/type_conversion/int4b_fp16_convert.py +++ b/testing/python/type_conversion/test_int4b_fp16_convert.py @@ -1,12 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import bitblas from bitblas import tvm import torch import numpy as np -from bitblas import tvm.testing from tvm.script import tir as T -import os -from tvm import te import numpy as np @@ -226,4 +224,5 @@ def test_lop3_interleave_weight(): np.testing.assert_allclose(tvm_interleaved_b_np_int8, interleaved_weight, atol=1e-5) -test_lop3_interleave_weight() +if __name__ == "__main__": + bitblas.testing.main() diff --git a/testing/python/type_conversion/test_lop3_type_conversion.py b/testing/python/type_conversion/test_lop3_type_conversion.py deleted file mode 100644 index d5853a108..000000000 --- a/testing/python/type_conversion/test_lop3_type_conversion.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from bitblas import tvm -from tvm.script import tir as T -import bitblas -from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA -from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags -from bitblas.base.utils import apply_and_build -from bitblas.ops.impl.matmul_impl import matmul_nt, matmul_nt_dequantize_b -import numpy as np - - -def test_f16_f16_gemm(): - ir_module = matmul_nt(1, 16384, 16384, "float16", "float16") - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except Exception: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - print( - "[BitBLAS] The best latency of top 1 is {:.3f} ms".format( - cpresults[0].latency * 1e3 - ) - ) - print( - "[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency * 1e3) - ) - - -def test_f16_i4_gemm(M=1, N=16384, K=16384, bit=4, fast_decoding=True): - ir_module = matmul_nt_dequantize_b( - M, - N, - K, - "float16", - bit=bit, - storage_dtype="uint32", - with_scaling=True, - group_size=-1, - fast_decoding=fast_decoding, - ) - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except Exception: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - assert best - - -test_f16_i4_gemm() diff --git a/testing/python/type_conversion/test_numpy_compress_convert.py b/testing/python/type_conversion/test_numpy_compress_convert.py deleted file mode 100644 index 59e481eb9..000000000 --- a/testing/python/type_conversion/test_numpy_compress_convert.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. diff --git a/testing/python/weight_only/correctness/test_fp16xint4_correctness.py b/testing/python/weight_only/correctness/test_fp16xint4_correctness.py deleted file mode 100644 index 7f5b3027d..000000000 --- a/testing/python/weight_only/correctness/test_fp16xint4_correctness.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import torch - -import bitblas -import numpy as np - -from bitblas.quantization.utils import general_compress, interleave_weight -from bitblas.ops.matmul import MatmulWeightOnlyDequantize - -M = 1 -N = 4096 -K = 1024 -bitblas_matmul = MatmulWeightOnlyDequantize( - M=M, - N=N, - K=K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - propagate_b=False, - bit=4, - storage_dtype="uint8", - source_format="int", - with_scaling=False, - group_size=128, - fast_decoding=False, - with_bias=False, -) - -torch_arrs = [] -torch_arrs.append(torch.randint(0, 10, (M, K), dtype=torch.float16, device="cuda")) -torch_arrs.append(torch.randint(0, 7, (N, K), dtype=torch.float16, device="cuda")) -torch_arrs.append(torch.zeros((M, K), dtype=torch.float16, device="cuda")) - -print("torch: {}".format(torch_arrs[-1])) - diff --git a/testing/python/weight_only/index_map_deduce.py b/testing/python/weight_only/index_map_deduce.py deleted file mode 100644 index 9f75267e6..000000000 --- a/testing/python/weight_only/index_map_deduce.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import numpy as np -import pytest - -from bitblas import tvm -from bitblas import tvm.testing -from tvm.ir import assert_structural_equal -from tvm.runtime import const -from tvm.tir import IndexMap, IntImm, floordiv, floormod -from tvm import tir -index_map = IndexMap.from_func(lambda i: [i // 4, i % 4], index_dtype="int32") -initial_i = index_map.initial_indices[0] - -# but what we have is i <=> i // 4 -# should do inverse - -block_iter_map = IndexMap.from_func(lambda i: [i // 4], index_dtype="int32") -inverse_block_iter_map = index_map.inverse([32,]) - -new_final_indices = index_map.map_indices([initial_i * 4]) - -# # tir.IndexMap([initial_i // 4], final_indices, None) -# print(new_final_indices) diff --git a/testing/python/weight_only/index_map_fuse.py b/testing/python/weight_only/index_map_fuse.py deleted file mode 100644 index 328f8184d..000000000 --- a/testing/python/weight_only/index_map_fuse.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from bitblas import tvm -from tvm.script import tir as T -from tvm.tir import IndexMap -from tvm.tir.tensor_intrin.cuda import ( - ldmatrix_trans_32x8_to_shared_16x16_layout, - ldmatrix_32x16_to_shared_16x32_layout_a, - ldmatrix_32x16_to_shared_16x32_layout_b, -) - -def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j): - thread_id = kernel_i * 2 + kernel_j // 8 - local_id = kernel_j % 8 - return ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id) - -@tvm.script.ir_module -class LDMATRIX_16x16: - @T.prim_func - def main(a: T.handle, b: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = T.match_buffer(a, [16, 16], dtype="float16") - B = T.match_buffer(b, [16, 16], dtype="float16") - - for i, j in T.grid(16, 16): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i, j]) - T.reads(B[vi, vj]) - T.writes(A[vi, vj]) - A[vi, vj] = B[vi, vj] - -ir_module = LDMATRIX_16x16 -sch = tvm.tir.Schedule(ir_module) - -block_b = sch.get_block("B") -sch.transform_layout(block_b, ('read', 0), ldmatrix_trans_permutation_16x16_32x8_16x16) -print("========================inject transform=============================") -print(sch.mod["main"].script()) - -index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x16_32x8_16x16, index_dtype="int32") -inversed_index_map = index_map.inverse([16, 16]) -def inverse_permutation(i, j): - return inversed_index_map.map_indices([i, j]) -sch.transform_layout(block_b, ('read', 0), inverse_permutation) -print("========================inverse inject transform=============================") -print(sch.mod["main"].script()) - - -def ldmatrix_trans_permutation_16x32_16x32_16x32(kernel_i, kernel_j): - thread_id = kernel_i * 2 + kernel_j // 16 - local_id = kernel_j % 16 - return ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id) - -@tvm.script.ir_module -class LDMATRIX_16x32_A: - @T.prim_func - def main(a: T.handle, b: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = T.match_buffer(a, [16, 32], dtype="float16") - B = T.match_buffer(b, [16, 32], dtype="float16") - - for i, j in T.grid(16, 32): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i, j]) - T.reads(B[vi, vj]) - T.writes(A[vi, vj]) - A[vi, vj] = B[vi, vj] - -ir_module = LDMATRIX_16x32_A -sch = tvm.tir.Schedule(ir_module) - -block_b = sch.get_block("B") -sch.transform_layout(block_b, ('read', 0), ldmatrix_trans_permutation_16x32_16x32_16x32) -print("========================inject transform=============================") -print(sch.mod["main"].script()) - -index_map_inter = IndexMap.from_func(lambda i, j: (i // 16, j // 16, i % 16, j % 16), index_dtype="int32") - -index_map_intra = IndexMap.from_func(ldmatrix_trans_permutation_16x32_16x32_16x32, index_dtype="int32") - -print("index_map_inter", index_map_inter) \ No newline at end of file diff --git a/testing/python/weight_only/inverse_index_map.py b/testing/python/weight_only/inverse_index_map.py deleted file mode 100644 index a27e72b28..000000000 --- a/testing/python/weight_only/inverse_index_map.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from bitblas import tvm -from tvm.script import tir as T -from tvm.tir import IndexMap -from tvm.tir.tensor_intrin.cuda import ( - ldmatrix_trans_32x8_to_shared_16x16_layout, - ldmatrix_32x16_to_shared_16x32_layout_a, - ldmatrix_32x16_to_shared_16x32_layout_b, -) - -def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j): - thread_id = kernel_i * 2 + kernel_j // 8 - local_id = kernel_j % 8 - return ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id) - -@tvm.script.ir_module -class LDMATRIX_16x16: - @T.prim_func - def main(a: T.handle, b: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = T.match_buffer(a, [16, 16], dtype="float16") - B = T.match_buffer(b, [16, 16], dtype="float16") - - for i, j in T.grid(16, 16): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i, j]) - T.reads(B[vi, vj]) - T.writes(A[vi, vj]) - A[vi, vj] = B[vi, vj] - -ir_module = LDMATRIX_16x16 -sch = tvm.tir.Schedule(ir_module) - -block_b = sch.get_block("B") -sch.transform_layout(block_b, ('read', 0), ldmatrix_trans_permutation_16x16_32x8_16x16) -print("========================inject transform=============================") -print(sch.mod["main"].script()) - -index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x16_32x8_16x16, index_dtype="int32") -inversed_index_map = index_map.inverse([16, 16]) -def inverse_permutation(i, j): - return inversed_index_map.map_indices([i, j]) -sch.transform_layout(block_b, ('read', 0), inverse_permutation) -print("========================inverse inject transform=============================") -print(sch.mod["main"].script()) - - -def ldmatrix_trans_permutation_16x32_16x32_16x32(kernel_i, kernel_j): - thread_id = kernel_i * 2 + kernel_j // 16 - local_id = kernel_j % 16 - return ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id) - -@tvm.script.ir_module -class LDMATRIX_16x32_A: - @T.prim_func - def main(a: T.handle, b: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = T.match_buffer(a, [16, 32], dtype="float16") - B = T.match_buffer(b, [16, 32], dtype="float16") - - for i, j in T.grid(16, 32): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i, j]) - T.reads(B[vi, vj]) - T.writes(A[vi, vj]) - A[vi, vj] = B[vi, vj] - -ir_module = LDMATRIX_16x32_A -sch = tvm.tir.Schedule(ir_module) - -block_b = sch.get_block("B") -sch.transform_layout(block_b, ('read', 0), ldmatrix_trans_permutation_16x32_16x32_16x32) -print("========================inject transform=============================") -print(sch.mod["main"].script()) - -index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x32_16x32_16x32, index_dtype="int32") -inversed_index_map = index_map.inverse([16, 32]) -def inverse_permutation(i, j): - return inversed_index_map.map_indices([i, j]) -sch.transform_layout(block_b, ('read', 0), inverse_permutation) -print("========================inverse inject transform=============================") -print(sch.mod["main"].script()) - -def ldmatrix_trans_permutation_16x32_16x32_16x32(kernel_i, kernel_j): - thread_id = kernel_i * 2 + kernel_j // 16 - local_id = kernel_j % 16 - return ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id) - -@tvm.script.ir_module -class LDMATRIX_16x32_B: - @T.prim_func - def main(a: T.handle, b: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A = T.match_buffer(a, [16, 32], dtype="float16") - B = T.match_buffer(b, [16, 32], dtype="float16") - - for i, j in T.grid(16, 32): - with T.block("B"): - vi, vj = T.axis.remap("SS", [i, j]) - T.reads(B[vi, vj]) - T.writes(A[vi, vj]) - A[vi, vj] = B[vi, vj] - -ir_module = LDMATRIX_16x32_B -sch = tvm.tir.Schedule(ir_module) - -block_b = sch.get_block("B") -sch.transform_layout(block_b, ('read', 0), ldmatrix_trans_permutation_16x32_16x32_16x32) -print("========================inject transform=============================") -print(sch.mod["main"].script()) - -index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x32_16x32_16x32, index_dtype="int32") -inversed_index_map = index_map.inverse([16, 32]) -def inverse_permutation(i, j): - return inversed_index_map.map_indices([i, j]) -sch.transform_layout(block_b, ('read', 0), inverse_permutation) -print("========================inverse inject transform=============================") -print(sch.mod["main"].script())