diff --git a/3rdparty/tvm b/3rdparty/tvm index 618306ce3..a077796b9 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 618306ce3baa2c606d43856afbe6655e4e67b2c8 +Subproject commit a077796b9e2dd3b2275fbaa212786645758c360d diff --git a/docs/QuickStart.md b/docs/QuickStart.md index 5a57edbb2..10f9948d3 100644 --- a/docs/QuickStart.md +++ b/docs/QuickStart.md @@ -14,7 +14,7 @@ import torch # enabling debug output -bitblas.set_debug_level("Debug") +bitblas.set_log_level("Debug") matmul_config = bitblas.MatmulConfig( M=1, # M dimension N=1024, # N dimension @@ -129,7 +129,7 @@ import bitblas import torch # enabling debug output -bitblas.set_debug_level("Debug") +bitblas.set_log_level("Debug") model = bitblas.Linear( in_features=1024, @@ -185,7 +185,7 @@ from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import ( ) # enabling debug output -bitblas.set_debug_level("Debug") +bitblas.set_log_level("Debug") in_features = 1024 out_features = 1024 diff --git a/python/bitblas/__init__.py b/python/bitblas/__init__.py index 6a18c4703..14b510845 100644 --- a/python/bitblas/__init__.py +++ b/python/bitblas/__init__.py @@ -40,6 +40,7 @@ import logging from tqdm import tqdm + class TqdmLoggingHandler(logging.Handler): """ Custom logging handler that directs log output to tqdm progress bar to avoid interference. """ @@ -61,6 +62,7 @@ def set_log_level(level): Args: level (str or int): Can be the string name of the level (e.g., 'INFO') or the actual level (e.g., logging.INFO). + OPTIONS: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' """ if isinstance(level, str): level = getattr(logging, level.upper(), logging.INFO) diff --git a/python/bitblas/base/roller/arch/cuda.py b/python/bitblas/base/roller/arch/cuda.py index 63775ecbe..2189947e7 100644 --- a/python/bitblas/base/roller/arch/cuda.py +++ b/python/bitblas/base/roller/arch/cuda.py @@ -4,7 +4,7 @@ import tvm from tvm.target import Target from .arch_base import TileDevice -from typing import List, Dict +from typing import List, Dict, Union def check_sm_version(arch: str) -> int: @@ -28,7 +28,9 @@ def __init__( class CUDA(TileDevice): - def __init__(self, target: Target): + def __init__(self, target: Union[Target, str]): + if isinstance(target, str): + target = tvm.target.Target(target) self.target = target self.sm_version = check_sm_version(self.target.arch) device = tvm.runtime.cuda(0) diff --git a/python/bitblas/base/roller/hint.py b/python/bitblas/base/roller/hint.py index 89f607cde..f6e2fb03a 100644 --- a/python/bitblas/base/roller/hint.py +++ b/python/bitblas/base/roller/hint.py @@ -154,18 +154,20 @@ def __init__(self) -> None: self.arch = None self.use_tc = None # todo(lei): this should be renamed. - # special axes tiling info + # Special axes tiling info self.block = [] self.thread = [] - # special axes for tensorCore + # Special axes for MMA self.warp = [] - # reduce axes tiling info + # Reduce axes tiling info self.rstep = [] self.reduce_thread = [] self.rasterization_plan = NoRasterization() self.cached_tensors = [] self.output_strides = {} self.schedule_stages = None + # Config for block reduction + self.block_reduction_depth = None # type: int # Experimental self._raxis_order = [] @@ -203,6 +205,10 @@ def to_dict(self) -> Dict: dic["raxis_order"] = self._raxis_order if self.vectorize != {}: dic["vectorize"] = self.vectorize + if self.pipeline_stage != 1: + dic["pipeline_stage"] = self.pipeline_stage + if self.block_reduction_depth is not None: + dic["block_reduction_depth"] = self.block_reduction_depth return dic def from_dict(self, dic: Dict) -> "Hint": diff --git a/python/bitblas/base/roller/policy/tensorcore.py b/python/bitblas/base/roller/policy/tensorcore.py index 97edb50fc..f4047ef08 100644 --- a/python/bitblas/base/roller/policy/tensorcore.py +++ b/python/bitblas/base/roller/policy/tensorcore.py @@ -25,6 +25,7 @@ def __init__(self, self.wmma_k = 16 self.pipeline_stage: int = 1 self.use_async_copy: bool = False + self.block_reduction_depth: Optional[int] = None self._legalize_info() def _legalize_info(self): @@ -44,6 +45,11 @@ def _legalize_info(self): self.use_async_copy = True else: self.use_async_copy = False + # TODO: block reduction depth is not used for now. + # As there still exists some performance issues for block reduction. + # block_reduction_depth = self.prim_func_node.get_tag("block_reduction_depth") + # if block_reduction_depth: + # self.block_reduction_depth = block_reduction_depth def _compute_tc_strides( self, @@ -114,6 +120,7 @@ def _check_small_tile(td: TileDict): smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) rstep_map = td.rstep_map.copy() + is_block_reduction = self.block_reduction_depth is not None def _optimize(node, rstep): all_steps = self.get_node_reduce_step_candidates(node) @@ -177,6 +184,13 @@ def _enlarge(rstep_id): if len(node.raxis) > 0: rstep = _optimize(node, rstep_map) rstep_map = rstep + + if is_block_reduction: + # If block reduction, we should constrain the max value is 64 + # Otherwise it will introduce an issue of cuda invalid args. + MAX_REDUCE_K = 64 + for k in rstep_map: + rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K) td.rstep_map = rstep_map td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) return @@ -289,6 +303,7 @@ def _score(node, thread): # small is better codegen_dict.warp = warp_tile codegen_dict.use_tc = True codegen_dict.pipeline_stage = self.pipeline_stage + codegen_dict.block_reduction_depth = self.block_reduction_depth codegen_dict.use_async = self.use_async_copy codegen_dict.rstep = [int(rsteps[ax.var.name]) for ax in node.raxis] codegen_dict.cached_tensors = td.cached_tensors_map[node] diff --git a/python/bitblas/base/utils.py b/python/bitblas/base/utils.py index 7da309dd5..50adc135f 100644 --- a/python/bitblas/base/utils.py +++ b/python/bitblas/base/utils.py @@ -168,6 +168,7 @@ def apply_and_build_parallel(func, arch, num_repeats=3, max_workers=10, + timeout=30, data_distribution="uniform") -> CompileResult: cpresults = [] @@ -187,10 +188,10 @@ def _apply_schedule(f, c): with ThreadPoolExecutor(max_workers=4) as scheduler: futures = {scheduler.submit(_apply_schedule, func, config) for config in configs} - for future in as_completed(futures): + for future in as_completed(futures, timeout=timeout): _sched.append(future.result()) - builder = PopenPoolExecutor(max_workers=max_workers) + builder = PopenPoolExecutor(max_workers=max_workers, timeout=timeout) # build in process parallel def _build(context) -> str: diff --git a/python/bitblas/gpu/gemv.py b/python/bitblas/gpu/gemv.py index 7b08179d3..60a290a81 100644 --- a/python/bitblas/gpu/gemv.py +++ b/python/bitblas/gpu/gemv.py @@ -21,7 +21,7 @@ # Modifications Copyright (c) Microsoft. # The code below is mostly copied from apache/tvm gemv.py in dlight. """A rule for GEMV and DecodeGEMV.""" -import re + from functools import reduce from typing import List, Optional, Union, Dict diff --git a/python/bitblas/gpu/gemv_dequantize.py b/python/bitblas/gpu/gemv_dequantize.py index 5a6405f52..5ccc5b40e 100644 --- a/python/bitblas/gpu/gemv_dequantize.py +++ b/python/bitblas/gpu/gemv_dequantize.py @@ -287,7 +287,6 @@ def get_vectorize_factor(target_format): assert len(config.thread) == 2, "SplitK only support 2D thread config" num_warps = int(num_warps // config.thread[0]) - # get target dequantize buffer's idx def get_idx(weight_decode_info: Dict): # for LUT dequantize, the expr is LUT(w), the idx is 1 diff --git a/python/bitblas/gpu/matmul_analysis.py b/python/bitblas/gpu/matmul_analysis.py index 4638f2e72..6537a555a 100644 --- a/python/bitblas/gpu/matmul_analysis.py +++ b/python/bitblas/gpu/matmul_analysis.py @@ -619,6 +619,16 @@ def check_last_trait(region: List[Range]): if func.attrs is not None and "weight_transform_kind" in func.attrs: intrin_info["weight_transform_kind"] = func.attrs["weight_transform_kind"] tags["intrin_info"] = intrin_info + # Analysis Block Reduction Optimization + # Currently, we only support block reduction depth 2 for small M + # When the func is a dequantize like ops, we should consider the M + if hasattr(func.attrs, "dequantize_info"): + for arg in func.params: + inp_shape = func.buffer_map[arg].shape + M = inp_shape[0] + if isinstance(M, tir.IntImm) and M <= 128: + tags["block_reduction_depth"] = 2 + break return tags diff --git a/python/bitblas/gpu/matmul_mma.py b/python/bitblas/gpu/matmul_mma.py index a20359e11..4bf8be4e6 100644 --- a/python/bitblas/gpu/matmul_mma.py +++ b/python/bitblas/gpu/matmul_mma.py @@ -8,6 +8,7 @@ from tvm import tir, DataType from tvm.target import Target +from ..base.roller import Hint from ..base.roller.rasterization import NoRasterization from ..base import analysis from .base import GPUScheduleRule @@ -338,12 +339,15 @@ def store_output(block_outer, write_buffer_idx): def apply_config( # pylint: disable=too-many-locals,missing-docstring self, func: tir.PrimFunc, - config, + config: Hint, ) -> Optional[tir.Schedule]: if "dequantize_info" in func.attrs: dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo() return dequantize_rule.apply_config(func, config) + if hasattr(config, "block_reduction_depth") and config.block_reduction_depth is not None: + return self.apply_block_reduction_with_config(func, config) + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel get_mma_intrin_group,) @@ -686,7 +690,375 @@ def tensorize_init_store_compute(): sch.annotate( sch.get_loops(block_init_c)[-2], ann_key="inject_customized_code_prepend", - ann_val=invoke_func) + ann_val=invoke_func, + ) + # plan import source + if len(import_source) > 0: + sch.annotate( + thread_idz, + ann_key="pragma_import_c", + ann_val=("\n").join(import_source), + ) + return sch + + def apply_block_reduction_with_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config: Hint, + ) -> Optional[tir.Schedule]: + if "dequantize_info" in func.attrs: + dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo() + return dequantize_rule.sch_shared_memory_prefetch_block_reduction_with_config( + func, config) + + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group,) + + import_source: List[str] = [] + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + + output_blocks = [sch.get(block) for block in sch.get_output_blocks(root_block)] + + def check_require_cache(func: tir.PrimFunc, config): + conditions: List[bool] = [] + + # check if has dynamic symbolic + def check_has_dynamic(func: tir.PrimFunc): + for param in func.params: + if param not in func.buffer_map: + continue + arg = func.buffer_map[param] + for i in arg.shape: + if isinstance(i, tir.Var): + return True + return False + + conditions.append(check_has_dynamic(func)) + # check if has post process + conditions.append(sch.get(main_block) not in output_blocks) + # check if not use async copy + conditions.append(config.use_async is False) + return any(conditions) + + cache_write_required = check_require_cache(func, config=config) + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): + sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) + + shared_scope = config.shared_scope + + intrin_info = config.intrin_info + intrin_group = get_mma_intrin_group( + load_scope=shared_scope, + store_scope=shared_scope if cache_write_required else "global", + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_a=intrin_info.trans_a, + trans_b=intrin_info.trans_b, + smooth_a=intrin_info.smooth_a, + smooth_b=intrin_info.smooth_b, + not_use_mma_store_intrinic=False, + ) + + # Start Schedule + # Step 0. Get schedule config. + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + stage = config.pipeline_stage + use_async = config.use_async + assert (config.block_reduction_depth is not None), "block_reduction_depth is required" + reduce_k = config.block_reduction_depth + chunk = config.rstep[0] // reduce_k + + # tensor core intrinsic size + micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] + + # get the axis for layout transform + def get_axis(l, r, trans): # noqa: E741 + return (r, l) if trans else (l, r) # noqa: E741 + + a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) + b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) + + def can_enable_swizzle(dtype: str, smooth: bool): + # inject_permuted_layout only support float16 currently + if dtype == "float16" or dtype == "int8": + # introduce the constraint of reduce_k because reduce_k will doubling the size of + if (chunk * reduce_k) * DataType(dtype).bits != (512): + # currently the swizzle rule only support 512 bit. + return False + # if we use smooth layout, we don't need to do swizzling + return not smooth + return False + + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) + + warp_size = 32 + + i_factors, j_factors, k_factors = ( + [None, 1, block_row_warps, warp_row_tiles // micro_size_x], + [1, None, block_col_warps, warp_col_tiles // micro_size_y], + [None, (reduce_k * chunk) // micro_size_k], + ) + + num_ty = i_factors[2] + num_tz = j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + k0, kr = sch.split(k0, [None, reduce_k]) + + sch.reorder(i0, j0, i1, j1, i2, j2, kr, i3, j3, k0, k1) + + block_idy = sch.fuse(i0, j0) + block_idx = sch.fuse(i1, j1) + thread_idy = i2 + thread_idz = j2 + + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz) + sch.bind(thread_idy, "threadIdx.y") + sch.bind(kr, "threadIdx.z") + + # rewrite smooth layout of shared memory + def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True): # noqa: E741 + if not enable: + return + sch.transform_layout( + block, + scope, + lambda b, i, j: ( + b, + i // l, + j // r, + i % l, + j % r, + ), + ) + + smooth_smem_layout_rewrite( + block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) + smooth_smem_layout_rewrite( + block_outer, ("read", 1), *b_lr, enable=intrin_info.inter_transform_b) + smooth_smem_layout_rewrite(block_outer, ("write", 0), enable=True) + + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, trans=False): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_read).iter_vars) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + f_r, f_0, f_1, f_2, f_3, f_4 = sch.split( + fused, factors=[reduce_k, num_ty, num_tz, None, warp_size, vec_len]) + + sch.bind(f_3, "threadIdx.x") + f_0 = f_1 = sch.fuse(f_0, f_1) + sch.bind(f_0, "threadIdx.y") + sch.bind(f_r, "threadIdx.z") + sch.vectorize(f_4) + sch.unroll(f_2) + # Apply Swizzling + sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) + # if not, apply padding to alleviate bank conflict + if not (can_swizzle or is_smooth): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + sch.annotate(f_2, "pragma_unroll_explicit", False) + return block_read + + if len(config.vectorize.values()) < 2: + return None + + a_g2s = fetch_to_shared( + block_outer, + 0, + vec_len=list(config.vectorize.values())[0], + can_swizzle=can_swizzle_a, + is_smooth=intrin_info.smooth_a, + trans=intrin_info.trans_a, + ) + b_g2s = fetch_to_shared( + block_outer, + 1, + vec_len=list(config.vectorize.values())[1], + can_swizzle=can_swizzle_b, + is_smooth=intrin_info.smooth_b, + trans=intrin_info.trans_b, + ) + + # rewrite global smooth layout + def smooth_gmem_layout_rewrite(sch, block, enable=True, trans=False, matrix_name="A"): + if not enable: + return + # step1: find the first producer block + # Notes: we assume the layout propagate happens in the first producer block + # otherwise, the layout transform will have no effect as it will transform both + # read and write buffer + producers = _collect_producers(sch, block) + g2s_block = a_g2s if matrix_name == "A" else b_g2s + propagate_block: tir.Block = (producers[-1] if len(producers) > 0 else g2s_block) + + # step2: transform the layout with inverse permutation + intra_indexmap, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + + def inverse_permutation(i, j, ii, jj): + return (i, j, *intra_indexmap.map_indices([ii, jj])) + + sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) + + smooth_gmem_layout_rewrite( + sch, a_g2s, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") + smooth_gmem_layout_rewrite( + sch, b_g2s, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "warp") + B_mat = sch.cache_read(block_outer, 1, "warp") + sch.compute_at(A_mat, k1) + sch.compute_at(B_mat, k1) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + if cache_write_required: + accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) + + store = sch.cache_write(block_outer, 0, "warp") + sch.reverse_compute_at(store, j2) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, micro_size_x], preserve_unit_iters=False) + j0, j1 = sch.split(j, factors=[None, micro_size_y], preserve_unit_iters=False) + sch.reorder(i0, j0, i1, j1) + + if cache_write_required: + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + sch.reverse_compute_at( + accumulator_shared_to_global, + sch.get_loops(store)[-5], + preserve_unit_loops=True, + ) + vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) + f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + sch.unroll(f0) + sch.annotate(f0, "pragma_unroll_explicit", False) + else: + auto_inline_consumer_chain(sch, store) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + index_map_a, index_map_b, index_map_c = intrin_group["index_map"] + + sch.transform_layout( + A_mat, + ("write", 0), + get_warp_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), + ) + sch.transform_layout( + B_mat, + ("write", 0), + get_warp_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), + ) + sch.transform_layout( + store, + ("read", 0), + get_warp_index_map(index_map_c, is_5d=True), + ) + + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, a_lr[0]]) + j0, j1 = sch.split(j, factors=[None, a_lr[1]]) + sch.reorder(i0, j0, i1, j1) + ba = sch.blockize(i1) + # sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) + sch.tensorize(ba, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, b_lr[0]]) + j0, j1 = sch.split(j, factors=[None, b_lr[1]]) + sch.reorder(i0, j0, i1, j1) + bb = sch.blockize(i1) + # sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) + sch.tensorize(bb, intrin_group["load_b"]) + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + tensorize_init_store_compute() + + if stage > 1: + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + if use_async: + sch.annotate(k0, "software_pipeline_async_stages", [0]) + + # plan rasteration + if not isinstance(config.rasterization_plan, NoRasterization): + device_func, invoke_func = config.rasterization_plan.get_code() + import_source.append(device_func) + sch.annotate( + sch.get_loops(block_init_c)[-2], + ann_key="inject_customized_code_prepend", + ann_val=invoke_func, + ) # plan import source if len(import_source) > 0: sch.annotate( diff --git a/python/bitblas/gpu/matmul_mma_dequantize.py b/python/bitblas/gpu/matmul_mma_dequantize.py index 96461db45..679e84395 100644 --- a/python/bitblas/gpu/matmul_mma_dequantize.py +++ b/python/bitblas/gpu/matmul_mma_dequantize.py @@ -27,6 +27,94 @@ ) +def _bind_thread_based_on_config(sch, block, block_row_warps, block_col_warps, warp_size): + # assume the block loops has been fused + last_loop = sch.get_loops(block)[-1] + loop_extent = sch.get(last_loop).extent + vec = get_coalesced_veclen(sch.get(block)) + + if loop_extent // (vec * warp_size) == 0: + last_loop, B_shared_tx, B_shared_vi = sch.split(last_loop, factors=[1, warp_size, None]) + sch.bind(B_shared_tx, "threadIdx.x") + sch.vectorize(B_shared_vi) + loop_extent = sch.get(last_loop).extent + + # warp_size - 1 handling for 32x7 alike case, which may cause unaligned threadIdx.x mapping. + if loop_extent // vec >= 1 and loop_extent & (vec - 1) == 0 and (loop_extent // + vec) & (warp_size - 1) == 0: + last_loop, B_shared_vi = sch.split(last_loop, factors=[None, vec]) + sch.vectorize(B_shared_vi) + loop_extent = sch.get(last_loop).extent + + if loop_extent // warp_size >= 1 and loop_extent % warp_size == 0: + last_loop, B_shared_tx = sch.split(last_loop, factors=[None, warp_size]) + sch.bind(B_shared_tx, "threadIdx.x") + loop_extent = sch.get(last_loop).extent + + if loop_extent // block_row_warps >= 1 and loop_extent % block_row_warps == 0: + last_loop, B_shared_ty = sch.split(last_loop, factors=[None, block_row_warps]) + sch.bind(B_shared_ty, "threadIdx.y") + loop_extent = sch.get(last_loop).extent + + if loop_extent // block_col_warps >= 1 and loop_extent % block_col_warps == 0: + last_loop, B_shared_tz = sch.split(last_loop, factors=[None, block_col_warps]) + sch.bind(B_shared_tz, "threadIdx.z") + loop_extent = sch.get(last_loop).extent + + sch.unroll(last_loop) + sch.annotate(last_loop, "pragma_unroll_explicit", False) + + +def _bind_thread_based_with_block_reduce_on_config(sch, block, block_row_warps, block_col_warps, + warp_size, reduce_k): + # assume the block loops has been fused + last_loop = sch.get_loops(block)[-1] + loop_extent = sch.get(last_loop).extent + vec = get_coalesced_veclen(sch.get(block)) + + if loop_extent // (vec * warp_size) == 0: + last_loop, B_shared_tx, B_shared_vi = sch.split(last_loop, factors=[1, warp_size, None]) + sch.bind(B_shared_tx, "threadIdx.x") + sch.vectorize(B_shared_vi) + loop_extent = sch.get(last_loop).extent + + # warp_size - 1 handling for 32x7 alike case, which may cause unaligned threadIdx.x mapping. + if loop_extent // vec >= 1 and loop_extent & (vec - 1) == 0 and (loop_extent // + vec) & (warp_size - 1) == 0: + last_loop, B_shared_vi = sch.split(last_loop, factors=[None, vec]) + sch.vectorize(B_shared_vi) + loop_extent = sch.get(last_loop).extent + + if loop_extent // warp_size >= 1 and loop_extent % warp_size == 0: + last_loop, B_shared_tx = sch.split(last_loop, factors=[None, warp_size]) + sch.bind(B_shared_tx, "threadIdx.x") + loop_extent = sch.get(last_loop).extent + + B_shared_ty = None + if loop_extent // block_row_warps >= 1 and loop_extent % block_row_warps == 0: + last_loop, B_shared_ty = sch.split(last_loop, factors=[None, block_row_warps]) + loop_extent = sch.get(last_loop).extent + + B_shared_tz = None + if loop_extent // block_col_warps >= 1 and loop_extent % block_col_warps == 0: + last_loop, B_shared_tz = sch.split(last_loop, factors=[None, block_col_warps]) + loop_extent = sch.get(last_loop).extent + + if B_shared_ty and B_shared_tz: + B_shared_ty = sch.fuse(B_shared_tz, B_shared_ty) + sch.bind(B_shared_ty, "threadIdx.y") + elif B_shared_ty: + sch.bind(B_shared_ty, "threadIdx.y") + + if loop_extent // reduce_k >= 1 and loop_extent % reduce_k == 0: + last_loop, B_shared_tk = sch.split(last_loop, factors=[None, reduce_k]) + sch.bind(B_shared_tk, "threadIdx.z") + loop_extent = sch.get(last_loop).extent + + sch.unroll(last_loop) + sch.annotate(last_loop, "pragma_unroll_explicit", False) + + def get_index_map_3d(index_map, l=16, r=16): # noqa: E741 def index_map_3d(b, i, j): @@ -359,7 +447,7 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): if not (can_swizzle or is_smooth): pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) - sch.annotate(f_2, "pragma_unroll_explicit", False) + sch.annotate(f_0, "pragma_unroll_explicit", False) return block_read a_g2s = fetch_to_shared( @@ -579,7 +667,7 @@ def tensorize_init_store_compute(): def sch_dequantize_in_register_with_config( self, func: tir.PrimFunc, - config, + config: Hint, ): """ For devices without async copy, we can use a simple dequantize schedule without shared memory prefetch. @@ -841,13 +929,14 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): sch.bind(f_1, "threadIdx.y") sch.vectorize(f_4) sch.unroll(f_0) + sch.annotate(f_0, "pragma_unroll_explicit", False) + # Apply Swizzling sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) # if not, apply padding to alleviate bank conflict if not (can_swizzle or is_smooth): pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) - sch.annotate(f_2, "pragma_unroll_explicit", False) return block_read a_g2s = fetch_to_shared( @@ -1067,7 +1156,7 @@ def tensorize_init_store_compute(): def sch_shared_memory_prefetch_with_config( self, func: tir.PrimFunc, - config, + config: Hint, ): """ For A100 Like devices, the shared memory prefetch(async) is required @@ -1083,6 +1172,9 @@ def sch_shared_memory_prefetch_with_config( V compute """ + if hasattr(config, "block_reduction_depth") and config.block_reduction_depth is not None: + return self.sch_shared_memory_prefetch_block_reduction_with_config(func, config) + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel get_mma_intrin_group,) from .intrin import get_lop3_intrin_group @@ -1136,8 +1228,6 @@ def check_weight_decode_info(weight_decode_info): # Start Schedule # Step 0. Get schedule config. - # NOTE: we can analyze the config by the hardware spec in the future - # tensor core intrinsic size shared_scope = config.shared_scope @@ -1387,13 +1477,14 @@ def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): sch.bind(f_1, "threadIdx.y") sch.vectorize(f_4) sch.unroll(f_0) + sch.annotate(f_0, "pragma_unroll_explicit", False) + # Apply Swizzling sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) # if not, apply padding to alleviate bank conflict if not (can_swizzle or is_smooth): pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) - sch.annotate(f_2, "pragma_unroll_explicit", False) return block_read a_g2s = fetch_to_shared( @@ -1496,25 +1587,10 @@ def get_idx(): sch.compute_at(block_shared_local_local_shared, k0, preserve_unit_loops=True) ndim = len(sch.get(block_shared_local_local_shared).iter_vars) - fused = sch.fuse(*sch.get_loops(block_shared_local_local_shared)[-ndim:]) - - f_0, f_1, f_2, f_3, f_4 = sch.split( - fused, - factors=[ - None, - num_tz, - num_ty, - warp_size, - get_coalesced_veclen(sch.get(block_shared_local_local_shared)), - ], - ) + _ = sch.fuse(*sch.get_loops(block_shared_local_local_shared)[-ndim:]) - sch.bind(f_3, "threadIdx.x") - sch.bind(f_2, "threadIdx.y") - sch.bind(f_1, "threadIdx.z") - sch.vectorize(f_4) - sch.unroll(f_0) - sch.annotate(f_0, "pragma_unroll_explicit", False) + _bind_thread_based_on_config(sch, block_shared_local_local_shared, num_ty, num_tz, + warp_size) # cache small tensors, e.g. LUT if b_idx: @@ -1636,10 +1712,572 @@ def tensorize_init_store_compute(): ) return sch + def sch_shared_memory_prefetch_block_reduction_with_config( + self, + func: tir.PrimFunc, + config: Hint, + ): + """ + For A100 Like devices, the shared memory prefetch(async) is required + to achieve optimal performance. + quantized weight + | + V + shared memory prefetch (with async copy) + | + V + dequantized into shared memory + | + V + compute + """ + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group,) + from .intrin import get_lop3_intrin_group + + import_source: List[str] = [] + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + # always enable shared memory rewrite + cache_write_required = True + + # Check Dequantize Info + # TODO(leiwang): this is a hack to get the configuration, can be improved by writing a pass to analysis the dequantize block. + dequantize_info = func.attrs["dequantize_info"] + + def check_dequantize_info(dequantize_info): + conditions = [] + # currently only support weight only dequantization + conditions.append(len(dequantize_info) == 1) + # TODO(@lei) check if the dequantize value name is weight + return all(conditions) + + assert check_dequantize_info(dequantize_info) + + (weight_decode_info,) = list(dequantize_info.values()) + + def check_weight_decode_info(weight_decode_info): + conditions = [] + # check source format in ["int", "fp", "nf"] + conditions.append("source_format" in weight_decode_info) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e4m3"]) + # check source bits in [1, 2, 4, 8] + conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) + # check target format in ["float16", "int8"] + conditions.append("target_format" in weight_decode_info) + conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) + return all(conditions) + + assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info" + + # Start Schedule + # Step 0. Get schedule config. + # tensor core intrinsic size + shared_scope = config.shared_scope + + intrin_info = config.intrin_info + intrin_group = get_mma_intrin_group( + load_scope=shared_scope, + store_scope=shared_scope if cache_write_required else "global", + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_a=intrin_info.trans_a, + trans_b=intrin_info.trans_b, + smooth_a=intrin_info.smooth_a, + smooth_b=intrin_info.smooth_b, + not_use_mma_store_intrinic=False, + ) + + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + stage = config.pipeline_stage + use_async = config.use_async + assert (config.block_reduction_depth is not None), "block_reduction_depth is required" + reduce_k = config.block_reduction_depth + chunk = config.rstep[0] // reduce_k + + micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] + + # get the axis for layout transform + def get_axis(l, r, trans): # noqa: E741 + return (r, l) if trans else (l, r) # noqa: E741 + + a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) + b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) + + def can_enable_swizzle(dtype: str, smooth: bool): + # inject_permuted_layout only support float16 currently + if dtype == "float16" or dtype == "int8": + if (chunk * reduce_k) * DataType(dtype).bits != 512: + # currently the swizzle rule only support 512 bit. + return False + # if we use smooth layout, we don't need to do swizzling + return not smooth + return False + + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) + + # rewrite global smooth layout, for dequantize, currently only support weight only recover. + def smooth_gmem_layout_rewrite( + sch, + main_block, + enable=True, + trans=False, + matrix_name="A", + intrin_group=intrin_group, + ): + if not enable: + return + + # normalized block may have three read buffers, while the first one is the write buffer. + buffer_offset = (1 if sch.get(main_block).reads[0].buffer + == sch.get(main_block).writes[0].buffer else 0) + buffer_idx = 0 if matrix_name == "A" else 1 + source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer + + # step1: find the first producer block + # Notes: we assume the layout propagate happens in the first producer block + # otherwise, the layout transform will have no effect as it will transform both + # read and write buffer + propagate_block: tir.Block = find_last_producer_from_buffer( + sch, main_block, source_buffer) + # some trick impl may not have reindex block + (weight_dequantize_info,) = dequantize_info.values() + if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]): + return + # step2: transform the layout with inverse permutation + intra_indexmap, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + + # step3: propagate the matmul layout to the first reindex block + + intra_indexmap = layout_propagate_chain( + sch, + start_block=main_block, + start_buffer=source_buffer, + end_block=propagate_block, + index_map=intra_indexmap, + ) + + def inverse_permutation(i, j, ii, jj): + return (i, j, *intra_indexmap.map_indices([ii, jj])) + + sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) + + intra_index_map, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + + # get target dequantize buffer's offset + def get_offset(): + # for LUT dequantize, the expr is LUT(w), the idx is 1 + # maybe we can use a more general and structural based way + # to analysis the idx + if weight_dequantize_info["source_format"]["format"] == "nf": + return 1 + return 0 + + offset = get_offset() + dequantize_block = sch.get_block(weight_dequantize_info["decode_block"]) + group_size = weight_dequantize_info["group_size"] + + _, mn, mk = intrin_group["micro_kernel"] + + def get_param_indices( + indexmap, + l=mn, + r=mk, + group_size=group_size # noqa: E741 + ): # noqa: E741 + # assume the param layout is n, k + rl, rr = [x.var for x in sch.get(dequantize_block).iter_vars] + warp_i, warp_j = rl % l, rr % r + spatial_i, spatial_j = rl // l, rr // r + warp_i, warp_j = indexmap.map_indices([warp_i, warp_j]) + new_indices = ( + spatial_i * l + warp_i, + (spatial_j * r + warp_j) // group_size, + ) + return new_indices + + with_scaling = bool(weight_dequantize_info["with_scaling"]) + if with_scaling: + sch.unsafe_rewrite_buffer_region( + dequantize_block, + ("read", offset + 1), + get_param_indices(intra_index_map), + ) + with_zeros = bool(weight_dequantize_info["with_zeros"]) + if with_zeros: + sch.unsafe_rewrite_buffer_region( + dequantize_block, + ("read", offset + 2), + get_param_indices(intra_index_map), + ) + + smooth_gmem_layout_rewrite( + sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") + + smooth_gmem_layout_rewrite( + sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") + + warp_size = 32 + + i_factors, j_factors, k_factors = ( + [None, 1, block_row_warps, warp_row_tiles // micro_size_x], + [1, None, block_col_warps, warp_col_tiles // micro_size_y], + [None, chunk // micro_size_k], + ) + + num_ty = i_factors[2] + num_tz = j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): + sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + k0, kr = sch.split(k0, [None, reduce_k]) + + sch.reorder(i0, j0, i1, j1, i2, j2, kr, i3, j3, k0, k1) + + block_idy = sch.fuse(i0, j0) + block_idx = sch.fuse(i1, j1) + thread_idy = i2 + thread_idz = j2 + + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz) + sch.bind(thread_idy, "threadIdx.y") + sch.bind(kr, "threadIdx.z") + + def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 + if not enable: + return + sch.transform_layout( + block, + scope, + lambda b, i, j: ( + b, + i // l, + j // r, + i % l, + j % r, + ), + ) + + smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) + smooth_layout_recover( + block_outer, + ("read", 1), + *b_lr, + enable=intrin_info.inter_transform_b, + ) + smooth_layout_recover(block_outer, ("write", 0), enable=True) + + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_read).iter_vars) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + f_0, f_r, f_1, f_2, f_3, f_4 = sch.split( + fused, factors=[None, reduce_k, num_ty, num_tz, warp_size, vec_len]) + + sch.bind(f_3, "threadIdx.x") + f_1 = f_2 = sch.fuse(f_1, f_2) + sch.bind(f_1, "threadIdx.y") + sch.bind(f_r, "threadIdx.z") + sch.vectorize(f_4) + sch.unroll(f_0) + sch.annotate(f_0, "pragma_unroll_explicit", False) + + # Apply Swizzling + sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) + # if not, apply padding to alleviate bank conflict + if not (can_swizzle or is_smooth): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + return block_read + + a_g2s = fetch_to_shared( + block_outer, + 0, + vec_len=list(config.vectorize.values())[0], + can_swizzle=can_swizzle_a, + is_smooth=intrin_info.smooth_a, + ) + + auto_inline_producers(sch, a_g2s) + + def decode_fetch_to_shared(block, idx): + # step1. create memory hierarchy + # global -> local -> shared + block_shared = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_shared, k0, preserve_unit_loops=True) + + # TODO(lei): the factor should be analyzed more deeper. + decode_factor = get_coalesced_veclen(sch.get(block_shared)) + _, B_shared_vi, _ = sch.split( + sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor]) + block_shared_local = sch.cache_read(block_shared, 0, "local") + # global -> dequantzed_local -> shared + # step2. inline to local block, should skip qzeros + is_qzeros = ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"] and + weight_decode_info["zeros_mode"] == "quantized") + weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"]) + weight_producers = ( + _collect_producers(sch, weight_dequantize_block) if is_qzeros else []) + auto_inline_producers(sch, block_shared_local, weight_producers) + + # get target dequantize buffer's idx + def get_idx(): + # for LUT dequantize, the expr is LUT(w), the idx is 1 + # maybe we can use a more general and structural based way + # to analysis the idx + if weight_decode_info["source_format"]["format"] == "nf": + return 1 + return 0 + + b_idx = get_idx() + # global -> prefetch_local -> dequantzed_local -> shared + block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local") + # global -> prefetch_shared -> vector load -> dequantzed_local -> shared + block_shared_local_local_shared = sch.cache_read(block_shared_local_local, 0, + shared_scope) + sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True) + sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True) + + dequantize_block_local = block_shared_local + if is_qzeros: + if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]): + block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local") + sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_scales) + + if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]): + block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local") + sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_zeros) + + for producer in weight_producers: + with suppress(Exception): + auto_inline_producers(sch, producer) + sch.compute_inline(producer) + + # fast type conversion + if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): + source_bit = weight_decode_info["source_format"]["bits"] + out_dtype = weight_decode_info["target_format"] + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + storage_dtype=weight_decode_info["storage_dtype"], + source_format=weight_decode_info["source_format"]["format"], + source_bit=source_bit, + with_scaling=weight_decode_info["with_scaling"], + with_zeros=weight_decode_info["with_zeros"], + zeros_mode=weight_decode_info["zeros_mode"], + ) + sch.tensorize( + sch.get_loops(dequantize_block_local)[-1], + lop3_intrin_info["compute"], + ) + import_source.append(lop3_intrin_info["c_source"]) + + sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b) + union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2) + B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2]) + _, B_shared_rk, B_shared_ty, B_shared_tz, B_shared_tx = sch.split( + B_shared_fused, factors=[None, reduce_k, num_ty, num_tz, warp_size]) + if not (can_swizzle_b or intrin_info.smooth_b): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset) + sch.bind(B_shared_tx, "threadIdx.x") + B_shared_tz = B_shared_ty = sch.fuse(B_shared_ty, B_shared_tz) + sch.bind(B_shared_ty, "threadIdx.y") + sch.bind(B_shared_rk, "threadIdx.z") + sch.vectorize(sch.get_loops(block_shared)[-1]) + sch.vectorize(sch.get_loops(block_shared_local_local)[-1]) + + sch.compute_at(block_shared_local_local_shared, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_shared_local_local_shared).iter_vars) + _ = sch.fuse(*sch.get_loops(block_shared_local_local_shared)[-ndim:]) + + _bind_thread_based_with_block_reduce_on_config(sch, block_shared_local_local_shared, + num_ty, num_tz, warp_size, reduce_k) + + # cache small tensors, e.g. LUT + if b_idx: + block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope) + sch.reverse_compute_at(block_shared_lut, j2) + _, B_shared_tx = sch.split( + sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size]) + sch.bind(B_shared_tx, "threadIdx.x") + return block_shared_local + + _ = decode_fetch_to_shared(block_outer, 1) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "warp") + B_mat = sch.cache_read(block_outer, 1, "warp") + sch.compute_at(A_mat, k1, preserve_unit_loops=True) + sch.compute_at(B_mat, k1, preserve_unit_loops=True) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + if cache_write_required: + accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) + + store = sch.cache_write(block_outer, 0, "warp") + sch.reverse_compute_at(store, j2) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, micro_size_x]) + j0, j1 = sch.split(j, factors=[None, micro_size_y]) + sch.reorder(i0, j0, i1, j1) + + if cache_write_required: + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + sch.reverse_compute_at( + accumulator_shared_to_global, + sch.get_loops(store)[-5], + preserve_unit_loops=True, + ) + vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) + f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + sch.unroll(f0) + sch.annotate(f0, "pragma_unroll_explicit", False) + else: + auto_inline_consumer_chain(sch, store) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + + index_map_a, index_map_b, index_map_c = intrin_group["index_map"] + + sch.transform_layout( + A_mat, + ("write", 0), + get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), + ) + sch.transform_layout( + B_mat, + ("write", 0), + get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), + ) + sch.transform_layout( + store, + ("read", 0), + get_index_map(index_map_c, is_5d=True), + ) + + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, a_lr[0]]) + j0, j1 = sch.split(j, factors=[None, a_lr[1]]) + sch.reorder(i0, j0, i1, j1) + ba = sch.blockize(i1) + sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) + sch.tensorize(ba, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, b_lr[0]]) + j0, j1 = sch.split(j, factors=[None, b_lr[1]]) + sch.reorder(i0, j0, i1, j1) + bb = sch.blockize(i1) + sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) + sch.tensorize(bb, intrin_group["load_b"]) + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + tensorize_init_store_compute() + + if stage > 1: + sch.annotate( + k0, + ann_key="software_pipeline_stage", + ann_val=[0, 0, stage - 1, stage - 1], + ) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2, 3]) + if use_async: + sch.annotate(k0, "software_pipeline_async_stages", [0]) + # plan rasteration + if not isinstance(config.rasterization_plan, NoRasterization): + device_func, invoke_func = config.rasterization_plan.get_code() + import_source.append(device_func) + sch.annotate( + sch.get_loops(block_init_c)[-2], + ann_key="inject_customized_code_prepend", + ann_val=invoke_func, + ) + # plan import source + if len(import_source) > 0: + sch.annotate( + j2, + ann_key="pragma_import_c", + ann_val=("\n").join(import_source), + ) + return sch + def apply_config( # pylint: disable=too-many-locals,missing-docstring self, func: tir.PrimFunc, - config, + config: Hint, ) -> Optional[tir.Schedule]: def check_sm_version(arch: str) -> int: diff --git a/python/bitblas/quantization/quantization.py b/python/bitblas/quantization/quantization.py index b08f3d188..71ef224d7 100644 --- a/python/bitblas/quantization/quantization.py +++ b/python/bitblas/quantization/quantization.py @@ -144,10 +144,12 @@ def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str): assert dtype == "float16" s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16") e4 = val & tir.const(0x40, "uint16") - prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"), tir.const(0x4000, "uint16")) + prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"), + tir.const(0x4000, "uint16")) e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | prefix return tir.reinterpret("float16", s_f16 | e_f16) + def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): assert nbit == 8 assert dtype == "float16" @@ -157,6 +159,7 @@ def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): e_f16 = e_f16 ^ tir.const(0x2000, "uint16") return tir.reinterpret("float16", s_f16 | e_f16) + def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): assert nbit == 8 assert dtype == "float16" diff --git a/python/bitblas/wrapper/general.py b/python/bitblas/wrapper/general.py index 6c34f6e4e..58aa8d226 100644 --- a/python/bitblas/wrapper/general.py +++ b/python/bitblas/wrapper/general.py @@ -291,8 +291,8 @@ def legalize_c(p): call_str = "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0]) else: call_str = "" - call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str, smem_str, - call_args) + call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str, + smem_str, call_args) # Create the host function wrapper for the CUDA kernel host_func = """ extern "C" void call({}) {{ @@ -410,9 +410,7 @@ def legalize_c(p): (symbolic,) = list(dynamic_symbolic_set) range_str = opt_shapes[symbolic] if last_range == 0: - call_str = "if ({} == 0) return; \n".format( - symbolic, - ) + call_str = "if ({} == 0) return; \n".format(symbolic,) call_str += "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( symbolic, range_str, diff --git a/testing/python/operators/test_general_matmul_splitk_ops.py b/testing/python/operators/test_general_matmul_splitk_ops.py index 5d21116f1..12fbbcabe 100644 --- a/testing/python/operators/test_general_matmul_splitk_ops.py +++ b/testing/python/operators/test_general_matmul_splitk_ops.py @@ -41,6 +41,7 @@ def test_matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtyp matmul = MatmulWithSplitK(config=matmul_config, enable_tuning=False) assert get_codegen_result(matmul) + @pytest.mark.parametrize( "SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", [ @@ -83,6 +84,7 @@ def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accu output_torch = torch.matmul(inputs[0], inputs[1].t() if layout == "nt" else inputs[1]) torch.testing.assert_close(output_bitblas, output_torch, rtol=1e-2, atol=1e-1) + @pytest.mark.parametrize( "SPlitK,M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", [ @@ -93,8 +95,8 @@ def test_matmul_torch_forward_consistent(SplitK, M, N, K, A_dtype, W_dtype, accu ], ) def test_matmul_torch_forward_fp8e4m3(SplitK, M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, - layout, with_bias, group_size, with_scaling, with_zeros, - zeros_mode): + layout, with_bias, group_size, with_scaling, with_zeros, + zeros_mode): import torch torch.random.manual_seed(0) matmul_config = MatmulConfigWithSplitK( @@ -119,6 +121,7 @@ def test_matmul_torch_forward_fp8e4m3(SplitK, M, N, K, A_dtype, W_dtype, accum_d input_shape = (M, K) weight_shape = (N, K) if layout == "nt" else (K, N) + def map_torch_type(intype): typemap = { @@ -148,6 +151,10 @@ def map_torch_type(intype): print("torch_ref_out", ref_out) print("bitblas_out", bitblas_out) + matmul.forward(torch_a, torch_b, output=bitblas_out) + print("torch_ref_out", ref_out) + print("bitblas_out", bitblas_out) + torch.testing.assert_close(bitblas_out, ref_out, rtol=1e0, atol=1e-1)