From e71b85de49497c88c37c9e683f7ea78e013a881a Mon Sep 17 00:00:00 2001 From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com> Date: Tue, 13 Aug 2024 02:34:32 +0800 Subject: [PATCH] [Dev] Set default weight transformation into Ladder Stage3 LDMatrixTransform (#133) * Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability * Refactor import statements for improved readability and maintainability * Refactor import statements for improved readability and maintainability * disable failure email for ci * remove email notifications. * move relax pass from testing to mlc_llm * Refactor scripts with the check_eual_ref_scripts_with_emitter function * Lint Fix * Refactor scripts with the check_eual_ref_scripts_with_emitter function * bug fix in test * lint fix. * test cuda i4 kernel * Refactor copyright notice in i4matmul.hpp * Refactor BitBLASLinear test module for improved readability and maintainability * refactor test as versions below Python 3.9 cannot handle int32 overflow. * format lint for test * Refactor test_int4b_fp16_convert.py for improved readability and maintainability * remove unused design file * move tile device from package to base * dummy impl for codegen * Refactor file structure for ladder_permutate module * Refactor backend class and fix typos in comments * Deep refactor of Lib-related code. * remove ci pull. * LintFix * refactor builder for whl build * Refactor TIRWrapper.wrap() method to include an assertion for the optimized module * Refactor lib_generator to set library and source paths * lint fix * BitNet vllm integration * chore: update codespell to version 2.3.0 * Lintfix * Bump version to 0.0.1.dev13 * lint fix * disable fast decoding [u]int4xint8 by default. * optimize from dict design in Hint * Implement SplitK * bitnet benchmark generation. * Add benchmark script for BitNet integration * AtomicAdd Support * LintFix * ci fix when 3rdparty tvm is initialized. * bug fix for setup * fix a bug in block reduce * typo fix * BUG Fix for block reduce. 
* Lint fix * Refactor block reduce schedule template * transform branch from bitblas to bitblas_tl * Fix subproject commit reference in 3rdparty/tvm * chore: update submodule branch from bitblas to bitblas_tl * force update config.cmake * Bug fix * Fix subproject commit reference in 3rdparty/cutlass * chore: Add submodule for cutlass library * update tl cutlass path * Refactor BitBLASLinear test module for improved readability and maintainability * format fix * Copy CUTLASS to the package directory * Refactor setup.py to include additional TVM header files * lint fix * bug fix * Refactor BitBLASLinear test module for improved readability and maintainability * Implement Matmul Benchmark Design * chore: Update BitBLAS Matmul benchmark script * lint fix * Refactor BitBLASMatmulOpsBenchmark for improved readability and maintainability * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * lint fix * Benchmark bot test * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * int8 test case * Refactor compare_benchmark.py to handle missing benchmark results gracefully * ci fix * disable ci for test benchmark * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * remove cli installation * chore: Create virtual environment and install dependencies for benchmark * chore: Update benchmark workflow to include comparison step * Lint fix * update tvm commit * Improve lower warp memory pass * Bug fix * Enhance to support warp schedule. * Enhance LOP3 Instructions * Enhance LOP3 Instructions * add test for stage3 propagate * implement propagate func * Stage3 Ladder Permutate integration * get_ladder_stage3_propagate * comment out benchmark scripts as the setting is too big * ci fix for benchmark * lint fix * chore: Update benchmark workflow to trigger on pull request comments * Add LDMatrix Transform 3 * Support GPTQ Test * Fuse BlockReduce Schedule * Support mma propagate 3 * Support MMA Propagate Stage 3 * Lint Fix * Merge block reduce for dequantize config. * fix codeql * chore: Update submodule reference to latest commit * chore: Disable common subexpression elimination in TIR passes * Lint Fix * 4bit related lop3 updates. * lint fix * gptq test fix * Fix for test * lint fix * lint fix * typo fix * QuantCompress Test * chore: Refactor quant_compress_impl.py for readability and maintainability * Enhance docs to update latest works. 
* Refactor weight executors in Matmul class for improved readability and maintainability * Refactor weight executors in Matmul class for improved readability and maintainability * Refactor weight executors in Matmul class for improved readability and maintainability * removed legacy operator * Refactor weight executors in Matmul class for improved readability and maintainability * LintFix * Fix GPTQ Repack with the latest weight transform * lint fix * bug fix for rescale dequantize * test fix * typo fix * lint fix * Set default weight propagate kind into LDMatrixTransform * lint fix * bug fix * bug fix for test * set default to stage3 * revert change * lint fix * case fix * bug fix * fix for legalize * bug fix * chore: Clear global operator cache before running tests * revert optimize_stratety into SingleBatchDecodeOnly * typofix * update benchmark scripts * chore: Refactor benchmark scripts and fix typos * fix for testing * lint fix * fix import. * typo * operator benchmark * optimize * always with shared.dyn * optimize cache. * dsl fix * tqdm * chore: Add serialize_results method to benchmark_matmul_strategies.py * fix performance issue for dynamic async copy * chore: Refactor benchmark_matmul_strategies.py for improved performance and code readability * bug fix * update readme --- 3rdparty/tvm | 2 +- README.md | 1 + benchmark/dsl/convolution.py | 2 +- benchmark/dsl/matmul.py | 2 +- benchmark/dsl/matmul_dequantize_af.py | 2 +- benchmark/dsl/matmul_dequantize_fp.py | 2 +- benchmark/dsl/matmul_dequantize_int1.py | 3 +- benchmark/dsl/matmul_dequantize_int4.py | 2 +- benchmark/dsl/weight_propagate.py | 2 +- benchmark/gridsearch/grid_search_example.py | 143 ++++++++++ .../benchmark_matmul_scope_compare.py | 166 ++++++++++++ .../operators/benchmark_matmul_strategies.py | 248 ++++++++++++++++++ benchmark/operators/benchmark_ops_matmul.py | 90 ++++--- bitblas/base/roller/hint.py | 2 + bitblas/base/roller/policy/tensorcore.py | 3 +- bitblas/base/utils.py | 7 +- bitblas/benchmark/operator/__init__.py | 17 ++ bitblas/cache/__init__.py | 1 + bitblas/gpu/matmul_mma.py | 12 +- bitblas/gpu/matmul_mma_dequantize.py | 18 +- bitblas/module/__init__.py | 5 +- bitblas/ops/general_matmul/__init__.py | 37 ++- bitblas/ops/general_matmul_splitk.py | 14 +- .../ops/impl/matmul_dequantize_splitk_impl.py | 205 ++++++++++++++- docs/ExtendOperatorsWithDSL.md | 2 +- testing/python/cache/test_operator_cache.py | 15 +- testing/python/module/test_bitblas_linear.py | 2 + .../operators/test_general_matmul_ops.py | 3 + .../test_general_matmul_tile_schedule.py | 8 +- 29 files changed, 942 insertions(+), 74 deletions(-) create mode 100644 benchmark/gridsearch/grid_search_example.py create mode 100644 benchmark/operators/benchmark_matmul_scope_compare.py create mode 100644 benchmark/operators/benchmark_matmul_strategies.py diff --git a/3rdparty/tvm b/3rdparty/tvm index 6daecacc7..07648907e 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 6daecacc73c8c8fdea1b9732891e1d4a5ebbf818 +Subproject commit 07648907e1678ec2b84d8ec579b2ec8f4925d218 diff --git a/README.md b/README.md index 34c203cf1..7eeda4ab0 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ Some of the key features of BitBLAS include: ## Latest News +- 08/12/2024 🚀🚀: We've improved performance for contiguous batching. To enable it, you'll need to set specific flags. For more details, please refer to [PR #133](https://github.com/microsoft/BitBLAS/pull/133). - 07/11/2024 ✨: Ladder is published and presented in OSDI'24. 
Please find [Ladder paper and presentation](https://www.usenix.org/conference/osdi24/presentation/wang-lei) if you are interested in the technical details of BitBLAS. - 06/25/2024 🚀🚀: BitBLAS has been integrated into [GPTQModel](https://github.com/ModelCloud/GPTQModel)! You can now use BitBLAS as a backend in GPTQ. - 05/04/2024 🚀🚀: We’ve added integration examples for the 1.58-bit model! Check out the files under integration/BitNet. diff --git a/benchmark/dsl/convolution.py b/benchmark/dsl/convolution.py index bf02a41d7..9bb9f4e48 100644 --- a/benchmark/dsl/convolution.py +++ b/benchmark/dsl/convolution.py @@ -3,7 +3,7 @@ import numpy as np import tvm from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA +from bitblas.base.arch import CUDA from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags from bitblas.gpu import Matmul from bitblas.base.utils import apply_and_build diff --git a/benchmark/dsl/matmul.py b/benchmark/dsl/matmul.py index 85b9374e9..13ab7d189 100644 --- a/benchmark/dsl/matmul.py +++ b/benchmark/dsl/matmul.py @@ -3,7 +3,7 @@ import tvm from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA +from bitblas.base.arch import CUDA from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags from bitblas.gpu import Matmul from bitblas.utils import auto_detect_nvidia_target diff --git a/benchmark/dsl/matmul_dequantize_af.py b/benchmark/dsl/matmul_dequantize_af.py index 5bc8362af..69939e1b7 100644 --- a/benchmark/dsl/matmul_dequantize_af.py +++ b/benchmark/dsl/matmul_dequantize_af.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import bitblas from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA +from bitblas.base.arch import CUDA from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags from bitblas.gpu import Matmul from bitblas.utils import auto_detect_nvidia_target diff --git a/benchmark/dsl/matmul_dequantize_fp.py b/benchmark/dsl/matmul_dequantize_fp.py index 102ba978c..7b310456d 100644 --- a/benchmark/dsl/matmul_dequantize_fp.py +++ b/benchmark/dsl/matmul_dequantize_fp.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import bitblas from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA +from bitblas.base.arch import CUDA from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags from bitblas.gpu import Matmul from bitblas.utils import auto_detect_nvidia_target diff --git a/benchmark/dsl/matmul_dequantize_int1.py b/benchmark/dsl/matmul_dequantize_int1.py index 8874c37f5..e810782cf 100644 --- a/benchmark/dsl/matmul_dequantize_int1.py +++ b/benchmark/dsl/matmul_dequantize_int1.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. 
import bitblas from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA +from bitblas.base.arch import CUDA from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags from bitblas.gpu import Matmul from bitblas.utils import auto_detect_nvidia_target @@ -14,6 +14,7 @@ import tvm import time import argparse + bitblas.set_log_level("DEBUG") # append a parser for the benchmark set diff --git a/benchmark/dsl/matmul_dequantize_int4.py b/benchmark/dsl/matmul_dequantize_int4.py index f8c577556..5367ea3d9 100644 --- a/benchmark/dsl/matmul_dequantize_int4.py +++ b/benchmark/dsl/matmul_dequantize_int4.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import bitblas from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA +from bitblas.base.arch import CUDA from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags from bitblas.gpu import Matmul from bitblas.utils import auto_detect_nvidia_target diff --git a/benchmark/dsl/weight_propagate.py b/benchmark/dsl/weight_propagate.py index aab5316c0..4c587028e 100644 --- a/benchmark/dsl/weight_propagate.py +++ b/benchmark/dsl/weight_propagate.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import bitblas from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA +from bitblas.base.arch import CUDA from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags from bitblas.gpu import Matmul from bitblas.utils import auto_detect_nvidia_target diff --git a/benchmark/gridsearch/grid_search_example.py b/benchmark/gridsearch/grid_search_example.py new file mode 100644 index 000000000..a6a00adad --- /dev/null +++ b/benchmark/gridsearch/grid_search_example.py @@ -0,0 +1,143 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+import tqdm +import bitblas +from bitblas.base.arch import CUDA +from bitblas.ops.general_matmul.tirscript.matmul_dequantize_impl import ( + matmul_nt_dequantize_b_propagate_b,) +import tvm +import itertools + +search_space = { + "block_row_warps": [1], + "block_col_warps": [1, 2, 4, 8, 16], + "warp_row_tiles": [1], + "warp_col_tiles": [1, 2, 4, 8, 16], + "chunk": [1, 2, 4, 8, 16, 32], + "stage": [2, 3, 4], + "block_reduce": [2, 4], +} + +keys = search_space.keys() +values = search_space.values() +combinations = list(itertools.product(*values)) + +combinations_dicts = [dict(zip(keys, combination)) for combination in combinations] + +# for combination in combinations_dicts: +# print(combination) +print(len(combinations_dicts)) +group_size = -1 +# fmt:off +llm_shape_fp16xint4 = [ + # square test + (matmul_nt_dequantize_b_propagate_b, (16, 16384, 16384, "float16", "float16", "float16", 4, + "int8", "uint", False, False, group_size, True, False)), +] + +# fmt:on + +target = tvm.target.Target(bitblas.auto_detect_nvidia_target()) +benchmark_sets = llm_shape_fp16xint4 +tuning_results = {} + +min_time = 1e9 +min_combination = None +sucess_combinations = [] +for get_prim_func, input_args in benchmark_sets: + ir_module = get_prim_func(*input_args, transform_kind=3) + func = ir_module["main"] + arch = CUDA(target) + + M, N, K = input_args[0], input_args[1], input_args[2] + import numpy as np + + np.random.seed(0) + # a = np.random.randn(M // 16, K // 16, 16, 16).astype(np.float16) + a = np.random.randn(M, K).astype(np.float16) + b = np.random.randn(N // 16, K // 16, 16, 8).astype(np.int8) + c = np.random.randn(M, N).astype(np.float16) + + tvm_a = tvm.nd.array(a, device=tvm.cuda(0)) + tvm_b = tvm.nd.array(b, device=tvm.cuda(0)) + tvm_c = tvm.nd.array(c, device=tvm.cuda(0)) + + intrin_info = bitblas.base.hint.IntrinInfo( + in_dtype="float16", + out_dtype="float16", + trans_b=True, + input_transform_kind=0, + weight_transform_kind=3, + ) + + # set up tqdm + pbar = tqdm.tqdm(combinations_dicts) + for combination in pbar: + pbar.set_description( + f"sucess_combinations: {len(sucess_combinations)} min_time: {min_time}") + block_row_warps = combination["block_row_warps"] + block_col_warps = combination["block_col_warps"] + warp_row_tiles = combination["warp_row_tiles"] + warp_col_tiles = combination["warp_col_tiles"] + chunk = combination["chunk"] + stage = combination["stage"] + block_reduce = combination["block_reduce"] + + mma_row = mma_col = 16 + mma_k = 16 + + block = [ + block_row_warps * warp_row_tiles * mma_row, + block_col_warps * warp_col_tiles * mma_col, + ] + warp = [mma_row * warp_row_tiles, mma_col * warp_col_tiles] + rstep = [mma_k * chunk * block_reduce] + pipeline_stage = stage + block_reduction_depth = block_reduce + hint = bitblas.base.Hint.from_dict({ + "use_tc": True, + "arch": arch, + "block": block, + "warp": warp, + "rstep": rstep, + "pipeline_stage": pipeline_stage, + "use_async": True, + "intrin_info": intrin_info, + "shared_scope": "shared.dyn", + "vectorize": { + "b": 8, + "a": 8 + }, + "block_reduction_depth": block_reduction_depth, + "rasterization_plan": bitblas.base.rasterization.Rasterization2DColumn(10), + }) + print("Tuning Hint is", hint) + try: + sch = bitblas.gpu.MatmulTensorizationMMAWithDequantizeInfo( + ).sch_warp_memory_prefetch_with_config( + func, config=hint) + + with tvm.transform.PassContext( + config={ + "tir.use_async_copy": True, + "tir.merge_static_smem": False, + "tir.disable_cse_tir": True, + }): + rt_mod = tvm.build(sch.mod, target=target) + + 
time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, tvm.cuda(0), number=10) + + t = time_evaluator(tvm_a, tvm_b, tvm_c).mean * 1e3 + + print(f"For combination {combination}, time is {t} ms") + tuning_results["-".join([str(v) for v in combination.values()])] = t + if t < min_time: + min_time = t + min_combination = combination + sucess_combinations.append(combination) + except Exception as e: + del e + print(f"Failed for combination {combination}") + continue + +print(f"Minimum time is {min_time} for combination {min_combination}") diff --git a/benchmark/operators/benchmark_matmul_scope_compare.py b/benchmark/operators/benchmark_matmul_scope_compare.py new file mode 100644 index 000000000..0fb3c5ba6 --- /dev/null +++ b/benchmark/operators/benchmark_matmul_scope_compare.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas.utils.target_detector import auto_detect_nvidia_target +from bitblas import Matmul, MatmulConfig +import argparse +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.base.arch import CUDA +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +from bitblas.base.utils import apply_and_build + +# Initialize the parser +parser = argparse.ArgumentParser(description="Benchmark BitBLAS int4 on a specific target.") + +# Add arguments to the parser +parser.add_argument( + "--target", + type=str, + default=auto_detect_nvidia_target(), + help="Specify the target device for benchmarking.") +parser.add_argument( + "--group_size", type=int, default=None, help="Group size for grouped quantization.") +parser.add_argument( + "--A_dtype", + type=str, + default="float16", + choices=["float16", "float32", "float64", "int32", + "int8"], # Assuming these are the valid choices + help="Data type of activation A.") +parser.add_argument( + "--W_dtype", + type=str, + default="int4", + choices=[ + "float16", "float32", "float64", "int32", "int8", "int4", "int2", "int1", "nf4", "fp4_e2m1" + ], # Assuming these are the valid choices + help="Data type of weight W.") +parser.add_argument( + "--accum_dtype", + type=str, + default="float16", + choices=["float16", "int32"], # Assuming these are the valid choices + help="Data type for accumulation.") +parser.add_argument( + "--out_dtype", + type=str, + default="float16", + choices=["float16", "float32", "int32", "int8"], # Assuming these are the valid choices + help="Data type for output.") +parser.add_argument( + "--layout", + type=str, + default="nt", + choices=["nt", "nn"], # Assuming these are the valid choices + help="Matrix layout, 'nt' for non-transpose A and transpose W.") +parser.add_argument("--with_bias", action="store_true", help="Include bias in the benchmark.") +parser.add_argument( + "--with_scaling", action="store_true", help="Include scaling factor in the quantization.") +parser.add_argument("--with_zeros", action="store_true", help="Include zeros in the quantization.") +parser.add_argument( + "--zeros_mode", + type=str, + default=None, + choices=["original", "rescale", "quantized"], # Replace with actual modes if applicable + help="Specify the mode for calculating zeros.") + +# Parse the arguments +args = parser.parse_args() + +# Assign arguments to variables +target = args.target +group_size = args.group_size +A_dtype = args.A_dtype +W_dtype = args.W_dtype +accum_dtype = args.accum_dtype +out_dtype = args.out_dtype +layout = args.layout +with_bias = args.with_bias +group_size = args.group_size +with_scaling = args.with_scaling 
+with_zeros = args.with_zeros +zeros_mode = args.zeros_mode + +test_shapes = [ + # square test + (MatmulConfig, Matmul, (1, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (16, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (32, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (64, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (128, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (256, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (1024, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (16, 43008, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (32, 14336, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (64, 57344, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (128, 14336, 57344, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (256, 9216, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (128, 36864, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (64, 9216, 36864, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (32, 22016, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (16, 8192, 22016, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (32, 8192, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (64, 28672, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), + (MatmulConfig, Matmul, (128, 8192, 28672, A_dtype, W_dtype, out_dtype, accum_dtype, layout, + with_bias, group_size, with_scaling, with_zeros, zeros_mode)), +] + +benchmark_sets = [] +benchmark_sets.extend(test_shapes) + +# fmt:on + +benchmark_results = {} +for config, operator, input_args in benchmark_sets: + config = config(*input_args) + matmul = operator(config, target=target, enable_tuning=False) + func = matmul.prim_func + arch = CUDA(target) + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except 
Exception: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + + configs = policy.emit_config(20) + static_configs = [] + for config in configs: + static_config = config + static_config.shared_scope = "shared" + static_configs.append(static_config) + dynamic_configs = [] + for config in configs: + dynamic_config = config + dynamic_config.shared_scope = "shared.dyn" + dynamic_configs.append(dynamic_config) + + _, best_static = apply_and_build(func, static_configs, arch, parallel_build=True) + + _, best_dynamic = apply_and_build(func, dynamic_configs, arch, parallel_build=True) + benchmark_results[input_args] = (best_static.latency, best_dynamic.latency, + best_static.latency - best_dynamic.latency) + +for key, value in benchmark_results.items(): + print( + f"Input arguments: {key}, Static latency: {value[0]}, Dynamic latency: {value[1]}, Difference: {value[2]}" + ) diff --git a/benchmark/operators/benchmark_matmul_strategies.py b/benchmark/operators/benchmark_matmul_strategies.py new file mode 100644 index 000000000..114ba900e --- /dev/null +++ b/benchmark/operators/benchmark_matmul_strategies.py @@ -0,0 +1,248 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas.benchmark import BitblasOperatorBenchmarkBase +from bitblas import Matmul, MatmulConfig +from bitblas.ops.general_matmul import OptimizeStrategy +from bitblas.utils import get_commit_id +from bitblas import set_log_level +from tabulate import tabulate +from os import path, makedirs +from typing import List +import argparse +from tqdm import tqdm + +set_log_level("DEBUG") + + +class BitblasMatmulOpsBenchmarkCompareStategies(BitblasOperatorBenchmarkBase): + + BENCHMARK_RESULTS_FILE = "benchmark_results.json" + BENCHMARK_SHAPES_FILE = "benchmark_shapes.json" + BENCHMARK_DEVICE_FILE = "benchmark_device.json" + + config_map = { + "FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV": { + "A_dtype": "float16", + "W_dtype": "uint4", + "accum_dtype": "float16", + "optimize_stratety": OptimizeStrategy.SingleBatchDecodeOnly, + }, + "FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching": { + "A_dtype": "float16", + "W_dtype": "uint4", + "accum_dtype": "float16", + "optimize_stratety": OptimizeStrategy.ContigousBatching, + }, + } + + OPT_SHAPES = [1, 16, 32, 64, 128, 256, 512, 4096] + + CURRENT_COMMIT_ID = get_commit_id() + + def __init__(self): + super().__init__() + + def prepare_set_group_4x(self, name: str, N, K) -> List: + assert name in self.config_map, f"Operator {name} not found in config map" + optimize_strategy = self.config_map[name]["optimize_stratety"] + return [ + self.generate_op_unit( + self.generate_operator_config( + name, self.OPT_SHAPES if optimize_strategy + == OptimizeStrategy.SingleBatchDecodeOnly else self.OPT_SHAPES[1:], N, K)), + ] + + def prepare_benchmark_sets(self): + """Prepare benchmark sets.""" + self.add_benchmark_set( + "FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", + [ + *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 16384, 16384), + *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 3200, 3200), + *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 8640, 3200), + *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 3200, 8640), + *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 1024, 8192), + *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 8192, 8192), + *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 28672, 8192), + 
*self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV", 8192, 28672), + ], + ) + + self.add_benchmark_set( + "FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", + [ + *self.prepare_set_group_4x( + "FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", + 16384, + 16384, + ), + *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", + 3200, 3200), + *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", + 8640, 3200), + *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", + 3200, 8640), + *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", + 1024, 8192), + *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", + 8192, 8192), + *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", + 28672, 8192), + *self.prepare_set_group_4x("FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching", + 8192, 28672), + ], + ) + + def generate_operator_config(self, name: str, M, N, K) -> MatmulConfig: + """Generate configuration for the given operator.""" + if name not in self.config_map: + raise ValueError(f"Operator {name} not found in config map") + return self.get_operator_config()( + M=M, + N=N, + K=K, + **self.config_map[name], + ) + + def report(self): + """Generate and print a report of the benchmark results.""" + results4compare = {} + for name, results in self.benchmark_results.items(): + name, strategy = name.split("STRATEGY") + results4compare.setdefault(name, {})[strategy] = results + + for name, strategy in results4compare.items(): + table_data = [ + ["TAG:", name, "Device:", self.benchmark_target], + [ + "Shape (M-N-K / N-K_M)", + "Single Batching Time (ms)", + "Shape (M-N-K / N-K_M)", + "Contiguous Batching Time (ms)", + "Tune Time (s)", + ], + ] + + def legalize_shape(M, N, K, dyn_prof_shape): + """Generate a string representation of the operator shape. + + Args: + M: The M dimension (can be an int or a tuple). + N: The N dimension (must be an int). + K: The K dimension (must be an int). + dyn_prof_shape: The dynamic profiling shape (dict with "m" key if M is dynamic). + + Returns: + A string representing the shape in either 'M-N-K' or 'N-K_M' format. 
+ """ + if isinstance(M, int): + return f"{M}-{N}-{K}" + elif dyn_prof_shape and "m" in dyn_prof_shape: + return f"{M}-{N}-{K}_{dyn_prof_shape['m']}" + else: + # Calculate the average of tuple M + str_m = "[" + "-".join(str(m) for m in M) + "]" + opt_m = sum(M) / len(M) + return f"{N}-{K}_{str_m}_{opt_m}" + + data = [] + for strategy_name, results in strategy.items(): + tmp_data = [] + origin_name = f"{name}STRATEGY{strategy_name}" + for i, benchmark_set in enumerate(self.benchmark_sets[origin_name]): + op_config = benchmark_set[1] + sub_results = results[i * len(self.OPT_SHAPES):(i + 1) * len(self.OPT_SHAPES)] + for i, result in enumerate(sub_results): + latency = result[0] + dyn_prof_shape = {"m": self.OPT_SHAPES[i]} + shape = legalize_shape("DYN", op_config.N, op_config.K, dyn_prof_shape) + latency_str = "N/A" if latency is None else f"{latency:.3f}" + tmp_data.append([shape, latency_str]) + if len(data) == 0: + data = tmp_data + else: + for i, item in enumerate(tmp_data): + data[i].extend(item) + + for i, item in enumerate(data): + base = item[1] + head = item[3] + + speedup = float(head) / float(base) - 1 + symbol = "+" if speedup > 0 else "-" + speedup = abs(speedup) + data[i][3] = f"{head} ({symbol}{speedup * 100 :.3f}%)" + table_data.append([*data[i], "N/A"]) + + print(tabulate(table_data, headers="firstrow", tablefmt="fancy_grid")) + + for data in table_data: + print(data) + + def get_operator(self): + """Return the Matmul operator.""" + return Matmul + + def get_operator_config(self): + """Return the Matmul operator configuration.""" + return MatmulConfig + + def make_operator(self, operator: Matmul, config: MatmulConfig) -> Matmul: + """Make an Matmul instance.""" + # Disable default tuning when do benchmark + return operator(config, target=self.benchmark_target, enable_tuning=False) + + def benchmark(self): + """Run benchmarks on all benchmark sets.""" + # Calculate the total number of benchmark runs for the progress bar + total_runs = sum( + len(benchmark_set) * len(self.OPT_SHAPES) + for benchmark_set in self.benchmark_sets.values()) + + with tqdm(total=total_runs, desc="Total Progress", unit="benchmark") as pbar: + for name, benchmark_set in self.benchmark_sets.items(): + self.benchmark_results[name] = [] + for op, config, _ in benchmark_set: + for opt in self.OPT_SHAPES: + print(f"Running benchmark for {name} with shape {opt}") + self.benchmark_results[name].extend( + [self.run_benchmark(op, config, {"m": opt})]) + # Update the progress bar after each run + pbar.update(1) + + def run_compare_strategy(self, report=True, serialize=True, enable_tuning: bool = False): + """Run the benchmark process.""" + + if not path.exists(self.log_path): + makedirs(self.log_path) + + if enable_tuning: + self.enable_tuning() + + self.prepare_benchmark_sets() + self.benchmark() + + if report: + self.report() + + self.cleanup() + + def serialize_results(self) -> None: + """Serialize the benchmark results.""" + pass + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Bitblas Matmul Operator Benchmark") + + parser.add_argument( + "--enable_tuning", + action="store_true", + help="Enable hardware-aware tuning", + ) + + args = parser.parse_args() + enable_tuning = args.enable_tuning + BitblasMatmulOpsBenchmarkCompareStategies().run_compare_strategy( + enable_tuning=args.enable_tuning) diff --git a/benchmark/operators/benchmark_ops_matmul.py b/benchmark/operators/benchmark_ops_matmul.py index a17baa154..282e26997 100644 --- a/benchmark/operators/benchmark_ops_matmul.py +++ 
b/benchmark/operators/benchmark_ops_matmul.py @@ -3,12 +3,14 @@ from bitblas.benchmark import BitblasOperatorBenchmarkBase from bitblas import Matmul, MatmulConfig +from bitblas.ops.general_matmul import OptimizeStrategy from bitblas.utils import get_commit_id from bitblas import set_log_level from tabulate import tabulate import json from os import path, makedirs from typing import Tuple, Dict, List, Union +import argparse set_log_level("DEBUG") @@ -54,10 +56,29 @@ class BitblasMatmulOpsBenchmark(BitblasOperatorBenchmarkBase): "accum_dtype": "int32", "out_dtype": "int8", }, + "FP16xUINT4_ACCFP16_NT_STRATEGY_GEMV": { + "A_dtype": "float16", + "W_dtype": "uint4", + "accum_dtype": "float16", + "optimize_stratety": OptimizeStrategy.GEMV, + }, + "FP16xUINT4_ACCFP16_NT_STRATEGY_ContigiousBatching": { + "A_dtype": "float16", + "W_dtype": "uint4", + "accum_dtype": "float16", + "optimize_stratety": OptimizeStrategy.ContigousBatching, + }, } CURRENT_COMMIT_ID = get_commit_id() + def __init__(self, optimize_strategy: Union[int, OptimizeStrategy, None] = None): + super().__init__() + if optimize_strategy is not None: + self.optimize_strategy = optimize_strategy + else: + self.optimize_strategy = OptimizeStrategy.SingleBatchDecodeOnly + def prepare_set_group_4x(self, name: str, M, N, K) -> List: return [ self.generate_op_unit(self.generate_operator_config(name, 1, N, K)), @@ -102,47 +123,40 @@ def prepare_set_group_llm(self, name: str, N, K) -> List: ), ] + def get_llm_benchmark_sets(self, name: str) -> List: + return [ + *self.prepare_set_group_llm(name, 3200, 3200), + *self.prepare_set_group_llm(name, 8640, 3200), + *self.prepare_set_group_llm(name, 3200, 8640), + *self.prepare_set_group_llm(name, 5120, 5120), + *self.prepare_set_group_llm(name, 13824, 5120), + *self.prepare_set_group_llm(name, 5120, 13824), + *self.prepare_set_group_llm(name, 6656, 6656), + *self.prepare_set_group_llm(name, 17920, 6656), + *self.prepare_set_group_llm(name, 6656, 17920), + *self.prepare_set_group_llm(name, 1024, 8192), + *self.prepare_set_group_llm(name, 8192, 8192), + *self.prepare_set_group_llm(name, 28672, 8192), + *self.prepare_set_group_llm(name, 8192, 28672) + ] + def prepare_benchmark_sets(self): """Prepare benchmark sets.""" self.add_benchmark_set( "FP16xFP16_ACCFP16_NT", [ *self.prepare_set_group_4x("FP16xFP16_ACCFP16_NT", 16384, 16384, 16384), - # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 3200), - # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8640, 3200), - # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 3200, 8640), - # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 5120), - # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 13824, 5120), - # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 5120, 13824), - # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 6656), - # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 17920, 6656), - # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 6656, 17920), - # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 1024, 8192), - # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 8192), - # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 28672, 8192), - # *self.prepare_set_group_llm("FP16xFP16_ACCFP16_NT", 8192, 28672), + *self.get_llm_benchmark_sets("FP16xFP16_ACCFP16_NT"), ], ) - # self.add_benchmark_set( - # "INT8xINT8_ACCINT32_NT", - # [ - # *self.prepare_set_group_4x("INT8xINT8_ACCINT32_NT", 16384, 16384, 16384), - # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 3200), - # 
*self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8640, 3200), - # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 3200, 8640), - # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 5120), - # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 13824, 5120), - # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 5120, 13824), - # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 6656), - # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 17920, 6656), - # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 6656, 17920), - # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 1024, 8192), - # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 8192), - # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 28672, 8192), - # *self.prepare_set_group_llm("INT8xINT8_ACCINT32_NT", 8192, 28672), - # ], - # ) + self.add_benchmark_set( + "INT8xINT8_ACCINT32_NT", + [ + *self.prepare_set_group_4x("INT8xINT8_ACCINT32_NT", 16384, 16384, 16384), + *self.get_llm_benchmark_sets("INT8xINT8_ACCINT32_NT"), + ], + ) def generate_operator_config(self, name: str, M, N, K) -> MatmulConfig: """Generate configuration for the given operator.""" @@ -270,6 +284,7 @@ def legalize_shape(M, N, K, dyn_prof_shape): dyn_prof_shape = self.benchmark_sets[name][i][2] shape = legalize_shape(op_config.M, op_config.N, op_config.K, dyn_prof_shape) + # This is a bug and should be fixed. benchmark_M = ( sum(op_config.M) / len(op_config.M) if isinstance(op_config.M, Tuple) else op_config.M) @@ -300,4 +315,13 @@ def make_operator(self, operator: Matmul, config: MatmulConfig) -> Matmul: if __name__ == "__main__": - BitblasMatmulOpsBenchmark().run(enable_tuning=False) + parser = argparse.ArgumentParser(description="Bitblas Matmul Operator Benchmark") + parser.add_argument( + "--enable_tuning", + action="store_true", + help="Enable hardware-aware tuning", + ) + + args = parser.parse_args() + enable_tuning = args.enable_tuning + BitblasMatmulOpsBenchmark().run(enable_tuning=args.enable_tuning) diff --git a/bitblas/base/roller/hint.py b/bitblas/base/roller/hint.py index 191614dfa..14ee510c4 100644 --- a/bitblas/base/roller/hint.py +++ b/bitblas/base/roller/hint.py @@ -245,5 +245,7 @@ def complete_config(self, node: PrimFuncNode): # int32 and float32 accum may take too much shared memory if self.use_tc and self.intrin_info.out_dtype in ["float32", "int32"]: merge_static_smem = True + # Always merge static shared memory + merge_static_smem = False self.pass_context = {"tir.merge_static_smem": merge_static_smem} return self diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py index 468498fbd..f5a0d1f24 100644 --- a/bitblas/base/roller/policy/tensorcore.py +++ b/bitblas/base/roller/policy/tensorcore.py @@ -333,8 +333,7 @@ def _score(node, thread): # small is better logger.info(info_message) codegen_dict.shared_scope = "shared.dyn" - # Or assume we always use shared memory - # codegen_dict.shared_scope = "shared.dyn" + codegen_dict.shared_scope = "shared.dyn" codegen_dict.complete_config(node) codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size) diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index 71839b957..92c642b68 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -238,11 +238,12 @@ def tvm_callback_cuda_postproc(code, _): continue elif map_result.status == StatusKind.COMPLETE: idx, code, artifact_path = map_result.value - if artifact_path is None: - logger.debug("Artifact path is None") - continue 
sch = _sched[idx] config = configs[idx] + if artifact_path is None: + ARTIFACT_NOT_FOUND = f"Apply config {config} failed, artifact path is None" + logger.debug(ARTIFACT_NOT_FOUND) + continue rt_mod = tvm.runtime.load_module(artifact_path) cpresult = CompileResult(config, sch, rt_mod) timer_cuda_mod = rt_mod.time_evaluator( diff --git a/bitblas/benchmark/operator/__init__.py b/bitblas/benchmark/operator/__init__.py index f59ca34ee..7c21d9d0c 100644 --- a/bitblas/benchmark/operator/__init__.py +++ b/bitblas/benchmark/operator/__init__.py @@ -9,6 +9,10 @@ from bitblas.utils import get_default_cache_path from bitblas import auto_detect_nvidia_target from bitblas import tvm as tvm +from bitblas.cache import OperatorCache +import logging + +logger = logging.getLogger(__name__) class BitblasOperatorBenchmarkBase(ABC): @@ -28,6 +32,9 @@ class BitblasOperatorBenchmarkBase(ABC): # Log path log_path: Optional[str] = path.join(get_default_cache_path(), "benchmark") + # Operator cache + operator_cache: OperatorCache = OperatorCache() + @abstractmethod def prepare_benchmark_sets(self): pass @@ -98,6 +105,14 @@ def run_benchmark( dynamic_profiling_shape: Optional[Dict[str, int]] = None, ) -> Optional[float]: """Run a single benchmark.""" + + if self.operator_cache.exists(config): + logger.info(f"Operator {config} found in cache") + op_inst = self.operator_cache.get(config) + latency = op_inst.profile_latency(dynamic_symbolic_constraints=dynamic_profiling_shape) + op_inst.cleanup() + return latency, None + op_inst = self.make_operator(operator, config) tuning_time = None @@ -106,6 +121,8 @@ def run_benchmark( op_inst.hardware_aware_finetune(topk=20, parallel_build=True) tuning_time = perf_counter() - start + self.operator_cache.add(config, op_inst) + latency = op_inst.profile_latency(dynamic_symbolic_constraints=dynamic_profiling_shape) op_inst.cleanup() diff --git a/bitblas/cache/__init__.py b/bitblas/cache/__init__.py index 0c8fd3b9c..ee522ec3f 100644 --- a/bitblas/cache/__init__.py +++ b/bitblas/cache/__init__.py @@ -6,4 +6,5 @@ load_global_ops_cache, # noqa: F401 get_database_path, # noqa: F401 set_database_path, # noqa: F401 + OperatorCache, # noqa: F401 ) diff --git a/bitblas/gpu/matmul_mma.py b/bitblas/gpu/matmul_mma.py index 8700e6580..5d92f99b1 100644 --- a/bitblas/gpu/matmul_mma.py +++ b/bitblas/gpu/matmul_mma.py @@ -315,7 +315,17 @@ def store_output(block_outer, write_buffer_idx): sch.tensorize(sch.get_loops(block_init_inner)[-2], intrin_group["init"]) sch.tensorize(sch.get_loops(block_read_reg_a)[-2], intrin_group["load_a"]) - sch.tensorize(sch.get_loops(block_read_reg_b)[-2], intrin_group["load_b"]) + weight_transform_kind = 0 + if hasattr(func, "attrs") and "weight_transform_kind" in func.attrs: + weight_transform_kind = func.attrs["weight_transform_kind"] + if weight_transform_kind >= TransformKind.LDMatrixTransform: + fused = sch.fuse(sch.get_loops(block_read_reg_b)[-2:]) + vec_len = get_coalesced_veclen(sch.get(block_read_reg_b)) + f0, f1, f2 = sch.split(fused, factors=[None, 32, vec_len]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + else: + sch.tensorize(sch.get_loops(block_read_reg_b)[-2], intrin_group["load_b"]) sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) sch.tensorize(sch.get_loops(block_write_reg)[-2], intrin_group["store"]) diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py index 2033b8f75..7c40d3243 100644 --- a/bitblas/gpu/matmul_mma_dequantize.py +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -628,9 +628,19 
@@ def get_idx(): i0, i1 = sch.split(i, factors=[None, b_lr[0]]) j0, j1 = sch.split(j, factors=[None, b_lr[1]]) sch.reorder(i0, j0, i1, j1) - bb = sch.blockize(i1) - sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) - sch.tensorize(bb, intrin_group["load_b"]) + weight_transform_kind = 0 + if hasattr(func, "attrs") and "weight_transform_kind" in func.attrs: + weight_transform_kind = func.attrs["weight_transform_kind"] + if weight_transform_kind >= TransformKind.LDMatrixTransform: + fused = sch.fuse(i1, j1) + vec_len = get_coalesced_veclen(sch.get(B_mat)) + f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + else: + bb = sch.blockize(i1) + sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) + sch.tensorize(bb, intrin_group["load_b"]) def tensorize_init_store_compute(): sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) @@ -1176,7 +1186,7 @@ def sch_shared_memory_prefetch_with_config( """ weight_transform_kind = config.intrin_info.weight_transform_kind - if weight_transform_kind == TransformKind.LDMatrixTransform: + if weight_transform_kind == TransformKind.LDMatrixTransform and config.block_reduction_depth is not None: return self.sch_warp_memory_prefetch_with_config(func, config) is_cross_thread_reduce = ( diff --git a/bitblas/module/__init__.py b/bitblas/module/__init__.py index f148097d6..dd79b4bfa 100644 --- a/bitblas/module/__init__.py +++ b/bitblas/module/__init__.py @@ -18,7 +18,6 @@ from bitblas.quantization.utils import general_compress from bitblas import auto_detect_nvidia_target -BITBLAS_TARGET = auto_detect_nvidia_target() BITBLAS_DATABASE_PATH = get_database_path() @@ -59,7 +58,7 @@ def unpack_qweight(qweight, bits): class Linear(nn.Module): - opt_M = [1, 16, 32, 64, 128, 256, 512] + opt_M = [16, 32, 64, 128, 256, 512] STORAGE_DTYPE = "int8" # assume int8 storage TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE) BITBLAS_DTYPES = { @@ -224,6 +223,8 @@ def _configure_bitblas_matmul( self.source_format = self.bitblas_matmul.source_format def _get_or_create_bitblas_operator(self, config, enable_tuning): + BITBLAS_TARGET = auto_detect_nvidia_target() + if global_operator_cache.size() == 0: global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET) logger.info(f"Loaded {global_operator_cache.size()} operators from database.") diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 7d99d9628..16908dd41 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -4,6 +4,7 @@ from tvm.target import Target import operator from functools import reduce +from enum import IntEnum from bitblas.base.arch.cuda import CUDA from typing import Any, Literal, Optional, Tuple, Union from ..operator import OperatorConfig, Operator, TransformKind, OPExecutorCPU @@ -41,6 +42,15 @@ def is_native_compute(A_dtype, W_dtype) -> bool: return (A_dtype, W_dtype) in NATIVE_COMPUTE_PATTERNS +CONFIG_INFO_MESSAGE_STRATEGY = """Optimization Strategy Notice: You are currently using the "{}" optimization strategy. If you wish to change this, you can do so by setting the `optimize_strategy` in the Config. The **SingleBatchDecodeOnly** strategy provides the best performance when the batch size (M) is 1. On the other hand, the **ContiguousBatching** strategy is optimized for situations where the batch size (M) is greater than 1. 
However, please note that using ContiguousBatching for M=1 will result in a slight performance decrease of about 5%. +""" + + +class OptimizeStrategy(IntEnum): + SingleBatchDecodeOnly = 0 + ContigousBatching = 1 + + @dataclass(frozen=True) class MatmulConfig(OperatorConfig): M: Union[int, Tuple[int]] = None @@ -75,6 +85,9 @@ class MatmulConfig(OperatorConfig): None # propagate_b is a flag to control the ladder permutation ) + # optimize strategy, default is SingleBatchDecodeOnly + optimize_stratety: Union[int, OptimizeStrategy] = OptimizeStrategy.SingleBatchDecodeOnly + def __legalize_dynamic_symbolic(self, M): return tuple(self.M) if isinstance(self.M, list) else self.M @@ -86,6 +99,11 @@ def __legalize_propagate(self, propagate): return propagate + def __legalize_optimize_strategy(self, optimize_stratety): + if isinstance(optimize_stratety, int): + return OptimizeStrategy(optimize_stratety) + return optimize_stratety + def __initialize_propagate(self, propagate_a: Optional[TransformKind], propagate_b: Optional[TransformKind]): MICRO_KERNEL_SIZE = 16 @@ -111,6 +129,13 @@ def __initialize_propagate(self, propagate_a: Optional[TransformKind], if propagate_b is not None: object.__setattr__(self, "propagate_b", propagate_b) + # enhance propagate_b into ldmatrix transform if allowed + if (self.optimize_stratety == OptimizeStrategy.ContigousBatching + # TODO(lei): Should add ladder stage 3 inverse layout propagation in the expr. + # And recover the layout in the schedule template. + and (self.M != 1 or (isinstance(self.M, Tuple) and 1 not in self.M))): + object.__setattr__(self, "propagate_b", TransformKind.LDMatrixTransform) + # TODO(lei): This is a limitation arose by pytorch and llvm # Should be removed in the future. if self.A_dtype in ["e4m3_float8", "e5m2_float8"]: @@ -144,7 +169,10 @@ def is_not_fast_decoding_supported(): def __post_init__(self): # set M to default dynamic range if it is None if self.M is None: - object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024]) + if self.optimize_stratety == OptimizeStrategy.SingleBatchDecodeOnly: + object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024]) + else: + object.__setattr__(self, "M", [16, 32, 64, 128, 256, 512, 1024]) if self.N is None: raise ValueError("N should be specified currently.") if self.K is None: @@ -158,6 +186,10 @@ def __post_init__(self): object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a)) object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b)) + # set optimize_stratety to legal value + object.__setattr__(self, "optimize_stratety", + self.__legalize_optimize_strategy(self.optimize_stratety)) + # This is hack to legalize propagate_a and b # TODO(lei): should be removed in the future when tc+br template is ready. 
self.__initialize_propagate(self.propagate_a, self.propagate_b) @@ -547,7 +579,8 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: if self.lib is None: self._forward_from_torch_func(*args) - self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream) + else: + self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream) return output diff --git a/bitblas/ops/general_matmul_splitk.py b/bitblas/ops/general_matmul_splitk.py index 28e3cbbf2..39671432a 100644 --- a/bitblas/ops/general_matmul_splitk.py +++ b/bitblas/ops/general_matmul_splitk.py @@ -160,18 +160,16 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: if output is None: output = torch.empty( - A.shape[:-1] + (self.N,), - dtype=self.torch_output_dtype, - device=A.device) + A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device) if scale is not None: args.append(scale) if zeros is not None: args.append(zeros) if bias is not None: args.append(bias) - - sk_output = torch.empty((self.k_split,) + - A.shape[:-1] + (self.N,), + + sk_output = torch.empty( + (self.k_split,) + A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device) args.append(sk_output) @@ -184,7 +182,9 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: if self.lib is None: self._forward_from_torch_func(*args) - self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream) + else: + self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream) + torch.sum(sk_output, dim=0, out=output) return output diff --git a/bitblas/ops/impl/matmul_dequantize_splitk_impl.py b/bitblas/ops/impl/matmul_dequantize_splitk_impl.py index cc1b60de0..657b45a42 100644 --- a/bitblas/ops/impl/matmul_dequantize_splitk_impl.py +++ b/bitblas/ops/impl/matmul_dequantize_splitk_impl.py @@ -9,6 +9,7 @@ from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, _tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16, _tir_u8_to_f8_e4m3_to_f16) +from typing import Union def matmul_nt_dequantize_b( @@ -132,6 +133,201 @@ def decode_func(n, k): return tvm.IRModule.from_expr(func) +def matmul_nt_dequantize_b_propagate_b( + SplitK, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + transform_kind_input: Union[int, TransformKind] = TransformKind.IntraWarpTransform, + transform_kind_weight: Union[int, TransformKind] = TransformKind.IntraWarpTransform, +): + if isinstance(transform_kind_input, int): + transform_kind_input = TransformKind(transform_kind_input) + if isinstance(transform_kind_weight, int): + transform_kind_weight = TransformKind(transform_kind_weight) + + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") + A = te.placeholder((M, K), name="A", dtype=in_dtype) + + _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + 
target_dtype.bits) + initial_indices = inversed_index_map.initial_indices + scaling_final_indices = inversed_index_map.map_indices( + initial_indices[:-1] + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inversed_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + qr = r * bit // storage_nbit + B = te.placeholder((N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind_weight >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (N, K // storage_nbit * bit), + fcompute, + name="B_reindex", + ) + + def decode_func(n, k): + if source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "int": + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. 
+ if bit == 1: + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + # Describe the matrix multiplication in TE + RK = K // SplitK + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, RK), name="k") + C = te.compute( + (SplitK, M, N), + lambda sk, i, j: te.sum( + A[i, sk * RK + k].astype(accum_dtype) * B_decode[j, sk * RK + k].astype(accum_dtype), + axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((SplitK, M, N), + lambda b, i, j: last_output[b, i, j].astype(out_dtype), + name="D") + last_output = D + + args = [A, B] + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_zeros: + args.append(Zeros) + if with_bias: + E = te.compute((SplitK, M, N), lambda b, i, j: D[b, i, j] + Bias[j], name="E") + last_output = E + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "with_scaling": with_scaling, + "group_size": group_size, + } + }, + ) + func = func.with_attr("weight_transform_kind", transform_kind_weight.value) + return tvm.IRModule.from_expr(func) + + def matmul_nt_dequantize_b_propagate_a_propagate_b( SplitK, M, @@ -149,9 +345,14 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b( fast_decoding=False, with_bias=False, zeros_mode="original", - transform_kind_input: TransformKind = TransformKind.IntraWarpTransform, - transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform, + transform_kind_input: Union[int, TransformKind] = TransformKind.IntraWarpTransform, + transform_kind_weight: Union[int, TransformKind] = TransformKind.IntraWarpTransform, ): + if isinstance(transform_kind_input, int): + transform_kind_input = TransformKind(transform_kind_input) + if isinstance(transform_kind_weight, int): + transform_kind_weight = TransformKind(transform_kind_weight) + assert bit in [1, 2, 4, 8], "Unsupported 
bit: {}".format(bit) if not isinstance(M, int): M = tvm.te.var("m") diff --git a/docs/ExtendOperatorsWithDSL.md b/docs/ExtendOperatorsWithDSL.md index 1c01602b9..279cc2490 100644 --- a/docs/ExtendOperatorsWithDSL.md +++ b/docs/ExtendOperatorsWithDSL.md @@ -1,7 +1,7 @@ ### Using BitBLAS from DSL ```python from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA +from bitblas.base.arch import CUDA from bitblas.base.utils import apply_and_build @tvm.script.ir_module class MatmulNT: diff --git a/testing/python/cache/test_operator_cache.py b/testing/python/cache/test_operator_cache.py index 51a155f7f..e0a9d0118 100644 --- a/testing/python/cache/test_operator_cache.py +++ b/testing/python/cache/test_operator_cache.py @@ -6,10 +6,13 @@ import bitblas from bitblas import Matmul, MatmulConfig from bitblas.cache import global_operator_cache +from bitblas import tvm as tvm +from tvm.contrib import utils target = bitblas.utils.auto_detect_nvidia_target() bitblas.set_log_level("DEBUG") + def get_codegen_result(ops, target): code = ops.get_source(target=target) return code @@ -38,7 +41,7 @@ def test_config_hashable( layout, enable_tuning, ): - + global_operator_cache.clear() matmul_config = MatmulConfig( M=M, N=N, @@ -90,7 +93,7 @@ def test_global_cache_inquery( layout, enable_tuning, ): - + global_operator_cache.clear() matmul_config = MatmulConfig( M=M, N=N, @@ -143,7 +146,7 @@ def test_global_cache_inquery_torch_forward( layout, enable_tuning, ): - + global_operator_cache.clear() matmul_config = MatmulConfig( M=M, N=N, @@ -217,7 +220,7 @@ def test_global_cache_save_to_database( layout, enable_tuning, ): - + global_operator_cache.clear() matmul_config = MatmulConfig( M=M, N=N, @@ -244,7 +247,9 @@ def test_global_cache_save_to_database( print(hash_error) assert success - database_path = "/tmp/.tmp_bitblas_cache.db" + tempdir = utils.tempdir() + database_path = str(tempdir.path) + global_operator_cache.save_into_database(database_path, target=target) assert os.path.exists(database_path) global_operator_cache.clear() diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py index 3adacaa8c..e15a7adc5 100644 --- a/testing/python/module/test_bitblas_linear.py +++ b/testing/python/module/test_bitblas_linear.py @@ -59,7 +59,9 @@ def correctness_weight_only_dequantize( ): import numpy as np from bitblas.quantization.utils import general_compress + from bitblas.cache import global_operator_cache + global_operator_cache.clear() linear_bitblas = BitBLASLinear( in_features, out_features, diff --git a/testing/python/operators/test_general_matmul_ops.py b/testing/python/operators/test_general_matmul_ops.py index 354914d22..2d6890577 100644 --- a/testing/python/operators/test_general_matmul_ops.py +++ b/testing/python/operators/test_general_matmul_ops.py @@ -134,6 +134,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo with_scaling=with_scaling, with_zeros=with_zeros, zeros_mode=zeros_mode, + propagate_a=False, ) matmul = Matmul(config=matmul_config, enable_tuning=False) @@ -194,6 +195,8 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo permuted_inputs.append(bias) permuted_inputs.append(inputs[2]) matmul(*permuted_inputs[:-1], output=permuted_inputs[-1]) + print(permuted_inputs[-1]) + print(ref_result) if zeros_mode == "rescale": torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e0) else: diff --git 
a/testing/python/operators/test_general_matmul_tile_schedule.py b/testing/python/operators/test_general_matmul_tile_schedule.py index 68452140f..c2d263c7b 100644 --- a/testing/python/operators/test_general_matmul_tile_schedule.py +++ b/testing/python/operators/test_general_matmul_tile_schedule.py @@ -127,6 +127,7 @@ def assert_correctness_with_ladder_ldmatrix_propagate( accum_dtype="float16", block_reduction_depth=1, ): + propagate_b = 3 matmul_func = matmul_select_implementation( M=M, N=N, @@ -135,8 +136,8 @@ def assert_correctness_with_ladder_ldmatrix_propagate( out_dtype=out_dtype, accum_dtype=accum_dtype, propagate_a=0, - propagate_b=3)["main"] - propagate_b = 3 + propagate_b=propagate_b)["main"] + target = bitblas.auto_detect_nvidia_target() intrin_info = bitblas.base.hint.IntrinInfo( in_dtype=in_dtype, @@ -169,7 +170,6 @@ def assert_correctness_with_ladder_ldmatrix_propagate( "tir.merge_static_smem": False }): block_reduce_rt_mod = tvm.build(block_reduce_sch.mod, target=target) - # Evaluate the correctness import numpy as np a = np.random.randn(M, K).astype(np.float16 if in_dtype == "float16" else "int8") @@ -456,7 +456,7 @@ def assert_dequantize_correctness_with_ladder_ldmatrix_propagate( dequantize_bits=bit, storage_dtype="int8", transpose_matrix=True, - transform_kind=3, + transform_kind=propagate_b, ) ladder_permutate = bitblas.ops.LadderPermutate(ladder_permutate_config)
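For readers who want to try the behavior this patch introduces, below is a minimal usage sketch (not part of the patch) of the new configuration surface: the `OptimizeStrategy` enum, the `optimize_stratety` field on `MatmulConfig` (identifier spelling as in the patch), and the Ladder Stage-3 `LDMatrixTransform` default it selects for `propagate_b`. The shape, dtypes, and dynamic-M range are illustrative assumptions, not values taken from the patch.

```python
# Minimal sketch: exercising the new OptimizeStrategy knob added to MatmulConfig.
# Shapes, dtypes, and the dynamic-M range below are illustrative placeholders.
from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target
from bitblas.ops.general_matmul import OptimizeStrategy

config = MatmulConfig(
    M=[16, 32, 64, 128],  # dynamic batch sizes; M=1 keeps the GEMV-oriented default
    N=16384,
    K=16384,
    A_dtype="float16",
    W_dtype="uint4",
    accum_dtype="float16",
    optimize_stratety=OptimizeStrategy.ContigousBatching,  # spelling follows the patch
)

# With the contiguous-batching strategy and no M=1 in the range, propagate_b
# should now resolve to TransformKind.LDMatrixTransform (Ladder Stage 3).
print(config.propagate_b)

# enable_tuning=False keeps this sketch cheap; hardware-aware tuning can be
# run separately via matmul.hardware_aware_finetune().
matmul = Matmul(config, target=auto_detect_nvidia_target(), enable_tuning=False)
```

The default strategy remains `OptimizeStrategy.SingleBatchDecodeOnly`, so existing single-batch (M=1) decode workloads keep their previous propagation behavior unless the contiguous-batching strategy is opted into.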