[Dev] Bring Block Reduction into our search space and policy (#132)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix in test

* lint fix.

* test cuda i4 kernel

* Refactor copyright notice in i4matmul.hpp

* Refactor BitBLASLinear test module for improved readability and maintainability

* refactor test, as Python versions below 3.9 cannot handle int32 overflow.

* format lint for test

* Refactor test_int4b_fp16_convert.py for improved readability and maintainability

* remove unused design file

* move tile device from package to base

* dummy impl for codegen

* Refactor file structure for ladder_permutate module

* Refactor backend class and fix typos in comments

* Deep refactor Lib related code.

* remove ci pull.

* LintFix

* refactor builder for whl build

* Refactor TIRWrapper.wrap() method to include an assertion for the optimized module

* Refactor lib_generator to set library and source paths

* lint fix

* BitNet vllm integration

* chore: update codespell to version 2.3.0

* Lintfix

* Bump version to 0.0.1.dev13

* lint fix

* disable fast decoding [u]int4xint8 by default.

* optimize the from_dict design in Hint

* Implement SplitK

* bitnet benchmark generation.

* Add benchmark script for BitNet integration

* AtomicAdd Support

* LintFix

* ci fix when 3rdparty tvm is initialized.

* bug fix for setup

* fix a bug in block reduce

* typo fix

* BUG Fix for block reduce.

* Lint fix

* Refactor block reduce schedule template

* transform branch from bitblas to bitblas_tl

* Fix subproject commit reference in 3rdparty/tvm

* chore: update submodule branch from bitblas to bitblas_tl

* force update config.cmake

* Bug fix

* Fix subproject commit reference in 3rdparty/cutlass

* chore: Add submodule for cutlass library

* update tl cutlass path

* Refactor BitBLASLinear test module for improved readability and maintainability

* format fix

* Copy CUTLASS to the package directory

* Refactor setup.py to include additional TVM header files

* lint fix

* bug fix

* Refactor BitBLASLinear test module for improved readability and maintainability

* Implement Matmul Benchmark Design

* chore: Update BitBLAS Matmul benchmark script

* lint fix

* Refactor BitBLASMatmulOpsBenchmark for improved readability and maintainability

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* lint fix

* Benchmark bot test

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* int8 test case

* Refactor compare_benchmark.py to handle missing benchmark results gracefully

* ci fix

* disable ci for test benchmark

* Refactor BitBLASMatmulOpsBenchmark to disable tuning during benchmark run

* remove cli installation

* chore: Create virtual environment and install dependencies for benchmark

* chore: Update benchmark workflow to include comparison step

* Lint fix

* update tvm commit

* Improve lower warp memory pass

* Bug fix

* Enhance to support warp schedule.

* Enhance LOP3 Instructions

* Enhance LOP3 Instructions

* add test for stage3 propagate

* implement propagate func

* Stage3 Ladder Permutate integration

* get_ladder_stage3_propagate

* comment out benchmark scripts as the setting is too big

* ci fix for benchmark

* lint fix

* chore: Update benchmark workflow to trigger on pull request comments

* Add LDMatrix Transform 3

* Support GPTQ Test

* Fuse BlockReduce Schedule

* Support mma propagate 3

* Support MMA Propagate Stage 3

* Lint Fix

* Merge block reduce for dequantize config.

* fix codeql

* chore: Update submodule reference to latest commit

* chore: Disable common subexpression elimination in TIR passes

* Lint Fix

* 4-bit related LOP3 updates.

* lint fix

* gptq test fix

* Fix for test

* lint fix

* lint fix

* typofix

* QuantCompress Test

* chore: Refactor quant_compress_impl.py for readability and maintainability

* Enhance docs to cover the latest work.

* Refactor weight executors in Matmul class for improved readability and maintainability

* Refactor weight executors in Matmul class for improved readability and maintainability

* Refactor weight executors in Matmul class for improved readability and maintainability

* removed legacy operator

* Refactor weight executors in Matmul class for improved readability and maintainability

* LintFix

* Fix GPTQ Repack with the latest weight transform

* lint fix

* bug fix for rescale dequantize

* test fix

* typo fix

* lint fix
LeiWang1999 authored Aug 5, 2024
1 parent 5d14d31 commit 2e60d2b
Showing 4 changed files with 93 additions and 79 deletions.
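The change set centers on block reduction: splitting a matmul's K (reduction) axis so that one thread block reduces several K chunks cooperatively and then combines the partial sums (split-K within a block). The sketch below is illustrative only and is not code from this commit; the function name and shapes are hypothetical, and it assumes K divides evenly by the depth.

```python
# Illustrative only -- not code from this commit. Block reduction splits
# the K (reduction) axis into `depth` chunks that are reduced separately
# and then combined, which is what the policy's block_reduction_depth
# knob controls on the GPU side (via shared memory or atomics).
def matmul_block_reduce(A, B, depth=2):
    M, K = len(A), len(A[0])
    N = len(B[0])
    assert K % depth == 0, "illustration assumes K divides evenly"
    chunk = K // depth
    C = [[0.0] * N for _ in range(M)]
    for d in range(depth):  # each chunk would map to a warp group on GPU
        for i in range(M):
            for j in range(N):
                acc = 0.0
                for k in range(d * chunk, (d + 1) * chunk):
                    acc += A[i][k] * B[k][j]
                C[i][j] += acc  # partials combined across chunks
    return C

assert matmul_block_reduce([[1, 2]], [[3], [4]]) == [[11.0]]
```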
2 changes: 1 addition & 1 deletion 3rdparty/tvm
158 changes: 85 additions & 73 deletions bitblas/base/roller/policy/tensorcore.py
```diff
@@ -117,83 +117,92 @@ def _check_small_tile(td: TileDict):
                     return True
             return False
 
-        if not _check_small_tile(td):
-            return None
-
-        smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
-        rstep_map = td.rstep_map.copy()
-
-        def _optimize(node, rstep):
-            all_steps = self.get_node_reduce_step_candidates(node)
-            # todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k]
-            for k in all_steps:
-                all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k]))
-            if any([v == [] for v in all_steps.values()]):
-                return rstep
-
-            def _shared_memory_usage(td: TileDict):
-                return node.footprint(td.output_tile, new_rstep_map, td.tensor_strides_map[node])
-
-            def _score(rstep_id):
-                rstep = {
-                    k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis
-                }
-                score = 0
-                shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep)
-                input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
-                for i, input_buffer in enumerate(input_buffers):
-                    score += coalesced_factor(shape[i], input_buffer.shape)
-                return score
-
-            def _enlarge(rstep_id):
-                candidates = []
-                for ax in rstep_id:
-                    if rstep_id[ax] + 1 == len(all_steps[ax]):
-                        continue
-                    r = rstep_id.copy()
-                    r[ax] += 1
-                    candidates.append((r, _score(r)))
-                if len(candidates) == 0:
-                    return None
-                return max(candidates, key=lambda x: x[1])[0]
-
-            cur_rstep_id = {
-                k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis
-            }
-            new_rstep_map = rstep_map.copy()
-            while True:
-                new_rstep_id = _enlarge(cur_rstep_id)
-                if new_rstep_id is None:
-                    break
-                new_rstep_map = {
-                    k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis
-                }
-                old_rstep_map = td.rstep_map
-                td.rstep_map = new_rstep_map
-                smem_usage, _ = _shared_memory_usage(td)
-                td.rstep_map = old_rstep_map
-                if smem_usage > smem_limit:
-                    break
-                else:
-                    cur_rstep_id = new_rstep_id
-            rstep = {
-                k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis
-            }
-            return rstep
-
-        for node in self.ordered_nodes:
-            if len(node.raxis) > 0:
-                rstep = _optimize(node, rstep_map)
-                rstep_map = rstep
-
-        # if is_block_reduction:
-        #     # If block reduction, we should constrain the max value is 64
-        #     # Otherwise it will introduce an issue of cuda invalid args.
-        #     MAX_REDUCE_K = 64
-        #     for k in rstep_map:
-        #         rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K)
-        td.rstep_map = rstep_map
-        td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
+        if _check_small_tile(td):
+
+            smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
+            rstep_map = td.rstep_map.copy()
+
+            def _optimize(node, rstep):
+                all_steps = self.get_node_reduce_step_candidates(node)
+                # todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k]
+                for k in all_steps:
+                    all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k]))
+                if any([v == [] for v in all_steps.values()]):
+                    return rstep
+
+                def _shared_memory_usage(td: TileDict):
+                    return node.footprint(td.output_tile, new_rstep_map,
+                                          td.tensor_strides_map[node])
+
+                def _score(rstep_id):
+                    rstep = {
+                        k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis
+                    }
+                    score = 0
+                    shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep)
+                    input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
+                    for i, input_buffer in enumerate(input_buffers):
+                        score += coalesced_factor(shape[i], input_buffer.shape)
+                    return score
+
+                def _enlarge(rstep_id):
+                    candidates = []
+                    for ax in rstep_id:
+                        if rstep_id[ax] + 1 == len(all_steps[ax]):
+                            continue
+                        r = rstep_id.copy()
+                        r[ax] += 1
+                        candidates.append((r, _score(r)))
+                    if len(candidates) == 0:
+                        return None
+                    return max(candidates, key=lambda x: x[1])[0]
+
+                cur_rstep_id = {
+                    k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis
+                }
+                new_rstep_map = rstep_map.copy()
+                while True:
+                    new_rstep_id = _enlarge(cur_rstep_id)
+                    if new_rstep_id is None:
+                        break
+                    new_rstep_map = {
+                        k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]]
+                        for k in node.raxis
+                    }
+                    old_rstep_map = td.rstep_map
+                    td.rstep_map = new_rstep_map
+                    smem_usage, _ = _shared_memory_usage(td)
+                    td.rstep_map = old_rstep_map
+                    if smem_usage > smem_limit:
+                        break
+                    else:
+                        cur_rstep_id = new_rstep_id
+                rstep = {
+                    k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis
+                }
+                return rstep
+
+            for node in self.ordered_nodes:
+                if len(node.raxis) > 0:
+                    rstep = _optimize(node, rstep_map)
+                    rstep_map = rstep
+
+            td.rstep_map = rstep_map
+            td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
+
+        if self.block_reduction_depth is not None:
+
+            def _expand_with_tags(rstep):
+                new_rstep = {k: v * self.block_reduction_depth for k, v in rstep.items()}
+                return new_rstep
+
+            rstep_map = td.rstep_map.copy()
+            for node in self.ordered_nodes:
+                if len(node.raxis) > 0:
+                    rstep = _expand_with_tags(rstep_map)
+                    rstep_map = rstep
+            td.rstep_map = rstep_map
+
         return
 
     def get_node_reduce_step_candidates(self, node):
```
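The new `block_reduction_depth` branch above simply scales every reduce step that `_optimize` chose under the shared-memory limit. A minimal sketch of that expansion, with made-up step values:

```python
# Minimal sketch of the block_reduction_depth expansion above; the
# reduce-step values are made up for illustration.
block_reduction_depth = 2
rstep_map = {"k": 32}  # reduce steps chosen by _optimize under the smem limit

# _expand_with_tags multiplies every reduce step by the block reduction
# depth, so each thread block covers depth-times the original K slice.
expanded = {k: v * block_reduction_depth for k, v in rstep_map.items()}
assert expanded == {"k": 64}
```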

```diff
@@ -318,12 +327,15 @@ def _score(node, thread):  # small is better
         # smem capacity
         # TODO: This is a dummy mul which avoid reusing some shared memory.
         # Should be removed in the future.
-        if td.smem_cost > (self.arch.smem_cap * 1.3):
+        if td.smem_cost > (self.arch.smem_cap):
             info_message = f"Tile Dict: {td.output_tile} Shared memory exceeds the static capacity," \
                 " use dynamic shared memory."
             logger.info(info_message)
             codegen_dict.shared_scope = "shared.dyn"
+
+        # Or assume we always use shared memory
+        # codegen_dict.shared_scope = "shared.dyn"
 
         codegen_dict.complete_config(node)
         codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size)
         codegen_dict.arch = self.arch
```
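The hunk above replaces the old `smem_cap * 1.3` slack with a strict check against the static shared-memory capacity, falling back to dynamically allocated shared memory (`shared.dyn`). A hedged sketch of that decision, with an assumed 48 KB static capacity rather than a value read from this codebase:

```python
# Illustrative sketch of the static-vs-dynamic shared memory decision;
# the 48 KB default is an example (a common static limit on NVIDIA
# GPUs), not a constant from BitBLAS.
def pick_shared_scope(smem_cost: int, smem_cap: int = 48 * 1024) -> str:
    # The commit tightens the old `smem_cap * 1.3` slack to a plain
    # comparison against the static capacity.
    return "shared.dyn" if smem_cost > smem_cap else "shared"

assert pick_shared_scope(40 * 1024) == "shared"
assert pick_shared_scope(64 * 1024) == "shared.dyn"
```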
6 changes: 4 additions & 2 deletions bitblas/gpu/matmul_analysis.py
```diff
@@ -622,14 +622,16 @@ def check_last_trait(region: List[Range]):
         # Analysis Block Reduction Optimization
         # Currently, we only support block reduction depth 2 for small M
         # When the func is a dequantize like ops, we should consider the M
+        require_block_reduce = False
         if hasattr(func.attrs, "dequantize_info"):
             for arg in func.params:
                 inp_shape = func.buffer_map[arg].shape
                 M = inp_shape[0]
                 if isinstance(M, tir.IntImm) and M <= 128:
-                    tags["block_reduction_depth"] = 2
+                    require_block_reduce = True
                     break
 
+        if require_block_reduce and check_sm_version(target.arch) == 80:
+            tags["block_reduction_depth"] = 2
         return tags
 
     (main_block,) = reduction_blocks
```
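This hunk decouples detection (`require_block_reduce`) from the tag write, and additionally gates block reduction on SM80. A condensed standalone sketch of the rule; `is_dequantize`, `m_dim`, and `sm_version` are hypothetical inputs standing in for the attributes the real code reads from the TIR function and target via `check_sm_version`:

```python
# Condensed sketch of the tagging rule above; the parameters stand in
# for what the real code derives from func.attrs, buffer shapes, and
# target.arch, so treat them as hypothetical.
def block_reduction_tag(is_dequantize: bool, m_dim: int, sm_version: int) -> dict:
    tags = {}
    require_block_reduce = is_dequantize and m_dim <= 128
    # New in this commit: gate on SM80 so other archs keep the old path.
    if require_block_reduce and sm_version == 80:
        tags["block_reduction_depth"] = 2
    return tags

assert block_reduction_tag(True, 64, 80) == {"block_reduction_depth": 2}
assert block_reduction_tag(True, 64, 89) == {}
```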
```diff
@@ -515,7 +515,7 @@ def matmul_nt_dequantize_b_propagate_b(
     fast_decoding=False,
     with_bias=False,
     zeros_mode="original",
-    transform_kind: Union[int, TransformKind] = TransformKind.NonTransform,
+    transform_kind: Union[int, TransformKind] = TransformKind.IntraWarpTransform,
 ):
     if isinstance(transform_kind, int):
         transform_kind = TransformKind(transform_kind)
@@ -699,8 +699,8 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b(
     fast_decoding=False,
     with_bias=False,
     zeros_mode="original",
-    transform_kind_input: Union[int, TransformKind] = TransformKind.NonTransform,
-    transform_kind_weight: Union[int, TransformKind] = TransformKind.NonTransform,
+    transform_kind_input: Union[int, TransformKind] = TransformKind.IntraWarpTransform,
+    transform_kind_weight: Union[int, TransformKind] = TransformKind.IntraWarpTransform,
 ):
     if isinstance(transform_kind_input, int):
         transform_kind_input = TransformKind(transform_kind_input)
```
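Both hunks flip the default layout propagation for the dequantize matmul emitters from `NonTransform` to `IntraWarpTransform`, while keeping the `Union[int, TransformKind]` coercion. A self-contained usage sketch; the enum members and their ordering are assumed here and should be checked against the BitBLAS source:

```python
# Hedged sketch of the Union[int, TransformKind] coercion used by the
# emitters above. The enum is re-declared for a self-contained example;
# the real one lives in the BitBLAS codebase.
from enum import IntEnum
from typing import Union

class TransformKind(IntEnum):  # assumed ordering; verify against source
    NonTransform = 0
    InterWarpTransform = 1
    IntraWarpTransform = 2

def normalize(kind: Union[int, TransformKind]) -> TransformKind:
    # Mirrors: if isinstance(transform_kind, int):
    #              transform_kind = TransformKind(transform_kind)
    return TransformKind(kind) if isinstance(kind, int) else kind

assert normalize(2) is TransformKind.IntraWarpTransform  # the new default
```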
