[Dev] Set default weight transformation into Ladder Stage3 LDMatrixTransform (#133)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability * Refactor import statements for improved readability and maintainability * Refactor import statements for improved readability and maintainability * disable failure email for ci * remove email notifications. * move relax pass from testing to mlc_llm * Refactor scripts with se check_eual_ref_scripts_with_emitter function * Lint Fix * Refactor scripts with se check_eual_ref_scripts_with_emitter function * bug fix in test * lint fix. * test cuda i4 kernel * Refactor copyright notice in i4matmul.hpp * Refactor BitBLASLinear test module for improved readability and maintainability * refactor test as version below python 3.9 cannot handle int32 overflow. * format lint for test * Refactor test_int4b_fp16_convert.py for improved readability and maintainability * remove unused design file * move tile device from package to base * dummy impl for codegen * Refactor file structure for ladder_permutate module * Refactor backend class and fix typos in comments * Deep refactor Lib related code. * remove ci pull. * LintFix * refactor builder for whl build * Refactor TIRWrapper.wrap() method to include an assertion for the optimized module * Refactor lib_generator to set library and source paths * lint fix * BitNet vllm integration * chore: update codespell to version 2.3.0 * Lintfix * Bump version to 0.0.1.dev13 * lint fix * disable fast decoding [u]int4xint8 by default. * optimize from dict design in Hint * Implement SplitK * bitnet benchmark generation. * Add benchmark script for BitNet integration * AtomicAdd Support * LintFix * ci fix when 3rdparty tvm is initialized. * bug fix for setup * fix a bug in block reduce * typo fix * BUG Fix for block reduce. 
* Lint fix * Refactor block reduce schedule template * transform branch from bitblas to bitblas_tl * Fix subproject commit reference in 3rdparty/tvm * chore: update submodule branch from bitblas to bitblas_tl * force update config.cmake * Bug fix * Fix subproject commit reference in 3rdparty/cutlass * chore: Add submodule for cutlass library * update tl cutlass path * Refactor BitBLASLinear test module for improved readability and maintainability * format fix * Copy CUTLASS to the package directory * Refactor setup.py to include additional TVM header files * lint fix * bug fix * Refactor BitBLASLinear test module for improved readability and maintainability * Implement Matmul Benchmark Design * chore: Update BitBLAS Matmul benchmark script * lint fix * Refactor BitBLASMatmulOpsBenchmark for improved readability and maintainability * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * lint fix * Benchmark bot test * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * int8 test case * Refactor compare_benchmark.py to handle missing benchmark results gracefully * ci fix * disable ci for test benchmark * Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run * remove cli installation * chore: Create virtual environment and install dependencies for benchmark * chore: Update benchmark workflow to include comparison step * Lint fix * upodate tvm cmmit * Imporve lower warp memory pass * Bug fix * Enhance to support warp schedule. 
* Enhance LOP3 Instructions * Enhance LOP3 Instructions * add test for stage3 propagate * implement propagate func * Stage3 Ladder Permutate integration * get_ladder_stage3_propagate * comments benchmark scirpts as the setting is too big * ci fix for benchmark * lint fix * chore: Update benchmark workflow to trigger on pull request comments * Add LDMatrix Transform 3 * Support GPTQ Test * Fuse BlockReduce Schedule * Support mma propagate 3 * Support MMA Propagate Stage 3 * Lint Fix * Merge block reduce for dequantze config. * fix codeql * chore: Update submodule reference to latest commit * chore: Disable common subexpression elimination in TIR passes * Lint Fix * 4bit related lop3 updates. * lint fix * gptq test fix * Fix for test * lint fix * lint fix * typofix * QuantCompress Test * chore: Refactor quant_compress_impl.py for readability and maintainability * Enhance docs to update latest works. * Refactor weight executors in Matmul class for improved readability and maintainability * Refactor weight executors in Matmul class for improved readability and maintainability * Refactor weight executors in Matmul class for improved readability and maintainability * removed legacy operator * Refactor weight executors in Matmul class for improved readability and maintainability * LintFix * Fix GPTQ Repack with the latest weight transform * lint fix * bug fix for rescale dequantize * test fix * typo fix * lint fix * Set default weight propagate kind into LDMatrixTransform * lint fix * bug fix * bug fix for test * set default to stage3 * revert change * lint fix * case fix * bug fix * fix for legalize * bug fix * chore: Clear global operator cache before running tests * revert optimize_stratety into SingleBatchDecodeOnly * typofix * update benchmark scripts * chore: Refactor benchmark scripts and fix typos * fix for testing * lint fix * fix import. * typo * operator benchmark * optimize * always with shared.dyn * optimize cache. 
* dsl fix * tqdm * chore: Add serialize_results method to benchmark_matmul_strategies.py * fix performance issue for dynamic async copy * chore: Refactor benchmark_matmul_strategies.py for improved performance and code readability * bug fix * update readme
1 parent 0e1e366, commit e71b85d
Showing 29 changed files with 942 additions and 74 deletions.
Submodule tvm updated (2 files):
Changes | File |
---|---|
+4 −38 | src/target/source/ptx.cc |
+2 −1 | src/tir/transforms/lower_warp_memory.cc |
@@ -0,0 +1,143 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import tqdm
import bitblas
from bitblas.base.arch import CUDA
from bitblas.ops.general_matmul.tirscript.matmul_dequantize_impl import (
    matmul_nt_dequantize_b_propagate_b,)
import tvm
import itertools

search_space = {
    "block_row_warps": [1],
    "block_col_warps": [1, 2, 4, 8, 16],
    "warp_row_tiles": [1],
    "warp_col_tiles": [1, 2, 4, 8, 16],
    "chunk": [1, 2, 4, 8, 16, 32],
    "stage": [2, 3, 4],
    "block_reduce": [2, 4],
}

keys = search_space.keys()
values = search_space.values()
combinations = list(itertools.product(*values))

combinations_dicts = [dict(zip(keys, combination)) for combination in combinations]

# for combination in combinations_dicts:
#     print(combination)
print(len(combinations_dicts))
group_size = -1
# fmt:off
llm_shape_fp16xint4 = [
    # square test
    (matmul_nt_dequantize_b_propagate_b, (16, 16384, 16384, "float16", "float16", "float16", 4,
                                          "int8", "uint", False, False, group_size, True, False)),
]

# fmt:on

target = tvm.target.Target(bitblas.auto_detect_nvidia_target())
benchmark_sets = llm_shape_fp16xint4
tuning_results = {}

min_time = 1e9
min_combination = None
sucess_combinations = []
for get_prim_func, input_args in benchmark_sets:
    ir_module = get_prim_func(*input_args, transform_kind=3)
    func = ir_module["main"]
    arch = CUDA(target)

    M, N, K = input_args[0], input_args[1], input_args[2]
    import numpy as np

    np.random.seed(0)
    # a = np.random.randn(M // 16, K // 16, 16, 16).astype(np.float16)
    a = np.random.randn(M, K).astype(np.float16)
    b = np.random.randn(N // 16, K // 16, 16, 8).astype(np.int8)
    c = np.random.randn(M, N).astype(np.float16)

    tvm_a = tvm.nd.array(a, device=tvm.cuda(0))
    tvm_b = tvm.nd.array(b, device=tvm.cuda(0))
    tvm_c = tvm.nd.array(c, device=tvm.cuda(0))

    intrin_info = bitblas.base.hint.IntrinInfo(
        in_dtype="float16",
        out_dtype="float16",
        trans_b=True,
        input_transform_kind=0,
        weight_transform_kind=3,
    )

    # set up tqdm
    pbar = tqdm.tqdm(combinations_dicts)
    for combination in pbar:
        pbar.set_description(
            f"sucess_combinations: {len(sucess_combinations)} min_time: {min_time}")
        block_row_warps = combination["block_row_warps"]
        block_col_warps = combination["block_col_warps"]
        warp_row_tiles = combination["warp_row_tiles"]
        warp_col_tiles = combination["warp_col_tiles"]
        chunk = combination["chunk"]
        stage = combination["stage"]
        block_reduce = combination["block_reduce"]

        mma_row = mma_col = 16
        mma_k = 16

        block = [
            block_row_warps * warp_row_tiles * mma_row,
            block_col_warps * warp_col_tiles * mma_col,
        ]
        warp = [mma_row * warp_row_tiles, mma_col * warp_col_tiles]
        rstep = [mma_k * chunk * block_reduce]
        pipeline_stage = stage
        block_reduction_depth = block_reduce
        hint = bitblas.base.Hint.from_dict({
            "use_tc": True,
            "arch": arch,
            "block": block,
            "warp": warp,
            "rstep": rstep,
            "pipeline_stage": pipeline_stage,
            "use_async": True,
            "intrin_info": intrin_info,
            "shared_scope": "shared.dyn",
            "vectorize": {
                "b": 8,
                "a": 8
            },
            "block_reduction_depth": block_reduction_depth,
            "rasterization_plan": bitblas.base.rasterization.Rasterization2DColumn(10),
        })
        print("Tuning Hint is", hint)
        try:
            sch = bitblas.gpu.MatmulTensorizationMMAWithDequantizeInfo(
            ).sch_warp_memory_prefetch_with_config(
                func, config=hint)

            with tvm.transform.PassContext(
                    config={
                        "tir.use_async_copy": True,
                        "tir.merge_static_smem": False,
                        "tir.disable_cse_tir": True,
                    }):
                rt_mod = tvm.build(sch.mod, target=target)

            time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, tvm.cuda(0), number=10)

            t = time_evaluator(tvm_a, tvm_b, tvm_c).mean * 1e3

            print(f"For combination {combination}, time is {t} ms")
            tuning_results["-".join([str(v) for v in combination.values()])] = t
            if t < min_time:
                min_time = t
                min_combination = combination
            sucess_combinations.append(combination)
        except Exception as e:
            del e
            print(f"Failed for combination {combination}")
            continue

print(f"Minimum time is {min_time} for combination {min_combination}")
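
The grid search in the script above can be tried without a GPU or a BitBLAS install. The following is a minimal sketch, not the commit's code: it reuses the same search space and the same tile arithmetic (16x16x16 MMA tiles), but `fake_measure` is a hypothetical stand-in for the build-and-time step, and the helper names `expand`, `derive_tiles`, and `search` are assumptions introduced for illustration.

```python
import itertools

# Same grid as the benchmark script above.
search_space = {
    "block_row_warps": [1],
    "block_col_warps": [1, 2, 4, 8, 16],
    "warp_row_tiles": [1],
    "warp_col_tiles": [1, 2, 4, 8, 16],
    "chunk": [1, 2, 4, 8, 16, 32],
    "stage": [2, 3, 4],
    "block_reduce": [2, 4],
}

# MMA tile extents used by the script (mma_row = mma_col = mma_k = 16).
MMA_ROW = MMA_COL = MMA_K = 16


def expand(space):
    """Return one dict per point of the Cartesian product of the grid."""
    keys = list(space)
    return [dict(zip(keys, vals)) for vals in itertools.product(*space.values())]


def derive_tiles(c):
    """Map a grid point to the block/warp/rstep tile sizes used in the hint."""
    block = [
        c["block_row_warps"] * c["warp_row_tiles"] * MMA_ROW,
        c["block_col_warps"] * c["warp_col_tiles"] * MMA_COL,
    ]
    warp = [MMA_ROW * c["warp_row_tiles"], MMA_COL * c["warp_col_tiles"]]
    rstep = [MMA_K * c["chunk"] * c["block_reduce"]]
    return {"block": block, "warp": warp, "rstep": rstep}


def search(candidates, measure):
    """Keep the fastest candidate; record failures instead of discarding them."""
    best_time, best = float("inf"), None
    failures = []
    for c in candidates:
        try:
            t = measure(c)
        except Exception as err:  # a failed build or launch just skips the point
            failures.append((c, repr(err)))
            continue
        if t < best_time:
            best_time, best = t, c
    return best, best_time, failures


def fake_measure(c):
    """Hypothetical cost model standing in for tvm.build + time_evaluator."""
    tiles = derive_tiles(c)
    if tiles["block"][1] > 2048:
        raise ValueError("tile too large")
    return 1.0 / tiles["block"][1] + 0.001 * c["stage"]


candidates = expand(search_space)
best, best_time, failures = search(candidates, fake_measure)
print(len(candidates), "candidates,", len(failures), "failed")
print("best:", best, derive_tiles(best))
```

Unlike the script's `del e`, this sketch keeps the exception text for each failed point, which makes it easier to tell a shared-memory overflow from a scheduling error when most of the grid fails.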
@@ -0,0 +1,166 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from bitblas.utils.target_detector import auto_detect_nvidia_target
from bitblas import Matmul, MatmulConfig
import argparse
from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy
from bitblas.base.arch import CUDA
from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags
from bitblas.base.utils import apply_and_build

# Initialize the parser
parser = argparse.ArgumentParser(description="Benchmark BitBLAS int4 on a specific target.")

# Add arguments to the parser
parser.add_argument(
    "--target",
    type=str,
    default=auto_detect_nvidia_target(),
    help="Specify the target device for benchmarking.")
parser.add_argument(
    "--group_size", type=int, default=None, help="Group size for grouped quantization.")
parser.add_argument(
    "--A_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "float64", "int32",
             "int8"],  # Assuming these are the valid choices
    help="Data type of activation A.")
parser.add_argument(
    "--W_dtype",
    type=str,
    default="int4",
    choices=[
        "float16", "float32", "float64", "int32", "int8", "int4", "int2", "int1", "nf4", "fp4_e2m1"
    ],  # Assuming these are the valid choices
    help="Data type of weight W.")
parser.add_argument(
    "--accum_dtype",
    type=str,
    default="float16",
    choices=["float16", "int32"],  # Assuming these are the valid choices
    help="Data type for accumulation.")
parser.add_argument(
    "--out_dtype",
    type=str,
    default="float16",
    choices=["float16", "float32", "int32", "int8"],  # Assuming these are the valid choices
    help="Data type for output.")
parser.add_argument(
    "--layout",
    type=str,
    default="nt",
    choices=["nt", "nn"],  # Assuming these are the valid choices
    help="Matrix layout, 'nt' for non-transpose A and transpose W.")
parser.add_argument("--with_bias", action="store_true", help="Include bias in the benchmark.")
parser.add_argument(
    "--with_scaling", action="store_true", help="Include scaling factor in the quantization.")
parser.add_argument("--with_zeros", action="store_true", help="Include zeros in the quantization.")
parser.add_argument(
    "--zeros_mode",
    type=str,
    default=None,
    choices=["original", "rescale", "quantized"],  # Replace with actual modes if applicable
    help="Specify the mode for calculating zeros.")

# Parse the arguments
args = parser.parse_args()

# Assign arguments to variables
target = args.target
group_size = args.group_size
A_dtype = args.A_dtype
W_dtype = args.W_dtype
accum_dtype = args.accum_dtype
out_dtype = args.out_dtype
layout = args.layout
with_bias = args.with_bias
group_size = args.group_size
with_scaling = args.with_scaling
with_zeros = args.with_zeros
zeros_mode = args.zeros_mode

test_shapes = [
    # square test
    (MatmulConfig, Matmul, (1, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (16, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (32, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (64, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (128, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (256, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (1024, 16384, 16384, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (16, 43008, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (32, 14336, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (64, 57344, 14336, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (128, 14336, 57344, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (256, 9216, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (128, 36864, 9216, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (64, 9216, 36864, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (32, 22016, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (16, 8192, 22016, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (32, 8192, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (64, 28672, 8192, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
    (MatmulConfig, Matmul, (128, 8192, 28672, A_dtype, W_dtype, out_dtype, accum_dtype, layout,
                            with_bias, group_size, with_scaling, with_zeros, zeros_mode)),
]

benchmark_sets = []
benchmark_sets.extend(test_shapes)

# fmt:on

benchmark_results = {}
for config, operator, input_args in benchmark_sets:
    config = config(*input_args)
    matmul = operator(config, target=target, enable_tuning=False)
    func = matmul.prim_func
    arch = CUDA(target)
    policy = DefaultPolicy(func=func, arch=arch)
    try:
        tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target)
    except Exception:
        tags = None
    if tags:
        policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)

    configs = policy.emit_config(20)
    static_configs = []
    for config in configs:
        static_config = config
        static_config.shared_scope = "shared"
        static_configs.append(static_config)
    dynamic_configs = []
    for config in configs:
        dynamic_config = config
        dynamic_config.shared_scope = "shared.dyn"
        dynamic_configs.append(dynamic_config)

    _, best_static = apply_and_build(func, static_configs, arch, parallel_build=True)

    _, best_dynamic = apply_and_build(func, dynamic_configs, arch, parallel_build=True)
    benchmark_results[input_args] = (best_static.latency, best_dynamic.latency,
                                     best_static.latency - best_dynamic.latency)

for key, value in benchmark_results.items():
    print(
        f"Input arguments: {key}, Static latency: {value[0]}, Dynamic latency: {value[1]}, Difference: {value[2]}"
    )
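
One thing worth checking in the static versus dynamic comparison above: `static_config = config` binds a second name to the same object rather than copying it. If `policy.emit_config` returns ordinary mutable objects, the second loop would overwrite `shared_scope` on the very objects already stored in `static_configs`, and both builds would then use `shared.dyn`. The sketch below reproduces the pattern with a hypothetical `FakeHint` dataclass (not a BitBLAS type) and shows one way to keep the two variants independent. Whether a shallow `copy.copy` is sufficient for the real hint objects is an assumption that would need to be verified against BitBLAS.

```python
import copy
from dataclasses import dataclass


@dataclass
class FakeHint:
    """Hypothetical stand-in for the config objects returned by emit_config."""
    block: tuple
    shared_scope: str = "shared"


def with_scope(configs, scope):
    """Return shallow copies of the configs with shared_scope overridden."""
    variants = []
    for cfg in configs:
        variant = copy.copy(cfg)  # new object, so the original is left untouched
        variant.shared_scope = scope
        variants.append(variant)
    return variants


configs = [FakeHint((16, 128)), FakeHint((16, 256))]

# Aliasing pattern: both lists end up holding the same objects.
aliased_static = []
for cfg in configs:
    alias = cfg
    alias.shared_scope = "shared"
    aliased_static.append(alias)
aliased_dynamic = []
for cfg in configs:
    alias = cfg
    alias.shared_scope = "shared.dyn"
    aliased_dynamic.append(alias)
print("aliased static scopes:", [c.shared_scope for c in aliased_static])

# Copying pattern: each list keeps its own scope.
static_configs = with_scope(configs, "shared")
dynamic_configs = with_scope(configs, "shared.dyn")
print("copied static scopes: ", [c.shared_scope for c in static_configs])
print("copied dynamic scopes:", [c.shared_scope for c in dynamic_configs])
```

An alternative that needs no copying is to build and time the static variants before mutating the configs to `shared.dyn`, at the cost of making the two measurements order-dependent.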