From ff011624ff390362ceb092f9b41722001bc8cbe1 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 1 Jul 2024 16:57:48 +0000 Subject: [PATCH 1/3] chore: Update support matrix in README --- README.md | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index f112989ce..7b623aa6e 100644 --- a/README.md +++ b/README.md @@ -55,22 +55,22 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and ## Support Matrix -| **A_dtype** | **W_dtype** | **Accum_dtype** | **Out_dtype** | **BitBLAS
Support** | **Tested
Platform** | -|:-----------:|:-----------:|:---------------:|:---------------:|:----------------------:|:----------------------:| -| FP16 | FP16 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | -| FP16 | FP4_E2M1 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | -| FP16 | FP8_E4M3 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | -| FP16 | INT8 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | -| FP16 | UINT4/INT4 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | -| FP16 | UINT2/INT2 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | -| FP16 | UINT1 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | -| FP16 | NF4 | FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | -| INT8 | INT8 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | -| INT8 | UINT4/INT4 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | -| INT8 | UINT2/INT2 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | -| INT8 | UINT1 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | -| FP8_E4M3 | FP8_E4M3 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) | -| FP8_E5M2 | FP8_E5M2 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) | +| **A_dtype** | **W_dtype** | **Accum_dtype** | **Out_dtype** | **BitBLAS Support** | **Tested Platform** | +|:-----------:|:-----------:|:---------------:|:--------------------:|:-------------------:|:----------------------------------------------------:| +| FP16 | FP16 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP16 | FP4_E2M1 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP16 | FP8_E4M3 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP16 | INT8 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP16 | UINT4/INT4 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP16 | UINT2/INT2 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP16 | UINT1 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP16 | NF4 | FP32/FP16 | FP16 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| INT8 | INT8 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| INT8 | UINT4/INT4 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| INT8 | UINT2/INT2 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| INT8 | UINT1 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP8_E4M3 | FP8_E4M3 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) | +| FP8_E5M2 | FP8_E5M2 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) | We are continuously expanding the support matrix. If you have any specific requirements, please feel free to open an issue or PR. From 5b9c49e142eb472fe185d36f6ddd27c8e1fde3ce Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 3 Jul 2024 10:06:05 +0000 Subject: [PATCH 2/3] Move bitblas package to root --- bitblas/__init__.py | 87 + bitblas/base/__init__.py | 18 + bitblas/base/analysis.py | 300 +++ bitblas/base/common_schedules.py | 163 ++ bitblas/base/roller/__init__.py | 7 + bitblas/base/roller/arch/__init__.py | 14 + bitblas/base/roller/arch/arch_base.py | 40 + bitblas/base/roller/arch/cpu.py | 19 + bitblas/base/roller/arch/cuda.py | 67 + bitblas/base/roller/bestfit.py | 66 + bitblas/base/roller/hint.py | 248 ++ bitblas/base/roller/node.py | 408 +++ bitblas/base/roller/policy/__init__.py | 5 + bitblas/base/roller/policy/common.py | 56 + bitblas/base/roller/policy/default.py | 748 ++++++ bitblas/base/roller/policy/tensorcore.py | 349 +++ bitblas/base/roller/rasterization.py | 88 + .../base/roller/shape_inference/__init__.py | 4 + bitblas/base/roller/shape_inference/common.py | 66 + bitblas/base/roller/shape_inference/tir.py | 399 +++ bitblas/base/schedule_rule.py | 149 ++ bitblas/base/transform.py | 218 ++ bitblas/base/utils.py | 517 ++++ bitblas/cache/__init__.py | 9 + bitblas/cache/operator.py | 179 ++ bitblas/generator.py | 15 + bitblas/gpu/__init__.py | 23 + bitblas/gpu/base.py | 44 + bitblas/gpu/element_wise.py | 97 + bitblas/gpu/fallback.py | 95 + bitblas/gpu/gemv.py | 794 ++++++ bitblas/gpu/gemv_dequantize.py | 369 +++ bitblas/gpu/general_reduction.py | 465 ++++ bitblas/gpu/intrin/__init__.py | 3 + bitblas/gpu/intrin/lop3.py | 1667 ++++++++++++ bitblas/gpu/matmul.py | 372 +++ bitblas/gpu/matmul_analysis.py | 786 ++++++ bitblas/gpu/matmul_mma.py | 1069 ++++++++ bitblas/gpu/matmul_mma_dequantize.py | 2295 +++++++++++++++++ bitblas/gpu/matmul_wmma.py | 892 +++++++ bitblas/gpu/reduction.py | 301 +++ bitblas/gpu/rmsnorm.py | 144 ++ bitblas/gpu/transpose.py | 133 + bitblas/gpu/utils.py | 86 + bitblas/module/__init__.py | 305 +++ bitblas/ops/__init__.py | 7 + bitblas/ops/general_matmul.py | 588 +++++ bitblas/ops/general_matmul_splitk.py | 199 ++ bitblas/ops/impl/__init__.py | 3 + .../ops/impl/batch_matmul_dequantize_impl.py | 392 +++ bitblas/ops/impl/batch_matmul_impl.py | 93 + bitblas/ops/impl/convolution2d_impl.py | 190 ++ bitblas/ops/impl/ladder_permutate_impl.py | 81 + bitblas/ops/impl/lop3_permutate_impl.py | 152 ++ bitblas/ops/impl/matmul_dequantize_impl.py | 644 +++++ .../ops/impl/matmul_dequantize_splitk_impl.py | 184 ++ bitblas/ops/impl/matmul_impl.py | 356 +++ bitblas/ops/impl/matmul_splitk_impl.py | 94 + bitblas/ops/impl/param_permutate_impl.py | 56 + bitblas/ops/ladder_permutate.py | 97 + bitblas/ops/lop3_permutate.py | 72 + bitblas/ops/matmul.py | 288 +++ bitblas/ops/matmul_dequantize.py | 331 +++ bitblas/ops/operator.py | 367 +++ bitblas/ops/param_permutate.py | 91 + bitblas/quantization/__init__.py | 12 + bitblas/quantization/quantization.py | 217 ++ bitblas/quantization/utils.py | 110 + bitblas/relax/op/interleave_weight.py | 23 + bitblas/relax/transform/__init__.py | 5 + .../relax/transform/annotate_decode_block.py | 123 + .../relax/transform/weight_only_propagate.py | 432 ++++ bitblas/testing/__init__.py | 25 + bitblas/utils/__init__.py | 5 + bitblas/utils/post_process.py | 38 + bitblas/utils/target_detector.py | 103 + bitblas/utils/tensor_adapter.py | 130 + bitblas/wrapper/__init__.py | 4 + bitblas/wrapper/general.py | 518 ++++ 79 files changed, 20209 insertions(+) create mode 100644 bitblas/__init__.py create mode 100644 bitblas/base/__init__.py create mode 100644 bitblas/base/analysis.py create mode 100644 bitblas/base/common_schedules.py create mode 100644 bitblas/base/roller/__init__.py create mode 100644 bitblas/base/roller/arch/__init__.py create mode 100644 bitblas/base/roller/arch/arch_base.py create mode 100644 bitblas/base/roller/arch/cpu.py create mode 100644 bitblas/base/roller/arch/cuda.py create mode 100644 bitblas/base/roller/bestfit.py create mode 100644 bitblas/base/roller/hint.py create mode 100644 bitblas/base/roller/node.py create mode 100644 bitblas/base/roller/policy/__init__.py create mode 100644 bitblas/base/roller/policy/common.py create mode 100644 bitblas/base/roller/policy/default.py create mode 100644 bitblas/base/roller/policy/tensorcore.py create mode 100644 bitblas/base/roller/rasterization.py create mode 100644 bitblas/base/roller/shape_inference/__init__.py create mode 100644 bitblas/base/roller/shape_inference/common.py create mode 100644 bitblas/base/roller/shape_inference/tir.py create mode 100644 bitblas/base/schedule_rule.py create mode 100644 bitblas/base/transform.py create mode 100644 bitblas/base/utils.py create mode 100644 bitblas/cache/__init__.py create mode 100644 bitblas/cache/operator.py create mode 100644 bitblas/generator.py create mode 100644 bitblas/gpu/__init__.py create mode 100644 bitblas/gpu/base.py create mode 100644 bitblas/gpu/element_wise.py create mode 100644 bitblas/gpu/fallback.py create mode 100644 bitblas/gpu/gemv.py create mode 100644 bitblas/gpu/gemv_dequantize.py create mode 100644 bitblas/gpu/general_reduction.py create mode 100644 bitblas/gpu/intrin/__init__.py create mode 100644 bitblas/gpu/intrin/lop3.py create mode 100644 bitblas/gpu/matmul.py create mode 100644 bitblas/gpu/matmul_analysis.py create mode 100644 bitblas/gpu/matmul_mma.py create mode 100644 bitblas/gpu/matmul_mma_dequantize.py create mode 100644 bitblas/gpu/matmul_wmma.py create mode 100644 bitblas/gpu/reduction.py create mode 100644 bitblas/gpu/rmsnorm.py create mode 100644 bitblas/gpu/transpose.py create mode 100644 bitblas/gpu/utils.py create mode 100644 bitblas/module/__init__.py create mode 100644 bitblas/ops/__init__.py create mode 100644 bitblas/ops/general_matmul.py create mode 100644 bitblas/ops/general_matmul_splitk.py create mode 100644 bitblas/ops/impl/__init__.py create mode 100644 bitblas/ops/impl/batch_matmul_dequantize_impl.py create mode 100644 bitblas/ops/impl/batch_matmul_impl.py create mode 100644 bitblas/ops/impl/convolution2d_impl.py create mode 100644 bitblas/ops/impl/ladder_permutate_impl.py create mode 100644 bitblas/ops/impl/lop3_permutate_impl.py create mode 100644 bitblas/ops/impl/matmul_dequantize_impl.py create mode 100644 bitblas/ops/impl/matmul_dequantize_splitk_impl.py create mode 100644 bitblas/ops/impl/matmul_impl.py create mode 100644 bitblas/ops/impl/matmul_splitk_impl.py create mode 100644 bitblas/ops/impl/param_permutate_impl.py create mode 100644 bitblas/ops/ladder_permutate.py create mode 100644 bitblas/ops/lop3_permutate.py create mode 100644 bitblas/ops/matmul.py create mode 100644 bitblas/ops/matmul_dequantize.py create mode 100644 bitblas/ops/operator.py create mode 100644 bitblas/ops/param_permutate.py create mode 100644 bitblas/quantization/__init__.py create mode 100644 bitblas/quantization/quantization.py create mode 100644 bitblas/quantization/utils.py create mode 100644 bitblas/relax/op/interleave_weight.py create mode 100644 bitblas/relax/transform/__init__.py create mode 100644 bitblas/relax/transform/annotate_decode_block.py create mode 100644 bitblas/relax/transform/weight_only_propagate.py create mode 100644 bitblas/testing/__init__.py create mode 100644 bitblas/utils/__init__.py create mode 100644 bitblas/utils/post_process.py create mode 100644 bitblas/utils/target_detector.py create mode 100644 bitblas/utils/tensor_adapter.py create mode 100644 bitblas/wrapper/__init__.py create mode 100644 bitblas/wrapper/general.py diff --git a/bitblas/__init__.py b/bitblas/__init__.py new file mode 100644 index 000000000..172c4cbf1 --- /dev/null +++ b/bitblas/__init__.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import sys +import os + +# installing tvm +install_tvm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm", "python") +if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path: + os.environ["PYTHONPATH"] = install_tvm_path + ":" + os.environ.get("PYTHONPATH", "") + sys.path.insert(0, install_tvm_path) + +develop_tvm_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "3rdparty", "tvm", "python") +if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path: + os.environ["PYTHONPATH"] = develop_tvm_path + ":" + os.environ.get("PYTHONPATH", "") + sys.path.insert(0, develop_tvm_path) + +from . import gpu # noqa: F401 +from .base import ( + TileDevice, # noqa: F401 + fast_tune, # noqa: F401 + ApplyDefaultSchedule, # noqa: F401 + ApplyFastTuning, # noqa: F401 + BlockInfo, # noqa: F401 + IterInfo, # noqa: F401 + ScheduleRule, # noqa: F401 + normalize_prim_func, # noqa: F401 + try_inline, # noqa: F401 + try_inline_contiguous_spatial, # noqa: F401 +) + +from . import testing # noqa: F401 +from .utils import auto_detect_nvidia_target # noqa: F401 +from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401 +from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401 +from .ops.matmul_dequantize import MatmulWeightOnlyDequantizeConfig, MatmulWeightOnlyDequantize # noqa: F401 +from .module import Linear # noqa: F401 + +import logging +from tqdm import tqdm + + +class TqdmLoggingHandler(logging.Handler): + """ Custom logging handler that directs log output to tqdm progress bar to avoid interference. """ + + def __init__(self, level=logging.NOTSET): + """ Initialize the handler with an optional log level. """ + super().__init__(level) + + def emit(self, record): + """ Emit a log record. Messages are written to tqdm to ensure output in progress bars isn't corrupted. """ + try: + msg = self.format(record) + tqdm.write(msg) + except Exception: + self.handleError(record) + + +def set_log_level(level): + """ Set the logging level for the module's logger. + + Args: + level (str or int): Can be the string name of the level (e.g., 'INFO') or the actual level (e.g., logging.INFO). + OPTIONS: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' + """ + if isinstance(level, str): + level = getattr(logging, level.upper(), logging.INFO) + logger = logging.getLogger(__name__) + logger.setLevel(level) + + +def _init_logger(): + """ Initialize the logger specific for this module with custom settings and a Tqdm-based handler. """ + logger = logging.getLogger(__name__) + handler = TqdmLoggingHandler() + formatter = logging.Formatter( + fmt="%(asctime)s [BitBLAS:%(levelname)s]: %(message)s", datefmt="%Y-%m-%d %H:%M:%S") + handler.setFormatter(formatter) + logger.addHandler(handler) + logger.propagate = False + set_log_level('WARNING') + + +_init_logger() + +__version__ = "0.0.1.dev12" diff --git a/bitblas/base/__init__.py b/bitblas/base/__init__.py new file mode 100644 index 000000000..122c44cbd --- /dev/null +++ b/bitblas/base/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Base infra""" +from .analysis import ( + BlockInfo, + IterInfo, + collect_block_iter_vars_used_in_access_region, + collect_vars_used_in_prim_expr, + detect_dominant_read, + is_broadcast_epilogue, + normalize_prim_func, +) +from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial +from .schedule_rule import ScheduleRule +from .transform import ApplyDefaultSchedule, ApplyFastTuning +from .utils import fast_tune, fast_tune_with_dynamic_range +from .roller import * diff --git a/bitblas/base/analysis.py b/bitblas/base/analysis.py new file mode 100644 index 000000000..eb9c19415 --- /dev/null +++ b/bitblas/base/analysis.py @@ -0,0 +1,300 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Analysis on TIR blocks, loops and functions.""" +from typing import List, Optional, Set, Union +from typing_extensions import Literal + +from tvm import ir, tir, DataType +from tvm._ffi import get_global_func +from tvm.target.target import Target +from tvm.tir import Schedule, IterVar +from tvm.tir.schedule import BlockRV + + +class IterInfo: + """Information about a loop/iter var.""" + + kind: Literal["S", "R", "O"] + var: tir.Var + _dom: tir.PrimExpr + loop_rv: tir.schedule.LoopRV + + def __init__( + self, + kind: Literal["S", "R", "O"], + var: tir.Var, + dom: tir.PrimExpr, + loop_rv: tir.schedule.LoopRV, + ): + """Construct an IterInfo object.""" + self.kind = kind + self.var = var + self._dom = dom + self.loop_rv = loop_rv + + @property + def dom(self) -> Union[int, tir.PrimExpr]: + """The iteration domain of the loop.""" + return int(self._dom) if isinstance(self._dom, tir.IntImm) else self._dom + + def __str__(self) -> str: + return f'Iter("{self.kind}", {self.dom})' + + def __repr__(self) -> str: + return str(self) + + +class BlockInfo: + """Information about a TIR block.""" + + name: str + iters: List[IterInfo] + block_rv: tir.schedule.BlockRV + _reduction_block: bool + + def __init__( + self, + name: str, + iters: List[IterInfo], + block_rv: tir.schedule.BlockRV, + reduction_block: bool = False, + ): + """Construct a BlockInfo object.""" + self.name = name + self.block_rv = block_rv + self.iters = iters + self._reduction_block = reduction_block + + def dom(self) -> List[Union[int, tir.PrimExpr]]: + """The iteration domain of the block.""" + return [i.dom for i in self.iters] + + def dom_kind(self) -> str: + """The iteration domain kind of the block, for example, SSSS, SSSR.""" + return "".join(i.kind for i in self.iters) + + def is_injective(self) -> bool: + """Whether the block is injective, i.e. all its iteration domains are injective.""" + return all(k == "S" for k in self.dom_kind()) + + def is_elementwise(self, sch: tir.Schedule) -> bool: + """Whether the block is elementwise, i.e. trivial mapping between read/write region""" + + def _check_unit_var_range(dom: ir.Range, var: tir.Var) -> bool: + return dom.min.same_as(var) and dom.extent == 1 + + if not self.is_injective(): + return False + block = sch.get(self.block_rv) + if len(block.reads) != 1 or len(block.writes) != 1: + return False + r_region = block.reads[0].region + w_region = block.writes[0].region + if len(r_region) != len(w_region): + return False + for var, r_dom, w_dom in zip(block.iter_vars, r_region, w_region): + if not _check_unit_var_range(var, r_dom) or not _check_unit_var_range(var, w_dom): + return False + return True + + def is_reduction(self) -> bool: + """Whether the block is a reduction workload.""" + # TODO(@junrushao): distinguish GEMV and reduction + return self._reduction_block + + def is_gemv(self) -> bool: + """Whether the block is a GEMV workload.""" + raise NotImplementedError + + def is_gemm(self) -> bool: + """Whether the block is a GEMM workload.""" + raise NotImplementedError + + def __str__(self) -> str: + return f'BlockInfo("{self.name}", "{self.dom_kind()}", {self.dom()})' + + def __repr__(self) -> str: + return str(self) + + +_normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc") + + +def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]: + """Normalize the primfunc to normal form""" + try: + result = _normalize_prim_func(sch) + if result is None: + return None + except Exception: # pylint: disable=broad-except + return None + + def _iter_kind(i: tir.IterVar) -> str: + return { + tir.IterVar.DataPar: "S", + tir.IterVar.CommReduce: "R", + }.get(i.iter_type, "O") + + blocks: List[BlockInfo] = [] + for block, loops, iters, is_reduction in zip(*result): + blocks.append( + BlockInfo( + name=sch.get(block).name_hint, + iters=[ + IterInfo( + kind=_iter_kind(iter), # type: ignore + var=iter.var, + dom=iter.dom, + loop_rv=loop, + ) for loop, iter in zip(loops, iters) + ], + block_rv=block, + reduction_block=is_reduction, + )) + return blocks + + +def find_var_from_func(func, var: str): + for buffer in func.buffer_map.values(): + for i in buffer.shape: + if isinstance(i, tir.Var) and i.name == var: + return i + return None + + +def check_func_with_dynamic(func): + for buffer in func.buffer_map.values(): + for i in buffer.shape: + if isinstance(i, tir.Var): + return True + return False + + +def _assert_gpu_target(target: Target): + if "gpu" not in target.keys: + raise ValueError(f"Expect a GPU target, but got {target}") + + +def get_max_threads_per_block(target: Target) -> int: + _assert_gpu_target(target) + max_threads_per_block = None + for name in ["max_threads_per_block", "max_num_threads"]: + if max_threads_per_block is None: + max_threads_per_block = target.attrs.get(name, None) + if max_threads_per_block is None: + max_threads_per_block = 64 + return int(max_threads_per_block) + + +def get_max_shared_memory_per_block(target: Target) -> int: + _assert_gpu_target(target) + max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None) + if max_shared_memory_per_block is None: + raise ValueError( + f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually") + return int(max_shared_memory_per_block) + + +def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV: + try: + block = sch.mod[func_name].body.block + except Exception: + raise ValueError(f"The function body is expected to be the root block, but got:\n" + f"{sch.mod[func_name].body}") from None + return sch.get_block(block.name_hint) + + +def collect_block_iter_vars_used_in_access_region(block: tir.Block, + region: List[ir.Range]) -> Set[tir.Var]: + """Collect the block iter variables used in the access region of a buffer region.""" + tir_vars = set() + for expr in region: + if expr.extent != 1: + continue + tir_vars |= collect_vars_used_in_prim_expr(expr.min) + tir_vars &= set(iter_var.var for iter_var in block.iter_vars) + return tir_vars + + +def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> Set[tir.Var]: + """Collect the variables used in the PrimExpr.""" + tir_vars = set() + + def _collect_tir_var(expr): + if isinstance(expr, tir.Var): + tir_vars.add(expr) + + tir.stmt_functor.post_order_visit(expr, _collect_tir_var) + return tir_vars + + +def detect_dominant_read(block: tir.Block) -> tir.PrimExpr: + """Detect the dominant read indices in the block.""" + dominant_read = None + num_read_iters = -1 + for buffer_region in block.reads: + tir_vars = collect_block_iter_vars_used_in_access_region(block, buffer_region.region) + if num_read_iters < len(tir_vars): + num_read_iters = len(tir_vars) + dominant_read = buffer_region + assert dominant_read is not None + (result,) = dominant_read.buffer.offset_of([e.min for e in dominant_read.region]) + return result + + +def is_broadcast_epilogue( + sch: tir.Schedule, + block: tir.schedule.BlockRV, + epilogue: tir.schedule.BlockRV, +) -> bool: + """Check if the epilogue block is a broadcast pattern""" + write_buffers = {r.buffer for r in sch.get(block).writes} + epilogue_iters = {i.var: i for i in sch.get(epilogue).iter_vars if i.dom != 1} + for buffer_region in sch.get(epilogue).reads: + if buffer_region.buffer not in write_buffers: + continue + tir_vars = collect_block_iter_vars_used_in_access_region( + sch.get(epilogue), buffer_region.region) + if len(tir_vars) < len(epilogue_iters): + return True + return False + + +def get_reduction_blocks(sch: tir.Schedule, + blocks: List[tir.schedule.BlockRV]) -> List[tir.schedule.BlockRV]: + # Get the main computation block + def is_reduction(block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.CommReduce, IterVar.DataPar} + + def is_spatial(block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.DataPar} + + # NOTE: We assume there is only one reduction block in the function + # all blocks are required to be spatial or reduction + if not all([is_reduction(block) or is_spatial(block) for block in blocks]): + return None + + # There is only one reduction block + reduction_blocks = [block for block in blocks if is_reduction(block)] + if len(reduction_blocks) == 0: + return None + return reduction_blocks + + +def get_coalesced_veclen(block_stmt: tir.Block, target_bits: int = 128) -> int: + # gpu memory prefer 128 bits coalesced access (e.g. four banks) + # 128 bits + buffers: List[tir.Buffer] = [] + for read in block_stmt.reads: + buffers.append(read.buffer) + for write in block_stmt.writes: + buffers.append(write.buffer) + # pick the dtype with the largest bits + max_dtype_bits: int = 0 + for buffer in buffers: + max_dtype_bits = max(max_dtype_bits, DataType(buffer.dtype).bits) + return target_bits // max_dtype_bits diff --git a/bitblas/base/common_schedules.py b/bitblas/base/common_schedules.py new file mode 100644 index 000000000..7d528c70a --- /dev/null +++ b/bitblas/base/common_schedules.py @@ -0,0 +1,163 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm common_schedules.py in dlight. +"""Common schedule strategies for TIR.""" +from typing import Callable, List + +from tvm import tir + +from .analysis import BlockInfo + + +def get_block( + sch: tir.Schedule, + blocks: List[BlockInfo], + name: str, +): + """Get the target block from a schedule. + + Parameters + ---------- + sch : tir.Schedule + The TIR schedule used to get target block. + name : str + The name of the target block. + + Returns + ------- + target_block : BlockRV + The target block. + """ + + target_block: tir.BlockRV = None + for block_info in blocks: + block = block_info.block_rv + if sch.get(block).name_hint == name: + target_block = block + return target_block + + +def get_output_blocks( + sch: tir.Schedule, + blocks: List[BlockInfo], +): + """Get the output blocks of a schedule. + + Parameters + ---------- + sch : tir.Schedule + The TIR schedule used to get output blocks. + blocks : List[BlockInfo] + The blocks to be analyzed. + + Returns + ------- + output_blocks : List[BlockInfo] + The output blocks. + """ + + # collect arguments buffer + func = sch.mod["main"] + args = list(func.buffer_map.values()) + + output_blocks = [] + for block_info in blocks: + block = block_info.block_rv + for write in sch.get(block).writes: + if write.buffer in args: + output_blocks.append(block) + + return output_blocks + + +def try_inline( + sch: tir.Schedule, + blocks: List[BlockInfo], +) -> List[BlockInfo]: + """Try to inline as many blocks as possible, and return the remaining blocks. + + Parameters + ---------- + sch : tir.Schedule + The TIR schedule used to inline blocks. + blocks : List[BlockInfo] + The blocks to be inlined. + + Returns + ------- + remaining : List[BlockInfo] + The remaining blocks that cannot be inlined. + """ + + def _trial(func: Callable): + for i, block in enumerate(blocks): + try: + func(block.block_rv) + except Exception: # pylint: disable=bare-except + continue + return i + return None + + while True: + i = _trial(sch.compute_inline) + if i is None: + i = _trial(sch.reverse_compute_inline) + if i is None: + break + blocks.pop(i) + return blocks + + +def try_inline_contiguous_spatial( + sch: tir.Schedule, + block_infos: List[BlockInfo], +) -> List[BlockInfo]: + """Try to inline contiguous spatial blocks in a schedule + + Parameters + ---------- + sch : tir.Schedule + The TIR schedule used to inline blocks. + block_infos : List[BlockInfo] + The blocks to be try. + + Returns + ------- + remaining : List[BlockInfo] + The remaining blocks that cannot be inlined. + """ + + if block_infos is None: + return None + results = [] + spatial_blocks = [] + block: BlockInfo + for block in block_infos: + if block.is_injective(): + spatial_blocks.append(block) + elif spatial_blocks: + results.extend(try_inline(sch, spatial_blocks)) + results.append(block) + spatial_blocks = [] + else: + results.append(block) + if spatial_blocks: + results.extend(try_inline(sch, spatial_blocks)) + return results diff --git a/bitblas/base/roller/__init__.py b/bitblas/base/roller/__init__.py new file mode 100644 index 000000000..9afd7cff0 --- /dev/null +++ b/bitblas/base/roller/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .node import PrimFuncNode # noqa: F401 +from .rasterization import NoRasterization, Rasterization2DRow, Rasterization2DColumn # noqa: F401 +from .hint import Hint # noqa: F401 +from .policy import DefaultPolicy, TensorCorePolicy # noqa: F401 +from .arch import TileDevice, CUDA # noqa: F401 diff --git a/bitblas/base/roller/arch/__init__.py b/bitblas/base/roller/arch/__init__.py new file mode 100644 index 000000000..9cb036792 --- /dev/null +++ b/bitblas/base/roller/arch/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .arch_base import TileDevice +from .cuda import * +from .cpu import * + + +def get_arch(target: tvm.target.Target) -> TileDevice: + if target.kind.name == "cuda": + return CUDA(target) + elif target.kind.name == "llvm": + return CPU(target) + else: + raise ValueError(f"Unsupported target: {target.kind.name}") diff --git a/bitblas/base/roller/arch/arch_base.py b/bitblas/base/roller/arch/arch_base.py new file mode 100644 index 000000000..6e98838c7 --- /dev/null +++ b/bitblas/base/roller/arch/arch_base.py @@ -0,0 +1,40 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import List + + +class TileDevice: + """ + Represents the architecture of a computing device, capturing various hardware specifications. + """ + + def __init__(self) -> None: + self.reg_cap: int = 0 # Register capacity: The amount of register memory available + self.smem_cap: int = 0 # Shared memory capacity: The amount of shared memory available + self.compute_max_core: int = 0 # The maximum number of computing cores + self.warp_size: int = ( + 0 # The size of a warp, a group of threads that execute instructions in lockstep + ) + self.sm_partition: int = 0 # The number of streaming multiprocessor partitions + self.transaction_size: List[int] = [ + 0, + 0, + ] # The size of memory transactions, typically in bytes + self.max_smem_usage: int = 0 # The maximum shared memory usage allowed + self.bandwidth: List[int] = [ + 0, + 0, + ] # Bandwidth specifications, possibly including peak and sustained rates + self.platform: str = "unknown" # The platform or manufacturer of the device + self.compute_capability: str = ( + "unknown" # The compute capability, indicating the feature set and performance level + ) + self.l2_cache_size_bytes: int = 0 + # the number of transaction size in bytes + self.transaction_size: List[int] = [0, 0] # in bytes + # bandwidth in MB/s, will be used for recommend basic tile size + self.bandwidth: List[int] = [0, 0] + + def get_avaliable_tensorintrin_shapes(self): + raise NotImplementedError() diff --git a/bitblas/base/roller/arch/cpu.py b/bitblas/base/roller/arch/cpu.py new file mode 100644 index 000000000..98fb14af5 --- /dev/null +++ b/bitblas/base/roller/arch/cpu.py @@ -0,0 +1,19 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tvm +from tvm.target import Target +from .arch_base import TileDevice + + +# For LLVM Backend, we do not provide the detailed information of the CPU +# As the LLVM backend do not required tuning, just maintain the consistency +class CPU(TileDevice): + + def __init__(self, target: Target): + self.target = target + device = tvm.runtime.cpu(0) + if not device.exist: + raise RuntimeError("Cannot find cpu device 0.") + self.device: tvm.runtime.Device = device + self.platform: str = "CPU" diff --git a/bitblas/base/roller/arch/cuda.py b/bitblas/base/roller/arch/cuda.py new file mode 100644 index 000000000..2189947e7 --- /dev/null +++ b/bitblas/base/roller/arch/cuda.py @@ -0,0 +1,67 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tvm +from tvm.target import Target +from .arch_base import TileDevice +from typing import List, Dict, Union + + +def check_sm_version(arch: str) -> int: + sm_version = arch.replace("sm_", "") + return int(sm_version) if sm_version.isdigit() else -1 + + +class TensorInstruction(object): + + def __init__( + self, + name: str, + intrin_group: Dict, + shape: List[int], + ): + self.name: str = name + self.intrin_group: Dict = intrin_group + # only maintain the shape of M and N + self.shape: List[int] = shape + + +class CUDA(TileDevice): + + def __init__(self, target: Union[Target, str]): + if isinstance(target, str): + target = tvm.target.Target(target) + self.target = target + self.sm_version = check_sm_version(self.target.arch) + device = tvm.runtime.cuda(0) + if not device.exist: + raise RuntimeError("Cannot find cuda device 0.") + self.device: tvm.runtime.Device = device + self.platform: str = "CUDA" + self.smem_cap = device.max_shared_memory_per_block + self.compute_max_core = device.multi_processor_count + self.warp_size = device.warp_size + self.compute_capability = device.compute_version.replace(".", "") + self.reg_cap: int = 65536 + self.max_smem_usage: int = 2 * self.smem_cap + self.sm_partition: int = 4 + self.l2_cache_size_bytes: int = target.l2_cache_size_bytes + # the number of transaction size in bytes + self.transaction_size: List[int] = [32, 128] # in bytes + # bandwidth in MB/s, will be used for recommend basic tile size + # TODO(lei): find some way to get the real bandwidth + # However, the ratio of bandwidth between different devices can + # be similar. The bandwidth can work for another devices as well. + self.bandwidth: List[int] = [750, 12080] + # get the available tensor instructions during runtime to avoid + # the dependency of the tensor intrinsics registration + self.available_tensor_instructions: List[TensorInstruction] = None + + def get_avaliable_tensorintrin_shapes(self): + from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group, get_mma_intrin_group + + self.available_tensor_instructions = ( + TensorInstruction("mma", get_mma_intrin_group, [16, 16]), + TensorInstruction("wmma", get_wmma_intrin_group, [16, 16]), + ) + return [t.shape for t in self.available_tensor_instructions] diff --git a/bitblas/base/roller/bestfit.py b/bitblas/base/roller/bestfit.py new file mode 100644 index 000000000..ad8ec20a8 --- /dev/null +++ b/bitblas/base/roller/bestfit.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Benifit For BitBLAS Schedule""" +class Block: + def __init__(self, start, end, is_free): + self.start = start + self.end = end + self.is_free = is_free + + def size(self) -> int: + return self.end - self.start + + def merge(self, other): + assert self.is_free == other.is_free + self.start = min(self.start, other.start) + self.end = max(self.end, other.end) + + def __repr__(self) -> str: + return "".format(self.start, self.size()) + + +class BestFit: + def __init__(self, align=32): + self.limit = 0 + self.list = [] + self.align = align + + def malloc(self, size) -> Block: + size = (size + self.align - 1) // self.align * self.align + found = None + for block in self.list: + if block.is_free and block.size() >= size: + if not found or found.size() > block.size(): + found = block + if found: + found.is_free = False + remain = found.size() - size + if remain != 0: + found.end -= remain + self.list.insert( + self.list.index(found) + 1, Block(found.end, found.end + remain, True) + ) + return found + elif len(self.list) > 0 and self.list[-1].is_free: + add = size - self.list[-1].size() + self.list[-1].end += add + self.limit = self.list[-1].end + self.list[-1].is_free = False + return self.list[-1] + else: + block = Block(self.limit, self.limit + size, False) + self.list.append(block) + self.limit += size + return block + + def free(self, block: Block) -> None: + assert not block.is_free + idx = self.list.index(block) + self.list[idx] = Block(block.start, block.end, True) + if idx + 1 < len(self.list) and self.list[idx + 1].is_free: + self.list[idx].merge(self.list[idx + 1]) + self.list.pop(idx + 1) + if idx - 1 >= 0 and self.list[idx - 1].is_free: + self.list[idx].merge(self.list[idx - 1]) + self.list.pop(idx - 1) diff --git a/bitblas/base/roller/hint.py b/bitblas/base/roller/hint.py new file mode 100644 index 000000000..f6e2fb03a --- /dev/null +++ b/bitblas/base/roller/hint.py @@ -0,0 +1,248 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Hint definition for schedule""" +from typing import Dict, List, Tuple +from . import PrimFuncNode +import numpy as np +from .rasterization import * + + +class TensorCoreExtraConfig: + """ + This class is used to store extra information for tensorcore + """ + + def __init__( + self, + AS_shape: Tuple[int], + BS_shape: Tuple[int], + AF_shape: Tuple[int], + BF_shape: Tuple[int], + tc_axis: Tuple[int], + ) -> None: + self.AS_shape: Tuple[int] = AS_shape + self.BS_shape: Tuple[int] = BS_shape + self.AF_shape: Tuple[int] = AF_shape + self.BF_shape: Tuple[int] = BF_shape + self.tc_axis: Tuple[int] = tc_axis + + +class Stride: + """ + Manages stride information for a given axis of a tensor. + """ + + def __init__(self, stride: int = 1, ax: int = -1) -> None: + # which axis to put stride on + self._ax: int = int(ax) + # the stride size of the axis + self._stride: int = int(stride) + + @property + def ax(self) -> int: + return self._ax + + @property + def stride(self) -> int: + return self._stride + + def compute_strides_from_shape(self, shape: List[int]) -> List[int]: + ndim = len(shape) + strides = [1 for _ in shape] + for i in range(ndim - 2, -1, -1): + if i == self.ax: + strides[i] = self.stride + else: + strides[i] = int(strides[i + 1] * shape[i + 1]) + return strides + + def compute_elements_from_shape(self, shape: List[int]) -> int: + original_shape = np.prod(shape) + if not self.is_valid(): + strided_elem = original_shape + else: + assert self.ax < len(shape) + strided_elem = np.prod(shape[0:self.ax + 1]) * self.stride + assert strided_elem >= original_shape + return int(strided_elem) + + def is_valid(self) -> bool: + return self.ax >= 0 + + def __repr__(self) -> str: + return f"" + + +class TileDict: + """ + Manages tiling information and configurations for computational tasks. + """ + + def __init__(self, output_tile) -> None: + self.output_tile = output_tile + # schedule config + self.tile_map = {} + self.rstep_map = {} + self.cached_tensors_map = {} + self.output_strides_map = {} + self.tensor_strides_map = {} + + # analysis + self.traffic = -1 + self.smem_cost = -1 + self.block_per_SM = -1 + self.num_wave = -1 + self.grid_size = -1 + self.valid = True + + def get_tile(self, func) -> List[int]: + return self.tile_map[func] + + def get_rstep(self, func) -> Dict[str, int]: + return self.rstep_map + + def __hash__(self) -> int: + return hash(tuple(self.output_tile)) + + +class IntrinInfo: + """ + The information of tensorcore intrinsic related information + """ + + def __init__( + self, + in_dtype: str, + out_dtype: str, + trans_b: bool, + input_transform_kind: int = 0, + weight_transform_kind: int = 0, + ) -> None: + self.in_dtype = in_dtype + self.out_dtype = out_dtype + self.trans_a = False + self.trans_b = trans_b + self.input_transform_kind = input_transform_kind + self.weight_transform_kind = weight_transform_kind + + def __repr__(self) -> str: + return f"" + + @property + def smooth_a(self) -> bool: + return self.input_transform_kind >= 2 + + @property + def smooth_b(self) -> bool: + return self.weight_transform_kind >= 2 + + @property + def inter_transform_a(self) -> bool: + return self.input_transform_kind >= 1 + + @property + def inter_transform_b(self) -> bool: + return self.weight_transform_kind >= 1 + + +class Hint(object): + """ + Central configuration class for managing various parameters of computational tasks. + """ + + def __init__(self) -> None: + self.arch = None + self.use_tc = None # todo(lei): this should be renamed. + + # Special axes tiling info + self.block = [] + self.thread = [] + # Special axes for MMA + self.warp = [] + # Reduce axes tiling info + self.rstep = [] + self.reduce_thread = [] + self.rasterization_plan = NoRasterization() + self.cached_tensors = [] + self.output_strides = {} + self.schedule_stages = None + # Config for block reduction + self.block_reduction_depth = None # type: int + + # Experimental + self._raxis_order = [] + self._step = [] + self.vectorize: Dict[str, int] = {} + self.pipeline_stage = 1 + self.use_async = False + self.opt_shapes: Dict[str, int] = {} + self.intrin_info = IntrinInfo("float16", "float16", True) + self.shared_scope: str = "shared" + self.pass_context: Dict = {} + + def to_dict(self) -> Dict: + dic = {} + dic["block"] = self.block + if self.use_tc: + dic["warp"] = self.warp + else: + dic["thread"] = self.thread + dic["rstep"] = self.rstep + if np.prod(self.reduce_thread) > 1: + dic["reduce_thread"] = self.reduce_thread + if self.use_tc: + dic["use_tc"] = self.use_tc + if self.output_strides: + dic["strides"] = {} + for k, stride in self.output_strides.items(): + if stride.is_valid(): + dic["strides"][k] = stride + if len(dic["strides"]) == 0: + del dic["strides"] + if np.prod(self._step) > 1: + dic["step"] = self._step + if self._raxis_order != []: + dic["raxis_order"] = self._raxis_order + if self.vectorize != {}: + dic["vectorize"] = self.vectorize + if self.pipeline_stage != 1: + dic["pipeline_stage"] = self.pipeline_stage + if self.block_reduction_depth is not None: + dic["block_reduction_depth"] = self.block_reduction_depth + return dic + + def from_dict(self, dic: Dict) -> "Hint": + self.__init__() + for k, v in dic.items(): + setattr(self, k, v) + return self + + def tensorcore_legalization(self): + # only keep the last 2 axes for tensorcore + self.warp = self.warp[-2:] + self.block = self.block[-2:] + return self + + @property + def raxis_order(self) -> List[int]: + if self._raxis_order != []: + return self._raxis_order + return list(range(len(self.rstep))) + + @property + def step(self) -> List[int]: + if self._step != []: + return self._step + return [1 for _ in self.block] + + def __repr__(self) -> str: + return str(self.to_dict()) + + def complete_config(self, node: PrimFuncNode): + # analysis pass context, for int8 mma, we should merge static shared memory + merge_static_smem = False + # int32 and float32 accum may take too much shared memory + if self.use_tc and self.intrin_info.out_dtype in ["float32", "int32"]: + merge_static_smem = True + self.pass_context = {"tir.merge_static_smem": merge_static_smem} + return self diff --git a/bitblas/base/roller/node.py b/bitblas/base/roller/node.py new file mode 100644 index 000000000..8e20440bb --- /dev/null +++ b/bitblas/base/roller/node.py @@ -0,0 +1,408 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""PrimFunc Wrapper and Block information Analaysis""" + +import tvm +from tvm import tir +from tvm.tir import IterVar, PrimFunc +from typing import Any, Dict, List, Tuple, Optional +from tvm.tir.schedule.schedule import BlockRV +import numpy as np +import functools +from ..analysis import BlockInfo, get_reduction_blocks +from .. import analysis +from .. import normalize_prim_func +from .shape_inference import get_analyzer_by_tir + + +def pre_order_traverse(block_analyzer, blocks, func): + visited = set() + + def _traverse(block): + if block in visited: + return + visited.add(block) + for dep_blocks in block_analyzer.get_consumer_blocks(block): + _traverse(dep_blocks) + func(block) + + for block in blocks: + _traverse(block) + + +class BlockAnalyzer(object): + + def __init__(self, sch) -> None: + self.sch: tir.Schedule = sch + self.block_infos: List[BlockInfo] = normalize_prim_func(self.sch) + + def get_block_name(self, block: BlockRV) -> str: + return self.sch.get(block).name_hint + + def get_block_info(self, block: BlockRV) -> BlockInfo: + for block_info in self.block_infos: + if self.get_block_name(block) == block_info.name: + return block_info + return None + + def get_spatial_axis(self, block: BlockRV) -> List[IterVar]: + block_info = self.get_block_info(block) + axis = [] + for iter in block_info.iters: + if iter.kind == "S": + axis.append(iter) + return axis + + def get_reduce_axis(self, block: BlockRV) -> List[IterVar]: + block_info = self.get_block_info(block) + raxis = [] + for iter in block_info.iters: + if iter.kind == "R": + raxis.append(iter) + return raxis + + def get_input_buffers(self, block: BlockRV) -> List[tir.Buffer]: + buffers = [] + for read in self.sch.get(block).reads: + buffers.append(read.buffer) + return buffers + + def get_output_buffers(self, block: BlockRV) -> List[tir.Buffer]: + buffers = [] + for write in self.sch.get(block).writes: + buffers.append(write.buffer) + return buffers + + def get_buffers(self, block: BlockRV) -> List[tir.Buffer]: + return self.get_input_buffers(block) + self.get_output_buffers(block) + + def get_producer_blocks(self, block: BlockRV) -> List[BlockRV]: + return self.sch.get_producers(block) + + def get_consumer_blocks(self, block: BlockRV) -> List[BlockRV]: + return self.sch.get_consumers(block) + + +class Node(object): + + def __init__(self, tags: Optional[Dict] = None) -> None: + if tags is None: + tags = {} + self._dtypes = [] + self._tag: Dict = {} + for tag in tags: + self.add_tag(tag, tags[tag]) + + def set_tag(self, k: str, v: Any = True) -> None: + self.add_tag(k, v) + + def add_tag(self, k: str, v: Any = True) -> None: + self._tag[k] = v + + def get_tag(self, k: str) -> Any: + if k not in self._tag: + return None + return self._tag[k] + + +class PrimFuncNode(Node): + + def __init__(self, prim_func: PrimFunc, tags: Optional[Dict] = None) -> None: + super().__init__(tags) + self.prim_func = self._specialize_func(prim_func) + self.sch: tir.Schedule = tir.Schedule(self.prim_func) + self.block_analyzer: BlockAnalyzer = BlockAnalyzer(self.sch) + self.schedule_stages: List[BlockRV] = [] + self.blocks: List[BlockRV] = [] + self.output_blocks: List[BlockRV] = None + self.reduction_block: BlockRV = None + self.raxis = [] + self.input_buffers = [] + self.output_buffers = [] + self.buffers = [] + self.args = [] + self._analysis_funcinfo() + self.ana = get_analyzer_by_tir(self.block_analyzer, self.blocks) + + def _specialize_func(self, func: PrimFunc): + # Specialize the function to make it more friendly for analysis. + # set attrs + for k, v in func.attrs.items(): + self.set_tag(k, v) + if self.get_tag("is_speclized"): + return func + opt_shapes = self.get_tag("opt_shapes") + if opt_shapes: + for name, shape in opt_shapes.items(): + var = analysis.find_var_from_func(func, name) + if var is not None: + func = func.specialize({var: shape.astype(var.dtype)}) + return func + + def _analysis_funcinfo(self): + root_block = analysis.get_root_block(self.sch) + blocks = self.sch.get_child_blocks(root_block) + self.blocks = blocks + + self.output_blocks = self.sch.get_output_blocks(root_block) + reduction_blocks = get_reduction_blocks(self.sch, blocks) + if reduction_blocks is None: + self.reduction_block = None + self.schedule_stages.append(*self.output_blocks) + else: + # analysis on the last reduction block + self.reduction_block = reduction_blocks[-1] + # set raxis + reduce_block_info = self.block_analyzer.get_block_info(self.reduction_block) + for iter in reduce_block_info.iters: + if iter.kind == "R": + self.raxis.append(iter) + self.schedule_stages.append(self.reduction_block) + + # collect output buffers + for output_block in self.output_blocks: + for write in self.sch.get(output_block).writes: + if write not in self.output_buffers: + self.output_buffers.append(write.buffer) + + for param in self.prim_func.params: + if param not in self.prim_func.buffer_map: + # in case of dynamic symbolic may in params + continue + buffer = self.prim_func.buffer_map[param] + if buffer not in self.output_buffers: + self.input_buffers.append(buffer) + + self.args = self.input_buffers + self.output_buffers + self.buffers = [buffer for buffer in self.prim_func.buffer_map.values()] + + # set dtype + self.set_dtype(tvm.DataType(self.output_buffers[0].dtype)) + + def get_opt_shape(self, name) -> int: + opt_shapes = self.get_tag("opt_shapes") + if opt_shapes is None: + return None + return opt_shapes[name] + + def extent_wrapper(self, value) -> int: + if isinstance(value, tvm.tir.Var): + return self.get_opt_shape(value.name) + elif isinstance(value, tvm.tir.IntImm): + return int(value) + else: + return value + + @functools.lru_cache() + def get_space_dim(self) -> List[int]: + dim_size = [] + if self.reduction_block: + block_info = self.block_analyzer.get_block_info(self.reduction_block) + for iter in block_info.iters: + if iter.kind == "S": + if isinstance(iter.dom.extent, tvm.tir.IntImm): + dim_size.append(int(iter.dom.extent)) + else: + assert isinstance(iter.dom.extent, tvm.tir.Var) + dim_size.append(self.get_opt_shape(iter.dom.extent.name)) + else: + # assume outer stage has the same shape + loops = self.sch.get_loops(self.schedule_stages[0]) + for loop in loops: + dim_size.append(int(self.sch.get(loop).extent)) + return [int(x) for x in dim_size] + + def set_dtype(self, dtype: tvm.DataType, id=0) -> None: + assert isinstance(dtype, tvm.DataType), type(dtype) + if dtype == tvm.DataType("bool"): + dtype = tvm.DataType("int8") + if len(self._dtypes) <= id: + self._dtypes.extend([None for _ in range(id - len(self._dtypes) + 1)]) + elif self._dtypes[id] is not None: + assert self._dtypes[id] == dtype, (self._dtypes, dtype) + self._dtypes[id] = dtype + + def get_dtype(self, id=0) -> tvm.DataType: + return self._dtypes[id] + + def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType: + return tvm.DataType(buffer.dtype) + + def propagate(self, tile, rstep: Optional[Dict] = None, targets=None): + if rstep is None: + rstep = {} + shape = { + self.block_analyzer.get_output_buffers(block)[0].name: + [tvm.arith.ConstIntBound(0, val - 1) for val in tile] for block in self.schedule_stages + } + return self.ana.infer(shape, rstep, targets) + + def propagate_inputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]: + if rstep is None: + rstep = {} + read_idx_offset = len(self.input_buffers) + targets = [t.name for t in self.args[:read_idx_offset]] + shapes, intermediate_bind = self.propagate(tile, rstep, targets) + results = [] + for i, arg in enumerate(self.args[:read_idx_offset]): + if arg.name in intermediate_bind: + results.append(shapes[arg.name]) + continue + # should not exceed original shape + trimmed_shape = [ + self.extent_wrapper(i) + for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape))) + ] + results.append(trimmed_shape) + return results + + # Propagate inputs only on reduction block + def propagate_inputs_on_reduction(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]: + if rstep is None: + rstep = {} + reduction_block = self.reduction_block + args = self.block_analyzer.get_input_buffers(reduction_block) + targets = [t.name for t in args] + shapes, intermediate_bind = self.propagate(tile, rstep, targets) + results = [] + for i, arg in enumerate(args): + if arg.name in intermediate_bind: + results.append(shapes[arg.name]) + continue + # should not exceed original shape + propagate_shape = shapes[arg.name] + buffer_shape = args[i].shape + if len(buffer_shape) > len(propagate_shape): + buffer_shape = buffer_shape[-len(propagate_shape):] + trimmed_shape = [ + self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape))) + ] + results.append(trimmed_shape) + return results + + def propagate_outputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]: + if rstep is None: + rstep = {} + read_idx_offset = len(self.input_buffers) + targets = [t.name for t in self.args[read_idx_offset:]] + shapes, _ = self.propagate(tile, rstep, targets) + results = [] + for i, arg in enumerate(self.args[read_idx_offset:]): + # should not exceed original shape + trimmed_shape = list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape))) + results.append(trimmed_shape) + return results + + def propagate_reduction_inputs(self, + shape, + rstep: Optional[Dict] = None) -> Dict[str, List[int]]: + if rstep is None: + rstep = {} + if self.reduction_block is None: + return {} + targets = [b.name for b in self.block_analyzer.get_input_buffers(self.reduction_block)] + results, _ = self.propagate(shape, rstep, targets) + return results + + def get_reduce_inputs_dtype(self): + if self.reduction_block is None: + return {} + return { + b.name: tvm.DataType(b.dtype) + for b in self.block_analyzer.get_input_buffers(self.reduction_block) + } + + @functools.lru_cache() + def infer_tensorcore_axis(self) -> Tuple[int]: + # axis is fixed for one expression, so only inference and cached + assert self.get_tag("tensorcore_config") + + C_ax_m, C_ax_n = self.get_tag("tensorcore_config") + wmma_m, wmma_n, wmma_k = [16, 16, 16] # just for testing, any number is ok + + output_buffer_shape = ( + self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape) + valid_region = [] + for region in output_buffer_shape: + if region.value == 1: + continue + valid_region.append(region) + + num_nvalid_regions = len(output_buffer_shape) - len(valid_region) + self.set_tag("num_nvalid_regions", num_nvalid_regions) + + def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions): + spatial_dim = self.get_space_dim() + assert len(valid_region) == len( + spatial_dim), f" {valid_region} mismatch with {spatial_dim}" + cl_shapes = [1] * len(spatial_dim) + cl_shapes[c_ax_m - num_nvalid_regions] = wmma_m + cl_shapes[c_ax_n - num_nvalid_regions] = wmma_n + return cl_shapes + + CL_shape = get_cl_shapes(C_ax_m, C_ax_n, num_nvalid_regions) + self.set_tag("tensorcore_config", [s - num_nvalid_regions for s in [C_ax_m, C_ax_n]]) + shapes = self.propagate_reduction_inputs(CL_shape, {x.var.name: 1 for x in self.raxis}) + A_deps, B_deps = shapes.values() + A_ax_m = A_deps.index(wmma_m) + B_ax_n = B_deps.index(wmma_n) + + CL_shape = [1] * len(self.get_space_dim()) + shapes = self.propagate_reduction_inputs(CL_shape, {x.var.name: wmma_k for x in self.raxis}) + A_deps, B_deps = shapes.values() + A_ax_k = len(A_deps) - 1 - A_deps[::-1].index(wmma_k) + B_ax_k = len(B_deps) - 1 - B_deps[::-1].index(wmma_k) + tc_axis = (A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n) + return tc_axis + + def footprint(self, shape, rstep, stride_map: Optional[Dict] = None) -> int: + if stride_map is None: + stride_map = {} + result = 0 + shapes, _ = self.propagate(shape, rstep) + + def is_broadcast_pattern(buffer, output_buffer): + return (buffer in self.args and + len(shapes[output_buffer.name]) > len(shapes[buffer.name]) and + np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name])) + + def is_after_reduce_stage(block): + if not self.reduction_block: + return False + reduce_dependent_blocks = getattr(self, "reduce_dependent_blocks", None) + if reduce_dependent_blocks is None: + reduce_dependent_blocks = set() + pre_order_traverse( + self.block_analyzer, + [self.reduction_block], + lambda block: reduce_dependent_blocks.add(block), + ) + self.reduce_dependent_blocks = reduce_dependent_blocks + return block not in reduce_dependent_blocks + + # compute cached stages + cached_tensor = [] + for block in self.blocks: + output_buffer = self.block_analyzer.get_output_buffers(block)[0] + for buffer in self.block_analyzer.get_input_buffers(block): + cache = buffer.name not in cached_tensor and ( + is_broadcast_pattern(buffer, output_buffer) or + self.block_analyzer.get_block_info(block).is_reduction) + if not cache: + continue + cached_tensor.append(buffer.name) + if is_after_reduce_stage(block): + continue # cache after reduce op can often reuse buffer in reduce stage + + if buffer.name in stride_map: + num_elem = stride_map[buffer.name].compute_elements_from_shape( + shapes[buffer.name]) + else: + num_elem = np.prod(shapes[buffer.name]) + buffer_len = num_elem * int((tvm.DataType(buffer.dtype).bits + 7) // 8) + buffer_len = (buffer_len + 31) // 32 * 32 + result += buffer_len + return result, cached_tensor + + def get_input_buffers(self) -> List[tir.Buffer]: + return self.block_analyzer.input_buffers diff --git a/bitblas/base/roller/policy/__init__.py b/bitblas/base/roller/policy/__init__.py new file mode 100644 index 000000000..09ed1d51b --- /dev/null +++ b/bitblas/base/roller/policy/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .default import DefaultPolicy +from .tensorcore import TensorCorePolicy diff --git a/bitblas/base/roller/policy/common.py b/bitblas/base/roller/policy/common.py new file mode 100644 index 000000000..9141550c8 --- /dev/null +++ b/bitblas/base/roller/policy/common.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import List +import numpy as np + + +def get_all_factors(n: int) -> List[int]: + # Calculate the square root of n and round it up to the nearest integer + n0 = int(np.ceil(np.sqrt(n))) + + # Find all divisors of n that are less than n0 + val = np.where(n % np.arange(1, n0) == 0)[0] + 1 + + # If n is a perfect square, add the square root to the list of factors + mid = np.array([], dtype=int) if n0 * n0 != n else [n0] + + # Combine the factors and their corresponding larger pair factors + return [int(x) for x in np.concatenate([val, mid, n // val[::-1]])] + + +def factorize(n: int) -> List[int]: + i = 2 # Start with the smallest prime number + result = [] + + # Iterate through numbers to find factors + while n > 1: + if n % i == 0: # If i is a factor of n + n //= i # Divide n by i and keep the integer part + result.append(i) + else: + i += 1 # Try the next number + return result + + +def coalesced_factor(subtensor: List[int], tensor: List[int]) -> int: + # If the last dimension of the subtensor and tensor differ, or subtensor has only one dimension + if subtensor[-1] != tensor[-1] or len(subtensor) == 1: + return subtensor[-1] + else: + # Recursively calculate the coalesced factor for the remaining dimensions + return subtensor[-1] * coalesced_factor(subtensor[:-1], tensor[:-1]) + + +def coalesced_tensor_shape(subtensor: List[int], tensor: List[int], transaction_size: int) -> int: + # Calculate the total number of elements in the subtensor + bytes = int(np.prod(subtensor)) + + if bytes == 0: + return 0 + + # Calculate the coalesced factor for the subtensor + factor = int(coalesced_factor(subtensor, tensor)) + + # Compute the shape of the coalesced tensor + return transaction_size * bytes / min(transaction_size, factor) diff --git a/bitblas/base/roller/policy/default.py b/bitblas/base/roller/policy/default.py new file mode 100644 index 000000000..81aeba123 --- /dev/null +++ b/bitblas/base/roller/policy/default.py @@ -0,0 +1,748 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Policy for cuda core schedule""" +import functools +import math +from queue import PriorityQueue +from typing import Iterable, Dict, List, Optional + +import numpy as np +import tvm + +from ..arch import TileDevice +from ..bestfit import BestFit +from ..hint import Hint, Stride, TileDict +from .common import coalesced_factor, coalesced_tensor_shape, factorize, get_all_factors +from ..node import PrimFuncNode +from ..rasterization import NoRasterization + + +class DefaultPolicy: + """ + Default Policy for fastdlight, a heuristic plan that tries to + minimize memory traffic and maximize parallelism.for BitBLAS Schedule. + """ + + def __init__(self, + func: tvm.tir.PrimFunc, + arch: TileDevice, + tags: Optional[Dict] = None) -> None: + if tags is None: + tags = {} + self.arch = arch + self.prim_func_node = PrimFuncNode(func, tags) + self.ordered_nodes = [self.prim_func_node] + self.output_nodes = [self.prim_func_node] + + def emit_config(self, topk: int) -> List[Hint]: + base_tile = self.get_base_tile() + if base_tile is None: + return [] + + rstep_map = self._assign_reduce_step(self.prim_func_node) + smem_tile_condidates = self.dfs_smem_tile(base_tile, rstep_map) + results = [] + for td in smem_tile_condidates: + if not self.check_tile_shape_isvalid(td): + continue + + self._expand_reduce_axis(td) + for codegen_dicts in self.assign_block_size(td): + results.append(codegen_dicts) + if len(results) >= topk: + break + if len(results) >= topk: + break + return results + + def dfs_smem_tile(self, init_tile, rstep_map) -> Iterable[TileDict]: + _steps = [get_all_factors(n) for n in self.prim_func_node.get_space_dim()] + steps = [step[step.index(t):] for step, t in zip(_steps, init_tile)] + for i in range(len(steps)): + added = list( + filter( + lambda s: s < steps[i][-1] and s > steps[i][0] and s not in steps[i], + [2, 4, 8, 16, 32], + )) + steps[i].extend(added) + steps[i] = sorted(steps[i]) + visited_tiles = {} + queue = PriorityQueue() + + def prio(td: TileDict): + return (td.traffic + 1) * td.num_wave + + def add_to_queue(tile): + if tuple(tile) in visited_tiles: + return + td = self.compute_tile_dict(tile, rstep_map) + visited_tiles[tuple(tile)] = td + if td.valid: + queue.put([prio(td), tile]) + + add_to_queue(init_tile) + while not (queue.empty() or len(visited_tiles) > 2000): + _, tile = queue.get() + dim_ids = [step.index(t) for step, t in zip(steps, tile)] + for i in reversed(range(len(dim_ids))): + if dim_ids[i] + 1 < len(steps[i]): + new_tile = tile.copy() + new_tile[i] = steps[i][dim_ids[i] + 1] + add_to_queue(new_tile) + + visited_tiles = filter(lambda td: td.valid, visited_tiles.values()) + sorted_tiles = sorted(visited_tiles, key=lambda td: prio(td)) + return sorted_tiles + + def get_base_tile(self): + """ + Gets the minimum tile configuration that satisfies no redundancy in computation. + + Returns + ------- + List[int] + The base tile configuration, which is a list of 1s equal in length to the space dimensions + of the primary function node. + """ + shape = self.prim_func_node.get_space_dim() + base_tile = [1 for _ in shape] + + return base_tile + + # handles multiple output cases + def _get_output_tile_map(self, tile): + """ + Handles multiple output cases by mapping output nodes to their respective tile configurations. + + Parameters + ---------- + tile : List[int] + The tile configuration. + + Returns + ------- + Dict + A dictionary mapping the primary function node to its corresponding tile configuration + based on the output nodes' space dimensions. + """ + tile_map = {} + tile_map[self.prim_func_node] = [ + tile[i] * self.prim_func_node.get_space_dim()[i] // + self.output_nodes[0].get_space_dim()[i] for i in range(len(tile)) + ] + return tile_map + + def score_block_size(self, n): + """ + Scores a block size based on its efficiency and fit relative to the architecture's warp size and SM partition. + + Parameters + ---------- + n : int + The block size to score. + + Returns + ------- + Tuple[float, float] + A tuple containing two scores representing efficiency and fit, respectively. + """ + num_wrap = (n + self.arch.warp_size - 1) // self.arch.warp_size + r1 = max(num_wrap / self.arch.sm_partition, self.arch.sm_partition / num_wrap) + r2 = (num_wrap * self.arch.warp_size - n) / n + return (r1, r2) + + def get_block_size(self, n): + """ + Determines the optimal block size for a given constraint, based on scoring various factors. + + Parameters + ---------- + n : int + The constraint size. + + Returns + ------- + int + The optimal block size chosen from the factors of n, constrained by a maximum of 1024 and + scored by the `score_block_size` method. + """ + factors = get_all_factors(n) + factors = list(filter(lambda x: x <= 1024, factors)) + factor_ordered = sorted(factors, key=self.score_block_size) + return factor_ordered[0] + + def get_node_reduce_step_candidates(self, node: PrimFuncNode): + """ + Calculates reduction step candidates for each reduction axis in a PrimFuncNode. General idea : use factor first, since it does not require extra boundary check. for large prime number, which is rare case, use power of 2. + + Parameters + ---------- + node : PrimFuncNode + The node for which to calculate reduction step candidates. It contains reduction axes (raxis) + with their domains (dom.extent). + + Returns + ------- + Dict[str, List[int]] + A dictionary mapping axis variable names to lists of step candidates. For each axis in the node, + this function calculates possible step sizes. For axes with a large prime domain, it uses powers of 2 + as step candidates; for others, it uses all factors of the domain. + """ + + results = {} + for k_iter in node.raxis: + all_factors = get_all_factors(int(k_iter.dom.extent)) + if len(all_factors) == 2 and int(k_iter.dom.extent) > 64: + all_factors = [1] + while all_factors[-1] * 2 < int(k_iter.dom.extent): + all_factors.append(all_factors[-1] * 2) + results[k_iter.var.name] = all_factors + return results + + def _assign_reduce_step(self, node: PrimFuncNode): + """ + Assigns an optimal reduction step for the given PrimFuncNode. + + Parameters + ---------- + node : PrimFuncNode + The node for which the reduction step is to be assigned. + + Returns + ------- + Dict + A dictionary mapping reduction axis variable names to their optimal reduction steps. + """ + if node.reduction_block is None: + return {} + + raxis = node.raxis + tile = [1] * len(node.get_space_dim()) + all_steps = self.get_node_reduce_step_candidates(node) + + def sim(a: int, b: int): + return (2 * a * b) / (a * a + b * b) + + def _score(rstep_id): + rstep = {k: all_steps[k][rstep_id[k]] for k in rstep_id} + score = 0 + shape = node.propagate_inputs(tile, rstep=rstep) + for i, input_buffer in enumerate(node.input_buffers): + read_transaction_elements = self.arch.transaction_size[1] // ( + (node.get_buffer_dtype(input_buffer).bits + 7) // 8) + score += sim( + int(coalesced_factor(shape[i], input_buffer.shape)), + read_transaction_elements, + ) + return score + + def _enlarge(rstep_id): + candidates = [] + candidates.append((rstep_id, _score(rstep_id))) + for ax in rstep_id: + if rstep_id[ax] + 1 == len(all_steps[ax]): + continue + r = rstep_id.copy() + r[ax] += 1 + candidates.append((r, _score(r))) + best = max(candidates, key=lambda x: x[1]) + return best + + # enlarge rstep to ensure read is coaleased + cur_rstep_id = {ax.var.name: 0 for ax in raxis} + cur_score = _score(cur_rstep_id) + while True: + if cur_score == 0: + break + new_rstep, new_score = _enlarge(cur_rstep_id) + if new_score <= cur_score: + break + else: + cur_rstep_id, cur_score = new_rstep, new_score + rstep = {k: all_steps[k][cur_rstep_id[k]] for k in cur_rstep_id} + return rstep + + def _expand_reduce_axis(self, td: TileDict): + """ + Expands the reduction axis in the TileDict based on shared memory limits. + + Parameters + ---------- + td : TileDict + The TileDict object to be optimized. + + Returns + ------- + None + This function modifies the TileDict in place. + """ + smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) + rstep_map = td.rstep_map.copy() + + def _optimize(node, rstep): + all_steps = self.get_node_reduce_step_candidates(node) + for k in all_steps: + all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k])) + + def _score(rstep_id): + rstep = { + k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis + } + score = 0 + shape = node.propagate_inputs(td.get_tile(node), rstep=rstep) + for i, input_buffer in enumerate(node.input_buffers): + score += coalesced_factor(shape[i], input_buffer.shape) + return score + + def _enlarge(rstep_id): + candidates = [] + for ax in rstep_id: + if rstep_id[ax] + 1 == len(all_steps[ax]): + continue + r = rstep_id.copy() + r[ax] += 1 + candidates.append((r, _score(r))) + if len(candidates) == 0: + return None + return max(candidates, key=lambda x: x[1])[0] + + cur_rstep_id = { + k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis + } + new_rstep_map = rstep_map.copy() + while True: + new_rstep_id = _enlarge(cur_rstep_id) + if new_rstep_id is None: + break + new_rstep_map = { + k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis + } + old_rstep_map = td.rstep_map + td.rstep_map = new_rstep_map + smem_usage, _ = self._compute_shared_memory_usage(td) + td.rstep_map = old_rstep_map + if smem_usage > smem_limit: + break + else: + cur_rstep_id = new_rstep_id + rstep = { + k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis + } + return rstep + + for node in self.ordered_nodes: + if len(node.raxis) > 0: + rstep = _optimize(node, rstep_map) + rstep_map = rstep + td.rstep_map = rstep_map + td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) + + def _compute_memory_traffic(self, output_tile): + """ + Computes the memory traffic for a given output tile configuration. + + Parameters + ---------- + output_tile : List[int] + The output tile configuration. + + Returns + ------- + Tuple[int, Dict] + The total memory traffic and a map of operation tiles. + """ + op_tile_map = self._get_output_tile_map(output_tile) + traffic = 0 + for node in reversed(self.ordered_nodes): + tile = op_tile_map[node] + input_shapes = node.propagate_inputs(tile) + output_shapes = node.propagate_outputs(tile) + for i, buffer in enumerate(node.input_buffers): + nbytes = (node.get_buffer_dtype(buffer).bits + 7) // 8 + read_transaction_elements = self.arch.transaction_size[1] // nbytes + traffic += ( + coalesced_tensor_shape(input_shapes[i], buffer.shape, read_transaction_elements) + * nbytes) + for i, buffer in enumerate(node.output_buffers): + nbytes = (node.get_buffer_dtype(buffer).bits + 7) // 8 + write_transaction_elements = self.arch.transaction_size[0] // nbytes + traffic += ( + coalesced_tensor_shape(output_shapes[i], buffer.shape, + write_transaction_elements) * nbytes) + return traffic, op_tile_map + + def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode): + """ + Infers the shared memory usage of a node given a TileDict configuration. + + Parameters + ---------- + td : TileDict + The TileDict object containing the tile configuration. + node : PrimFuncNode + The node for which to infer the shared memory usage. + + Returns + ------- + int + The estimated amount of shared memory used by the node. + """ + return node.footprint(td.get_tile(node), td.get_rstep(node), td.tensor_strides_map[node]) + + def _compute_shared_memory_usage(self, td: TileDict): + """ + Computes the stride map for a given node and TileDict configuration. + + Parameters + ---------- + node : PrimFuncNode + The node for which to compute the stride map. + td : TileDict + The TileDict object containing the tile configuration. + + Returns + ------- + Tuple[Dict, Dict] + The output strides and tensor strides. + """ + self._compute_stride_map(td) + allocator = BestFit() + block_map = {} + cached_tensors_map = {} + + node_internal_bytes, cached_tensors_map[self.prim_func_node] = self.infer_node_smem_usage( + td, self.prim_func_node) + block = allocator.malloc(node_internal_bytes) + allocator.free(block) + assert len(block_map) == 0 + return allocator.limit, cached_tensors_map + + def compute_node_stride_map(self, node: PrimFuncNode, td: TileDict): + """ + Computes the stride map for a given node based on the TileDict configuration. + + Parameters + ---------- + node : PrimFuncNode + The node for which to compute the stride map. + td : TileDict + The TileDict object containing the tile configuration. + + Returns + ------- + Tuple[Dict, Dict] + A tuple of dictionaries containing the output strides and tensor strides. + """ + output_strides = { + int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers) + } + tensor_strides = {} + return output_strides, tensor_strides + + def _compute_stride_map(self, td: TileDict): + """ + Computes the stride map for all nodes in a TileDict. + + Parameters + ---------- + td : TileDict + The TileDict object for which to compute the stride maps. + + Returns + ------- + None + This function updates the TileDict object in place with the computed stride maps. + """ + output_strides_map = {} + tensor_strides_map = {} + for node in self.ordered_nodes: + output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map( + node, td) + td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map + + def compute_tile_dict(self, output_tile: List[int], rstep_map) -> TileDict: + """ + Computes and returns a TileDict object for a given output tile configuration and reduction step map. + + Parameters + ---------- + output_tile : List[int] + The output tile configuration. + rstep_map : Dict + The reduction step map. + + Returns + ------- + TileDict + A TileDict object containing the computed tile configuration, memory traffic, shared memory cost, + grid size, and other related parameters. + """ + td = TileDict(output_tile) + td.rstep_map = rstep_map + td.traffic, td.tile_map = self._compute_memory_traffic(output_tile) + td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) + if td.smem_cost > self.arch.smem_cap: + td.valid = False + return td + output_shape = self.output_nodes[0].get_space_dim() + td.grid_size = int(np.prod([(y + x - 1) // x for x, y in zip(output_tile, output_shape)])) + # estimated reg usage + reg_usage = int(2 * max([ + np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes + ])) + if reg_usage > self.arch.reg_cap: + td.valid = False + return td + td.block_per_SM = min( + self.arch.max_smem_usage // max(td.smem_cost, 1), + self.arch.reg_cap // max(reg_usage, 1), + self.arch.sm_partition, + ) + td.num_wave = int(np.ceil(td.grid_size / int(td.block_per_SM * self.arch.compute_max_core))) + return td + + def check_tile_shape_isvalid(self, td: TileDict) -> bool: + """ + Checks if the tile shapes in the TileDict are valid for the nodes in this context. + + Parameters: + - td (TileDict): The TileDict object containing tile shapes and other configurations. + + Returns: + - bool: True if all tile shapes are valid, False otherwise. + """ + for node in self.ordered_nodes: + if np.prod(td.get_tile(node)) == 0: + return False + node_grid_size = np.prod([ + (y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim()) + ]) + if node_grid_size != td.grid_size: + return False + if (hasattr(node, "reduce_op") and node.reduce_op is not None and + len(node.reduce_op.axis) == len(td.output_tile)): + for i, tile_extent in enumerate(td.output_tile): + if node.reduce_op.axis[i].dom.extent % tile_extent: + return False + + return True + + def recommend_block_size(self, td: TileDict) -> List[int]: + """ + Recommends optimal block sizes based on the TileDict configuration. + + Parameters + ---------- + td : TileDict + The TileDict object containing the tile configuration. + + Returns + ------- + List[int] + A list of recommended block sizes sorted based on their score. + """ + node_space_sizes = [int(np.prod(td.get_tile(node))) for node in self.ordered_nodes] + max_block_size = functools.reduce(math.gcd, node_space_sizes) + + if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min( + node_space_sizes): + node_reduce_sizes = [ + int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes + ] + total_sizes = [x * y for x, y in zip(node_space_sizes, node_reduce_sizes)] + max_possible_size = functools.reduce(math.gcd, total_sizes) + possible_block_sizes = list( + filter( + lambda x: x % max_block_size == 0 and x <= 1024, + get_all_factors(max_possible_size), + )) + possible_block_sizes = list( + filter( # either be a factor of space or cover fully cover the space + lambda x: all([x % s == 0 or s % x == 0 for s in node_space_sizes]), + possible_block_sizes, + )) + factor_ordered = sorted(possible_block_sizes, key=self.score_block_size) + return factor_ordered + else: + possible_block_sizes = get_all_factors(max_block_size) + possible_block_sizes = list(filter(lambda x: x <= 1024, possible_block_sizes)) + factor_ordered = sorted(possible_block_sizes, key=self.score_block_size) + return factor_ordered + + def assign_block_size(self, td: TileDict, topk=1): + """ + Assigns block sizes to the TileDict based on the recommended block sizes. + + Parameters + ---------- + td : TileDict + The TileDict object to assign block sizes to. + topk : int, optional + The number of top block sizes to consider. + + Yields + ------- + Dict + The block size assignment for the primary function node. + """ + block_size_ordered = self.recommend_block_size(td) + for block_size in block_size_ordered: + result = {} + failed = False + result = self._assign_block_size(self.prim_func_node, td, block_size) + if result is None: + failed = True + break + if failed: + continue + else: + yield result + topk -= 1 + if topk == 0: + break + + def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int): + """ + Assigns a block size to a given PrimFuncNode based on the TileDict configuration and the specified block size. + + Parameters + ---------- + node : PrimFuncNode + The node to assign the block size to. + td : TileDict + The TileDict object containing the tile configuration. + block_size : int + The block size to be assigned. + + Returns + ------- + Hint + A Hint object containing the assigned block size and other related settings. + """ + tile, rsteps = td.get_tile(node), td.get_rstep(node) + factors = factorize(block_size) + cur_threads = [1 for _ in tile] + reduce_thread = {k: 1 for k in rsteps} + ndim = len(tile) + + def _score(node, thread): # small is better + score = 0 + block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)] + shape = node.propagate_inputs(block_tile) + for i, _ in enumerate(node.input_buffers): + score += np.prod(shape[i]) / self.arch.bandwidth[1] + for buffer in node.output_buffers: + score += coalesced_tensor_shape(thread, buffer.shape, 8) / self.arch.bandwidth[0] + return score + + for factor in reversed(factors): + score_map = {} + for i in range(ndim): + if cur_threads[i] >= tile[i]: + continue + if (tile[i] % (cur_threads[i] * factor)) != 0: + continue + cur_threads[i] *= factor + score_map[i] = (_score(node, cur_threads), i) + cur_threads[i] //= factor + if len(score_map) > 0: + # assign to space axis + dim_order = sorted(score_map.keys(), key=lambda x: score_map[x]) + cur_threads[dim_order[0]] *= factor + else: + # assign to reduce axis + target_ax = None + for ax, ax_len in reversed(list(rsteps.items())): + if ax_len % (reduce_thread[ax] * factor) == 0: + target_ax = ax + break + assert target_ax + reduce_thread[target_ax] *= factor + + codegen_dict = Hint() + codegen_dict.block = tile + codegen_dict.thread = cur_threads + codegen_dict.rstep = [rsteps[ax.var.name] for ax in node.raxis] + codegen_dict.reduce_thread = [reduce_thread[ax.var.name] for ax in node.raxis] + codegen_dict.cached_tensors = td.cached_tensors_map[node] + codegen_dict.rasterization_plan = self.plan_rasterization(td) + + if node.get_dtype().bits == 16: # set step=2 for 16bit case to ensure coalesced access + codegen_dict._step = [1 for _ in range(ndim)] + for i in reversed(range(ndim)): + if codegen_dict.block[i] // codegen_dict.thread[i] % 2 == 0: + codegen_dict._step[i] = 2 + break + elif node.get_dtype().bits == 8: # set step=4 for 8bit case to ensure coalesced access + codegen_dict._step = [1 for _ in range(ndim)] + for i in reversed(range(ndim)): + if codegen_dict.block[i] // codegen_dict.thread[i] % 4 == 0: + codegen_dict._step[i] = 4 + break + # Plan vectorize + codegen_dict.vectorize = self._plan_vectorize(node, td, block_size) + codegen_dict.arch = self.arch + codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes") + return codegen_dict + + def _plan_vectorize(self, node: PrimFuncNode, td: TileDict, block_size: int): + """ + Plans vectorization for a given PrimFuncNode based on the TileDict configuration and block size. + + Parameters + ---------- + node : PrimFuncNode + The node for which to plan vectorization. + td : TileDict + The TileDict object containing the tile configuration. + block_size : int + The block size used for vectorization planning. + + Returns + ------- + Dict + A dictionary mapping tensors to their vectorization size. + """ + + def is_cont(shape, vec): + if len(shape) == 0: + return vec == 1 + last = shape[-1] + if last == 1: + return is_cont(shape[0:-1], vec // last) + else: + return last % vec == 0 + + def is_shape_aligned(shape, factor): + return int(np.prod(shape)) % factor == 0 + + def is_type_allowed(dtype, vec): + return dtype.bits * vec <= 128 + + vectorize_sizes = [16, 8, 4, 2] + dtypes = node.get_reduce_inputs_dtype() + shapes = node.propagate_reduction_inputs(td.get_tile(node), td.get_rstep(node)) + vectorize_result = {} + for tensor, shape in shapes.items(): + for v in vectorize_sizes: + if (is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and + is_type_allowed(dtypes[tensor], v)): + vectorize_result[tensor] = v + break + return vectorize_result + + def plan_rasterization(self, td: TileDict): # pylint: disable=unused-argument + """ + Plans the rasterization for the given TileDict. This function is not implemented yet. + + Parameters + ---------- + td : TileDict + The TileDict object to plan rasterization for. + + Raises + ------- + RasterRationPlan + This function is not implemented yet. + """ + return NoRasterization() diff --git a/bitblas/base/roller/policy/tensorcore.py b/bitblas/base/roller/policy/tensorcore.py new file mode 100644 index 000000000..f4047ef08 --- /dev/null +++ b/bitblas/base/roller/policy/tensorcore.py @@ -0,0 +1,349 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Policy for tensorcore schedule""" +import tvm +from typing import Dict, List, Tuple, Optional +import numpy as np + +from ..arch import TileDevice +from ..hint import Hint, Stride, TileDict, IntrinInfo +from ..node import PrimFuncNode +from .common import coalesced_factor, factorize, get_all_factors +from .default import DefaultPolicy +from ..rasterization import NoRasterization, Rasterization2DColumn + + +class TensorCorePolicy(DefaultPolicy): + + def __init__(self, + func: tvm.tir.PrimFunc, + arch: TileDevice, + tags: Optional[Dict] = None) -> None: + super().__init__(func, arch, tags) + # this is the trick for wmma. + # However, for int8 mma, the wmma_k should be 32. + self.wmma_k = 16 + self.pipeline_stage: int = 1 + self.use_async_copy: bool = False + self.block_reduction_depth: Optional[int] = None + self._legalize_info() + + def _legalize_info(self): + pipleline_stage = self.prim_func_node.get_tag("pipeline_stage") + if pipleline_stage: + self.pipeline_stage = pipleline_stage + else: + if self.arch.compute_capability == "sm_80": + self.pipeline_stage = 2 + else: + self.pipeline_stage = 1 + use_async_copy = self.prim_func_node.get_tag("use_async_copy") + if use_async_copy: + self.use_async_copy = use_async_copy + else: + if self.arch.compute_capability == "sm_80": + self.use_async_copy = True + else: + self.use_async_copy = False + # TODO: block reduction depth is not used for now. + # As there still exists some performance issues for block reduction. + # block_reduction_depth = self.prim_func_node.get_tag("block_reduction_depth") + # if block_reduction_depth: + # self.block_reduction_depth = block_reduction_depth + + def _compute_tc_strides( + self, + node: PrimFuncNode, + tile: List[int], + rstep: Optional[Dict[str, int]] = None, + ) -> Tuple[Stride, Stride, Stride]: + if rstep is None: + rstep = {} + # strides was used for shared memory padding. which is necessary for avoiding + # shared memory load bank conflict when we do not applying tensorcore layout. + shapes = node.propagate_reduction_inputs(tile, rstep) + AS_shape, BS_shape = shapes.values() + CS_shape = tile + A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n = node.infer_tensorcore_axis() + + # applying strides + # TODO(leiwang1999): offset should be dynamically set. we can use tag -> enable_offset to control this option.. + offset = 8 + A_high_ax = min(A_ax_m, A_ax_k) + B_high_ax = min(B_ax_n, B_ax_k) + C_high_ax = min(C_ax_m, C_ax_n) + A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1:]) + offset, ax=A_high_ax) + B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1:]) + offset, ax=B_high_ax) + C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1:]) + offset, ax=C_high_ax) + return A_stride, B_stride, C_stride + + def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode): + value, cached_tensors = super().infer_node_smem_usage(td, node) + value *= self.pipeline_stage + return value, cached_tensors + + def _assign_reduce_step(self, node): + if not node.get_tag("tensorcore_config"): + return super()._assign_reduce_step(node) + # get reduce input size + target_transaction = self.arch.transaction_size[0] * 2 + # 512 bytes // type bits + reduce_input_dtype = node.get_buffer_dtype( + node.block_analyzer.get_input_buffers(node.reduction_block)[0]) + basic = (target_transaction * 8) // reduce_input_dtype.bits + + result = {} + for iter_info in node.raxis: + iter_name = iter_info.var.name + iter_dom = iter_info.dom.extent + if iter_dom % 16 > 0: + result[iter_name] = (16 if iter_dom < basic else basic) # for the case of padding + elif iter_dom % basic == 0: + result[iter_name] = basic + else: + return super()._assign_reduce_step(node) + return result + + def _expand_reduce_axis(self, td: TileDict): + # For tensorcore program, if we got a small tilesize, we should consider expand the reduce axis + # to improve compute efficiency. + def _check_small_tile(td: TileDict): + minimal_threadhold = 32 + for node in self.ordered_nodes: + tile = td.get_tile(node) + if any([t <= minimal_threadhold for t in tile]): + return True + return False + + if not _check_small_tile(td): + return None + + smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) + rstep_map = td.rstep_map.copy() + is_block_reduction = self.block_reduction_depth is not None + + def _optimize(node, rstep): + all_steps = self.get_node_reduce_step_candidates(node) + # todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k] + for k in all_steps: + all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k])) + if any([v == [] for v in all_steps.values()]): + return rstep + + def _shared_memory_usage(td: TileDict): + return node.footprint(td.output_tile, new_rstep_map, td.tensor_strides_map[node]) + + def _score(rstep_id): + rstep = { + k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis + } + score = 0 + shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep) + input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) + for i, input_buffer in enumerate(input_buffers): + score += coalesced_factor(shape[i], input_buffer.shape) + return score + + def _enlarge(rstep_id): + candidates = [] + for ax in rstep_id: + if rstep_id[ax] + 1 == len(all_steps[ax]): + continue + r = rstep_id.copy() + r[ax] += 1 + candidates.append((r, _score(r))) + if len(candidates) == 0: + return None + return max(candidates, key=lambda x: x[1])[0] + + cur_rstep_id = { + k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis + } + new_rstep_map = rstep_map.copy() + while True: + new_rstep_id = _enlarge(cur_rstep_id) + if new_rstep_id is None: + break + new_rstep_map = { + k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis + } + old_rstep_map = td.rstep_map + td.rstep_map = new_rstep_map + smem_usage, _ = _shared_memory_usage(td) + td.rstep_map = old_rstep_map + if smem_usage > smem_limit: + break + else: + cur_rstep_id = new_rstep_id + rstep = { + k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis + } + return rstep + + for node in self.ordered_nodes: + if len(node.raxis) > 0: + rstep = _optimize(node, rstep_map) + rstep_map = rstep + + if is_block_reduction: + # If block reduction, we should constrain the max value is 64 + # Otherwise it will introduce an issue of cuda invalid args. + MAX_REDUCE_K = 64 + for k in rstep_map: + rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K) + td.rstep_map = rstep_map + td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) + return + + def get_node_reduce_step_candidates(self, node): + if not node.get_tag("tensorcore_config"): + return super().get_node_reduce_step_candidates(node) + else: + # must be a a multiple of wmma_k + return { + k.var.name: + [x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)] + for k in node.raxis + } + + def check_tile_shape_isvalid(self, td: TileDict): + for node in self.ordered_nodes: + if node.get_tag("tensorcore_config"): + ax_m, ax_n = node.get_tag("tensorcore_config") + block_m, block_n = ( + td.tile_map[node][ax_m], + td.tile_map[node][ax_n], + ) + # check the tile size is valid + wmma_invalid = [ + block_m < wmma_m or block_n < wmma_n + for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes() + ] + if all(wmma_invalid): + return False + if any([y % x for x, y in zip(td.tile_map[node], node.get_space_dim())]): + return False + return super().check_tile_shape_isvalid(td) + + def _can_implement_layout(self, node: PrimFuncNode, td: TileDict): + # Not implemented yet + # This function is used to check whether we can implement swizzling + # layout under this tile config + return False + + def compute_node_stride_map(self, node: PrimFuncNode, td: TileDict): + if not node.get_tag("tensorcore_config"): + return super().compute_node_stride_map(node, td) + use_layout = self._can_implement_layout(node, td) + + AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node), + td.get_rstep(node)) + A_stride, B_stride, _ = self._compute_tc_strides(node, td.get_tile(node)) + tensor_strides = {} + output_strides = { + int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers) + } + tensor_strides = {} + # when connected to shared input, should use full stride without rstep + for i, (_, _) in enumerate(zip([AS_stride, BS_stride], [A_stride, B_stride])): + if use_layout: + continue + _ = node.block_analyzer.get_input_buffers(node.reduction_block)[i].name + # TODO(lei): should dig further for shared memory connection case. + + return output_strides, tensor_strides + + def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int): + if not node.get_tag("tensorcore_config"): + return super()._assign_block_size(node, td, block_size) + ax_m, ax_n = node.get_tag("tensorcore_config") + if block_size % self.arch.warp_size != 0: + return None + tile, rsteps = td.get_tile(node), td.get_rstep(node) + warps = block_size // self.arch.warp_size + ndim = len(tile) + + wmma = self.arch.get_avaliable_tensorintrin_shapes()[-1] + wmma_tile = [1 for _ in range(ndim)] + wmma_tile[ax_m] = wmma[0] + wmma_tile[ax_n] = wmma[1] + + space = [tile[i] // wmma_tile[i] for i in range(ndim)] + if tile[ax_m] < wmma_tile[ax_m] or tile[ax_n] < wmma_tile[ax_n]: + # allow pad, otherwise, we can not get a valid tile shape + return None + + factors = factorize(np.prod(space) // warps) + + def _score(node, thread): # small is better + score = 0 + block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)] + shape = node.propagate_inputs_on_reduction(block_tile) + input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) + for i, _ in enumerate(input_buffers): + score += np.prod(shape[i]) / self.arch.bandwidth[1] + return score + + warp_tile = wmma_tile.copy() + for factor in reversed(factors): + score_map = {} + for i in range(ndim): + if tile[i] % (warp_tile[i] * factor) != 0: + continue + warp_tile[i] *= factor + score_map[i] = (_score(node, warp_tile), i) + warp_tile[i] //= factor + if len(score_map) == 0: + return None + dim_order = sorted(score_map.keys(), key=lambda x: score_map[x]) + warp_tile[dim_order[0]] *= factor + + codegen_dict = Hint() + codegen_dict.block = tile + codegen_dict.warp = warp_tile + codegen_dict.use_tc = True + codegen_dict.pipeline_stage = self.pipeline_stage + codegen_dict.block_reduction_depth = self.block_reduction_depth + codegen_dict.use_async = self.use_async_copy + codegen_dict.rstep = [int(rsteps[ax.var.name]) for ax in node.raxis] + codegen_dict.cached_tensors = td.cached_tensors_map[node] + codegen_dict.rasterization_plan = self.plan_rasterization(td) + + intrin_info = node.get_tag("intrin_info") + if intrin_info: + codegen_dict.intrin_info = IntrinInfo(**intrin_info) + if intrin_info["out_dtype"] in ["float32"]: + codegen_dict.shared_scope = "shared.dyn" + # smem capacity + if td.smem_cost > self.arch.smem_cap: + codegen_dict.shared_scope = "shared.dyn" + + codegen_dict.complete_config(node) + codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size) + codegen_dict.arch = self.arch + codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes") + codegen_dict.tensorcore_legalization() + return codegen_dict + + def plan_rasterization(self, td: TileDict): + conditions = [] + # only support single node for now + conditions.append(len(self.ordered_nodes) > 1) + # only on Ampere+ arch + conditions.append(self.arch.compute_capability < "80") + + def _check_memory_size(): + overall_gmem_size_in_bytes: int = 0 + for node in self.ordered_nodes: + for buffer in node.input_buffers: + overall_gmem_size_in_bytes += ( + int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8) + return overall_gmem_size_in_bytes < self.arch.l2_cache_size_bytes + + conditions.append(_check_memory_size()) + if any(conditions): + return NoRasterization() + # otherwise, simply provide a block rasterization factor + raster_factor = int(self.arch.compute_max_core**0.5) + + return Rasterization2DColumn(raster_factor) diff --git a/bitblas/base/roller/rasterization.py b/bitblas/base/roller/rasterization.py new file mode 100644 index 000000000..4fb779069 --- /dev/null +++ b/bitblas/base/roller/rasterization.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Rasteration Plan For L2 Cache Locality""" + +from typing import List + + +class Rasterization: + + def __init__(self) -> None: + pass + + def get_code(self) -> List[str]: + raise NotImplementedError() + + +class NoRasterization(Rasterization): + + def __init__(self) -> None: + super().__init__() + + def __repr__(self) -> str: + return "" + + def get_code(self) -> List[str]: + return [] + + +class Rasterization2DRow(Rasterization): + """ + Rasterization by Row, each Row line width is panel_width + _________ + _________| + |_________ + __________| + """ + + def __init__(self, panel_width=4) -> None: + super().__init__() + self.panel_width_ = panel_width + + def __repr__(self) -> str: + return f"" + + def get_code(self) -> List[str]: + raise NotImplementedError() + + +class Rasterization2DColumn(Rasterization): + """ + Rasterization by Column, each column line width is panel_width + _ + | | | | + | | | | + |_| |_| + """ + + def __init__(self, panel_width=4) -> None: + super().__init__() + self.panel_width_ = panel_width + + def __repr__(self) -> str: + return f"" + + def get_device_function(self) -> str: + return """ +__device__ __inline__ dim3 rasterization2DColumn(const int panel_width) { + const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y; + const auto totalPanel = (gridDim.x * gridDim.y +panel_width * gridDim.x - 1) / (panel_width * gridDim.x); + const auto totalBlock = gridDim.x * gridDim.y; + const auto panelIdx = baseBlockIdx / (panel_width *gridDim.x); + const auto strideLd = panelIdx + 1 < totalPanel ?panel_width : (totalBlock - panelIdx * (panel_width *gridDim.x)) / gridDim.x; + const auto bx = (panelIdx & 1) ? gridDim.x -(baseBlockIdx - panelIdx * panel_width * gridDim.x) /strideLd - 1 : (baseBlockIdx - panelIdx * panel_width *gridDim.x) / strideLd; + const auto by = (baseBlockIdx - panelIdx * panel_width *gridDim.x) % strideLd + panelIdx * panel_width; + const auto bz = blockIdx.z; + + dim3 blockIdx(bx, by, bz); + return blockIdx; +} + """ + + def get_code(self, panel_width: int = None) -> List[str]: + if panel_width is None: + panel_width = self.panel_width_ + return [ + self.get_device_function(), + "const dim3 blockIdx = rasterization2DColumn({});\n".format(panel_width), + ] diff --git a/bitblas/base/roller/shape_inference/__init__.py b/bitblas/base/roller/shape_inference/__init__.py new file mode 100644 index 000000000..188aa0bb7 --- /dev/null +++ b/bitblas/base/roller/shape_inference/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .tir import get_analyzer_by_tir # pylint: disable=unused-import diff --git a/bitblas/base/roller/shape_inference/common.py b/bitblas/base/roller/shape_inference/common.py new file mode 100644 index 000000000..730bbbeef --- /dev/null +++ b/bitblas/base/roller/shape_inference/common.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from collections import OrderedDict +from typing import Dict, List + +from tvm import arith + + +class Statement(): + def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, range_map: OrderedDict): + self.output = output + self.dependent_region = dependent_region + self.var_map = var_map + self.range_map = range_map + +def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound): + return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value)) + +class InputShapeInference(): + def __init__(self, deps: List[Statement]): + self.deps = deps + + def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, int]): + shape = shape.copy() + ana = arith.Analyzer() + for dep in reversed(self.deps): + for var, bound in zip(dep.var_map.values(), shape[dep.output]): + ana.update(var, bound) + for var, bound in dep.range_map.items(): + if var.name in rstep: + bound = arith.ConstIntBound(0, min(bound.max_value, rstep[var.name] - 1)) + ana.update(var, bound) + for name, regions in dep.dependent_region.items(): + for region in regions: + bounds = [ana.const_int_bound(index) for index in region] + if name in shape: # simply merge two bounds + bounds = [_merge_two_bounds(x, y) for x, y in zip(shape[name], bounds)] + shape[name] = bounds + + for name, bounds in shape.items(): + shape[name] = [c.max_value - c.min_value + 1 for c in bounds] + return shape + + def infer(self, shape, rstep: Dict[str, int] = {}): + if isinstance(shape, (list, tuple)): + shape = {"output0" : [arith.ConstIntBound(0, val - 1) for val in shape]} + shape = self._infer(shape, rstep) + return shape + + def get_input_exprs(self, output_exprs): + result = output_exprs.copy() + ana = arith.Analyzer() + for dep in reversed(self.deps): + for var, expr in zip(dep.var_map.values(), result[dep.output]): + ana.bind(var, expr) + for var in dep.range_map: + ana.bind(var, 0) + for name, regions in dep.dependent_region.items(): + if name in result: + continue + region = regions[0] + input_expr = [ana.simplify(index) for index in region] + result[name] = input_expr + return result + diff --git a/bitblas/base/roller/shape_inference/tir.py b/bitblas/base/roller/shape_inference/tir.py new file mode 100644 index 000000000..35bf0b7d8 --- /dev/null +++ b/bitblas/base/roller/shape_inference/tir.py @@ -0,0 +1,399 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Dict, List, Tuple, Set, Mapping +from tvm.tir.schedule.schedule import BlockRV +from tvm.ir import structural_equal +from tvm import arith, tir + + +class Statement: + def __init__(self, block_analyzer, block: BlockRV): + self.block_analyzer = block_analyzer + self.block = block + # assume one tir block only has one output buffer + self.dep_name = block_analyzer.get_output_buffers(block)[0].name + self.dependent_region = _extract_dependent_region(block_analyzer, block) + + self.reverse_bound_inference = {} + + def make_reverse(self, input_name: str, input_iter: List[tir.PrimExpr]): + if len(self.block_analyzer.get_reduce_axis(self.block)) > 0: + return None + if len(self.dependent_region[input_name]) != 1: + return None + indices = self.dependent_region[input_name][0] + iter_map_range = { + _iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block) + } + iter_map_result = arith.detect_iter_map( + indices, + iter_map_range, + check_level=arith.iter_affine_map.IterMapLevel.Surjective, + simplify_trivial_iterators=False, + ) + if len(iter_map_result.errors) > 0: + return None + results = arith.iter_affine_map.inverse_affine_iter_map(iter_map_result.indices, input_iter) + output_indices = [] + for _iter in self.block_analyzer.get_spatial_axis(self.block): + if _iter.var in results: + output_indices.append(results[_iter.var]) + else: + # not Bijective mapping case + output_indices.append(tir.Var("undefined", dtype="int32") % int(_iter.dom.extent)) + return output_indices + + +def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound): + return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value)) + + +class TensorDepNode(object): + """ + For tensor dependency analysis. + """ + + def __init__(self, name): + self.name = name + self._next = [] + self._prev = [] + + def add_next(self, node): + self._next.append(node) + self.deduplicate(self._next) + + def add_prev(self, node): + self._prev.append(node) + self.deduplicate(self._prev) + + def deduplicate(self, lst): + seen = set() + lst[:] = [n for n in lst if not (n in seen or seen.add(n))] + + def __str__(self): + return self.name + + def __repr__(self): + return self.name + + +class DependencyAnalysis(object): + def __init__(self, deps): + self.deps = deps + # issue: duplicate name when we have two same ops. + self.name2dep = self._construct_unique_name2dep(deps) + self.mapping = {} # name -> TensorDepNode + + def _construct_unique_name2dep(self, deps): + """ + This is a workaround for the issue that we have two same ops' fuse case. + See https://github.com/apache/tvm/issues/16433 + """ + _names:Set = set() + name2dep:Mapping = {} + for dep in deps: + output_buffer = dep.block_analyzer.get_output_buffers(dep.block)[0] + base_name = output_buffer.name + if base_name not in _names: + _names.add(base_name) + else: + i = 1 + while f"{base_name}_{i}" in _names: + i += 1 + base_name = f"{base_name}_{i}" + _names.add(base_name) + name2dep[base_name] = dep + return name2dep + + def get_or_create_node(self, name): + if name not in self.mapping: + self.mapping[name] = TensorDepNode(name) + return self.mapping[name] + + def traverse_dependencies(self, compute): + if isinstance(compute, Statement): + node = self.get_or_create_node( + compute.block_analyzer.get_output_buffers(compute.block)[0].name + ) + # Loop through input tensors + for input_buffer in compute.block_analyzer.get_input_buffers(compute.block): + # Get the input node + input_node = self.traverse_dependencies(input_buffer) + input_node.add_next(node) + node.add_prev(input_node) + elif isinstance(compute, tir.Buffer): + node = self.get_or_create_node(compute.name) + return node + + def analyze(self): + # Starting point for traversal + for _, compute in self.name2dep.items(): + self.traverse_dependencies(compute) + + def print_dependencies(self): + for name, node in self.mapping.items(): + print(f"{name} depends on {', '.join([prev.name for prev in node._prev])}") + + def find_path_from_source(self, start_name, target_name): + """ + Finds the path (if it exists) from a starting node (source) to a target node. + Returns the path as a list of nodes. + """ + visited = set() + path = [] + if self._find_path_recursive(self.mapping[start_name], target_name, visited, path): + return path + return [] + + def _find_path_recursive(self, current_node, target_name, visited, path): + """ + Recursive helper function for find_path_from_source. + """ + if current_node.name == target_name: + path.append(current_node) + return True + + if current_node.name in visited: + return False + + visited.add(current_node.name) + path.append(current_node) + + for next_node in current_node._next: + if self._find_path_recursive(next_node, target_name, visited, path): + return True + + path.pop() + return False + + +class InputShapeInference: + def __init__(self, deps: List[Statement]): + self.deps = deps + self.target_mapping = {} + self.buffer_mapping = {} + self.reduce_axes = [] + for dep in self.deps: + for ax in dep.block_analyzer.get_reduce_axis(dep.block): + self.reduce_axes.append(ax) + self.dep_analysis = DependencyAnalysis(self.deps) + self.dep_analysis.analyze() + + def construct_dependency_target(self, targets: Tuple[str]): + if targets in self.target_mapping: + return self.target_mapping[targets] + # should be buffer name instead of block name + name2dep = { + dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps + } + mapping = {} + input_vars = [] + for target in targets: + vars = [ + iter.var + for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block) + ] + input_vars.append(vars) + mapping[target] = [vars] + ana = arith.Analyzer() + + for dep in self.deps: + for name in dep.dependent_region: + if name not in mapping: + continue + dep_name = dep.dep_name + indices = mapping[name][0] + output_indices = dep.make_reverse(name, indices) + if dep_name in targets: + continue + if dep_name not in mapping: + mapping[dep_name] = [output_indices] + elif not region_exist_in_list(output_indices, mapping[dep_name]): + mapping[dep_name].append(output_indices) + + for dep in reversed(self.deps): + indices_list = mapping[dep.dep_name] + ax_vars = [iter.var for iter in dep.block_analyzer.get_spatial_axis(dep.block)] + for input_name, regions in dep.dependent_region.items(): + if input_name in targets: + continue + if input_name not in mapping: + mapping[input_name] = [] + for indices in indices_list: + for region in regions: + vmap = { + k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v) + for k, v in zip(ax_vars, indices) + } + region = [ + ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region + ] + if not region_exist_in_list(region, mapping[input_name]): + mapping[input_name].append(region) + buffers = [] + for dep in self.deps: + for buffer in dep.block_analyzer.get_buffers(dep.block): + buffers.append(buffer) + + for buffer in buffers: + self.buffer_mapping[buffer.name] = buffer + + self.target_mapping[targets] = input_vars, mapping + return input_vars, mapping + + def infer( + self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, int] = {}, targets=None + ): + compute_targets = tuple(shape.keys()) + input_vars, mapping = self.construct_dependency_target(compute_targets) + ana = arith.Analyzer() + results = {} + intermediate_bind = {} + for vars, bounds in zip(input_vars, shape.values()): + for var, bound in zip(vars, bounds): + ana.update(var, bound, True) + for ax in self.reduce_axes: + # assume the dom.min is always 0, maybe we can extend the IterInfo to include the min value. + if ax.var.name in rstep: + bound = arith.ConstIntBound( + int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1) + ) + else: + bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + ax.dom.extent - 1)) + ana.update(ax.var, bound, True) + + for name, regions in mapping.items(): + if targets is not None and name not in targets: + continue + if compute_targets[0:1] == compute_targets: + (compute_target,) = compute_targets + path = self.dep_analysis.find_path_from_source(name, compute_target) + if len(path) > 2: + intermediate_nodes = path[1:-1] + for node in intermediate_nodes: + iters = mapping[node.name] + if len(iters) != len(regions) or len(iters) != 1: + continue + if len(*iters) != len(*regions): + break + regions = iters + intermediate_bind[name] = compute_target + + for region in regions: + bound = [ana.const_int_bound(indice) for indice in region] + if name in results: # simply merge two bounds + bound = [_merge_two_bounds(x, y) for x, y in zip(results[name], bound)] + results[name] = bound + else: + for region in regions: + bound = [ana.const_int_bound(indice) for indice in region] + if name in results: # simply merge two bounds + bound = [_merge_two_bounds(x, y) for x, y in zip(results[name], bound)] + results[name] = bound + + for name, bounds in results.items(): + results[name] = [c.max_value - c.min_value + 1 for c in bounds] + return results, intermediate_bind + + def get_input_exprs(self, output_exprs): + input_vars, mapping = self.construct_dependency_target(tuple(output_exprs.keys())) + ana = arith.Analyzer() + for ax in self.reduce_axes: + ana.bind(ax.var, 0) + vmap = {} + for vars, exprs in zip(input_vars, output_exprs.values()): + for var, expr in zip(vars, exprs): + if expr.dtype != var.dtype: + expr = tir.Cast(var.dtype, expr) + vmap[var] = expr + result = {} + + for name, regions in mapping.items(): + region = regions[0] + result[name] = [ + ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region + ] + return result + + +def region_exist_in_list(a, list) -> bool: + def expr_is_same(a, b) -> bool: + if isinstance(a, tir.IntImm) and isinstance(b, tir.IntImm): + return a.value == b.value + return structural_equal(a, b) + + def region_is_same(a, b) -> bool: + for indice_a, indice_b in zip(a, b): + if not expr_is_same(indice_a, indice_b): + return False + return True + + return any([region_is_same(a, x) for x in list]) + + +def walk_indice(expr): + if isinstance(expr, tir.expr.BinaryOpExpr): + a = walk_indice(expr.a) + b = walk_indice(expr.b) + if a is not None and b is not None: + return expr + else: + return None + elif isinstance(expr, tir.expr.ConstExpr): + return expr + elif isinstance(expr, tir.Var): + return expr + elif isinstance(expr, tir.ProducerLoad): + return None + elif isinstance(expr, tir.Cast): + a = walk_indice(expr.value) + if a is not None: + return expr + return None + elif isinstance(expr, tir.Call): + return None + else: + raise Exception("Unhandled node type in walk_indice(): %s" % expr) + + +def _extract_dependent_region(block_analyzer, block: BlockRV) -> Dict[str, List[tir.PrimExpr]]: + input_buffers = block_analyzer.get_input_buffers(block) + dependent_region = {buffer.name: [] for buffer in input_buffers} + + def fvisit(x): + if not isinstance(x, tir.BufferLoad): + return + if x.buffer.name not in dependent_region: + return + index = [] + for indice, shape_limit in zip(x.indices, x.buffer.shape): + expr = walk_indice(indice) + if expr is None: + expr = tir.Var("undefined", dtype="int8") % shape_limit + if isinstance(expr, tir.IntImm) and expr.value == 0: + """for tensor ir zero dim smplification case. + for ax0, ax1, ax2 in T.grid(T.int64(1024), T.int64(1024), T.int64(1024)): + with T.block("T_dense"): + v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2]) + T.reads(A_reindex[T.int64(0), v0, v2], B_reindex[T.int64(0), v1, v2]) + T.writes(T_dense_reindex[T.int64(0), v0, v1]) + with T.init(): + T_dense_reindex[T.int64(0), v0, v1] = T.float16(0) + T_dense_reindex[T.int64(0), v0, v1] = T_dense_reindex[T.int64(0), v0, v1] + A_reindex[T.int64(0), v0, v2] * B_reindex[T.int64(0), v1, v2] + For exmaple, the T_dense_reindex has three dims, however there're only two spatial loops. + """ + continue + index.append(expr) + if not region_exist_in_list(index, dependent_region[x.buffer.name]): + dependent_region[x.buffer.name].append(index) + + stmt = block_analyzer.sch.get(block) + tir.stmt_functor.post_order_visit(stmt, fvisit=fvisit) + return dependent_region + + +def get_analyzer_by_tir(block_analyzer, args) -> InputShapeInference: + deps = [Statement(block_analyzer, block) for block in args] + + return InputShapeInference(deps) diff --git a/bitblas/base/schedule_rule.py b/bitblas/base/schedule_rule.py new file mode 100644 index 000000000..53319b4fc --- /dev/null +++ b/bitblas/base/schedule_rule.py @@ -0,0 +1,149 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm schedule_rule.py in dlight. +"""A lightweight wrapper on an arbitrary function that can be used to schedule a TIR PrimFunc.""" +from typing import Callable, List, Union + +from tvm import tir +from tvm.target import Target + + +class ScheduleRule: # pylint: disable=too-few-public-methods + """A thin wrapper on an arbitrary function that can be used to schedule a TIR PrimFunc. + + Given a PrimFunc, a target, and a tunable flag, the apply method of a ScheduleRule + returns either a Schedule, a list of Schedules, or None, where None means that the rule + is not applicable to the given PrimFunc. If the tunable flag is True, the ScheduleRule is + allowed to return either a Schedule or a list of Schedules, and the Schedules are allowed to + contain tunable instructions. If the tunable flag is False, the ScheduleRule is only allowed to + return a Schedule, and the Schedule is not allowed to contain tunable instructions. + """ + + def apply( + self, + func: tir.PrimFunc, + target: Target, + tunable: bool, + ) -> Union[None, tir.Schedule, List[tir.Schedule]]: + """Apply the ScheduleRule to the given PrimFunc. + + Parameters + ---------- + func : tir.PrimFunc + The PrimFunc to apply the ScheduleRule to. + target : Target + The compilation target the schedule is supposed to be built for. + tunable : bool + Whether the schedule is allowed to contain tunable instructions. + + Returns + ------- + results : Union[None, tir.Schedule, List[tir.Schedule]] + Either a Schedule, a list of Schedules, or None, where None means that the rule + is not applicable to the given PrimFunc. + """ + raise NotImplementedError + + def apply_config( + self, + func: tir.PrimFunc, + config, + ): + """Apply the ScheduleRule to the given PrimFunc. + + Parameters + ---------- + func : tir.PrimFunc + The PrimFunc to apply the ScheduleRule to. + target : Target + The compilation target the schedule is supposed to be built for. + configs : + # todo: Discribe the configs + Returns + ------- + results : Union[None, tir.Schedule, List[tir.Schedule]] + Either a Schedule, a list of Schedules, or None, where None means that the rule + is not applicable to the given PrimFunc. + """ + raise NotImplementedError + + @staticmethod + def from_callable( + name, + ) -> Callable[ + [ + Callable[ + [tir.PrimFunc, Target, bool], + Union[None, tir.Schedule, List[tir.Schedule]], + ], + ], + "ScheduleRule", + ]: + """Create a ScheduleRule from a callable. + + Parameters + ---------- + name : str + + Returns + ------- + decorator : Callable + A decorator that takes a callable and returns a ScheduleRule. + + Examples + -------- + .. code-block:: python + + @ScheduleRule.from_callable("MyRule") + def my_rule(func: tir.PrimFunc, target: Target, tunable: bool) -> Union[None, Schedule] + # Do something with func and target + """ + + def decorator(f) -> "ScheduleRule": # pylint: disable=invalid-name + class _Rule(ScheduleRule): + def apply( + self, + func: tir.PrimFunc, + target: Target, + tunable: bool, + ) -> Union[None, tir.Schedule, List[tir.Schedule]]: + return f(func, target, tunable) + + _Rule.__name__ = name + return _Rule() + + return decorator + + def is_target_available( + self, target: Target + ) -> bool: # pylint: disable=unused-argument + """Check whether the rule is available for the given target. + + Parameters + ---------- + target : Target + The compilation target the schedule is supposed to be built for. + + Returns + ------- + available : bool + Whether the rule is available for the given target. + """ + return True diff --git a/bitblas/base/transform.py b/bitblas/base/transform.py new file mode 100644 index 000000000..647efa772 --- /dev/null +++ b/bitblas/base/transform.py @@ -0,0 +1,218 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +""" +Apply ScheduleRules onto an IRModule to generate default schedules without tuning, +or a space for MetaSchedule tuning +""" +from typing import List, Optional, Dict +import os +import shutil +import tempfile +import os.path as osp +import tvm +from tvm import tir +from tvm import meta_schedule as ms +from tvm.ir import IRModule +from tvm.ir.transform import PassContext, module_pass +from tvm.target import Target +from .schedule_rule import ScheduleRule +from ..base.analysis import check_func_with_dynamic +from .utils import fast_tune, fast_tune_with_dynamic_range +import logging + +logger = logging.getLogger(__name__) + + +def _is_scheduled(func: tir.PrimFunc) -> bool: + if not isinstance(func, tir.PrimFunc): + return False + if not func.attrs: + return False + if "tir.is_scheduled" not in func.attrs: + return False + return func.attrs["tir.is_scheduled"] == 1 + + +@module_pass(opt_level=0, name="ApplyDefaultSchedule") +class ApplyDefaultSchedule: # pylint: disable=too-few-public-methods + """A IRModule pass that applies a list of ScheduleRules to all PrimFuncs in the module.""" + + def __init__(self, *rules: ScheduleRule): + """Construct a new ApplyDefaultSchedule pass. + + Parameters + ---------- + *rules : ScheduleRule + The ScheduleRules to apply to all PrimFuncs in the module. + """ + self.rules = list(rules) + + def transform_module( # pylint: disable=missing-function-docstring + self, + mod: IRModule, + _: PassContext, + ) -> IRModule: + target = Target.current(allow_none=False) + + updated_functions = {} + for g_var, func in mod.functions_items(): + if isinstance(func, tir.PrimFunc) and not _is_scheduled(func): + sch = _apply_rules(func, target, self.rules, tunable=False) + if sch is not None: + assert len(sch) == 1 + updated_functions[g_var] = (sch[0].mod["main"].with_attr("tir.is_scheduled", 1)) + for g_var, func in updated_functions.items(): + mod[g_var] = func + return mod + + +@module_pass(opt_level=0, name="ApplyFastTuning") +class ApplyFastTuning: # pylint: disable=too-few-public-methods + """A IRModule pass that applies a list of ScheduleRules to all PrimFuncs in the module.""" + + def __init__( + self, + topk: int = 10, + target: Optional[Target] = None, + parallel_build: bool = True, + meta_database_dir: str = None, + whitelist: Optional[List[str]] = None, + dynamic_range: Optional[Dict[str, List[int]]] = None, + ): + """Construct a new ApplyFastTuning pass. + + Parameters + ---------- + meta_database : str + The path of database. + dynamic_range : Dict[str, List[int]] + Use for generate kernel based on dynamic range. + """ + if whitelist is None: + whitelist = [] + if dynamic_range is None: + dynamic_range = {} + self.topk = topk + self.target = Target.current() if target is None else target + self.parallel_build = parallel_build + self.meta_database_dir = meta_database_dir + self.whitelist = whitelist + self.dynamic_range = dynamic_range + self.temp_dir = tempfile.TemporaryDirectory() + path_workload = osp.join(self.temp_dir.name, "database_workload.json") + path_tuning_record = osp.join(self.temp_dir.name, "database_tuning_record.json") + self.cache_meta_database = ms.database.JSONDatabase( + path_workload, path_tuning_record, module_equality="structural") + + def _in_white_list(self, func_name: str) -> bool: + if len(self.whitelist) == 0: + return True + return any([name in func_name for name in self.whitelist]) + + def transform_module( # pylint: disable=missing-function-docstring + self, + mod: IRModule, + _: PassContext, + ) -> IRModule: + target = self.target + updated_functions = {} + + for g_var, func in mod.functions_items(): + if isinstance(func, tir.PrimFunc) and not _is_scheduled(func): + if not self._in_white_list(g_var.name_hint): + continue + normalize_mod_func_ = tvm._ffi.get_global_func("tvm.meta_schedule.normalize_mod") + _normalized_func_mod = normalize_mod_func_(func) + + if self.cache_meta_database.has_workload(_normalized_func_mod): + tuning_record = self.cache_meta_database.query_tuning_record( + _normalized_func_mod, + target, + g_var.name_hint, + ) + if tuning_record: + trace = tuning_record.trace + sch = tvm.tir.Schedule(func) + trace.apply_to_schedule(sch, remove_postproc=False) + updated_functions[g_var] = sch.mod["main"].with_attr("tir.is_scheduled", 1) + continue + + if check_func_with_dynamic(func): + + dispatch_mod = fast_tune_with_dynamic_range( + func, + target=target, + topk=self.topk, + parallel_build=self.parallel_build, + global_symbol=g_var.name_hint, + dynamic_range=self.dynamic_range, + ) + + if dispatch_mod: + for g, f in dispatch_mod.functions_items(): + if g.name_hint == g_var.name_hint: + # avoid duplicated global symbol + updated_functions[g_var] = f.without_attr( + "global_symbol").with_attr("tir.is_scheduled", 1) + else: + updated_functions[g] = f.with_attr("tir.is_scheduled", 1) + # cannot reuse meta database as it cannot be recorvered from the trace + workload = self.cache_meta_database.commit_workload(_normalized_func_mod) + else: + # otherwise is static shape analysis + _, best = fast_tune( + func, + target=target, + topk=self.topk, + parallel_build=self.parallel_build, + ) + + if best is not None: + updated_functions[g_var] = best.sch.mod["main"].with_attr( + "tir.is_scheduled", 1) + workload = self.cache_meta_database.commit_workload(_normalized_func_mod) + # only record the best schedule + self.cache_meta_database.commit_tuning_record( + ms.database.TuningRecord( + best.sch.trace, + workload, + [best.latency], + target, + ms.arg_info.ArgInfo.from_prim_func(func=best.sch.mod["main"]), + )) + + for g_var, func in updated_functions.items(): + mod[g_var] = func + + # copy database + if self.meta_database_dir is not None: + if not osp.exists(self.meta_database_dir): + os.makedirs(self.meta_database_dir) + # TODO(lei): maybe another way to copy the database + shutil.copytree(self.temp_dir.name, self.meta_database_dir, dirs_exist_ok=True) + + return mod + + def __del__(self): + # clean up the temp cache + self.temp_dir.cleanup() + + +def _apply_rules( + func: tir.PrimFunc, + target: Target, + rules: List[ScheduleRule], + tunable: bool, +) -> Optional[List[tir.Schedule]]: + for rule in rules: + try: + space = rule.apply(func, target, tunable) + except Exception: + logger.debug(f"[BitBLAS][Error] applying rule {rule} failed") + space = None + if space is None: + continue + if isinstance(space, tir.Schedule): + space = [space] + return space + return None diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py new file mode 100644 index 000000000..50adc135f --- /dev/null +++ b/bitblas/base/utils.py @@ -0,0 +1,517 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tvm +import os +from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind +from concurrent.futures import ThreadPoolExecutor, as_completed +import numpy as np +from typing import List, Tuple, Optional, Dict, Union, Literal +from tvm import tir, IRModule +from tvm.runtime import Module +from tvm.tir import Schedule +from tvm.relax.expr import Function +import bitblas +from .analysis import get_root_block, get_reduction_blocks, find_var_from_func +from bitblas.base.roller.arch import CUDA +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +import tempfile +import itertools +from tvm.ir.supply import GlobalVarSupply +from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 +import logging + +logger = logging.getLogger(__name__) + + +def get_rasterization_code(pannel_width: int = 8) -> str: + return f""" + const int MAX_BLOCK_N = {pannel_width}; + const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y; + const auto totalPanel = (gridDim.x * gridDim.y +MAX_BLOCK_N * gridDim.x - 1) / (MAX_BLOCK_N * gridDim.x); + const auto totalBlock = gridDim.x * gridDim.y; + const auto panelIdx = baseBlockIdx / (MAX_BLOCK_N *gridDim.x); + const auto strideLd = panelIdx + 1 < totalPanel ?MAX_BLOCK_N : (totalBlock - panelIdx * (MAX_BLOCK_N *gridDim.x)) / gridDim.x; + const auto bx = (panelIdx & 1) ? gridDim.x -(baseBlockIdx - panelIdx * MAX_BLOCK_N * gridDim.x) /strideLd - 1 : (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) / strideLd; + const auto by = (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) % strideLd + panelIdx * MAX_BLOCK_N; + const auto bz = blockIdx.z; + const dim3 blockIdx(bx, by, bz); + """ + + +class CompileResult: + """ + Class to store the result of compilation + """ + + def __init__(self, config, sch, mod: Module): + self.config = config + self.sch = sch + self.mod = mod + self.code = mod.imported_modules[0].get_source() if mod else None + self.latency = 1e9 + self.profile_tensors = [] + self.time_evaluator = None + + def profile(self): + profile_tensors = self.profile_tensors + return self.time_evaluator(*profile_tensors).mean * 1e3 + + +def _apply_config( + func: tir.PrimFunc, + config=None, # todo(lei): update typing +) -> Optional[tir.Schedule]: + """ + find rules: + case 1. if the main block has no reduce op, then use the Elementwise rule. + case 2. if the config enabled tensorcore, then use the TensorCore rule. + case 3. if any([t > 1 for t in config.reduce_thread]), we should use the InnerThread Reduction Rule. + case 4. else we should use general reduction rule. + """ + logger.debug("Apply config {}".format(config)) + + sch = tir.Schedule(func) + root_block = get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + reduction_blocks = get_reduction_blocks(sch, blocks) + + if not reduction_blocks: + return bitblas.gpu.ElementWise().apply_config(func, config) + elif config.use_tc: + if config.arch.sm_version >= 80: + # For A100(sm_80) or more advanced gpu, use MMA tensorization. + return bitblas.gpu.MatmulTensorizationMMA().apply_config(func, config) + else: + # For other GPUs, use WMMA tensorization. + return bitblas.gpu.MatmulTensorizationWMMA().apply_config(func, config) + else: + _reduction_rules = [] + + _reduction_rules.append(bitblas.gpu.GEMV()) + if not any([t > 1 for t in config.reduce_thread]): + # Matrix multiplication template doesn't support inner thread reduction + _reduction_rules.append(bitblas.gpu.Matmul()) + _reduction_rules.append(bitblas.gpu.GeneralReduction()) + + for rule in _reduction_rules: + sch = rule.apply_config(func, config) + try: + sch = rule.apply_config(func, config) + except Exception as e_msg: + logger.debug("Apply config failed: ", e_msg) + continue + if sch is not None: + return sch + return None + + +def get_dummy_input_arrays( + func: Union[tir.PrimFunc, Function], + device: tvm.runtime.Device, + distribution: Literal["uniform", "onefill"] = "uniform", +): + + def var_wrapper(v): + if isinstance(v, tvm.tir.Var): + assert "opt_shapes" in func.attrs + assert v.name in func.attrs["opt_shapes"] + return func.attrs["opt_shapes"][v.name].value + elif isinstance(v, tvm.tir.IntImm): + return v.value + else: + raise RuntimeError("Not supported type: ", type(v)) + + profile_tensors = [] + for param in func.params: + if isinstance(func, tir.PrimFunc): + if param not in func.buffer_map: + # in case of dynamic symbolic may in params + continue + arg = func.buffer_map[param] + elif isinstance(func, Function): + arg = param.struct_info + else: + raise ValueError("Not supported type: ", type(func)) + + def map_numpy_type(intype): + typemap = { + 'e4m3_float8': 'float8_e4m3fn', + 'e5m2_float8': 'float8_e5m2', + } + if intype in typemap: + return typemap[intype] + else: + return intype + + numpy_dtype = map_numpy_type(arg.dtype) + if distribution == "uniform": + profile_tensors.append( + tvm.nd.array( + np.random.rand(*[var_wrapper(i) for i in arg.shape]).astype(numpy_dtype), + device=device, + )) + elif distribution == "onefill": + profile_tensors.append( + tvm.nd.array( + np.ones([var_wrapper(i) for i in arg.shape]).astype(numpy_dtype), + device=device, + )) + else: + raise ValueError("Not supported distribution: ", distribution) + return profile_tensors + + +def apply_and_build_parallel(func, + configs, + arch, + num_repeats=3, + max_workers=10, + timeout=30, + data_distribution="uniform") -> CompileResult: + cpresults = [] + + profile_tensors = get_dummy_input_arrays(func, arch.device, distribution=data_distribution) + max_workers = min(len(configs), os.cpu_count(), max_workers) + + # apply config in thread parallel + _sched: List[Schedule] = [] + + def _apply_schedule(f, c): + try: + sch = _apply_config(f, c) + except Exception as apply_schedule_error: + logger.debug("Apply schedule failed: {}".format(apply_schedule_error)) + sch = None + return sch + + with ThreadPoolExecutor(max_workers=4) as scheduler: + futures = {scheduler.submit(_apply_schedule, func, config) for config in configs} + for future in as_completed(futures, timeout=timeout): + _sched.append(future.result()) + + builder = PopenPoolExecutor(max_workers=max_workers, timeout=timeout) + + # build in process parallel + def _build(context) -> str: + idx, mod, arch = context + if mod is None: + return idx, None, None + # TODO(lei): + # this is a trick to implement rasteration, will be removed in the future + config = configs[idx] + + @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) + def tvm_callback_cuda_postproc(code, _): + code = tensor_replace_dp4a(code) + code = tensor_remove_make_int4(code) + code = tensor_remove_make_int2(code) + return code + + with tvm.transform.PassContext(config={"tir.use_async_copy": True, **config.pass_context}): + rt_mod = tvm.build(mod, target=arch.target) + + from tvm.contrib.tar import tar # pylint: disable=import-outside-toplevel + + artifact_path = os.path.join(tempfile.mkdtemp(), "tvm_tmp_mod." + tar.output_format) + code = rt_mod.imported_modules[0].get_source() + rt_mod.export_library(artifact_path, fcompile=tar) + return idx, code, artifact_path + + _mods = [sch.mod if sch is not None else None for sch in _sched] + + for map_result in builder.map_with_error_catching( + _build, + [(i, mod, arch) for i, mod in enumerate(_mods)], + ): + if map_result.status == StatusKind.TIMEOUT: + logger.debug("LocalBuilder: Timeout") + elif map_result.status == StatusKind.EXCEPTION: + # TODO(lei): redirect the exception to file if needed + logger.debug("LocalBuilder: An exception occurred {}".format(map_result.value)) + continue + elif map_result.status == StatusKind.COMPLETE: + idx, code, artifact_path = map_result.value + if artifact_path is None: + logger.debug("Artifact path is None") + continue + sch = _sched[idx] + config = configs[idx] + rt_mod = tvm.runtime.load_module(artifact_path) + cpresult = CompileResult(config, sch, rt_mod) + timer_cuda_mod = rt_mod.time_evaluator( + rt_mod.entry_name, arch.device, number=num_repeats) + cpresult.profile_tensors = profile_tensors + cpresult.time_evaluator = timer_cuda_mod + cpresult.code = code + cpresults.append(cpresult) + else: + raise ValueError(f"Unreachable: unexpected result: {map_result}") + + del builder + + best = None + best_latency = 1e9 + for cpresult in cpresults: + config = cpresult.config + try: + latency = cpresult.profile() + except Exception as e_mesg: + logger.debug(f"Evaluation with config failed {e_mesg}") + continue + logger.info("Evaluation with config {}".format(config)) + logger.info("Time cost of this config: {:.3f} ms".format(latency)) + + cpresult.latency = latency + if latency < best_latency: + best_latency = latency + best = cpresult + + return cpresults, best + + +def apply_and_build( + func, + configs, + arch, + parallel_build=False, + data_distribution="uniform", +) -> Tuple[List[CompileResult], CompileResult]: + max_workers = 10 if parallel_build else 1 + return apply_and_build_parallel( + func, configs, arch, max_workers=max_workers, data_distribution=data_distribution) + + +def fast_tune( + func: tir.PrimFunc, + target: tvm.target.Target, + topk: int = 10, + parallel_build: bool = True, + data_distribution: Literal["uniform", "onefill"] = "uniform", +): + # check the function is a primfunc + if not isinstance(func, tir.PrimFunc): + raise ValueError("Only support func is PrimFunc") # pragma: no cover + + if target.kind.name != "cuda": + logger.error("Only support CUDA target") + return None, None + + specilized_func = func + if func.attrs is not None and "opt_shapes" in func.attrs: + opt_shapes = func.attrs["opt_shapes"] + # should be int value + if not all([isinstance(v.value, int) for v in opt_shapes.values()]): + logger.error("The opt_shapes should be int value") + return None, None + # currently only support one dynamic range + if len(opt_shapes) > 1: + logger.error("Currently only support one dynamic range") + return None, None + + for buffer in func.buffer_map.values(): + for axis in buffer.shape: + if isinstance(axis, tvm.tir.Var) and axis.name not in opt_shapes: + raise NotImplementedError( + "Currently do not support fast tune with none-dynamic range set") + if opt_shapes: + for name, shape in opt_shapes.items(): + var = find_var_from_func(func, name) + specilized_func = func.specialize({ + var: shape.astype(var.dtype) + }).with_attr("is_specialized") + + arch = CUDA(target) + + policy = DefaultPolicy(func=func, arch=arch) + try: + specilized_func, tags = get_tensorized_func_and_tags(specilized_func, arch.target) + except Exception as e_msg: + logger.debug("Get tensorized func and tags failed: ", e_msg) + tags = None + if tags: + policy = TensorCorePolicy(func=specilized_func, arch=arch, tags=tags) + + configs = policy.emit_config(topk) + + if len(configs) == 0: + raise ValueError("No valid config generated") + + cpresults, best = apply_and_build( + func, + configs, + arch, + parallel_build=parallel_build, + data_distribution=data_distribution, + ) + + return cpresults, best + + +# always use the first function as the base +def collect_buffers_to_declare(func): + params = [] + # collect dynamic symbolic + dyn_symbolic: List[tvm.tir.Var] = [] + buffers_to_declare = [] + for param in func.params: + if param not in func.buffer_map: + continue + buffer = func.buffer_map[param] + for axis in buffer.shape: + if isinstance(axis, tvm.tir.Var) and axis not in dyn_symbolic: + dyn_symbolic.append(axis) + buffers_to_declare.append(buffer) + params.append(buffer.data) + + # the args should be buffers + dynamic symbolic + params += list(dyn_symbolic) + + return params, buffers_to_declare + + +def refactor_specialized_func(g_var, func, params, buffers_to_declare): + body = func.body + attrs = func.attrs + global_symbol = g_var + if "opt_shapes" in func.attrs: + opt_shapes = func.attrs["opt_shapes"] + + def serialize_name(opt_shapes: Dict): + return "_opt_" + "_".join([f"{k}_{v}" for k, v in opt_shapes.items()]) + + global_symbol += serialize_name(opt_shapes) + ret_type = func.ret_type + for buf in buffers_to_declare: + body = tvm.tir.DeclBuffer(buf, body=body) + + # device func must be private + device_func = tvm.tir.PrimFunc( + params, body, ret_type, attrs=attrs).without_attr("global_symbol") + return global_symbol, device_func + + +def create_dispatch_func(g_var: str, func: tir.PrimFunc, refactored_funcs: List[str]): + global_symbol = g_var + attrs = func.attrs + buffer_map = func.buffer_map + params = func.params + ret_type = func.ret_type + + # collect dynamic symbolic + dyn_symbolic: List[tvm.tir.Var] = [] + _invoke_params = [] + for param in func.params: + if param not in func.buffer_map: + continue + buffer = func.buffer_map[param] + for axis in buffer.shape: + if isinstance(axis, tvm.tir.Var) and axis not in dyn_symbolic: + dyn_symbolic.append(axis) + _invoke_params.append(buffer.data) + _invoke_params += list(dyn_symbolic) + + func_range: List[int] = [] + global_symbols = [] + for g_var, refactor_func in refactored_funcs: + opt_shapes = refactor_func.attrs["opt_shapes"] + func_range.append(list(opt_shapes.values())[0]) + global_symbols.append(g_var) + + # TODO(lei): general the dispatch function to support multiple dynamic symbolics + assert len(dyn_symbolic) == 1, "Only support one dynamic symbolics currently" + + ib = tvm.tir.ir_builder.create() + syb = list(dyn_symbolic)[-1] + last_range = 0 + for i, (_range, g_var) in enumerate(zip(func_range, global_symbols)): + if i == 0: + with ib.if_scope(syb <= _range): + ib.emit(tvm.tir.Call(None, g_var, _invoke_params)) + else: + with ib.if_scope(tvm.tir.all(syb > last_range, syb <= _range)): + ib.emit(tvm.tir.Call(None, g_var, _invoke_params)) + last_range = _range + with ib.if_scope(syb > last_range): + ib.emit(tvm.tir.Call(None, g_var, _invoke_params)) + stmt = ib.get() + dispatch_func = tvm.tir.PrimFunc(params, stmt, ret_type, buffer_map, attrs).with_attrs({ + "tir.is_global_func": True, + "global_symbol": global_symbol + }) + return dispatch_func + + +def create_dispatch_mod(g_var: str, original_func: tir.PrimFunc, + specialized_funcs: List[tir.PrimFunc]) -> IRModule: + dispatch_mod: IRModule = tvm.IRModule() + g_var_supply = GlobalVarSupply(dispatch_mod) + refactored_funcs = [] + for func in specialized_funcs: + params, buffers_to_declare = collect_buffers_to_declare(func) + global_symbol, device_func = refactor_specialized_func(g_var, func, params, + buffers_to_declare) + global_symbol = g_var_supply.fresh_global(global_symbol, add_prefix=False) + dispatch_mod[global_symbol] = device_func + refactored_funcs.append((global_symbol, device_func)) + dispatch_func = create_dispatch_func(g_var, original_func, refactored_funcs=refactored_funcs) + dispatch_mod.update(tvm.IRModule.from_expr(dispatch_func)) + return dispatch_mod + + +def fast_tune_with_dynamic_range( + func: tir.PrimFunc, + target: tvm.target.Target, + topk: int = 10, + parallel_build: bool = True, + global_symbol: Optional[str] = None, + dynamic_range: Optional[Dict[str, List[int]]] = None, +) -> IRModule: + if dynamic_range is None: + dynamic_range = {} + if target.kind.name != "cuda": + logger.error("Only support CUDA target") + return None + if not global_symbol: + global_symbol = func.attrs["global_symbol"] + + # set opt_shapes for the primfunc with dynamic symbolic + opt_shapes: Dict[str, List[int]] = {} + for buffer in func.buffer_map.values(): + for axis in buffer.shape: + if isinstance(axis, tvm.tir.Var): + if axis.name in dynamic_range: + opt_shapes[axis.name] = dynamic_range[axis.name] + else: + raise ValueError(f"[BitBLAS] The axis {axis.name} is not in dynamic_range") + func = func.with_attr("opt_shapes", opt_shapes) + + if "opt_shapes" not in func.attrs: + logger.error( + "[BitBLAS] The primfunc has no opt_shapes, please set opt_shapes for the primfunc") + return None + else: + # should be list value + if not all([isinstance(v, tvm.ir.Array) for v in func.attrs["opt_shapes"].values()]): + logger.error("The opt_shapes should be list value") + return None + + logger.info("Start fast tuning with dynamic range") + opt_shapes = func.attrs["opt_shapes"] + + # Step 1.Calculate the Cartesian product using itertools.product + product_list = list(itertools.product(*(opt_shapes[key] for key in opt_shapes))) + + # Convert the Cartesian product to a list of dictionaries + specialize_items: List[Dict] = [dict(zip(opt_shapes.keys(), values)) for values in product_list] + + specilized_tuned_funcs: List[tir.PrimFunc] = [] + for item in specialize_items: + func = func.with_attr("opt_shapes", item) + _, best = fast_tune(func, target, topk, parallel_build) + if best is None: + return None + specilized_tuned_funcs.append(best.sch.mod["main"]) + + return create_dispatch_mod(global_symbol, func, specilized_tuned_funcs) diff --git a/bitblas/cache/__init__.py b/bitblas/cache/__init__.py new file mode 100644 index 000000000..0c8fd3b9c --- /dev/null +++ b/bitblas/cache/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .operator import ( + global_operator_cache, # noqa: F401 + load_global_ops_cache, # noqa: F401 + get_database_path, # noqa: F401 + set_database_path, # noqa: F401 +) diff --git a/bitblas/cache/operator.py b/bitblas/cache/operator.py new file mode 100644 index 000000000..9b30a6200 --- /dev/null +++ b/bitblas/cache/operator.py @@ -0,0 +1,179 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import bitblas +from bitblas.ops.operator import OperatorConfig, Operator +from dataclasses import asdict +import os +import json +import tempfile +from hashlib import sha256 +import shutil +import tvm +from tvm.contrib.tar import tar +import logging + +logger = logging.getLogger(__name__) + +BITBLAS_DATABASE_PATH = os.path.expanduser("~/.cache/bitblas") + + +class OperatorCache: + """ + Manages a cache for operator instances (e.g., Matmul, Convolution) based on their configurations. + """ + + def __init__(self): + self.cache = {} + + def add(self, config: OperatorConfig, op_inst: Operator): + self.cache[config] = op_inst + + def get(self, config: OperatorConfig): + return self.cache.get(config) + + def exists(self, config): + return config in self.cache + + def clear(self): + self.cache.clear() + + def size(self): + return len(self.cache) + + def save_into_database(self, database_path=None, target=None): + database_path = self._ensure_database_path(database_path) + for config, op_inst in self.cache.items(): + arch_str = self._determine_arch_str(op_inst, target) + arch_path = os.path.join(database_path, arch_str) + self._ensure_directory(arch_path) + hash_str = sha256(repr(config).encode()).hexdigest() + config_path = os.path.join(arch_path, hash_str) + # if the config already exists, skip saving + if os.path.exists(config_path): + continue + self._ensure_directory(config_path) + self._save_operator_config_and_artifact(config, op_inst, config_path) + + def load_from_database(self, database_path, target=None): + if not os.path.exists(database_path): + logger.info( + f"Database path {database_path} does not exist, skipping loading operators from the database" + ) + return + arch_str = self._determine_target_arch_str(target) + arch_path = os.path.join(database_path, arch_str) + if not os.path.exists(arch_path): + logger.info( + f"Target {arch_str} does not exist in the database, skipping loading operators from the database" + ) + return + self._load_operators_from_arch_path(arch_path, target) + + def _ensure_database_path(self, database_path): + if database_path is None: + return tempfile.mkdtemp() + os.makedirs(database_path, exist_ok=True) + return database_path + + def _determine_arch_str(self, op_inst, target): + return (target if target else "-".join(list(op_inst.target.keys) + [op_inst.target.arch])) + + def _ensure_directory(self, path): + os.makedirs(path, exist_ok=True) + + def _save_operator_config_and_artifact(self, config, op_inst, config_path): + config_type, operator_type = type(config).__name__, type(op_inst).__name__ + with open(os.path.join(config_path, f"{config_type}.json"), "w") as json_file: + json.dump(asdict(config), json_file) + artifact_path = os.path.join(config_path, "tvm_rt_mod." + tar.output_format) + try: + op_inst.rt_mod.export_library(artifact_path, fcompile=tar) + except Exception as e: + # library does not support export_library + export_error = e # noqa: F841 + pass + json_data = {"config_type": config_type, "operator_type": operator_type} + json_file_path = os.path.join(config_path, "mapping.json") + with open(json_file_path, "w") as json_file: + json.dump(json_data, json_file) + + # For writing source.cu file + source_file_path = os.path.join(config_path, "source.cu") + with open(source_file_path, "w") as source_file: + source_file.write(op_inst.get_source()) + + # For writing optimized.py file + optimized_file_path = os.path.join(config_path, "optimized.py") + with open(optimized_file_path, "w") as optimized_file: + if op_inst.optimized_func is not None: + optimized_file.write(op_inst.optimized_func.script(show_meta=False)) + if op_inst.wrapper.lib_name is not None: + # copy lib name to the same directory as the artifact + src_name = op_inst.wrapper.src_name + shutil.copy( + src_name, + os.path.join(config_path, os.path.basename("wrapper_source.cu")), + ) + lib_name = op_inst.wrapper.lib_name + shutil.copy( + lib_name, + os.path.join(config_path, os.path.basename("wrapper_compiled.so")), + ) + + def _determine_target_arch_str(self, target): + return (target if isinstance(target, str) else "-".join(list(target.keys) + [target.arch])) + + def _load_operators_from_arch_path(self, arch_path, target): + for root, dirs, _ in os.walk(arch_path): + for directory in dirs: + config_path = os.path.join(root, directory) + self._load_operator(config_path, target) + + def _load_operator(self, config_path, target): + mapping, config, rt_mod, src_name, lib_name = None, None, None, None, None + for file in os.listdir(config_path): + full_path = os.path.join(config_path, file) + if file == "mapping.json": + with open(full_path) as f: + mapping = json.load(f) + elif file.endswith(".json"): + with open(full_path) as f: + config = json.load(f) + elif file.endswith(".tar"): + rt_mod = tvm.runtime.load_module(full_path) + elif file == "wrapper_compiled.so": + lib_name = full_path + elif file == "wrapper_source.cu": + src_name = full_path + + if mapping and config and rt_mod: + self._instantiate_and_add_operator(mapping, config, rt_mod, src_name, lib_name, target) + + def _instantiate_and_add_operator(self, mapping, config, rt_mod, src_name, lib_name, target): + config_cls = getattr(bitblas, mapping["config_type"]) + operator_cls = getattr(bitblas, mapping["operator_type"]) + op_inst = operator_cls( + config=config_cls(**config), target=target, enable_tuning=False, from_database=True) + op_inst.update_runtime_module(rt_mod, src_name=src_name, lib_name=lib_name) + self.add(config_cls(**config), op_inst) + + +global_operator_cache = OperatorCache() + + +def load_global_ops_cache(database_path=BITBLAS_DATABASE_PATH, target=None): + if target is None: + target = bitblas.auto_detect_nvidia_target() + logger.info(f"Loading operators from database {database_path} for target {target}") + global_operator_cache.load_from_database(database_path, target) + return global_operator_cache + + +def get_database_path(): + return BITBLAS_DATABASE_PATH + + +def set_database_path(path): + global BITBLAS_DATABASE_PATH + BITBLAS_DATABASE_PATH = path + return BITBLAS_DATABASE_PATH diff --git a/bitblas/generator.py b/bitblas/generator.py new file mode 100644 index 000000000..4ac6f2be2 --- /dev/null +++ b/bitblas/generator.py @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +class BitBLASGenerator: + + def __init__(self): + # Initialize the generator with configuration + pass + + def generate_cuda_code(self): + pass + + def generate_header(self): + pass diff --git a/bitblas/gpu/__init__.py b/bitblas/gpu/__init__.py new file mode 100644 index 000000000..df0635b3c --- /dev/null +++ b/bitblas/gpu/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +""" +GPU-generic schedule rules. +For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead +""" +from .fallback import Fallback # noqa: F401 +from .element_wise import ElementWise # noqa: F401 +from .gemv import GEMV # noqa: F401 +from .gemv_dequantize import GEMVWithDequantizeInfo # noqa: F401 +from .general_reduction import GeneralReduction # noqa: F401 +from .matmul import ( + Matmul, # noqa: F401 + MatmulTensorizationMMA, # noqa: F401 + MatmulTensorizationWMMA, # noqa: F401 +) +from .matmul_mma_dequantize import ( + MatmulTensorizationMMAWithDequantizeInfo, # noqa: F401 +) +from .matmul_wmma import MatmulTensorizationLegacy # noqa: F401 + +from .reduction import Reduction # noqa: F401 +from .transpose import Transpose # noqa: F401 diff --git a/bitblas/gpu/base.py b/bitblas/gpu/base.py new file mode 100644 index 000000000..3bf927244 --- /dev/null +++ b/bitblas/gpu/base.py @@ -0,0 +1,44 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# /* Modifications Copyright (c) Microsoft. */ +# The code below is mostly copied from apache/tvm base.py in dlight. +"""Base schedule rule for GPU operators.""" + +from tvm.target import Target + +from ..base import ScheduleRule + + +class GPUScheduleRule(ScheduleRule): # pylint: disable=too-few-public-methods + """The Schedule Rule specific to GPU targets, will return None if the target is not GPU.""" + + def is_target_available(self, target: Target) -> bool: + """Check whether the target is available for gpu rule. + + Parameters + ---------- + target : Target + The compilation target to check. + + Returns + ------- + available : bool + Whether the target is available for this rule. + """ + return super().is_target_available(target) and "gpu" in target.keys diff --git a/bitblas/gpu/element_wise.py b/bitblas/gpu/element_wise.py new file mode 100644 index 000000000..07ea3a27e --- /dev/null +++ b/bitblas/gpu/element_wise.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring +"""A fallback schedule rule for GPU operators.""" +from typing import List + +from tvm import tir + +from ..base import ScheduleRule, normalize_prim_func, try_inline + + +class ElementWise(ScheduleRule): + """ + An elementwise schedule rule for GPU operators. + """ + + def apply_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config, + ) -> tir.Schedule: + block_factors = config.block + thread_factors = config.thread + step_factors = config.step + + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + + if block_infos is None: + return None + + block_infos = try_inline(sch, block_infos) + + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block = block.block_rv + + if ( + any( + [ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block) + ] + ) + or len(sch.get_loops(block)) == 0 + ): + continue + + for loop, iter_type in zip(sch.get_loops(block), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block)) + sch.reorder(*s_loops, *r_loops, *o_loops) + + block_loops = [] + vthread_loops = [] + thread_loops = [] + inner_loops = [] + for s_loop, block_factor, step_factor, thread_factor in zip( + s_loops, block_factors, step_factors, thread_factors + ): + block_loop, inner_loop = sch.split(s_loop, factors=[None, block_factor]) + vthread_loop, inner_loop = sch.split( + inner_loop, factors=[None, thread_factor * step_factor] + ) + thread_loop, inner_loop = sch.split( + inner_loop, factors=[None, step_factor] + ) + block_loops.append(block_loop) + vthread_loops.append(vthread_loop) + thread_loops.append(thread_loop) + inner_loops.append(inner_loop) + + # inner virtual thread first + vthread_loops = list(reversed(vthread_loops)) + sch.reorder( + *block_loops, + *vthread_loops, + *thread_loops, + *inner_loops, + *r_loops, + *o_loops + ) + sch.bind(sch.fuse(*block_loops), "blockIdx.x") + sch.bind(sch.fuse(*thread_loops), "threadIdx.x") + if len(vthread_loops) > 3: + vthread_loops = vthread_loops[0:2] + [sch.fuse(*vthread_loops[2:])] + + for i, ax in enumerate(vthread_loops): + sch.bind(ax, "vthread" + [".x", ".y", ".z"][i]) + + return sch diff --git a/bitblas/gpu/fallback.py b/bitblas/gpu/fallback.py new file mode 100644 index 000000000..3711d3682 --- /dev/null +++ b/bitblas/gpu/fallback.py @@ -0,0 +1,95 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm fallback.py in dlight. +# pylint: disable=missing-docstring +"""A fallback schedule rule for GPU operators.""" +from typing import List, Tuple + +from tvm import tir +from tvm.target import Target + +from ..base import normalize_prim_func, try_inline +from . import utils +from .base import GPUScheduleRule + + +class Fallback(GPUScheduleRule): + """ + A fallback schedule rule for all GPU operators. It will try to inline all the blocks first, + and then apply a simple block/grid mapping to the spatial loops on top of the remaining blocks. + """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> tir.Schedule: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + max_threads_per_block = utils.max_threads_per_block(target) + + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + + if block_infos is None: + return None + + block_infos = try_inline(sch, block_infos) + reduction_blocks: List[Tuple[tir.schedule.BlockRV, tir.schedule.LoopRV]] = [] + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block = block.block_rv + + if ( + any( + [ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block) + ] + ) + or len(sch.get_loops(block)) == 0 + ): + continue + + for loop, iter_type in zip(sch.get_loops(block), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block)) + sch.reorder(*s_loops, *r_loops, *o_loops) + bx, tx = sch.split( # pylint: disable=invalid-name + sch.fuse(*s_loops), + factors=[None, max_threads_per_block], + ) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + + if len(r_loops) > 0: + reduction_blocks.append((block, r_loops[0])) + + for block, r_loop in reduction_blocks: + sch.decompose_reduction(block, r_loop) + + return sch + \ No newline at end of file diff --git a/bitblas/gpu/gemv.py b/bitblas/gpu/gemv.py new file mode 100644 index 000000000..60a290a81 --- /dev/null +++ b/bitblas/gpu/gemv.py @@ -0,0 +1,794 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm gemv.py in dlight. +"""A rule for GEMV and DecodeGEMV.""" + +from functools import reduce +from typing import List, Optional, Union, Dict + +from tvm import DataType, arith, ir, tir +from tvm.target import Target + +from ..base import ( + BlockInfo, + collect_block_iter_vars_used_in_access_region, + collect_vars_used_in_prim_expr, + detect_dominant_read, + is_broadcast_epilogue, + normalize_prim_func, + try_inline_contiguous_spatial, + get_output_blocks, +) +from .base import GPUScheduleRule +from .gemv_dequantize import GEMVWithDequantizeInfo + + +def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: + # Detect and return `Y` in `X[...] = X[...] + Y` + buffer_store = block.body + if not isinstance(buffer_store, tir.BufferStore): + return None + if not isinstance(buffer_store.value, tir.Add): + return None + if not ir.structural_equal( + buffer_store.value.a, + tir.BufferLoad(buffer_store.buffer, block.body.indices), + map_free_vars=True, + ): + return None + return buffer_store.value.b + + +def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV): + loop: tir.For = sch.get(loop_rv) + return loop.extent.value if isinstance(loop.extent, tir.IntImm) else loop.extent + + +def get_bytes(dtype: Union[DataType, str]) -> int: + if isinstance(dtype, str): + dtype = DataType(dtype) + return int(dtype.bits) // 8 + + +def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: + """Check if the block is a GEMV. + + Parameters + ---------- + + sch : tir.Schedule + The schedule + + block_info : BlockInfo + The block info to be checked + + + Returns + ------- + ret : Optional[List[tir.Buffer]] + The vector buffers used in the GEMV if it is a GEMV, otherwise None. + """ + block = block_info.block_rv + block_stmt = sch.get(block) + conditions = [] + conditions.append(block_info.is_reduction()) + conditions.append(len(block_stmt.reads) >= 2) + conditions.append(len(block_stmt.writes) == 1) + conditions.append(_get_reduction_expr(block_stmt) is not None) + conditions.append( + len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) + > 0) + if not all(conditions): + return None + + iter_num = len(block_stmt.iter_vars) + ret = [ + read.buffer + for read in block_stmt.reads + if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) < iter_num + and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) > 0 + ] + if len(ret) == len(block_stmt.reads): + func = sch.mod["main"] + opt_shapes: Dict = {} + if "opt_shapes" in func.attrs: + opt_shapes = func.attrs["opt_shapes"] + # check with dynamic symbolic and at least one is unit + if not all([opt_shapes.get(buf.name, (1,))[0] == 1 for buf in ret]): + return None + elif len(ret) == 0: + return None + return ret + + +def normalize( + sch: tir.Schedule, + block_info: BlockInfo, +) -> Optional[bool]: + """Normalize the main block.""" + block_stmt: tir.Block = sch.get(block_info.block_rv) + access = arith.normalize_to_iter_sum( + detect_dominant_read(block_stmt), + input_iters={i.var: i.dom for i in block_stmt.iter_vars}, + ) + buffers_use_vars = [ + collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) + for buf in block_stmt.writes + ] + buffers_use_vars.extend([ + collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) + for buf in block_stmt.reads + ]) + if collect_vars_used_in_prim_expr(access.base) & set( + iter_var.var for iter_var in block_stmt.iter_vars): + return None + iter_to_info = {i.var: i for i in block_info.iters} + batch_loops, s_loops, r_loops, c_loops = [], [], [], [] + inner_axis = access.args[-1].source.source + is_inner_reduction = iter_to_info[inner_axis].kind == "R" + + for split_expr in access.args: + var = split_expr.source.source + info = iter_to_info.get(var) + loop = info.loop_rv + is_reduction = info.kind == "R" + if split_expr.lower_factor > 1: + if c_loops: + return None + loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) + # we only support the reduction dim being grouped atm + if not is_reduction: + return None + c_loops.append(c_loop) + if is_reduction: + r_loops.append(loop) + elif all([var in buf_vars for buf_vars in buffers_use_vars]): + batch_loops.append(loop) + else: + s_loops.append(loop) + + assert s_loops + assert r_loops + if not c_loops: + c_loops = [sch.add_unit_loop(block_info.block_rv)] + if not batch_loops: + batch_loops = [sch.add_unit_loop(block_info.block_rv)] + sch.reorder(*batch_loops, *s_loops, *r_loops, *c_loops) + sch.fuse(*batch_loops) + sch.fuse(*s_loops) + sch.fuse(*r_loops) + return is_inner_reduction + + +class GEMV(GPUScheduleRule): + """A rule for GEMV and DecodeGEMV.""" + + def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Union[None, tir.Schedule, List[tir.Schedule]]: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + if "dequantize_info" in func.attrs: + dequantize_rule = GEMVWithDequantizeInfo() + return dequantize_rule.apply(func, target, False) + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if len(block_infos) == 1: + epilogue = None + elif len(block_infos) == 2: + epilogue = block_infos[1] + if not epilogue.is_injective(): + return None + else: + return None + + block_info = block_infos[0] + if len(block_info.iters) not in [2, 3]: + # either [B, S, R] = [B, S, R] * [B, R] + # or [S, R] = [S, R] * [R] + return None + block = block_info.block_rv + vector_input_buffers = is_gemv(sch, block_info) + if vector_input_buffers is None: + return None + + # Step 1. Normalize the block, merge spatial and reduction iters + is_inner_reduction = normalize(sch, block_info) + + # Step 2. Do the scheduling + if is_inner_reduction is None: + return None + elif is_inner_reduction: + self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) + return sch + else: + return self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue) + + def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument + self, + sch: tir.Schedule, + target: Target, + block: tir.schedule.BlockRV, + vector_input_buffers: List[tir.Buffer], + epilogue_info: Optional[BlockInfo], + ): + """Schedule the inner reduction block.""" + + def get_max_factor(n, factors): + factors = sorted(factors, reverse=True) + for factor in factors: + if n % factor == 0: + return factor + return 1 + + def apply( + sch: tir.Schedule, + gemv, + TAG_S, + TAG_R, + TS, + TR, + TILE_S, + TILE_R, + VEC_LOAD, + VEC_C, + LOAD_V_SHARED, + LOAD_V_VEC, + UNROLL, + ): + # rfactor: reduce to tx * vec_c + _, s, r, c = sch.get_loops(block=gemv) + s = sch.fuse(_, s) + r = sch.fuse(r, c) + bx, ts, tile_s = sch.split(s, factors=[None, TS, TILE_S], preserve_unit_iters=True) + r, tr, tile_r_vec_n, vec_c = sch.split( + r, factors=[None, TR, TILE_R // VEC_C, VEC_C], preserve_unit_iters=True) + sch.reorder(r, tile_r_vec_n, tr, vec_c) + tr_vec_c = sch.fuse(tr, vec_c) + rf = sch.rfactor(tr_vec_c, 0) + + # rfactor: reduce to tx + bx, ts, tile_s, tr_vec_c = sch.get_loops(block=gemv) + tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True) + rf2 = sch.rfactor(tr, 0) + + # bind, vectorize compute + bx, ts, tile_s, r, tile_r_vec_n, tr_vec_c = sch.get_loops(block=rf) + tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True) + sch.reorder(bx, ts, tr, r, tile_s, tile_r_vec_n, vec_c) + sch.bind(bx, "blockIdx.x") + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + sch.vectorize(vec_c) + + shared_mem_usage = 0 + for buf in vector_input_buffers: + buf_size = reduce(lambda x, y: x * y, buf.shape, tir.IntImm( + buf.shape[0].dtype, 1)) * get_bytes(buf.dtype) + shared_mem_usage += buf_size + try: + max_shared_memory_per_block = target.max_shared_memory_per_block + except Exception: + max_shared_memory_per_block = 49152 + LOAD_V_SHARED = ( + LOAD_V_SHARED and isinstance(shared_mem_usage, tir.IntImm) and + shared_mem_usage.value <= max_shared_memory_per_block) + + # vectorize load A + # (TODO) this is now actually problematic since the number of loops is dependent on the + # number of dimensions of A_q + Aq_local = sch.cache_read(rf, read_buffer_index=1, storage_scope="local") + sch.compute_at(Aq_local, r, preserve_unit_loops=True) + s_local, r_local = sch.get_loops(block=Aq_local)[-2:] + s_local, vec_load = sch.split( + s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True) + sch.reorder(s_local, r_local, vec_load) # either s_local or r_local should be 1 + sch.vectorize(vec_load) + + # load vector into shared memory, shape should be the whole vector + if LOAD_V_SHARED: + V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared") + sch.compute_at(V_shared, tr, preserve_unit_loops=True) + l = sch.get_loops(block=V_shared)[-1] # noqa: E741 + loop: tir.For = sch.get(l) + if isinstance(loop.extent, tir.IntImm): + # avoid introducing predicates when vector length is too large + vec_length = max( + min( + get_max_factor( + (int)(loop.extent), + [TS * TR * 1, TS * TR * 2, TS * TR * 4, TS * TR * 8], + ) // TS // TR, + LOAD_V_VEC, + ), + 1, + ) + else: + vec_length = LOAD_V_VEC + if TAG_R == "threadIdx.x": + _, ty, tx, vec = sch.split( + l, factors=[None, TS, TR, vec_length], preserve_unit_iters=True) + else: + _, ty, tx, vec = sch.split( + l, factors=[None, TR, TS, vec_length], preserve_unit_iters=True) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + # reduce tile_s * tr * vec to tile_s * tr + sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True) + tr, vec_c, *ts_tile_s = sch.get_loops(block=rf2)[1:] + ts_tile_s = sch.fuse(*ts_tile_s) + ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + tile_s, vec_s = sch.split( + tile_s, + factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])], + preserve_unit_iters=True, + ) + sch.reorder(ts, tr, tile_s, vec_s, vec_c) + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + sch.vectorize(vec_s) + + # reduce tile_s * tr to tile_s + sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True) + tr, *ts_tile_s = sch.get_loops(block=gemv)[1:] + ts_tile_s = sch.fuse(*ts_tile_s) + ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + sch.reorder(tile_s, ts, tr) + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + + sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[3]) + sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1]) + + sch.set_scope(rf, buffer_index=0, storage_scope="local") + sch.set_scope(rf2, buffer_index=0, storage_scope="local") + + unroll_factor = UNROLL + + sch.annotate( + block_or_loop=sch.get_loops(rf)[3], + ann_key="pragma_auto_unroll_max_step", + ann_val=unroll_factor, + ) + sch.annotate( + block_or_loop=sch.get_loops(rf)[3], + ann_key="pragma_unroll_explicit", + ann_val=1, + ) + + sch.annotate( + block_or_loop=sch.get_loops(rf2)[3], + ann_key="pragma_auto_unroll_max_step", + ann_val=unroll_factor, + ) + sch.annotate( + block_or_loop=sch.get_loops(rf2)[3], + ann_key="pragma_unroll_explicit", + ann_val=1, + ) + + if LOAD_V_SHARED: + sch.annotate( + block_or_loop=sch.get_loops(V_shared)[-4], + ann_key="pragma_unroll_explicit", + ann_val=unroll_factor, + ) + sch.annotate( + block_or_loop=sch.get_loops(V_shared)[-4], + ann_key="pragma_vectorize", + ann_val=1, + ) + + # Schedule epilogue + if epilogue_info is not None: + epilogue = epilogue_info.block_rv + if is_broadcast_epilogue(sch, block, epilogue): + sch.reverse_compute_at(epilogue, bx) + sch.set_scope(block, 0, "shared") + _, _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name + _, tx = sch.split(sch.fuse(*s), factors=[None, TS]) + sch.bind(tx, "threadIdx.x") + else: + sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) + ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:]) + ts_tile_s = sch.get_loops(epilogue)[-1] + ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + sch.bind(ts, TAG_S) + sch.set_scope(block, 0, "local") + # pylint: enable=invalid-name + return sch + + # Specify the `len_tx` and `len_ty` according to the loop extent + batch, s, r, c = sch.get_loops(block=block) + len_batch, len_s, len_r, len_c = ( + get_extent(sch, batch), + get_extent(sch, s), + get_extent(sch, r), + get_extent(sch, c), + ) + len_S = len_batch * len_s + len_R = len_r * len_c + + TAG_S, TAG_R = "threadIdx.y", "threadIdx.x" + if target.kind.name == "cuda": + VEC_C = 4 + LOAD_V_SHARED = True + LOAD_V_VEC = 8 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 4, 64 + else: + TS, TR = 16, 32 + elif target.kind.name == "metal": + # Note that the following tile size is tuned on M2 Ultra for 7B + TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" + VEC_C = 1 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 4, 16 + else: + TS, TR = 2, 64 + elif target.kind.name == "rocm": + VEC_C = 4 + LOAD_V_SHARED = True + LOAD_V_VEC = 8 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 1, 128 + else: + TS, TR = 8, 64 + elif target.kind.name == "opencl" and "android" in str(target.host): + TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" + VEC_C = 8 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 + UNROLL = 8 + TS, TR = 2, 32 + elif target.kind.name == "vulkan": + VEC_C = 4 + LOAD_V_SHARED = True + LOAD_V_VEC = 4 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 4, 32 + else: + TS, TR = 16, 32 + elif target.kind.name == "opencl" and "mali" in str(target.attrs): + VEC_C = 8 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 + UNROLL = 64 + TS, TR = 1, 64 + else: + VEC_C = 1 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 + UNROLL = 64 + TS, TR = 1, 64 + + if not isinstance(len_S, int): + TS, TR = 1, 64 + + while TS * TR > target.max_num_threads: + if TS > 1: + TS //= 2 + else: + TR //= 2 + + TILE_S, TILE_R = ( + 1, + (len_c if len_c > 1 else max( + get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1)), + ) + VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C) + VEC_LOAD = 1 + + return apply( + sch, + gemv=block, + TAG_S=TAG_S, + TAG_R=TAG_R, + TS=TS, + TR=TR, + TILE_S=TILE_S, + TILE_R=TILE_R, + VEC_LOAD=VEC_LOAD, + VEC_C=VEC_C, + LOAD_V_SHARED=LOAD_V_SHARED, + LOAD_V_VEC=LOAD_V_VEC, + UNROLL=UNROLL, + ) + + def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument + self, + sch: tir.Schedule, + target: Target, + block: tir.schedule.BlockRV, + vector_input_buffers: List[tir.Buffer], + epilogue_info: Optional[BlockInfo], + ): + """Schedule the outer reduction block.""" + # NOTE: Only Android is supported so far + if not (target.kind.name == "opencl" and "android" in str(target.host)): + return None + batch, s, r, c = sch.get_loops(block) + len_s = get_extent(sch, s) + + # The config is designed for Adreno + tx_len = 64 + vec_len = (4 if len_s > 4096 else 2) if isinstance(len_s, int) else 1 + inner_r = 4 + + bx, tx, vec = sch.split(s, factors=[None, tx_len, vec_len]) + r0, r1 = sch.split(r, factors=[None, inner_r]) + sch.bind(batch, "blockIdx.y") + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.reorder(bx, tx, r0, r1, c, vec) + + sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=8) + sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1) + + cache_v = sch.cache_read(block, vector_input_buffers[0], "local") + sch.compute_at(cache_v, r1, preserve_unit_loops=True) + sch.vectorize(sch.get_loops(cache_v)[-1]) + + sch.vectorize(vec) + + # Schedule epilogue + if epilogue_info is not None: + sch.reverse_compute_at(epilogue_info.block_rv, tx) + + sch.set_scope(block, 0, "local") + + sch.decompose_reduction(block, r0) + + return sch + + def sch_inner_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + config, + ): + sch = tir.Schedule(func) + + block_infos = normalize_prim_func(sch) + + if block_infos is None: + return None + + reduction_block: tir.schedule.BlockRV = None + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block = block.block_rv + + if (any([ + sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) + ]) or len(sch.get_loops(block)) == 0): + continue + + for loop, iter_type in zip(sch.get_loops(block), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block)) + if len(r_loops) > 0: + reduction_block = block + # skip analysis for following blocks + break + + def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + vec = 1 + if len(config.vectorize): + vec = list(config.vectorize.values())[-1] + + num_warps = int(prod(config.thread)) + warp_size = int(prod(config.reduce_thread)) + + block_b = reduction_block + output_blocks = get_output_blocks(sch, block_infos) + # compute inline + for block_info in reversed(block_infos): + block = block_info.block_rv + if block not in (reduction_block, *output_blocks): + sch.compute_inline(block) + try: + i, j, k = sch.get_loops(block_b) + except Exception: + j, k = sch.get_loops(block_b) + block_local_A = sch.cache_read(block_b, 0, "local") + block_local_B = sch.cache_read(block_b, 1, "local") + block_local_C = sch.cache_write(block_b, 0, "local") + # reverse inline + if reduction_block is not None and reduction_block != output_blocks[0]: + sch.reverse_compute_inline(output_blocks[0]) + + bx, j = sch.split(j, factors=[None, num_warps]) + k, tx, vk = sch.split(k, factors=[None, warp_size, vec]) + sch.reorder(bx, j, k, tx) + + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.bind(j, "threadIdx.y") + + self.block_size = [sch.get(tx).extent, sch.get(j).extent, 1] + self.grid_size = [sch.get(bx).extent, 1, 1] + + sch.compute_at(block_local_A, tx, preserve_unit_loops=True) + sch.compute_at(block_local_B, tx, preserve_unit_loops=True) + sch.reverse_compute_at(block_local_C, j, preserve_unit_loops=True) + + block_local_a_v = sch.get_loops(block_local_A)[-1] + sch.vectorize(block_local_a_v) + block_local_b_v = sch.get_loops(block_local_B)[-1] + sch.vectorize(block_local_b_v) + + return sch + + def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + config, + ): + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + + if block_infos is None: + return None + + reduction_block: tir.schedule.BlockRV = None + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block = block.block_rv + + if (any([ + sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) + ]) or len(sch.get_loops(block)) == 0): + continue + + for loop, iter_type in zip(sch.get_loops(block), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block)) + if len(r_loops) > 0: + reduction_block = block + # skip analysis for following blocks + break + + C = reduction_block + CL = sch.cache_write(reduction_block, 0, "local") + + blck_axis = [] + vthd_axis = [] + thrd_axis = [] + tile_axis = [] + # for gemv, we should skip dynamic symbolic in s_loops + s_loops = [loop for loop in s_loops if isinstance(sch.get(loop).extent, tir.IntImm)] + assert len(s_loops) == len(config.block), f"{len(s_loops)} != {len(config.block)}" + for i, loop in enumerate(s_loops): + if sch.get(loop).extent % config.block[i]: + raise NotImplementedError("Undivisible block in TIR schedule is still buggy.") + bx, _t = sch.split(loop, factors=[None, config.block[i]]) + blck_axis.append(bx) + if config.step[i] > 1: + _t, tn = sch.split(_t, factors=[None, config.step[i]]) + tile_axis.append(tn) + if config.block[i] <= config.thread[i] * config.step[i]: + tx = _t + else: + vx, tx = sch.split(_t, factors=[None, config.thread[i]]) + vthd_axis.append(vx) + thrd_axis.append(tx) + + reduce_outer_axis, reduce_inner_axis = [], [] + + for i in config.raxis_order: + loop = r_loops[i] + ro, ri = sch.split(loop, factors=[None, config.rstep[i]]) + reduce_outer_axis.append(ro) + reduce_inner_axis.append(ri) + + vthd_axis = list(reversed(vthd_axis)) # inner virtual thread first + axis_order = ( + blck_axis + vthd_axis + thrd_axis + reduce_outer_axis + reduce_inner_axis + tile_axis) + + sch.reorder(*axis_order) + blck_fused = sch.fuse(*blck_axis) + thrd_fused = sch.fuse(*thrd_axis) + sch.bind(blck_fused, "blockIdx.x") + sch.bind(thrd_fused, "threadIdx.x") + if len(vthd_axis) > 3: + vthd_axis = vthd_axis[0:2] + [sch.fuse(*vthd_axis[2:])] + for i, ax in enumerate(vthd_axis): + sch.bind(ax, "vthread" + [".x", ".y", ".z"][i]) + for ax in tile_axis: + sch.unroll(ax) + + sch.reverse_compute_at(CL, thrd_fused) + if len(tile_axis) > 0: + for ax in sch.get_loops(CL)[-len(tile_axis):]: + sch.unroll(ax) + + sch.decompose_reduction(C, reduce_outer_axis[0]) + + try_inline_contiguous_spatial(sch, block_infos) + + return sch + + def apply_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config, + ) -> tir.Schedule: + if not isinstance(func, tir.PrimFunc): + return None + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if len(block_infos) == 1: + epilogue = None + elif len(block_infos) == 2: + epilogue = block_infos[1] + if not epilogue.is_injective(): + return None + else: + return None + + block_info = block_infos[0] + if len(block_info.iters) not in [2, 3, 4]: + # either [SK, B, S, R] = [SK, B, S, R] * [SK, B, R] + # either [B, S, R] = [B, S, R] * [B, R] + # or [S, R] = [S, R] * [R] + return None + + if is_gemv(sch, block_info) is None: + return None + + if "dequantize_info" in func.attrs: + dequantize_rule = GEMVWithDequantizeInfo() + return dequantize_rule.apply_config(func, config) + + if any([t > 1 for t in config.reduce_thread]): + return self.sch_inner_reduction_with_config(func, config) + + return self.sch_outer_reduction_with_config(func, config) diff --git a/bitblas/gpu/gemv_dequantize.py b/bitblas/gpu/gemv_dequantize.py new file mode 100644 index 000000000..5ccc5b40e --- /dev/null +++ b/bitblas/gpu/gemv_dequantize.py @@ -0,0 +1,369 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""A rule for GEMV and DecodeGEMV.""" +from functools import reduce +from typing import List, Dict +from tvm.target import Target +from tvm.tir.function import PrimFunc +from tvm import DataType, tir +import logging +from ..base import ( + normalize_prim_func, + get_output_blocks, + get_block, +) +from .base import GPUScheduleRule +from .matmul_analysis import auto_inline_producers, auto_inline_consumers + +logger = logging.getLogger(__name__) + + +class GEMVWithDequantizeInfo(GPUScheduleRule): + """A rule for Dequantized GEMV.""" + + def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ): + sch = tir.Schedule(func) + from .intrin import get_lop3_intrin_group + + dequantize_info = func.attrs["dequantize_info"] + + def check_dequantize_info(dequantize_info): + conditions = [] + # currently only support weight only dequantization + conditions.append(len(dequantize_info) == 1) + # TODO(@lei) check if the dequantize value name is weight + return all(conditions) + + if not check_dequantize_info(dequantize_info): + logger.debug("Dequantize info is not valid") + return None + + (weight_decode_info,) = list(dequantize_info.values()) + + def check_weight_decode_info(weight_decode_info): + conditions = [] + # check source format in ["int", "fp", "nf"] + conditions.append("source_format" in weight_decode_info) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"]) + # check source bits in [1, 2, 4, 8] + conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) + # check target format in ["float16", "int8"] + conditions.append("target_format" in weight_decode_info) + conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) + return all(conditions) + + if not check_weight_decode_info(weight_decode_info): + logger.debug("Weight Dequantize info is not valid") + return None + + block_infos = normalize_prim_func(sch) + + if block_infos is None: + return None + + reduction_block: tir.schedule.BlockRV = None + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block = block.block_rv + + if (any([ + sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) + ]) or len(sch.get_loops(block)) == 0): + continue + + for loop, iter_type in zip(sch.get_loops(block), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block)) + if len(r_loops) > 0: + reduction_block = block + + def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + def get_vectorize_factor(target_format): + # coalesced access requires the vectorize factor to be the same as the transaction size + return 128 // DataType(target_format).bits + + vec = get_vectorize_factor(weight_decode_info["target_format"]) + num_warps = 1 + warp_size = 32 + + block_b = reduction_block + output_blocks = get_output_blocks(sch, block_infos) # noqa: F841 + B_decode_block = get_block(sch, block_infos, weight_decode_info["decode_block"]) + + block_decode_B = sch.cache_read(block_b, 1, "local") + sch.compute_inline(B_decode_block) + + j, k = sch.get_loops(block_b)[-2:] + if len(sch.get_loops(block_b)) == 3: + i = sch.get_loops(block_b)[0] + sch.bind(i, "blockIdx.z") + elif len(sch.get_loops(block_b)) == 4: + # splitk case + sk, i = sch.get_loops(block_b)[:2] + sch.bind(sk, "blockIdx.y") + sch.bind(i, "blockIdx.z") + + # get target dequantize buffer's idx + def get_idx(weight_decode_info: Dict): + # for LUT dequantize, the expr is LUT(w), the idx is 1 + # maybe we can use a more general and structural based way + # to analysis the idx + if weight_decode_info["source_format"]["format"] == "nf": + return 1 + return 0 + + block_shared_local_A = sch.cache_read(block_b, 0, "local") + block_shared_local_B = sch.cache_read(block_decode_B, get_idx(weight_decode_info), "local") + block_local_C = sch.cache_write(block_b, 0, "local") + + auto_inline_producers(sch, block_shared_local_B) + auto_inline_consumers(sch, block_local_C) + + bx, j = sch.split(j, factors=[None, num_warps]) + k, tx, vk = sch.split(k, factors=[None, warp_size, vec]) + # for dp4a/hfma2 + inst_factor = 2 if weight_decode_info["target_format"] == "float16" else 4 + _, vk = sch.split(vk, factors=[None, inst_factor]) + sch.reorder(bx, j, k, tx) + + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.bind(j, "threadIdx.y") + + self.block_size = [sch.get(tx).extent, sch.get(j).extent, 1] + self.grid_size = [sch.get(bx).extent, 1, 1] + + sch.compute_at(block_decode_B, tx, preserve_unit_loops=True) + sch.compute_at(block_shared_local_A, tx, preserve_unit_loops=True) + sch.compute_at(block_shared_local_B, tx, preserve_unit_loops=True) + sch.reverse_compute_at(block_local_C, j, preserve_unit_loops=True) + + block_local_a_v = sch.get_loops(block_shared_local_A)[-1] + sch.vectorize(block_local_a_v) + block_local_b_v = sch.get_loops(block_shared_local_B)[-1] + sch.vectorize(block_local_b_v) + + skip_blocks = [block_shared_local_B] + + if "zeros_mode" in weight_decode_info and weight_decode_info["zeros_mode"] == "quantized": + if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]: + block_local_scales = sch.cache_read(block_decode_B, + get_idx(weight_decode_info) + 1, "local") + sch.compute_at(block_local_scales, tx, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_scales) + skip_blocks.append(block_local_scales) + + if "with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]: + block_local_zeros = sch.cache_read(block_decode_B, + get_idx(weight_decode_info) + 2, "local") + sch.compute_at(block_local_zeros, tx, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_zeros) + skip_blocks.append(block_local_zeros) + + auto_inline_producers(sch, block_decode_B, skip_blocks) + + if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): + source_bit = weight_decode_info["source_format"]["bits"] + out_dtype = weight_decode_info["target_format"] + intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + storage_dtype=weight_decode_info["storage_dtype"], + source_format=weight_decode_info["source_format"]["format"], + source_bit=source_bit, + with_scaling=weight_decode_info["with_scaling"], + with_zeros=weight_decode_info["with_zeros"], + zeros_mode=weight_decode_info["zeros_mode"], + ) + sch.tensorize(sch.get_loops(block_decode_B)[-1], intrin_info["compute"]) + sch.annotate(block_b, ann_key="pragma_import_c", ann_val=intrin_info["c_source"]) + return sch + + def sch_inner_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + config, + ): + sch = tir.Schedule(func) + from .intrin import get_lop3_intrin_group + + dequantize_info = func.attrs["dequantize_info"] + + def check_dequantize_info(dequantize_info): + conditions = [] + # currently only support weight only dequantization + conditions.append(len(dequantize_info) == 1) + # TODO(@lei) check if the dequantize value name is weight + return all(conditions) + + if not check_dequantize_info(dequantize_info): + logger.debug("Dequantize info is not valid") + return None + + (weight_decode_info,) = list(dequantize_info.values()) + + def check_weight_decode_info(weight_decode_info): + conditions = [] + # check source format in ["int", "fp", "nf"] + conditions.append("source_format" in weight_decode_info) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"]) + # check source bits in [1, 2, 4, 8] + conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) + # check target format in ["float16", "int8"] + conditions.append("target_format" in weight_decode_info) + conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) + return all(conditions) + + if not check_weight_decode_info(weight_decode_info): + logger.debug("Weight Dequantize info is not valid") + return None + + block_infos = normalize_prim_func(sch) + + if block_infos is None: + return None + + reduction_block: tir.schedule.BlockRV = None + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block = block.block_rv + + if (any([ + sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) + ]) or len(sch.get_loops(block)) == 0): + continue + + for loop, iter_type in zip(sch.get_loops(block), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block)) + if len(r_loops) > 0: + reduction_block = block + + def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + def get_vectorize_factor(target_format): + # coalesced access requires the vectorize factor to be the same as the transaction size + return config.arch.transaction_size[-1] // DataType(target_format).bits + + vec = get_vectorize_factor(weight_decode_info["target_format"]) + num_warps = int(prod(config.thread)) + warp_size = int(prod(config.reduce_thread)) + + block_b = reduction_block + output_blocks = get_output_blocks(sch, block_infos) # noqa: F841 + B_decode_block = get_block(sch, block_infos, weight_decode_info["decode_block"]) + + block_decode_B = sch.cache_read(block_b, 1, "local") + sch.compute_inline(B_decode_block) + + j, k = sch.get_loops(block_b)[-2:] + if len(sch.get_loops(block_b)) == 3: + i = sch.get_loops(block_b)[0] + sch.bind(i, "blockIdx.z") + elif len(sch.get_loops(block_b)) == 4: + # splitk case + sk, i = sch.get_loops(block_b)[:2] + sch.bind(sk, "blockIdx.y") + sch.bind(i, "blockIdx.z") + assert len(config.thread) == 2, "SplitK only support 2D thread config" + num_warps = int(num_warps // config.thread[0]) + + # get target dequantize buffer's idx + def get_idx(weight_decode_info: Dict): + # for LUT dequantize, the expr is LUT(w), the idx is 1 + # maybe we can use a more general and structural based way + # to analysis the idx + if weight_decode_info["source_format"]["format"] == "nf": + return 1 + return 0 + + block_shared_local_A = sch.cache_read(block_b, 0, "local") + block_shared_local_B = sch.cache_read(block_decode_B, get_idx(weight_decode_info), "local") + block_local_C = sch.cache_write(block_b, 0, "local") + + auto_inline_producers(sch, block_shared_local_B) + auto_inline_consumers(sch, block_local_C) + + bx, j = sch.split(j, factors=[None, num_warps]) + k, tx, vk = sch.split(k, factors=[None, warp_size, vec]) + # for dp4a/hfma2 + inst_factor = 2 if weight_decode_info["target_format"] == "float16" else 4 + _, vk = sch.split(vk, factors=[None, inst_factor]) + sch.reorder(bx, j, k, tx) + + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.bind(j, "threadIdx.y") + + self.block_size = [sch.get(tx).extent, sch.get(j).extent, 1] + self.grid_size = [sch.get(bx).extent, 1, 1] + + sch.compute_at(block_decode_B, tx, preserve_unit_loops=True) + sch.compute_at(block_shared_local_A, tx, preserve_unit_loops=True) + sch.compute_at(block_shared_local_B, tx, preserve_unit_loops=True) + sch.reverse_compute_at(block_local_C, j, preserve_unit_loops=True) + + block_local_a_v = sch.get_loops(block_shared_local_A)[-1] + sch.vectorize(block_local_a_v) + block_local_b_v = sch.get_loops(block_shared_local_B)[-1] + sch.vectorize(block_local_b_v) + + skip_blocks = [block_shared_local_B] + + if "zeros_mode" in weight_decode_info and weight_decode_info["zeros_mode"] == "quantized": + if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]: + block_local_scales = sch.cache_read(block_decode_B, + get_idx(weight_decode_info) + 1, "local") + sch.compute_at(block_local_scales, tx, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_scales) + skip_blocks.append(block_local_scales) + + if "with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]: + block_local_zeros = sch.cache_read(block_decode_B, + get_idx(weight_decode_info) + 2, "local") + sch.compute_at(block_local_zeros, tx, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_zeros) + skip_blocks.append(block_local_zeros) + + auto_inline_producers(sch, block_decode_B, skip_blocks) + + if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): + source_bit = weight_decode_info["source_format"]["bits"] + out_dtype = weight_decode_info["target_format"] + intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + storage_dtype=weight_decode_info["storage_dtype"], + source_format=weight_decode_info["source_format"]["format"], + source_bit=source_bit, + with_scaling=weight_decode_info["with_scaling"], + with_zeros=weight_decode_info["with_zeros"], + zeros_mode=weight_decode_info["zeros_mode"], + ) + sch.tensorize(sch.get_loops(block_decode_B)[-1], intrin_info["compute"]) + sch.annotate(block_b, ann_key="pragma_import_c", ann_val=intrin_info["c_source"]) + return sch + + def apply_config(self, func: PrimFunc, config): + if any([t > 1 for t in config.reduce_thread]): + return self.sch_inner_reduction_with_config(func, config) + else: + return None diff --git a/bitblas/gpu/general_reduction.py b/bitblas/gpu/general_reduction.py new file mode 100644 index 000000000..cc03acd99 --- /dev/null +++ b/bitblas/gpu/general_reduction.py @@ -0,0 +1,465 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=invalid-name +"""Reduction rule for operators including softmax, layer norm, RMS norm, etc""" +from typing import List, Union +from functools import reduce + +from tvm import tir +from tvm.target import Target + +from ..base import normalize_prim_func, try_inline_contiguous_spatial +from ..base.analysis import get_root_block, get_reduction_blocks, BlockInfo +from .base import GPUScheduleRule + + +class GeneralReduction(GPUScheduleRule): + """General Reduction rule for operators including softmax, layer norm, RMS norm, etc""" + + def apply( # pylint: disable=too-many-locals + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Union[None, tir.Schedule, List[tir.Schedule]]: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + + if target.kind.name == "cuda": + len_tx = 256 + unroll_depth = 256 + else: + len_tx = 64 + unroll_depth = 64 + + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if block_infos is None or len(block_infos) == 0: + return None + + dom_kind = block_infos[0].dom_kind() + num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S")) + num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R")) + + # Align the number of block iters of the last block. + num_last_block_iter = len(block_infos[-1].dom_kind()) + if num_last_block_iter < len(dom_kind): + index_map = tir.IndexMap.from_func( + lambda *iters: ( + [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) + + list(iters) + ), + ndim=num_last_block_iter, + ) + sch.transform_block_layout(block_infos[-1].block_rv, index_map) + + try: + # TODO: fix num_leading_s = 0 case + assert num_trailing_r > 0 + for block in block_infos[1:-1]: + assert block.dom_kind() == dom_kind + assert block_infos[-1].is_injective() + assert len(block_infos[-1].dom_kind()) <= len(dom_kind) + except AssertionError: + return None + + loops = sch.get_loops(block_infos[-1].block_rv) + bx = sch.fuse(*loops[:num_leading_s]) + r_loop, tx = sch.split(loops[-1], [None, len_tx]) + sch.reorder(tx, r_loop) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1) + + for block in reversed(block_infos[:-1]): + block = block.block_rv + for i, _ in enumerate(sch.get(block).writes): + sch.set_scope(block, buffer_index=i, storage_scope="shared") + sch.compute_at(block, bx, preserve_unit_loops=True) + r_loop = sch.fuse(*sch.get_loops(block)[-num_trailing_r:]) + r_loop, tx = sch.split(r_loop, [None, len_tx]) + sch.reorder(tx, r_loop) + sch.bind(tx, "threadIdx.x") + sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1) + + # TODO: It's just a workaround to avoid unroll spatial loops, because of the bug of + # the pass lower-thread-allreduce. We should fix it in the future. + # sch.annotate(bx, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + # sch.annotate(bx, ann_key="pragma_unroll_explicit", ann_val=1) + return sch + + def sch_inner_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + config, + ): + block_factors = config.block + thread_factors = config.thread + reduce_therad_factors = config.reduce_thread + + # For inter thread reduction case, one thread must only compute one element + assert thread_factors == block_factors + + # inline all the other blocks + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + + schedule_block: tir.schedule.BlockRV = None + reduction_blocks: List[tir.schedule.BlockRV] = [] + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block_rv = block.block_rv + + if ( + any( + [ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block_rv) + ] + ) + or len(sch.get_loops(block.block_rv)) == 0 + ): + continue + + for loop, iter_type in zip(sch.get_loops(block_rv), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block_rv)) + if len(r_loops) > 0: + # always use the last reduction block for scheduling + schedule_block = block + reduction_blocks.append(block_rv) + + # Align the number of block iters of the last block. + dom_kind = schedule_block.dom_kind() + num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S")) + num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R")) + + schedule_block = schedule_block.block_rv + loops = sch.get_loops(schedule_block) + s_loops = loops[:num_leading_s] + r_loops = loops[-num_trailing_r:] + + block_axis = [] + thread_axis = [] + + for s_loop, block_factor in zip(s_loops, block_factors): + block_loop, thread_loop = sch.split(s_loop, factors=[None, block_factor]) + block_axis.append(block_loop) + thread_axis.append(thread_loop) + + axis_order = block_axis + thread_axis + + sch.reorder(*axis_order) + blck_fused = sch.fuse(*block_axis) + thrd_fused = sch.fuse(*thread_axis) + sch.bind(blck_fused, "blockIdx.x") + sch.bind(thrd_fused, "threadIdx.y") + + reduce_outer_axis, reduce_inner_axis, reduce_inter_threads = [], [], [] + for i in config.raxis_order: + loop = r_loops[i] + ro, ri = sch.split(loop, factors=[None, config.rstep[i]]) + ri, thd = sch.split(ri, factors=[None, config.reduce_thread[i]]) + reduce_inter_threads.append(thd) + reduce_outer_axis.append(ro) + reduce_inner_axis.append(ri) + + axis_order = reduce_inter_threads + reduce_outer_axis + reduce_inner_axis + sch.reorder(*axis_order) + fused_reduce_inter_threads = sch.fuse(*reduce_inter_threads) + sch.bind(fused_reduce_inter_threads, "threadIdx.x") + + def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + reg_tile = sch.cache_write(schedule_block, 0, "local") + + # todo(lei): should add the shared_inputs/stride memory pad analysis at shared memory fusion stage. + for i, input_region in enumerate(sch.get(schedule_block).reads): + if input_region.buffer.name not in config.cached_tensors: + continue + + # otherwise cooperative fetch in shared memory. + cache_shared = sch.cache_read(schedule_block, i, "shared") + sch.compute_at(cache_shared, reduce_outer_axis[-1]) + + dim_offset = ( + len(reduce_inner_axis) + len(reduce_outer_axis) + 2 + ) # outer loops are: blck_fused, thrd_fused, vthread_axis, reduce_outer_axis + if input_region.buffer.name in config.vectorize: + vectorize = config.vectorize[input_region.buffer.name] + else: + vectorize = 1 + + loops = sch.get_loops(cache_shared) + if len(loops) == dim_offset: + # handle fetching only one element + loops.append(sch.add_unit_loop(schedule_block)) + assert len(loops) > dim_offset + + _, ty, tx, tv = sch.split( + sch.fuse(*loops[dim_offset:]), + factors=[ + None, + int(prod(thread_factors)), + int(prod(reduce_therad_factors)), + vectorize, + ], + ) + sch.vectorize(tv) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + sch.reverse_compute_at(reg_tile, thrd_fused) + + # resolve compute_at + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if block_infos is None or len(block_infos) == 0: + return None + return sch + + def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + config, + ): + block_factors = config.block + thread_factors = config.thread + step_factors = config.step + + # inline all the other blocks + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + + schedule_block: BlockInfo = None + for block in block_infos: + s_loops: List[tir.schedule.LoopRV] = [] + r_loops: List[tir.schedule.LoopRV] = [] + o_loops: List[tir.schedule.LoopRV] = [] + dom_kind = block.dom_kind() + block_rv = block.block_rv + + if ( + any( + [ + sch.get(loop_rv).thread_binding is not None + for loop_rv in sch.get_loops(block_rv) + ] + ) + or len(sch.get_loops(block.block_rv)) == 0 + ): + continue + + for loop, iter_type in zip(sch.get_loops(block_rv), dom_kind): + {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) + + if not s_loops: + s_loops.append(sch.add_unit_loop(block_rv)) + if len(r_loops) > 0: + # always use the last reduction block for scheduling + schedule_block = block + + # Align the number of block iters of the last block. + dom_kind = schedule_block.dom_kind() + num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S")) + num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R")) + + num_last_block_iter = len(block_infos[-1].dom_kind()) + if num_last_block_iter < len(dom_kind): + index_map = tir.IndexMap.from_func( + lambda *iters: ( + [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) + + list(iters) + ), + ndim=num_last_block_iter, + ) + sch.transform_block_layout(block_infos[-1].block_rv, index_map) + + schedule_block = schedule_block.block_rv + loops = sch.get_loops(schedule_block) + s_loops = loops[:num_leading_s] + r_loops = loops[-num_trailing_r:] + + reg_tile = sch.cache_write(schedule_block, 0, "local") + + block_axis = [] + vthread_axis = [] + thread_axis = [] + inner_axis = [] + for s_loop, block_factor, step_factor, thread_factor in zip( + s_loops, block_factors, step_factors, thread_factors + ): + block_loop, inner_loop = sch.split(s_loop, factors=[None, block_factor]) + vthread_loop, inner_loop = sch.split( + inner_loop, factors=[None, thread_factor * step_factor] + ) + thread_loop, inner_loop = sch.split(inner_loop, factors=[None, step_factor]) + block_axis.append(block_loop) + vthread_axis.append(vthread_loop) + thread_axis.append(thread_loop) + inner_axis.append(inner_loop) + + reduce_outer_axis, reduce_inner_axis = [], [] + for i in config.raxis_order: + loop = r_loops[i] + ro, ri = sch.split(loop, factors=[None, config.rstep[i]]) + reduce_outer_axis.append(ro) + reduce_inner_axis.append(ri) + + vthread_axis = list(reversed(vthread_axis)) # inner virtual thread first + axis_order = ( + block_axis + + vthread_axis + + thread_axis + + reduce_outer_axis + + reduce_inner_axis + + inner_axis + ) + + sch.reorder(*axis_order) + blck_fused = sch.fuse(*block_axis) + thrd_fused = sch.fuse(*thread_axis) + sch.bind(blck_fused, "blockIdx.x") + sch.bind(thrd_fused, "threadIdx.x") + if len(vthread_axis) > 3: + vthread_axis = vthread_axis[0:2] + [sch.fuse(*vthread_axis[2:])] + for i, ax in enumerate(vthread_axis): + sch.bind(ax, "vthread" + [".x", ".y", ".z"][i]) + + # todo(lei): should add the shared_inputs/stride memory pad analysis at shared memory fusion stage. + for i, input_region in enumerate(sch.get(schedule_block).reads): + if input_region.buffer.name not in config.cached_tensors: + continue + + # otherwise cooperative fetch in shared memory. + cache_shared = sch.cache_read(schedule_block, i, "shared") + sch.compute_at(cache_shared, reduce_outer_axis[-1]) + + dim_offset = ( + len(vthread_axis) + len(reduce_outer_axis) + 2 + ) # outer loops are: blck_fused, thrd_fused, vthread_axis, reduce_outer_axis + if input_region.buffer.name in config.vectorize: + vectorize = config.vectorize[input_region.buffer.name] + else: + vectorize = 1 + + loops = sch.get_loops(cache_shared) + if len(loops) == dim_offset: + # handle fetching only one element + loops.append(sch.add_unit_loop(schedule_block)) + assert len(loops) > dim_offset + + def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + _, tx, tv = sch.split( + sch.fuse(*loops[dim_offset:]), factors=[None, int(prod(thread_factors)), vectorize] + ) + sch.vectorize(tv) + sch.bind(tx, "threadIdx.x") + + sch.reverse_compute_at(reg_tile, thrd_fused) + + sch.decompose_reduction(schedule_block, reduce_outer_axis[0]) + + # resolve compute_at + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if block_infos is None or len(block_infos) == 0: + return None + + return sch + + def sch_mutiple_reductions_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + config, + ): + block_factors = config.block + thread_factors = config.thread + reduce_therad_factors = config.reduce_thread + + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if block_infos is None or len(block_infos) == 0: + return None + + def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + len_tx = prod(thread_factors) * prod(reduce_therad_factors) + block_factor = prod(block_factors) + + dom_kind = block_infos[0].dom_kind() + num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S")) + num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R")) + + # Align the number of block iters of the last block. + num_last_block_iter = len(block_infos[-1].dom_kind()) + if num_last_block_iter < len(dom_kind): + index_map = tir.IndexMap.from_func( + lambda *iters: ( + [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) + + list(iters) + ), + ndim=num_last_block_iter, + ) + sch.transform_block_layout(block_infos[-1].block_rv, index_map) + + try: + # TODO: fix num_leading_s = 0 case + assert num_trailing_r > 0 + for block in block_infos[1:-1]: + assert block.dom_kind() == dom_kind + assert block_infos[-1].is_injective() + assert len(block_infos[-1].dom_kind()) <= len(dom_kind) + except AssertionError: + return None + + loops = sch.get_loops(block_infos[-1].block_rv) + bx, _ = sch.split(sch.fuse(*loops[:num_leading_s]), factors=[None, block_factor]) + r_loop, tx = sch.split(loops[-1], [None, len_tx]) + sch.reorder(tx, r_loop) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + + for block in reversed(block_infos[:-1]): + block = block.block_rv + for i, _ in enumerate(sch.get(block).writes): + sch.set_scope(block, buffer_index=i, storage_scope="shared") + sch.compute_at(block, bx, preserve_unit_loops=True) + r_loop = sch.fuse(*sch.get_loops(block)[-num_trailing_r:]) + r_loop, tx = sch.split(r_loop, [None, len_tx]) + sch.reorder(tx, r_loop) + sch.bind(tx, "threadIdx.x") + + return sch + + def apply_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config, + ) -> tir.Schedule: + # check the number of reduction blocks + sch = tir.Schedule(func) + root_block = get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + reduction_blocks = get_reduction_blocks(sch, blocks) + if len(reduction_blocks) > 1: + # schedule for multiple reduction blocks (e.g. softmax) + return self.sch_mutiple_reductions_with_config(func, config) + + if any([t > 1 for t in config.reduce_thread]): + # todo(lei) should implement block reduction schedule + return self.sch_inner_reduction_with_config(func, config) + else: + return self.sch_outer_reduction_with_config(func, config) diff --git a/bitblas/gpu/intrin/__init__.py b/bitblas/gpu/intrin/__init__.py new file mode 100644 index 000000000..d9d9ba942 --- /dev/null +++ b/bitblas/gpu/intrin/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .lop3 import get_lop3_intrin_group # noqa: F401 diff --git a/bitblas/gpu/intrin/lop3.py b/bitblas/gpu/intrin/lop3.py new file mode 100644 index 000000000..b5426cf59 --- /dev/null +++ b/bitblas/gpu/intrin/lop3.py @@ -0,0 +1,1667 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.tir.function import TensorIntrin +from tvm.script import tir as T +from typing import Dict, Literal +from bitblas.quantization import ( + _tir_packed_int_to_int_convert, + _tir_packed_to_signed_convert, + _tir_packed_to_unsigned_convert, + _tir_packed_to_unsigned_convert_with_zeros, +) + +decode_i4_to_f16 = """ +template +__device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } +} + +template +__device__ void decode_i4s_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8) +{ + decode_i4b_to_f16(_i4s, B_local_decode, N); +} + +template +__device__ void decode_i4u_to_f16(T1 *_i4u, T2 *B_local_decode, const int N = 8) +{ + decode_i4b_to_f16(_i4u, B_local_decode, N); +} +""" + +decode_i4_to_f16_scale = """ +template +__device__ void decode_i4b_to_f16_scale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i4s_to_f16_scale(T1 *_i4s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + decode_i4b_to_f16_scale(_i4s, B_local_decode, N, scale); +} + +template +__device__ void decode_i4u_to_f16_scale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + decode_i4b_to_f16_scale(_i4u, B_local_decode, N, scale); +} + +""" + +decode_i4_to_f16_scale_zeros_original = """ +template +__device__ void decode_i4b_to_f16_zeros_original(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T4 const zero_r = *zeros; + uint const packed_zeros = __pack_half2(zero_r, zero_r); + + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_original(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i4b_to_f16_zeros_original(_i4u, B_local_decode, N, scale, zeros); +} +""" + +decode_i4_to_f16_scale_zeros_rescale = """ +template +__device__ void decode_i4b_to_f16_scale_zeros_rescale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + T4 const zero_r = *zeros; + uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i4b_to_f16_scale_zeros_rescale(_i4u, B_local_decode, N, scale, zeros); +} + +""" + +decode_i4_to_f16_scale_zeros_quantized = """ +template +__device__ void decode_i4b_to_f16_scale_zeros_quantized(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x000f000f; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + // Minus 7 to scale the value to signed + uint const i4s = *reinterpret_cast(_i4s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T4 const zero_r = *zeros; + uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i4u_to_f16_scale_zeros_quantized(storage_dtype *_i4u, target_dtype *B_local_decode, scale_dtype *scale = nullptr, zero_dtype *zeros = nullptr, const int N = 8) +{ + decode_i4b_to_f16_scale_zeros_quantized(_i4u, B_local_decode, N, scale, zeros); +} +""" + +decode_i2_to_f16 = """ +template +__device__ void decode_i2b_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } +} + +template +__device__ void decode_i2s_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8) +{ + decode_i2b_to_f16(_i2s, B_local_decode, N); +} + +template +__device__ void decode_i2u_to_f16(T1 *_i2u, T2 *B_local_decode, const int N = 8) +{ + decode_i2b_to_f16(_i2u, B_local_decode, N); +} +""" + +decode_i2_to_f16_scale = """ +template +__device__ void decode_i2b_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + } +} + +template +__device__ void decode_i2s_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + decode_i2b_to_f16_scale(_i2s, B_local_decode, scale, N); +} + +template +__device__ void decode_i2u_to_f16_scale(T1 *_i2u, T2 *B_local_decode, T3 *scale, const int N = 8) +{ + decode_i2b_to_f16_scale(_i2u, B_local_decode, scale, N); +} +""" + +decode_i2_to_f16_scale_zeros_original = """ +template +__device__ void decode_i2b_to_f16_scale_zeros_original(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + } +} + +template +__device__ void decode_i2u_to_f16_scale_zeros_original(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int N = 8) +{ + decode_i2b_to_f16_scale_zeros_original(_i2u, B_local_decode, scale, zeros, N); +} +""" + +decode_i2_to_f16_scale_zeros_rescale = """ +template +__device__ void decode_i2b_to_f16_scale_zeros_rescale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); + } +} + +template +__device__ void decode_i2u_to_f16_scale_zeros_rescale(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int N = 8) +{ + decode_i2b_to_f16_scale_zeros_rescale(_i2u, B_local_decode, scale, zeros, N); +} +""" + +decode_i2_to_f16_scale_zeros_quantized = """ +template +__device__ void decode_i2b_to_f16_scale_zeros_quantized(T1 *_i2s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + T4 const zero_r = *zeros; + uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r); + + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +template +__device__ void decode_i2u_to_f16_scale_zeros_quantized(T1 *_i2u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i2b_to_f16_scale_zeros_quantized(_i2u, B_local_decode, N, scale, zeros); +} +""" + +decode_i1_to_f16 = """ +template +__device__ void decode_i1u_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = 0x64006400; + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + } +} + +template +__device__ void decode_i1s_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = 0x64006400; + static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1 + + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i])); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT)); + } +} +""" + +decode_i1_to_f16_scale = """ +template +__device__ void decode_i1u_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} + +template +__device__ void decode_i1s_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = 0x64006400; + static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1 + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i])); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +""" + +decode_i1_to_f16_scale_zeros_original = """ +template +__device__ void decode_i1b_to_f16_zeros_original(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + // input zeros maybe int32(qzeros) or half format + T4 const zero_r = *zeros; + uint const packed_zeros = __pack_half2(zero_r, zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +template +__device__ void decode_i1u_to_f16_scale_zeros_original(T1 *_i1u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i1b_to_f16_zeros_original(_i1u, B_local_decode, N, scale, zeros); +} +""" +decode_i1_to_f16_scale_zeros_rescale = """ +template +__device__ void decode_i1b_to_f16_scale_zeros_rescale(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00010001; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = 0x64006400; + // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 + int8_t const i1s_i16 = *reinterpret_cast(_i1s); + int i1s = (i1s_i16 & 0x0f); + i1s |= ((i1s_i16 & 0xf0) << 12); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + T4 const zero_r = *zeros; + uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r); + +#pragma unroll + // decode 2 elems at one time. + for (int i = 0; i < (N / 2); i++) + { + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros)); + } +} + +template +__device__ void decode_i1u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i1b_to_f16_scale_zeros_rescale(_i4u, B_local_decode, N, scale, zeros); +} +""" + +decode_i1s_to_i8s = """template +__device__ void decode_i1s_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16) +{ + int i8s[4]; + // vector load + *reinterpret_cast(i8s) = *reinterpret_cast(_i8s); + int16_t i1b_i16 = *reinterpret_cast(_i1b); + // permutate: {e0,e4,e8,e12,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15} + // into: {e0,e4,e8,e12,x,x,x,x,e1,e5,e9,x,x,x,x,e13,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15,x,x,x,x} + int i1b = (i1b_i16 & 0x0f0f); + i1b |= ((i1b_i16 & 0xf0f0) << 12); + // i1b {0..,e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} + // interleave {0..,e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // First, we extract the i1b and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; + static constexpr uint TRANSFORM_SUBTRACT = 0xffffffff; // for signed int 2x - 1 + + for (int i = 0; i < N / 4; i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + i8s[i] = __vadd4(i8s[i], i8s[i]); + i8s[i] = __vadd4(i8s[i], TRANSFORM_SUBTRACT); + } + *reinterpret_cast(_i8s) = *reinterpret_cast(i8s); +} + +template +__device__ void decode_i1u_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16) +{ + int *i8s = reinterpret_cast(_i8s); + int16_t i1b_i16 = *reinterpret_cast(_i1b); + // permutate: {e0,e4,e8,e12,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15} + // into: {e0,e4,e8,e12,x,x,x,x,e1,e5,e9,x,x,x,x,e13,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15,x,x,x,x} + int i1b = (i1b_i16 & 0x0f0f); + i1b |= ((i1b_i16 & 0xf0f0) << 12); + // i1b {0..,e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} + // interleave {0..,e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // First, we extract the i1b and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; + static constexpr uint MEDIAN_NUM = 0x00000000; + + for (int i = 0; i < N / 4; i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + } +} + +""" + +decode_i2s_to_i8s = """template +__device__ void decode_i2s_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + static constexpr uint MEDIAN_NUM = 0x02020202; +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + i8s[i] = __vsub4(i8s[i], MEDIAN_NUM); + } +} +template +__device__ void decode_i2u_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) +{ + // convert 8 int2b_t to 8 int8b_t -> 2 int32 + uint *i8s = reinterpret_cast(_i8s); + + // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} + // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} + uint const i2b = *reinterpret_cast(_i2b); + + // First, we extract the i4s and construct an intermediate fp16 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 + static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 + static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 + +#pragma unroll + for (int i = 0; i < (N / 4); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); + } +} +""" + +decode_i4s_to_i8s = """template +__device__ void decode_i4s_to_i8s(T1 *_i4b, T2 *_i8s, const int N = 16) +{ + uint *i8s = reinterpret_cast(_i8s); + uint *i4b = reinterpret_cast(_i4b); + // First, we extract the i4s and construct an intermediate i8 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x0f0f0f0f; // 0xf -> 0b1111 select 0,4,8,12 + static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 + static constexpr uint MEDIAN_NUM = 0x07070707; +#pragma unroll + for (int i = 0; i < (N / 8); i++) + { + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i4b[0] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i + 2]) + : "r"(i4b[1] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + i8s[i] = __vsubss4(i8s[i], MEDIAN_NUM); + i8s[i + 2] = __vsubss4(i8s[i + 2], MEDIAN_NUM); + } +} + +template +__device__ void decode_i4u_to_i8s(T1 *_i4b, T2 *_i8s, const int N = 16) +{ + uint *i8s = reinterpret_cast(_i8s); + uint *i4b = reinterpret_cast(_i4b); + // First, we extract the i4s and construct an intermediate i8 number. + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x0f0f0f0f; // 0xf -> 0b1111 select 0,4,8,12 + static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 +#pragma unroll + for (int i = 0; i < (N / 8); i++) + { + // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i]) + : "r"(i4b[0] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(i8s[i + 2]) + : "r"(i4b[1] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); + } +} +""" + + +def get_fast_decode_intrin( + source_bit=4, + storage_dtype="int8", + source_format="uint", + target_dtype="float16", + loops_extent=8, + with_scale=False, + with_zeros=False, + zeros_mode="original", +): + """ + loops extent is the number of elements to be decoded in one stage + for memory friendly process, the loops_extent should be a multiple of (sizeof(int) // 8). + However, for the case of int1b, it is not possible to decode 8 elements in one stage, so we have to use 16. + """ + if target_dtype == "float16": + d4f = "f16" + elif target_dtype == "int8": + d4f = "i8s" + else: + raise ValueError("Unsupported target dtype: {}".format(target_dtype)) + source_symbol = "u" if source_format == "uint" else "s" + func_name = "decode_i{}{}_to_{}".format(source_bit, source_symbol, d4f) + if with_scale: + func_name += "_scale" + if with_zeros: + func_name += f"_zeros_{zeros_mode}" + assert storage_dtype in ["int8", "int32", "uint32"] + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + elem_per_unit = storage_nbit // source_bit + n_storage_elems = loops_extent // elem_per_unit + if with_zeros and zeros_mode == "quantized": + decode_func = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit) + elif source_format == "int": + if source_bit == 1: + decode_func = _tir_packed_int_to_int_convert(storage_type, storage_nbit) + else: + decode_func = _tir_packed_to_signed_convert(storage_type, storage_nbit) + elif source_format == "uint": + decode_func = _tir_packed_to_unsigned_convert(storage_type, storage_nbit) + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if with_scale is False: + + @T.prim_func + def fast_decode_desc(compressed: T.handle, decompressed: T.handle) -> None: + Compressed = T.match_buffer( + compressed, + [ + n_storage_elems, + ], + dtype=storage_dtype, + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + loops_extent, + ], + dtype=target_dtype, + scope="local", + ) + + with T.block("root"): + T.reads(Compressed[0:n_storage_elems]) + T.writes(Decompressed[0:loops_extent]) + for i in T.grid(loops_extent): + with T.block("decode"): + vi = T.axis.remap("S", [i]) + Decompressed[vi] = decode_func( + source_bit, + Compressed[vi // elem_per_unit], + vi % elem_per_unit, + dtype=target_dtype, + ) + + @T.prim_func + def fast_decode_impl(compressed: T.handle, decompressed: T.handle) -> None: + Compressed = T.match_buffer( + compressed, + [ + n_storage_elems, + ], + dtype=storage_dtype, + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + loops_extent, + ], + dtype=target_dtype, + scope="local", + ) + + with T.block("root"): + T.reads(Compressed[0:n_storage_elems]) + T.writes(Decompressed[0:loops_extent]) + T.call_extern( + "handle", + func_name, + Compressed.data, + Decompressed.data, + loops_extent, + ) + + elif with_zeros is False: + + @T.prim_func + def fast_decode_desc(compressed: T.handle, decompressed: T.handle, scale: T.handle) -> None: + Compressed = T.match_buffer( + compressed, + [ + n_storage_elems, + ], + dtype=storage_dtype, + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + loops_extent, + ], + dtype=target_dtype, + scope="local", + ) + Scale = T.match_buffer( + scale, + [ + 1, + ], + dtype=target_dtype, + scope="global", + ) + with T.block("root"): + T.reads(Compressed[0:n_storage_elems], Scale[0:1]) + T.writes(Decompressed[0:loops_extent]) + for i in T.grid(loops_extent): + with T.block("decode"): + vi = T.axis.remap("S", [i]) + Decompressed[vi] = ( + decode_func( + source_bit, + Compressed[vi // elem_per_unit], + vi % elem_per_unit, + dtype=target_dtype, + ) * Scale[0]) + + @T.prim_func + def fast_decode_impl(compressed: T.handle, decompressed: T.handle, scale: T.handle) -> None: + s0 = T.int32() + + Compressed = T.match_buffer( + compressed, + [ + n_storage_elems, + ], + dtype=storage_dtype, + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + loops_extent, + ], + dtype=target_dtype, + scope="local", + ) + Scale = T.match_buffer( + scale, + [ + 1, + ], + dtype=target_dtype, + offset_factor=1, + strides=[s0], + scope="global", + ) + with T.block("root"): + T.reads(Compressed[0:n_storage_elems], Scale[0:1]) + T.writes(Decompressed[0:loops_extent]) + T.call_extern( + "handle", + func_name, + Compressed.data, + Decompressed.data, + Scale.access_ptr("r"), + loops_extent, + ) + + elif zeros_mode == "quantized": + + def get_dequantize_buffers_list(weight, scale, zeros, zeros_mode="original"): + if zeros_mode == "original": + return [weight, zeros, scale] + elif zeros_mode == "rescale": + return [weight, scale, zeros] + elif zeros_mode == "quantized": + return [weight, zeros, scale] + else: + raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") + + def get_dequantize_func(weight, scale, zeros, zeros_mode="original"): + if zeros_mode == "original": + return (weight - zeros) * scale + elif zeros_mode == "rescale": + return weight * scale - zeros + elif zeros_mode == "quantized": + return weight * scale + else: + raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") + + # Scale with Zeros + @T.prim_func + def fast_decode_desc( + compressed: T.handle, + decompressed: T.handle, + scale: T.handle, + zeros: T.handle, + ) -> None: + Compressed = T.match_buffer( + compressed, + [ + n_storage_elems, + ], + dtype=storage_dtype, + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + loops_extent, + ], + dtype=target_dtype, + scope="local", + ) + Scale = T.match_buffer( + scale, + [ + 1, + ], + dtype=target_dtype, + scope="local", + ) + Zeros = T.match_buffer( + zeros, + [ + 1, + ], + dtype=storage_dtype, + scope="local", + ) + with T.block("root"): + T.reads(*get_dequantize_buffers_list( + Compressed[0:n_storage_elems], + Scale[0:1], + Zeros[0:1], + zeros_mode=zeros_mode, + )) + T.writes(Decompressed[0:loops_extent]) + for i in T.grid(loops_extent): + with T.block("decode"): + vi = T.axis.remap("S", [i]) + Decompressed[vi] = get_dequantize_func( + decode_func( + source_bit, + Compressed[vi // elem_per_unit], + vi % elem_per_unit, + Zeros[0], + dtype=target_dtype, + ), + Scale[0], + Zeros[0], + zeros_mode, + ) + + @T.prim_func + def fast_decode_impl( + compressed: T.handle, + decompressed: T.handle, + scale: T.handle, + zeros: T.handle, + ) -> None: + s0 = T.int32() + s1 = T.int32() + Compressed = T.match_buffer( + compressed, + [ + n_storage_elems, + ], + dtype=storage_dtype, + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + loops_extent, + ], + dtype=target_dtype, + scope="local", + ) + Scale = T.match_buffer( + scale, + [ + 1, + ], + dtype=target_dtype, + offset_factor=1, + strides=[s0], + scope="local", + ) + Zeros = T.match_buffer( + zeros, + [ + 1, + ], + dtype=storage_dtype, + offset_factor=1, + strides=[s1], + scope="local", + ) + with T.block("root"): + T.reads(Compressed[0:n_storage_elems], Scale[0:1], Zeros[0:1]) + T.writes(Decompressed[0:loops_extent]) + T.call_extern( + "handle", + func_name, + Compressed.data, + Decompressed.data, + Scale.access_ptr("r"), + Zeros.access_ptr("r"), + loops_extent, + ) + + else: + + def get_dequantize_buffers_list(weight, scale, zeros, zeros_mode="original"): + if zeros_mode == "original": + return [weight, zeros, scale] + elif zeros_mode == "rescale": + return [weight, scale, zeros] + else: + raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") + + def get_dequantize_func(weight, scale, zeros, zeros_mode="original"): + if zeros_mode == "original": + return (weight - zeros) * scale + elif zeros_mode == "rescale": + return weight * scale - zeros + else: + raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") + + # Scale with Zeros + @T.prim_func + def fast_decode_desc( + compressed: T.handle, + decompressed: T.handle, + scale: T.handle, + zeros: T.handle, + ) -> None: + Compressed = T.match_buffer( + compressed, + [ + n_storage_elems, + ], + dtype=storage_dtype, + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + loops_extent, + ], + dtype=target_dtype, + scope="local", + ) + Scale = T.match_buffer( + scale, + [ + 1, + ], + dtype=target_dtype, + scope="global", + ) + Zeros = T.match_buffer( + zeros, + [ + 1, + ], + dtype=target_dtype, + scope="global", + ) + with T.block("root"): + T.reads(*get_dequantize_buffers_list( + Compressed[0:n_storage_elems], + Scale[0:1], + Zeros[0:1], + zeros_mode=zeros_mode, + )) + T.writes(Decompressed[0:loops_extent]) + for i in T.grid(loops_extent): + with T.block("decode"): + vi = T.axis.remap("S", [i]) + Decompressed[vi] = get_dequantize_func( + decode_func( + source_bit, + Compressed[vi // elem_per_unit], + vi % elem_per_unit, + dtype=target_dtype, + ), + Scale[0], + Zeros[0], + zeros_mode, + ) + + @T.prim_func + def fast_decode_impl( + compressed: T.handle, + decompressed: T.handle, + scale: T.handle, + zeros: T.handle, + ) -> None: + s0 = T.int32() + s1 = T.int32() + Compressed = T.match_buffer( + compressed, + [ + n_storage_elems, + ], + dtype=storage_dtype, + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + loops_extent, + ], + dtype=target_dtype, + scope="local", + ) + Scale = T.match_buffer( + scale, + [ + 1, + ], + dtype=target_dtype, + offset_factor=1, + strides=[s0], + scope="global", + ) + Zeros = T.match_buffer( + zeros, + [ + 1, + ], + dtype=target_dtype, + offset_factor=1, + strides=[s1], + scope="global", + ) + with T.block("root"): + T.reads(Compressed[0:n_storage_elems], Scale[0:1], Zeros[0:1]) + T.writes(Decompressed[0:loops_extent]) + T.call_extern( + "handle", + func_name, + Compressed.data, + Decompressed.data, + Scale.access_ptr("r"), + Zeros.access_ptr("r"), + loops_extent, + ) + + return fast_decode_desc, fast_decode_impl + + +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u4_to_int8_to_f16_l8_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=4, storage_dtype="int8", target_dtype="float16", loops_extent=8), +) + +LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u2_to_int8_to_f16_l8_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=2, storage_dtype="int8", target_dtype="float16", loops_extent=8), +) + +LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u1_to_int8_to_f16_l8_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=1, storage_dtype="int8", target_dtype="float16", loops_extent=8), +) + +LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u4_to_int32_to_f16_l8_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=4, storage_dtype="int32", target_dtype="float16", loops_extent=8), +) + +LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_u4_to_int32_to_f16_l8_scale_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=4, + storage_dtype="int32", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + +LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u4_to_uint32_to_f16_l8_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=4, storage_dtype="uint32", target_dtype="float16", loops_extent=8), +) + +LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_u4_to_uint32_to_f16_l8_scale_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=4, + storage_dtype="uint32", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=4, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( + "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_original_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, + *get_fast_decode_intrin( + source_bit=4, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + with_zeros=True, + zeros_mode="original", + ), +) + +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( + "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_rescale_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=4, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + with_zeros=True, + zeros_mode="rescale", + ), +) + +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN = ( + "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_quantized_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN, + *get_fast_decode_intrin( + source_bit=4, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + with_zeros=True, + zeros_mode="quantized", + ), +) + +LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=2, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + +LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( + "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_original_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, + *get_fast_decode_intrin( + source_bit=2, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + with_zeros=True, + zeros_mode="original", + ), +) + +LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( + "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_rescale_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=2, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + with_zeros=True, + zeros_mode="rescale", + ), +) + +LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN = ( + "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_quantized_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN, + *get_fast_decode_intrin( + source_bit=2, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + with_zeros=True, + zeros_mode="quantized", + ), +) + +LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=1, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + +LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( + "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_zeros_original_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, + *get_fast_decode_intrin( + source_bit=1, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + with_zeros=True, + zeros_mode="original", + ), +) + +LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( + "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_zeros_rescale_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=1, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + with_zeros=True, + zeros_mode="rescale", + ), +) + +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L8_INTRIN = ("lop3_fast_decode_u4_to_int8_to_i8_l8_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=4, storage_dtype="int8", target_dtype="int8", loops_extent=8), +) + +LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u4_to_int8_to_i8_l16_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L16_INTRIN, + *get_fast_decode_intrin( + source_bit=4, storage_dtype="int8", target_dtype="int8", loops_extent=16), +) + +LOP3_FAST_DECODE_UINT2_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u2_to_int8_to_i8_l16_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT2_TO_INT8_TO_INT8_L16_INTRIN, + *get_fast_decode_intrin( + source_bit=2, storage_dtype="int8", target_dtype="int8", loops_extent=16), +) + +LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_i2_to_int8_to_i8_l16_") +TensorIntrin.register( + LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN, + *get_fast_decode_intrin( + source_bit=2, + source_format="int", + storage_dtype="int8", + target_dtype="int8", + loops_extent=16), +) + +LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u1_to_int8_to_i8_l16_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN, + *get_fast_decode_intrin( + source_bit=1, storage_dtype="int8", target_dtype="int8", loops_extent=16), +) + +LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_i1_to_int8_to_i8_l16_") +TensorIntrin.register( + LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN, + *get_fast_decode_intrin( + source_bit=1, + source_format="int", + storage_dtype="int8", + target_dtype="int8", + loops_extent=16), +) + +LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i4_to_int8_to_f16_l8_") +TensorIntrin.register( + LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=4, + storage_dtype="int8", + source_format="int", + target_dtype="float16", + loops_extent=8, + ), +) + +LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_i4_to_int8_to_f16_l8_scale_") +TensorIntrin.register( + LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=4, + storage_dtype="int8", + source_format="int", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + +LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i2_to_int8_to_f16_l8_") +TensorIntrin.register( + LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=2, + storage_dtype="int8", + source_format="int", + target_dtype="float16", + loops_extent=8, + ), +) + +LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_i2_to_int8_to_f16_l8_scale_") +TensorIntrin.register( + LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=2, + storage_dtype="int8", + source_format="int", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + +LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i1_to_int8_to_f16_l8_") +TensorIntrin.register( + LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN, + *get_fast_decode_intrin( + source_bit=1, + storage_dtype="int8", + source_format="int", + target_dtype="float16", + loops_extent=8, + ), +) + +LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( + "lop3_fast_decode_i1_to_int8_to_f16_l8_scale_") +TensorIntrin.register( + LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN, + *get_fast_decode_intrin( + source_bit=1, + storage_dtype="int8", + source_format="int", + target_dtype="float16", + loops_extent=8, + with_scale=True, + ), +) + + +def get_lop3_intrin_group( + out_dtype: Literal["float16", "int8"], + source_format: Literal["int", "uint"] = "uint", + source_bit: int = 4, + storage_dtype: Literal["int32", "int8"] = "int8", + with_scaling: bool = False, + with_zeros: bool = False, + zeros_mode: Literal["original", "rescale", "quantized"] = "original", +) -> Dict[str, str]: + """ + This function is used to get the intrinsic group of the LOP3 operation to avoid the overhead of fast decoding. + LOP3 is a type of logic operation that takes three inputs. The intrinsic group refers to the set of + intrinsic operations that can be performed on these inputs. This function retrieves and returns this group. + + Parameters + ---------- + in_dtype : Literal["int8"] + The data type of the input. It should be "int8". + + out_dtype : Literal["float16", "int8"] + The data type of the output. It can be either "float16" or "int8". + + storage_nbit : int, optional + The number of bits used for storage. By default, it is 4. + + with_scale : bool, optional + A boolean parameter that indicates whether scaling should be applied. By default, it is False. + + Returns + ------- + Dict[str, str] + A dictionary mapping the names of the intrinsics to their corresponding implementations. + """ + assert out_dtype in ["float16", "int8"] + + dtype_mapping = {"float16": "f16", "int8": "i8", "int32": "i32"} + target_dtype = dtype_mapping[out_dtype] + target_bits = tvm.DataType(out_dtype).bits + loop_extent = 128 // target_bits + if source_format not in ["int", "uint"]: + raise ValueError("Invalid source_format. Expected 'int' or 'uint'.") + source_symbol = "i" if source_format == "int" else "u" + + _intrin = f"lop3_fast_decode_{source_symbol}{source_bit}_to_{storage_dtype}_to_{target_dtype}_l{loop_extent}_" + if with_scaling: + _intrin += "scale_" + if with_zeros: + _intrin += f"zeros_{zeros_mode}_" + + import_c_map = { + "i4_to_f16": decode_i4_to_f16, + "i2_to_f16": decode_i2_to_f16, + "i1_to_f16": decode_i1_to_f16, + "i4_to_f16_scale": decode_i4_to_f16_scale, + "i2_to_f16_scale": decode_i2_to_f16_scale, + "i1_to_f16_scale": decode_i1_to_f16_scale, + "i4_to_f16_scale_zeros_original": decode_i4_to_f16_scale_zeros_original, + "i2_to_f16_scale_zeros_original": decode_i2_to_f16_scale_zeros_original, + "i1_to_f16_scale_zeros_original": decode_i1_to_f16_scale_zeros_original, + "i4_to_f16_scale_zeros_rescale": decode_i4_to_f16_scale_zeros_rescale, + "i2_to_f16_scale_zeros_rescale": decode_i2_to_f16_scale_zeros_rescale, + "i1_to_f16_scale_zeros_rescale": decode_i1_to_f16_scale_zeros_rescale, + "i4_to_f16_scale_zeros_quantized": decode_i4_to_f16_scale_zeros_quantized, + "i2_to_f16_scale_zeros_quantized": decode_i2_to_f16_scale_zeros_quantized, + "i1_to_i8": decode_i1s_to_i8s, + "i2_to_i8": decode_i2s_to_i8s, + "i4_to_i8": decode_i4s_to_i8s, + } + key = f"i{source_bit}_to_{target_dtype}" + if with_scaling: + key += "_scale" + if with_zeros: + key += f"_zeros_{zeros_mode}" + + return { + "c_source": import_c_map[key], + "compute": _intrin, + } diff --git a/bitblas/gpu/matmul.py b/bitblas/gpu/matmul.py new file mode 100644 index 000000000..ad450eff2 --- /dev/null +++ b/bitblas/gpu/matmul.py @@ -0,0 +1,372 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +from dataclasses import dataclass +from typing import Optional + +from tvm import tir +from tvm.target import Target +from tvm.tir.stmt import ForKind + +from ..base import analysis +from .base import GPUScheduleRule +from . import utils +from .matmul_analysis import ( + auto_inline_consumer_chain, + auto_inline_producers, + get_in_out_dtypes, + get_index_map, + normalize_to_matmul, + get_reduction_blocks, +) +from .matmul_mma import MatmulTensorizationMMA +from .matmul_wmma import ( + MatmulInt8Tensorization, + MatmulTensorizationWMMA, +) +from functools import reduce +import logging + +logger = logging.getLogger(__name__) + + +class Matmul(GPUScheduleRule): + """The schedule rule for matmul-like computation""" + + @dataclass + class Config: + block_size_x: int = 8 + block_size_y: int = 8 + vthread_x: int = 1 + vthread_y: int = 1 + micro_size_x: int = 4 + micro_size_y: int = 4 + micro_size_k: int = 8 + vector_size: int = 1 + unroll: int = 256 # 0 means no unroll + use_shared: bool = True + storage_align: bool = False + inner_x: bool = False + + def get_configs(self, target: Target) -> Config: + """Get the schedule config for the target""" + if target.kind.name == "cuda" or target.kind.name == "rocm": + return Matmul.Config( + block_size_x=8, + block_size_y=16, + vthread_x=1, + vthread_y=1, + micro_size_x=4, + micro_size_y=4, + micro_size_k=16, + vector_size=2, + unroll=256, + use_shared=True, + storage_align=True, + inner_x=False, + ) + elif target.kind.name == "opencl" and "android" in str(target.host): + return Matmul.Config( + block_size_x=8, + block_size_y=8, + vthread_x=1, + vthread_y=1, + micro_size_x=8, + micro_size_y=2, + micro_size_k=16, + vector_size=8, + unroll=64, + use_shared=False, + storage_align=False, + inner_x=True, + ) + else: + return Matmul.Config() + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + block_stmt = sch.get(main_block) + sch = normalize_to_matmul(sch, main_block) + if sch is None: + return None + + # Step 1. Check Tensor Core support + # Tensorization config: + # If any value of I, J, K is fixed and less than this threshold, + # tensorization rule will not be applied. + minimal_tensorize_threshold = 64 + block_stmt = sch.get(main_block) + if target.kind.name == "cuda" and utils.get_sm_version(target) >= 70: + apply_tensorization: bool = True + # the batch dimension is not taken into consideration. + # Analyze read/write buffers and choose correct tensorizer: int8 or fp16. + in_dtype, out_dtype = get_in_out_dtypes(block_stmt) + if in_dtype not in ["int8", "float16"]: + apply_tensorization = False + for item_var in block_stmt.iter_vars[1:]: + extent = item_var.dom.extent + if isinstance(extent, + tir.expr.IntImm) and extent.value <= minimal_tensorize_threshold: + apply_tensorization = False + if apply_tensorization: + if in_dtype == "int8" and out_dtype == "int32": + tensorize_sch = MatmulInt8Tensorization().apply(func, target, _) + elif utils.get_sm_version(target) >= 80: + # For A100(sm_80) or more advanced gpu, use MMA tensorization. + tensorize_sch = MatmulTensorizationMMA().apply(func, target, _) + else: + # For other GPUs, use WMMA tensorization. + tensorize_sch = MatmulTensorizationWMMA().apply(func, target, _) + if tensorize_sch is not None: + return tensorize_sch + + # Step 2. Get schedule config. + config = self.get_configs(target) + + # Step 3. Schedule matmul + y_kernel_size = config.vthread_y * config.block_size_y * config.micro_size_y + x_kernel_size = config.vthread_x * config.block_size_x * config.micro_size_x + if config.inner_x: + sch.pad_einsum( + main_block, + [1, y_kernel_size, x_kernel_size, config.micro_size_k], + ) + batch, y, x, k = sch.get_loops(main_block) + else: + sch.pad_einsum( + main_block, + [1, x_kernel_size, y_kernel_size, config.micro_size_k], + ) + batch, x, y, k = sch.get_loops(main_block) + by, vy, ty, yi = sch.split( + y, [None, config.vthread_y, config.block_size_y, config.micro_size_y]) + bx, vx, tx, xi = sch.split( + x, [None, config.vthread_x, config.block_size_x, config.micro_size_x]) + ko, ki = sch.split(k, factors=[None, config.micro_size_k]) + sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi) + by = sch.fuse(batch, by) + sch.bind(bx, "blockIdx.x") + sch.bind(by, "blockIdx.y") + sch.bind(vy, "vthread.y") + sch.bind(vx, "vthread.x") + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + inner_loop = config.micro_size_x if config.inner_x else config.micro_size_y + if inner_loop % config.vector_size == 0: + _, v = sch.split(xi, [None, config.vector_size]) + sch.vectorize(v) + + if config.unroll > 0: + sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=config.unroll) + sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1) + + l2g = sch.cache_write(main_block, 0, "local") + sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True) + if config.micro_size_x % config.vector_size == 0: + _, v = sch.split(sch.get_loops(l2g)[-1], [None, config.vector_size]) + sch.vectorize(v) + + if config.use_shared: + + def _cooperative_fetch(index, vec_len): + block = sch.cache_read(main_block, index, "shared") + num_loops = len(sch.get_loops(block)) + sch.compute_at(block, ko, preserve_unit_loops=True) + loops = sch.get_loops(block)[-num_loops:] + ty, tx, _, vec = sch.split( + sch.fuse(*loops), + factors=[config.block_size_y, config.block_size_x, None, vec_len], + ) + sch.vectorize(vec) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + if config.storage_align: + sch.storage_align(block, 0, axis=1, factor=8, offset=vec_len) + return block + + a_g2s = _cooperative_fetch(0, vec_len=config.vector_size) + b_g2s = _cooperative_fetch(1, vec_len=config.vector_size) + + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + else: + auto_inline_producers(sch, main_block) + + auto_inline_consumer_chain(sch, l2g) + sch.decompose_reduction(main_block, ko) + + # Step 4. Check if there are unbound blocks. Execute fallback scheduling to them. + def is_scheduled(block: tir.schedule.BlockRV) -> bool: + loops = sch.get_loops(block) + loop_kinds = {sch.get(loop).kind for loop in loops} + return loop_kinds != {ForKind.SERIAL} + + blocks = sch.get_child_blocks(root_block) + max_threads_per_block = utils.max_threads_per_block(target) # noqa: F841 + for block in blocks: + if is_scheduled(block): + continue + # no axis of the block is bound to thread or block + s_loops = sch.get_loops(block) + bx, tx = sch.split( + sch.fuse(*s_loops), + factors=[ + None, + 256, + ], + ) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + + return sch + + def apply_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config, + ) -> tir.Schedule: + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + # in some case conv template will use this rule, but the tile config is not + # analyzed by matmul expr. + if len(config.block) != 2: + logger.debug(f"Warning: block config {config.block} is not valid for matmul, skip.") + return None + + main_block = reduction_blocks[0] + + block_stmt = sch.get(main_block) + + # cuda core prefer b is [k, j] layout without swizzling. + index_maps = get_index_map(block_stmt, ["n", "n", "n"]) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + # Step 2. Get schedule config. + block_row_warps = config.block[0] // (config.thread[0] * config.step[0]) + block_col_warps = config.block[1] // (config.thread[1] * config.step[1]) + thread_row_tiles = config.thread[1] // (config.step[0] * 2) + thread_col_tiles = config.thread[1] // (config.step[1] * 2) + vthread_row_tiles = (config.step[0] * 2) # expand vtrhead to avoid load band conflict + vthread_col_tiles = (config.step[1] * 2) # expand vtrhead to avoid load band conflict + chunk = config.rstep[0] + + # Step 3. Schedule matmul + BM = block_row_warps * vthread_row_tiles * thread_row_tiles + BN = block_col_warps * vthread_col_tiles * thread_col_tiles + BK = chunk + + sch.pad_einsum( + main_block, + [1, BM, BN, BK], + ) + batch, y, x, k = sch.get_loops(main_block) + by, vy, ty, yi = sch.split(y, [None, vthread_row_tiles, block_row_warps, thread_row_tiles]) + bx, vx, tx, xi = sch.split(x, [None, vthread_col_tiles, block_col_warps, thread_col_tiles]) + ko, ki = sch.split(k, factors=[None, BK]) + sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi) + by = sch.fuse(batch, by) + sch.bind(bx, "blockIdx.x") + sch.bind(by, "blockIdx.y") + sch.bind(vy, "vthread.y") + sch.bind(vx, "vthread.x") + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + def prod(iterable): + return reduce(lambda x, y: x * y, iterable, 1) + + l2g = sch.cache_write(main_block, 0, "local") + sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True) + + def _cooperative_fetch(index, vec_len): + block = sch.cache_read(main_block, index, "shared") + num_loops = len(sch.get_loops(block)) + block_local = sch.cache_read(main_block, index, "local") + sch.compute_at(block_local, ki, preserve_unit_loops=True) + sch.compute_at(block, ko, preserve_unit_loops=True) + loops = sch.get_loops(block)[-num_loops:] + _, ty, tx, vec = sch.split( + sch.fuse(*loops), + factors=[None, block_row_warps, block_col_warps, vec_len], + ) + + auto_inline_producers(sch, block) + + def is_trivial_load(block): + # avoid vectorize under global[v2, v1]] shared[v1, v2] case + reads = sch.get(block).reads + writes = sch.get(block).writes + if len(reads) != 1 or len(writes) != 1: + return False + return all( + read.region[-1] == write.region[-1] for read, write in zip(reads, writes)) + + if is_trivial_load(block): + sch.vectorize(vec) + + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + + _, vec = sch.split( + sch.fuse(*sch.get_loops(block_local)[-2:]), + [None, vec_len // prod(config.step)], + ) + sch.vectorize(vec) + + return block + + for i, input_region in enumerate(sch.get(main_block).reads): + _buffer_name = input_region.buffer.name.replace("_reindex", "").replace("_pad", "") + if _buffer_name not in config.cached_tensors: + logger.warning( + f"Warning: {_buffer_name} is not in cached_tensors {config.cached_tensors}, skip." + ) + continue + + # otherwise cooperative fetch in shared memory. + vectorize = config.vectorize.get(_buffer_name, 1) + + _cooperative_fetch(i, vec_len=vectorize) + + auto_inline_consumer_chain(sch, l2g) + + _, vec = sch.split( + sch.fuse(*sch.get_loops(l2g)[-2:]), [None, vectorize // prod(config.step)]) + sch.vectorize(vec) + + sch.decompose_reduction(main_block, ko) + return sch diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py new file mode 100644 index 000000000..6537a555a --- /dev/null +++ b/bitblas/gpu/matmul_analysis.py @@ -0,0 +1,786 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional, Set, Union, Tuple, Dict +from tvm import tir +from tvm.ir import Range +from tvm.tir import IterVar, PrimExpr, Var, BufferRegion, IndexMap +from tvm.tir.analysis import undefined_vars +from tvm.tir.schedule.schedule import BlockRV +from ..base.analysis import ( + collect_block_iter_vars_used_in_access_region, + get_root_block, + get_reduction_blocks, +) +from tvm.target.target import Target +from tvm.tir.stmt_functor import pre_order_visit +import logging + +logger = logging.getLogger(__name__) + + +def collect_vars_from_expr(prim_expr): + vars = [] + + def callback(node): + if isinstance(node, Var): + vars.append(node) + return True + + pre_order_visit(prim_expr, callback) + + return vars + + +def _is_one(x: PrimExpr) -> bool: + return isinstance(x, tir.IntImm) and x.value == 1 + + +def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV): + result = [] + for producer in sch.get_producers(block): + result.append(producer) + result.extend(_collect_producers(sch, producer)) + return result + + +def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV): + result = [] + for consumer in sch.get_consumers(block): + result.append(consumer) + result.extend(_collect_consumers(sch, consumer)) + return result + + +def auto_inline_producers( + sch: tir.Schedule, + block: tir.schedule.BlockRV, + skip_blocks: Optional[List[tir.schedule.BlockRV]] = None, +): + skip_blocks = skip_blocks or [] + while True: + inlined_cnt = 0 + producers = _collect_producers(sch, block) + for producer in producers: + if any(sch.get(producer) == sch.get(skip_block) for skip_block in skip_blocks): + continue + try: + sch.compute_inline(producer) + inlined_cnt += 1 + except Exception: # pylint: disable=bare-except + continue + if inlined_cnt == 0: + return + + +def auto_inline_consumers( + sch: tir.Schedule, + block: tir.schedule.BlockRV, +): + while True: + inlined_cnt = 0 + consumers = _collect_consumers(sch, block) + for consumer in consumers: + try: + sch.compute_inline(consumer) + inlined_cnt += 1 + except Exception: # pylint: disable=bare-except + continue + for consumer in consumers: + try: + sch.reverse_compute_inline(consumer) + inlined_cnt += 1 + except Exception: # pylint: disable=bare-except + continue + if inlined_cnt == 0: + return + + +def auto_inline_consumer_chain( + sch: tir.Schedule, + block: tir.schedule.BlockRV, +): + auto_inline_consumers(sch, block) + remaining_consumers = sch.get_consumers(block) + + if len(remaining_consumers) != 0: + # Some blocks have failed to be inlined to the producer cache-write stage. + # This could be due to another producer block that has not been scheduled. + for c in remaining_consumers: + for p in sch.get_producers(c): + if sch.get(p) != sch.get(block): + sch.compute_inline(p) + + # Try inlining into the cache-write stage again, this time it should succeed. + auto_inline_consumers(sch, block) + + +# used to match the similar region with dequantize op. +def find_first_similar_region(regions: List[BufferRegion], buffer: tir.Buffer): + for region in regions: + if len(region.buffer.shape) == len(buffer.shape): + return region + return None + + +# used to match the similar buffer with dequantize op. +def find_first_similar_buffer(regions: List[BufferRegion], buffer: tir.Buffer): + for region in regions: + if len(region.buffer.shape) == len(buffer.shape): + return region.buffer + return None + + +# find the block that required to be reindex and scope. +def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Optional[BlockRV]: + # block that most near to the arguments + block = main_block + buffer = buffer + + while True: + last_buffer = buffer + producers = sch.get_producers(block) + + if len(producers) == 0: + # do not have any producer means it is the first block + break + + for producer in producers: + for write in sch.get(producer).writes: + if write.buffer == buffer: + block = producer + buffer = find_first_similar_buffer(sch.get(producer).reads, last_buffer) + if buffer == last_buffer: + break + return block + + +def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV, + buffer: tir.Buffer) -> int: + """traverse to find the arg index from the buffer""" + producers = sch.get_producers(main_block) + + # a head buffer has no producer blocks + def find_args_index(sch: tir.Schedule, buffer: tir.Buffer): + for i, param in enumerate(sch.mod["main"].params): + if sch.mod["main"].buffer_map[param] == buffer: + return i + return None + + is_head_buffer = len(producers) == 0 + if is_head_buffer: + return find_args_index(sch, buffer) + for block in sch.get_producers(main_block): + if len(sch.get(block).reads) != 1 or len(sch.get(block).writes) != 1: + continue + for write in sch.get(block).writes: + if write.buffer == buffer: + return find_arg_idx_from_buffer_chain(sch, block, buffer) + + # if no buffer producer block found, it means the buffer is an input buffer + return find_args_index(sch, buffer) + + +class IterKind(Enum): + """Iter kinds for GEMM-liked programs. + We can simplify the computation to C[S, I, J] += A[S, I, K] * B[S, J, K], + where `I, J, K` are fundamental axes for gemm and `S` represents all + other spatial axes (e.g. batches) + kIter_S: spatial axes + kIter_I: I axes + kIter_J: J axes + kIter_K: K axes + kIter_T: trivial axes (i.e. with extent 1) + """ + + kIter_S = 0 + kIter_I = 1 + kIter_J = 2 + kIter_K = 3 + kIter_T = 4 + + +@dataclass +class IterTrait: + kind: IterKind + extent: PrimExpr + + +def make_iter_fusion_index_map( + traits: List[IterTrait], + kind_order: List[IterKind], +) -> tir.IndexMap: + fused_iters: Dict[IterKind, PrimExpr] = {} + input_iters: List[tir.Var] = [] + for i, trait in enumerate(traits): + v_i = tir.Var(f"i{i}", trait.extent.dtype) + input_iters.append(v_i) + if trait.kind == IterKind.kIter_T: + continue + if trait.kind not in kind_order: + raise ValueError(f"Unknown iter kind {trait.kind}") + if trait.kind in fused_iters: + fused_iters[trait.kind] = fused_iters[trait.kind] * trait.extent + v_i + else: + fused_iters[trait.kind] = v_i + + final_indices: List[tir.PrimExpr] = [ + fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order + ] + + return tir.IndexMap(input_iters, final_indices, None) + + +def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]: + """Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K] + + Parameters + ---------- + block : tir.Block + The block to be analyzed + + Returns + ------- + traits : Optional[Tuple[List[IterTrait]]] + The detected iter traits for axes in A, B and C. None if the block + does not match the pattern. + + """ + + if len(block.reads) != 2 or len(block.writes) != 1: + return None + + def get_access_axes(region: List[Range]) -> Set[Var]: + axes: Set[Var] = set() + for r in region: + if not _is_one(r.extent): + raise ValueError("Expect elemwise block access") + axes = axes.union(set(undefined_vars(r.min))) + return axes + + try: + A_axes = get_access_axes(block.reads[0].region) + B_axes = get_access_axes(block.reads[1].region) + C_axes = get_access_axes(block.writes[0].region) + except ValueError: + return None + + traits: Dict[Var, IterTrait] = {} + for iter_var in block.iter_vars: + var = iter_var.var + kind: IterKind + if _is_one(iter_var.dom.extent): + if iter_var.iter_type == tir.IterVar.CommReduce: + # for simplified case (e.g. 1x1 conv kernel) + kind = IterKind.kIter_K + else: + kind = IterKind.kIter_T + elif iter_var.iter_type == iter_var.DataPar: + if var in A_axes and var in B_axes and var in C_axes: + kind = IterKind.kIter_S + elif var in A_axes and var in C_axes: + kind = IterKind.kIter_I + elif var in B_axes and var in C_axes: + kind = IterKind.kIter_J + else: + return None + elif iter_var.iter_type == tir.IterVar.CommReduce: + if var in A_axes and var in B_axes and var not in C_axes: + kind = IterKind.kIter_K + else: + return None + else: + return None + traits[var] = IterTrait(kind, iter_var.dom.extent) + + # A Gemm-kernel requires have I, J and K axes + gemm_traits = {IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K} + if {x.kind for x in traits.values()}.intersection(gemm_traits) != gemm_traits: + return None + + A_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes] + B_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes] + C_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes] + block_traits = [traits[i.var] for i in block.iter_vars] + return A_traits, B_traits, C_traits, block_traits + + +def get_index_map(block: tir.Block, + layout: Optional[List[str]] = None) -> Optional[Tuple[tir.IndexMap, ...]]: + """Get index maps for the block + + Parameters + ---------- + block : tir.Block + The block to be analyzed + + layout : List[str] + the target layout index map to be used. + 'n' for [i, k] layout + 't' for [k, j] layout + 'a' for auto inference based on whether the last axis is reduction. + + Returns + ------- + index_maps : Optional[Tuple[tir.IndexMap]] + The index maps for the block, or None if the block is not a gemm-liked kernel + """ + if layout is None: + layout = ["n", "t", "n"] + traits = detect_iter_traits(block) + if traits is None: + return None + A_traits, B_traits, C_traits, block_traits = traits + + def get_ordered_axes(region: List[Range]) -> Set[Var]: + axes: List[Var] = [] + for r in region: + if not _is_one(r.extent): + raise ValueError("Expect elemwise block access") + axes.append(r.min) + return axes + + def is_common_reduce(var: Var) -> bool: + for iter_var in block.iter_vars: + if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce: + return True + return False + + def has_common_reduce(var: Var) -> bool: + vars = collect_vars_from_expr(var) + return any(is_common_reduce(v) for v in vars) + + def check_last_trait(region: List[Range]): + axes = get_ordered_axes(region) + return has_common_reduce(axes[-1]) + + def infer_layout(layout: str, region: List[Range], kind: str = "A"): + """ + Infer the layout based on the region and the kind of buffer + kind: "A", "B", "C" + """ + primary_iter, secondary_iter, reduction_iter = { + "A": (IterKind.kIter_I, IterKind.kIter_K, IterKind.kIter_K), + "B": (IterKind.kIter_K, IterKind.kIter_J, IterKind.kIter_K), + "C": (IterKind.kIter_I, IterKind.kIter_J, None), + }[kind] + + spatial_iter = { + "A": IterKind.kIter_I, + "B": IterKind.kIter_J, + "C": None, + }[kind] + + if layout == "n": + return [IterKind.kIter_S, primary_iter, secondary_iter] + elif layout == "t": + return [IterKind.kIter_S, secondary_iter, primary_iter] + elif layout == "a": + # auto inference layout + # for buffer with reduction axis, we put it as the last axis + # otherwise, we put it as the first axis + if kind == "C": + return [IterKind.kIter_S, primary_iter, secondary_iter] + else: + return ([IterKind.kIter_S, spatial_iter, reduction_iter] if check_last_trait(region) + else [IterKind.kIter_S, reduction_iter, spatial_iter]) + else: + raise ValueError(f"Unknown layout {layout}") + + A_index_map = make_iter_fusion_index_map( + A_traits, infer_layout(layout[0], block.reads[0].region, kind="A")) + B_index_map = make_iter_fusion_index_map( + B_traits, infer_layout(layout[1], block.reads[1].region, kind="B")) + C_index_map = make_iter_fusion_index_map( + C_traits, infer_layout(layout[2], block.writes[0].region, kind="C")) + + matmul_index_map = make_iter_fusion_index_map( + block_traits, + [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K], + ) + + return ( + matmul_index_map, + A_index_map, + B_index_map, + C_index_map, + ) + + +def get_in_out_dtypes(block: tir.Block) -> Tuple[str]: + """ + Detect In/Out data types for the given block based on the analysis if read/write buffers. + """ + assert len(block.reads) > 0 and len(block.writes) > 0 + in_dtype = block.reads[0].buffer.dtype + out_dtype = block.writes[0].buffer.dtype + return (in_dtype, out_dtype) + + +def get_dequantize_block(sch, blocks) -> Optional[BlockRV]: + # check at least two input and one output + # at lease one input has uint dtype, and the output dtype is float + def is_dequantize(block: BlockRV) -> bool: + block_stmt = sch.get(block) + if len(block_stmt.reads) < 2: + return False + has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads) + if not has_uint_input: + return False + if len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype): + return False + return True + + dequantize_blocks = [block for block in blocks if is_dequantize(block)] + return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None + + +def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool: + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + if iter_types != {IterVar.DataPar}: + return False, False + if not isinstance(block_stmt.body, tir.BufferStore): + return False, False + if not isinstance(block_stmt.body.value, tir.BufferLoad): + return False, False + + def get_access_vars(region: List[Range]) -> List[Var]: + axes: List[Var] = [] + for r in region: + if not _is_one(r.extent): + return None + axes.extend(undefined_vars(r.min)) + # remove trivial axis + trivial_vars = set( + iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent)) + axes = [axis for axis in axes if axis not in trivial_vars] + # remove duplicate axis + axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]] + return axes + + lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:] + rhs_access_vars = get_access_vars(block_stmt.writes[0].region)[-2:] + is_identity = list(lhs_access_vars) == list(rhs_access_vars) + is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set( + rhs_access_vars) + return is_identity, is_transpose + + +def is_identity_block(block_stmt: tir.Block) -> bool: + return is_identity_or_transpose_block(block_stmt)[0] + + +def is_transpose_block(block_stmt: tir.Block) -> bool: + return is_identity_or_transpose_block(block_stmt)[1] + + +def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]): + result_blocks = [] + for block in blocks: + if not is_transpose_block(sch.get(block)): + result_blocks.append(block) + continue + try: + sch.compute_inline(block) + except Exception: + try: + sch.reverse_compute_inline(block) + except Exception: + result_blocks.append(block) + return result_blocks + + +def normalize_to_matmul(sch: tir.Schedule, + main_block: BlockRV, + layout: Optional[List[str]] = None) -> Optional[tir.Schedule]: + if layout is None: + layout = ["n", "t", "n"] + block_stmt = sch.get(main_block) + + # let layout be 'a' to auto inference the layout + index_maps = get_index_map(block_stmt, layout=layout) + if index_maps is None: + logger.debug("Cannot find the appropriate index map for tensorcore") + return None + + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # `skip_simplify` to avoid the bug in the 1x1 conv + block = sch.reindex(main_block, ("read", 0), skip_simplify=True) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1), skip_simplify=True) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0), skip_simplify=True) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + sch.mod["main"] = sch.mod["main"].with_attr("dlight.tensorcore_prenormlized", True) + return sch + + +def get_tensorized_func_and_tags( + func: tir.PrimFunc, + target: Target, + layout: Optional[List[str]] = None, + skip_normalize: bool = False, + allow_gemv: bool = False, +) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]: + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group,) + """ + transform function to matmul if necessary (e.g. transform conv2d with im2col) + """ + if layout is None: + layout = ["a", "a", "a"] + # step1. detect whether the function can utilize tensorcore + sch = tir.Schedule(func) + root_block = get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + reduction_blocks = get_reduction_blocks(sch, blocks) + if not reduction_blocks or len(reduction_blocks) != 1: + return func, None + + def _can_be_tensorized(sch: tir.Schedule, block: BlockRV) -> bool: + block_stmt = sch.get(block) + conditions = [] + conditions.append(len(block_stmt.reads) == 2) + conditions.append(len(block_stmt.writes) == 1) + conditions.append( + len( + collect_block_iter_vars_used_in_access_region(block_stmt, + block_stmt.writes[0].region)) > 0) + if not all(conditions): + return False + return True + + # step2. transform function to tensorcore matmul (e.g. conv2d with im2col) + def check_sm_version(arch: str) -> int: + sm_version = arch.replace("sm_", "") + return int(sm_version) if sm_version.isdigit() else -1 + + def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool: + tags: Dict[str, Union[List[int], int]] = {} + block_stmt = sch.get(block) + + # analysis tensorcore axis + # todo(lei): maybe we can remove this in the future + (write_buffer_region,) = block_stmt.writes + out_axis = len(write_buffer_region.buffer.shape) + tags["tensorcore_config"] = [out_axis - 2, out_axis - 1] + + # analysis pipeline stage + # todo(lei): maybe we can integrate this into policy in the future + tags["pipeline_stage"] = 1 + if target.kind.name == "cuda" and check_sm_version(target.arch) == 80: + # enable pipeline stage only for sm_80 devices + tags["pipeline_stage"] = 2 + + # analysis async copy + # todo(lei): maybe we can integrate this into policy in the future + tags["use_async_copy"] = False + if tags["pipeline_stage"] == 2 and check_sm_version(target.arch) >= 80: + # async copy only works in software pipeline. + tags["use_async_copy"] = True + + # analysis intrin information + def get_ordered_axes(region: List[Range]) -> Set[Var]: + axes: List[Var] = [] + for r in region: + if not _is_one(r.extent): + raise ValueError("Expect elemwise block access") + axes.append(r.min) + return axes + + def is_common_reduce(var: Var) -> bool: + for iter_var in block_stmt.iter_vars: + if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce: + return True + return False + + def has_common_reduce(var: Var) -> bool: + vars = collect_vars_from_expr(var) + return any(is_common_reduce(v) for v in vars) + + def check_last_trait(region: List[Range]): + axes = get_ordered_axes(region) + return has_common_reduce(axes[-1]) + + intrin_info: dict = {} + in_dtype, out_dtype = get_in_out_dtypes(block_stmt) + intrin_info["in_dtype"] = in_dtype + intrin_info["out_dtype"] = out_dtype + # if the last dimension is reduce axis, the B is transposed + intrin_info["trans_b"] = check_last_trait(block_stmt.reads[1].region) + if func.attrs is not None and "input_transform_kind" in func.attrs: + intrin_info["input_transform_kind"] = func.attrs["input_transform_kind"] + if func.attrs is not None and "weight_transform_kind" in func.attrs: + intrin_info["weight_transform_kind"] = func.attrs["weight_transform_kind"] + tags["intrin_info"] = intrin_info + # Analysis Block Reduction Optimization + # Currently, we only support block reduction depth 2 for small M + # When the func is a dequantize like ops, we should consider the M + if hasattr(func.attrs, "dequantize_info"): + for arg in func.params: + inp_shape = func.buffer_map[arg].shape + M = inp_shape[0] + if isinstance(M, tir.IntImm) and M <= 128: + tags["block_reduction_depth"] = 2 + break + + return tags + + (main_block,) = reduction_blocks + if _can_be_tensorized(sch, main_block) is None: + return func, None + + block_stmt = sch.get(main_block) + if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70: + # TODO(lei): we should consider the dtype of the input a and b + # instead of assuming both a and b share the same dtype. + # As the tensorcore may supports e4m3_float8 * e5m2_float8 + in_dtype, out_dtype = get_in_out_dtypes(block_stmt) + try: + _ = get_mma_intrin_group( + a_dtype=in_dtype, + b_dtype=in_dtype, + out_dtype=out_dtype, + ) + except Exception: + logger.debug("Cannot find the corresponding mma intrin group") + return func, None + + # reindex and transform functions + # Normalize tensor functions to C[S, I, J] += A[S, I, K] * B[S, J, K] + # or C[S, I, J] += A[S, I, K] * B[S, K, J] + # skip normalize when we want to detect tags only. + if not skip_normalize: + sch = normalize_to_matmul(sch, main_block, layout) + if sch is None: + return func, None + + block_stmt = sch.get(main_block) + + minimal_tensorize_threshold = 16 + # the batch dimension is not taken into consideration. + extent = block_stmt.iter_vars[1].dom.extent + if isinstance(extent, + tir.expr.IntImm) and (extent.value < + (1 if allow_gemv else minimal_tensorize_threshold)): + return func, None + for item_var in block_stmt.iter_vars[2:]: + extent = item_var.dom.extent + if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold): + return func, None + tags = analysis_tensorcore_tags(sch, main_block, target) + return sch.mod["main"], tags + + return func, None + + +def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"): + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout, + ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, + ) + + assert dtype in [ + "float16", + "int8", + "e4m3_float8", + "e5m2_float8", + ], "Only support float16, int8, e4m3_float8, e5m2_float8" + if dtype == "float16": + ldmatrix_layout = ldmatrix_32x8_to_shared_16x16_layout + ldmatrix_layout_trans = ldmatrix_trans_32x8_to_shared_16x16_layout + elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + # int8 mma only support 32x16 to 16x32 layout + if matrix_name == "A" and trans is False: + ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_a + elif matrix_name == "B" and trans is True: + ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_b + else: + raise ValueError("Unknown matrix name ", matrix_name) + + # IntraWarp memory layout was occurred by ldmatrix, we should lift the ld_matrix out + def ldmatrix_permutation_16x16_32x8_16x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 8 + local_id = kernel_j % 8 + return ldmatrix_layout(thread_id, local_id) + + def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 8 + local_id = kernel_j % 8 + return ldmatrix_layout_trans(thread_id, local_id) + + def ldmatrix_permutation_16x32_32x16_32x16(kernel_i, kernel_j): + thread_id = kernel_i * 2 + kernel_j // 16 + local_id = kernel_j % 16 + return ldmatrix_layout(thread_id, local_id) + + if dtype == "float16": + ldmatrix_index_map = ( + ldmatrix_trans_permutation_16x16_32x8_16x16 + if trans else ldmatrix_permutation_16x16_32x8_16x16) + else: + ldmatrix_index_map = ldmatrix_permutation_16x32_32x16_32x16 + + ldmatrix_index_map = IndexMap.from_func(ldmatrix_index_map, index_dtype=index_dtype) + # TODO(lei): index_dtype should be analyzed from the schedule + row, col = [16, 16] if dtype == "float16" else [16, 32] + inversed_index_map = ldmatrix_index_map.inverse([row, col]) + return ldmatrix_index_map, inversed_index_map + + +def layout_propagate_chain( + sch: tir.Schedule, + start_block: BlockRV, + start_buffer: tir.Buffer, + end_block: BlockRV, + index_map: IndexMap, +): + # some layout transformation may only apply to the last n dimensions + # propagate the layout transformation to the chain of blocks + block = start_block + buffer = start_buffer + index_map = index_map + while True: + last_buffer = buffer + producers = sch.get_producers(block) + if len(producers) == 0: + break + for producer in producers: + if len(sch.get(producer).writes) != 1: + return index_map + if sch.get(producer) == sch.get(end_block): + return index_map + (write,) = sch.get(producer).writes + + read = find_first_similar_region(sch.get(producer).reads, last_buffer) + if write.buffer == buffer: + block = producer + buffer = read.buffer + write_indices = [r.min for r in write.region] + read_indices = [r.min for r in read.region] + # reverse index map from [vi // x] -> [vi * x] to match the inconsistent layout + tmp_index_map = IndexMap(write_indices, read_indices, None) + tmp_index_map = tmp_index_map.non_surjective_inverse(write.buffer.shape)[0] + + # if dequantize like ops are used, the scaling factor should be considered + # to be applied to the final indices + scaling_factor = 1 + for i, j in zip(write.buffer.shape, read.buffer.shape): + scaling_factor *= i // j + final_indices = list( + index_map.map_indices(tmp_index_map.map_indices(write_indices))) + final_indices[-1] = final_indices[-1] // scaling_factor + index_map = IndexMap( + write_indices, + final_indices, + None, + ) + if buffer == last_buffer: + break + return index_map diff --git a/bitblas/gpu/matmul_mma.py b/bitblas/gpu/matmul_mma.py new file mode 100644 index 000000000..4bf8be4e6 --- /dev/null +++ b/bitblas/gpu/matmul_mma.py @@ -0,0 +1,1069 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +from typing import Literal, Optional, List + +from tvm import tir, DataType +from tvm.target import Target + +from ..base.roller import Hint +from ..base.roller.rasterization import NoRasterization +from ..base import analysis +from .base import GPUScheduleRule +from .matmul_mma_dequantize import MatmulTensorizationMMAWithDequantizeInfo +from ..base.analysis import get_coalesced_veclen +from .matmul_analysis import ( + auto_inline_consumer_chain, + is_transpose_block, + is_identity_block, + _collect_producers, + inline_transpose_block, + auto_inline_producers, + get_index_map, + get_reduction_blocks, + get_dequantize_block, + normalize_to_matmul, + get_propagate_map, +) + + +def get_index_map_3d(index_map, l=16, r=16): # noqa: E741 + + def index_map_3d(b, i, j): + return ( + b, + i // l, + j // r, + *index_map(i % l, j % r), + ) + + return index_map_3d + + +def get_index_map_5d(index_map): + """ + for layout transformed gemm, the index map should be 5d + """ + + def index_map_5d(b, i, j, ii, jj): + return ( + b, + i, + j, + *index_map(ii, jj), + ) + + return index_map_5d + + +def get_warp_index_map(index_map, l=16, r=16, is_5d=False): # noqa: E741 + if is_5d: + return get_index_map_5d(index_map) + return get_index_map_3d(index_map, l, r) + + +class MatmulTensorizationMMA(GPUScheduleRule): + """ + The schedule rule for float16 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. + """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + if "dequantize_info" in func.attrs: + dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo() + return dequantize_rule.apply(func, target, False) + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + # We first inline all transpose blocks for later analysis of transposed A and B + blocks = inline_transpose_block(sch, blocks) + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + dequantize_block = get_dequantize_block(sch, blocks) + + main_block = reduction_blocks[0] + main_block_stmt = sch.get(main_block) + + # Supported data types: + # fp16, fp16, fp16: fp16 precision + # fp16, fp16, fp32: fp16 mixed precision + dtype_a = main_block_stmt.reads[0].buffer.dtype + dtype_b = main_block_stmt.reads[1].buffer.dtype + dtype_c = main_block_stmt.writes[0].buffer.dtype + if dtype_a != dtype_b: + return None + + # Get index maps + index_maps = get_index_map(main_block_stmt) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # Tensorization by hardware intrinsics + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group, shared_16x16_to_mma_32x8_layout, + ) + + # tile size + block_m, block_n, block_k = 128, 128, 32 + + # tensor core intrinsic size + micro_size_m, micro_size_n, micro_size_k = 16, 16, 16 + + # thread size + # thread_x == warp_size + thread_z, thread_y, thread_x = 2, 2, 32 + + vector_size = 8 + unroll_depth = 4 # noqa: F841 + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + is_transpose_a = is_transpose_block(sch.get(block)) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + is_transpose_b = is_identity_block(sch.get(block)) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + batch, i, j, k = sch.get_loops(main_block) + + swizzle_factor_for_l2_m = [1, None] + swizzle_factor_for_l2_n = [1, None] + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + swizzle_factor_for_l2_m[0] * block_m, + swizzle_factor_for_l2_n[0] * block_n, + block_k, + ], + ) + + # Step 3. Reorder loops for tiling + + # Step 3.1 inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_m]) + j, j_inner = sch.split(j, factors=[None, micro_size_n]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = main_block + block_outer = sch.blockize(i_inner) + + # Step 3.2 outer loops for tiling + # split factors for i, j, and k + micro_block_cnt_in_warp_m = block_m // thread_z // micro_size_m + micro_block_cnt_in_warp_n = block_n // thread_y // micro_size_n + micro_block_cnt_in_warp_k = block_k // micro_size_k + + i_factors = swizzle_factor_for_l2_m + [thread_z, micro_block_cnt_in_warp_m] + j_factors = swizzle_factor_for_l2_n + [thread_y, micro_block_cnt_in_warp_n] + k_factors = [None, micro_block_cnt_in_warp_k] + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, factors=k_factors) + + sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) + + block_axis = sch.fuse(batch, i0, j0, i1, j1) + sch.bind(block_axis, "blockIdx.x") + + sch.bind(i2, "threadIdx.z") + sch.bind(j2, "threadIdx.y") + + # Step 4. Read/write to shared mem and register + def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], is_transpose): + # 1) Read to shared memory + block_read_smem = sch.cache_read(block_outer, read_buffer_idx, "shared.dyn") + sch.compute_at(block_read_smem, k0) + auto_inline_producers(sch, block_read_smem, + [dequantize_block] if dequantize_block else []) + + # For transposed read, we directly load transposed tensor from global + # Then use ldmatrix.trans to handle transpose later + if (tensor_name == "A" and is_transpose) or (tensor_name == "B" and not is_transpose): + # specifical handle transpose read (for NN matmul or TT matmul) + v0, v1 = sch.get_loops(block_read_smem)[-2:] + sch.reorder(v1, v0) + sch.transform_layout(block_read_smem, ("write", 0), lambda b, i, j: (b, j, i)) + + # bind loops + fused = sch.fuse(*sch.get_loops(block_read_smem)[-2:]) + f0, f1, f2, f3, f4 = sch.split(fused, [None, thread_z, thread_y, thread_x, vector_size]) + sch.bind(f1, "threadIdx.z") + sch.bind(f2, "threadIdx.y") + sch.bind(f3, "threadIdx.x") + sch.vectorize(f4) + + # swizzling + sch.annotate(block_read_smem, ann_key="permuted_layout", ann_val=1) + + # 2) Read to register + block_read_reg = sch.cache_read(block_outer, read_buffer_idx, "warp") + sch.compute_at(block_read_reg, k1) + + # bind_loops + micro_size_spatial = micro_size_m if tensor_name == "A" else micro_size_n + micro_size_1, micro_size_2 = ((micro_size_spatial, + micro_size_k) if not is_transpose else + (micro_size_k, micro_size_spatial)) + v00, v01 = sch.split(sch.get_loops(block_read_reg)[-2], [None, micro_size_1]) + v10, v11 = sch.split(sch.get_loops(block_read_reg)[-1], [None, micro_size_2]) + sch.reorder(v00, v10, v01, v11) + + # reorder read axis to match the layout of ldmatrix + sch.transform_layout( + block_read_reg, + ("write", 0), + lambda v0, v1, v2: ( + v0, + v1 // micro_size_1, + v2 // micro_size_2, + *shared_16x16_to_mma_32x8_layout(v1 % micro_size_1, v2 % micro_size_2), + ), + ) + + # swizzling + mma_read_block = sch.blockize(sch.get_loops(block_read_reg)[-2]) + sch.annotate(mma_read_block, ann_key="permuted_layout", ann_val=1) + + return block_read_smem, block_read_reg + + block_read_a, block_read_reg_a = fetch_input(block_outer, 0, "A", is_transpose_a) + block_read_b, block_read_reg_b = fetch_input(block_outer, 1, "B", is_transpose_b) + + # Write to register, and then smem + def store_output(block_outer, write_buffer_idx): + # 1) Write to shared memory + block_write_smem = sch.cache_write(block_outer, write_buffer_idx, "shared.dyn") + sch.reverse_compute_at(block_write_smem, block_axis) + auto_inline_consumer_chain(sch, block_write_smem) + + # bind loops + fused = sch.fuse(*sch.get_loops(block_write_smem)[-2:]) + f0, f1, f2 = sch.split(fused, [None, thread_x, vector_size]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + + # 2) Write to register + block_write_reg = sch.cache_write(block_outer, write_buffer_idx, "warp") + + # bind loops + v0, v1, v2 = sch.get_loops(block_write_reg)[-3:] + v11, v12, v13 = sch.split(v1, factors=[thread_z, None, micro_size_m]) + v21, v22, v23 = sch.split(v2, factors=[thread_y, None, micro_size_n]) + sch.reorder(v11, v21, v12, v22, v13, v23) + sch.bind(v11, "threadIdx.z") + sch.bind(v21, "threadIdx.y") + + # reorder write axis to match the layout of ldmatrix + sch.transform_layout( + block_write_reg, + ("read", 0), + lambda v0, v1, v2: ( + v0, + v1 // micro_size_m, + v2 // micro_size_n, + *shared_16x16_to_mma_32x8_layout(v1 % micro_size_m, v2 % micro_size_n), + ), + ) + + return block_write_smem, block_write_reg + + _, block_write_reg = store_output(block_outer, 0) + + # Step 5. Schedule tensor core computation + block_init = sch.decompose_reduction(block_outer, k0) + block_init_inner = sch.get_child_blocks(block_init)[0] + + intrin_group = get_mma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + a_dtype=str(dtype_a), + b_dtype=str(dtype_b), + out_dtype=str(dtype_c), + trans_a=is_transpose_a, + trans_b=is_transpose_b, + not_use_mma_store_intrinic=False, + ) + + sch.tensorize(sch.get_loops(block_init_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(block_read_reg_a)[-2], intrin_group["load_a"]) + sch.tensorize(sch.get_loops(block_read_reg_b)[-2], intrin_group["load_b"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + sch.tensorize(sch.get_loops(block_write_reg)[-2], intrin_group["store"]) + + # Step 6. Async pipeline + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3]) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + sch.annotate(k0, ann_key="software_pipeline_async_stages", ann_val=[0]) + + # Step 7. Handle dequantize block + # Now we just add a dummy kernel to compute dequantize + if dequantize_block is not None: + auto_inline_producers(sch, dequantize_block) + loops = sch.get_loops(dequantize_block) + loop = sch.fuse(*loops) + v0, v1, v2, v3 = sch.split(loop, [None, 128, 2, 4]) + sch.bind(v0, "blockIdx.x") + sch.bind(v1, "threadIdx.x") + sch.unroll(v2) + sch.vectorize(v3) + return sch + + def apply_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config: Hint, + ) -> Optional[tir.Schedule]: + if "dequantize_info" in func.attrs: + dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo() + return dequantize_rule.apply_config(func, config) + + if hasattr(config, "block_reduction_depth") and config.block_reduction_depth is not None: + return self.apply_block_reduction_with_config(func, config) + + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group,) + + import_source: List[str] = [] + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + + output_blocks = [sch.get(block) for block in sch.get_output_blocks(root_block)] + + def check_require_cache(func: tir.PrimFunc, config): + conditions: List[bool] = [] + + # check if has dynamic symbolic + def check_has_dynamic(func: tir.PrimFunc): + for param in func.params: + if param not in func.buffer_map: + continue + arg = func.buffer_map[param] + for i in arg.shape: + if isinstance(i, tir.Var): + return True + return False + + conditions.append(check_has_dynamic(func)) + # check if has post process + conditions.append(sch.get(main_block) not in output_blocks) + # check if not use async copy + conditions.append(config.use_async is False) + return any(conditions) + + cache_write_required = check_require_cache(func, config=config) + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): + sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) + + shared_scope = config.shared_scope + + intrin_info = config.intrin_info + intrin_group = get_mma_intrin_group( + load_scope=shared_scope, + store_scope=shared_scope if cache_write_required else "global", + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_a=intrin_info.trans_a, + trans_b=intrin_info.trans_b, + smooth_a=intrin_info.smooth_a, + smooth_b=intrin_info.smooth_b, + not_use_mma_store_intrinic=False, + ) + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + stage = config.pipeline_stage + use_async = config.use_async + chunk = config.rstep[0] + + # tensor core intrinsic size + micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] + + # get the axis for layout transform + def get_axis(l, r, trans): # noqa: E741 + return (r, l) if trans else (l, r) # noqa: E741 + + a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) + b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) + + def can_enable_swizzle(dtype: str, smooth: bool): + # inject_permuted_layout only support float16 currently + if dtype == "float16" or dtype == "int8": + if chunk * DataType(dtype).bits != 512: + # currently the swizzle rule only support 512 bit. + return False + # if we use smooth layout, we don't need to do swizzling + return not smooth + return False + + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) + + warp_size = 32 + + i_factors, j_factors, k_factors = ( + [None, 1, block_row_warps, warp_row_tiles // micro_size_x], + [1, None, block_col_warps, warp_col_tiles // micro_size_y], + [None, chunk // micro_size_k], + ) + + num_ty = i_factors[2] + num_tz = j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + + sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) + + block_idy = sch.fuse(i0, j0) + block_idx = sch.fuse(i1, j1) + thread_idy = i2 + thread_idz = j2 + + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + sch.bind(thread_idz, "threadIdx.z") + + # rewrite smooth layout of shared memory + def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True): # noqa: E741 + if not enable: + return + sch.transform_layout( + block, + scope, + lambda b, i, j: ( + b, + i // l, + j // r, + i % l, + j % r, + ), + ) + + smooth_smem_layout_rewrite( + block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) + smooth_smem_layout_rewrite( + block_outer, ("read", 1), *b_lr, enable=intrin_info.inter_transform_b) + smooth_smem_layout_rewrite(block_outer, ("write", 0), enable=True) + + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, trans=False): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_read).iter_vars) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + f_0, f_1, f_2, f_3, f_4 = sch.split( + fused, factors=[num_ty, num_tz, None, warp_size, vec_len]) + + sch.bind(f_3, "threadIdx.x") + sch.bind(f_1, "threadIdx.z") + sch.bind(f_0, "threadIdx.y") + sch.vectorize(f_4) + sch.unroll(f_2) + # Apply Swizzling + sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) + # if not, apply padding to alleviate bank conflict + if not (can_swizzle or is_smooth): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + sch.annotate(f_2, "pragma_unroll_explicit", False) + return block_read + + if len(config.vectorize.values()) < 2: + return None + + a_g2s = fetch_to_shared( + block_outer, + 0, + vec_len=list(config.vectorize.values())[0], + can_swizzle=can_swizzle_a, + is_smooth=intrin_info.smooth_a, + trans=intrin_info.trans_a, + ) + b_g2s = fetch_to_shared( + block_outer, + 1, + vec_len=list(config.vectorize.values())[1], + can_swizzle=can_swizzle_b, + is_smooth=intrin_info.smooth_b, + trans=intrin_info.trans_b, + ) + + # rewrite global smooth layout + def smooth_gmem_layout_rewrite(sch, block, enable=True, trans=False, matrix_name="A"): + if not enable: + return + # step1: find the first producer block + # Notes: we assume the layout propagate happens in the first producer block + # otherwise, the layout transform will have no effect as it will transform both + # read and write buffer + producers = _collect_producers(sch, block) + g2s_block = a_g2s if matrix_name == "A" else b_g2s + propagate_block: tir.Block = (producers[-1] if len(producers) > 0 else g2s_block) + + # step2: transform the layout with inverse permutation + intra_indexmap, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + + def inverse_permutation(i, j, ii, jj): + return (i, j, *intra_indexmap.map_indices([ii, jj])) + + sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) + + smooth_gmem_layout_rewrite( + sch, a_g2s, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") + smooth_gmem_layout_rewrite( + sch, b_g2s, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "warp") + B_mat = sch.cache_read(block_outer, 1, "warp") + sch.compute_at(A_mat, k1) + sch.compute_at(B_mat, k1) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + if cache_write_required: + accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) + + store = sch.cache_write(block_outer, 0, "warp") + sch.reverse_compute_at(store, j2) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, micro_size_x], preserve_unit_iters=False) + j0, j1 = sch.split(j, factors=[None, micro_size_y], preserve_unit_iters=False) + sch.reorder(i0, j0, i1, j1) + + if cache_write_required: + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + sch.reverse_compute_at( + accumulator_shared_to_global, + sch.get_loops(store)[-5], + preserve_unit_loops=True, + ) + vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) + f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + sch.unroll(f0) + sch.annotate(f0, "pragma_unroll_explicit", False) + else: + auto_inline_consumer_chain(sch, store) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + index_map_a, index_map_b, index_map_c = intrin_group["index_map"] + + sch.transform_layout( + A_mat, + ("write", 0), + get_warp_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), + ) + sch.transform_layout( + B_mat, + ("write", 0), + get_warp_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), + ) + sch.transform_layout( + store, + ("read", 0), + get_warp_index_map(index_map_c, is_5d=True), + ) + + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, a_lr[0]]) + j0, j1 = sch.split(j, factors=[None, a_lr[1]]) + sch.reorder(i0, j0, i1, j1) + ba = sch.blockize(i1) + sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) + sch.tensorize(ba, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, b_lr[0]]) + j0, j1 = sch.split(j, factors=[None, b_lr[1]]) + sch.reorder(i0, j0, i1, j1) + bb = sch.blockize(i1) + sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) + sch.tensorize(bb, intrin_group["load_b"]) + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + tensorize_init_store_compute() + + if stage > 1: + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + if use_async: + sch.annotate(k0, "software_pipeline_async_stages", [0]) + + # plan rasteration + if not isinstance(config.rasterization_plan, NoRasterization): + device_func, invoke_func = config.rasterization_plan.get_code() + import_source.append(device_func) + sch.annotate( + sch.get_loops(block_init_c)[-2], + ann_key="inject_customized_code_prepend", + ann_val=invoke_func, + ) + # plan import source + if len(import_source) > 0: + sch.annotate( + thread_idz, + ann_key="pragma_import_c", + ann_val=("\n").join(import_source), + ) + return sch + + def apply_block_reduction_with_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config: Hint, + ) -> Optional[tir.Schedule]: + if "dequantize_info" in func.attrs: + dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo() + return dequantize_rule.sch_shared_memory_prefetch_block_reduction_with_config( + func, config) + + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group,) + + import_source: List[str] = [] + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + + output_blocks = [sch.get(block) for block in sch.get_output_blocks(root_block)] + + def check_require_cache(func: tir.PrimFunc, config): + conditions: List[bool] = [] + + # check if has dynamic symbolic + def check_has_dynamic(func: tir.PrimFunc): + for param in func.params: + if param not in func.buffer_map: + continue + arg = func.buffer_map[param] + for i in arg.shape: + if isinstance(i, tir.Var): + return True + return False + + conditions.append(check_has_dynamic(func)) + # check if has post process + conditions.append(sch.get(main_block) not in output_blocks) + # check if not use async copy + conditions.append(config.use_async is False) + return any(conditions) + + cache_write_required = check_require_cache(func, config=config) + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): + sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) + + shared_scope = config.shared_scope + + intrin_info = config.intrin_info + intrin_group = get_mma_intrin_group( + load_scope=shared_scope, + store_scope=shared_scope if cache_write_required else "global", + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_a=intrin_info.trans_a, + trans_b=intrin_info.trans_b, + smooth_a=intrin_info.smooth_a, + smooth_b=intrin_info.smooth_b, + not_use_mma_store_intrinic=False, + ) + + # Start Schedule + # Step 0. Get schedule config. + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + stage = config.pipeline_stage + use_async = config.use_async + assert (config.block_reduction_depth is not None), "block_reduction_depth is required" + reduce_k = config.block_reduction_depth + chunk = config.rstep[0] // reduce_k + + # tensor core intrinsic size + micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] + + # get the axis for layout transform + def get_axis(l, r, trans): # noqa: E741 + return (r, l) if trans else (l, r) # noqa: E741 + + a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) + b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) + + def can_enable_swizzle(dtype: str, smooth: bool): + # inject_permuted_layout only support float16 currently + if dtype == "float16" or dtype == "int8": + # introduce the constraint of reduce_k because reduce_k will doubling the size of + if (chunk * reduce_k) * DataType(dtype).bits != (512): + # currently the swizzle rule only support 512 bit. + return False + # if we use smooth layout, we don't need to do swizzling + return not smooth + return False + + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) + + warp_size = 32 + + i_factors, j_factors, k_factors = ( + [None, 1, block_row_warps, warp_row_tiles // micro_size_x], + [1, None, block_col_warps, warp_col_tiles // micro_size_y], + [None, (reduce_k * chunk) // micro_size_k], + ) + + num_ty = i_factors[2] + num_tz = j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + k0, kr = sch.split(k0, [None, reduce_k]) + + sch.reorder(i0, j0, i1, j1, i2, j2, kr, i3, j3, k0, k1) + + block_idy = sch.fuse(i0, j0) + block_idx = sch.fuse(i1, j1) + thread_idy = i2 + thread_idz = j2 + + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz) + sch.bind(thread_idy, "threadIdx.y") + sch.bind(kr, "threadIdx.z") + + # rewrite smooth layout of shared memory + def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True): # noqa: E741 + if not enable: + return + sch.transform_layout( + block, + scope, + lambda b, i, j: ( + b, + i // l, + j // r, + i % l, + j % r, + ), + ) + + smooth_smem_layout_rewrite( + block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) + smooth_smem_layout_rewrite( + block_outer, ("read", 1), *b_lr, enable=intrin_info.inter_transform_b) + smooth_smem_layout_rewrite(block_outer, ("write", 0), enable=True) + + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, trans=False): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_read).iter_vars) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + f_r, f_0, f_1, f_2, f_3, f_4 = sch.split( + fused, factors=[reduce_k, num_ty, num_tz, None, warp_size, vec_len]) + + sch.bind(f_3, "threadIdx.x") + f_0 = f_1 = sch.fuse(f_0, f_1) + sch.bind(f_0, "threadIdx.y") + sch.bind(f_r, "threadIdx.z") + sch.vectorize(f_4) + sch.unroll(f_2) + # Apply Swizzling + sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) + # if not, apply padding to alleviate bank conflict + if not (can_swizzle or is_smooth): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + sch.annotate(f_2, "pragma_unroll_explicit", False) + return block_read + + if len(config.vectorize.values()) < 2: + return None + + a_g2s = fetch_to_shared( + block_outer, + 0, + vec_len=list(config.vectorize.values())[0], + can_swizzle=can_swizzle_a, + is_smooth=intrin_info.smooth_a, + trans=intrin_info.trans_a, + ) + b_g2s = fetch_to_shared( + block_outer, + 1, + vec_len=list(config.vectorize.values())[1], + can_swizzle=can_swizzle_b, + is_smooth=intrin_info.smooth_b, + trans=intrin_info.trans_b, + ) + + # rewrite global smooth layout + def smooth_gmem_layout_rewrite(sch, block, enable=True, trans=False, matrix_name="A"): + if not enable: + return + # step1: find the first producer block + # Notes: we assume the layout propagate happens in the first producer block + # otherwise, the layout transform will have no effect as it will transform both + # read and write buffer + producers = _collect_producers(sch, block) + g2s_block = a_g2s if matrix_name == "A" else b_g2s + propagate_block: tir.Block = (producers[-1] if len(producers) > 0 else g2s_block) + + # step2: transform the layout with inverse permutation + intra_indexmap, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + + def inverse_permutation(i, j, ii, jj): + return (i, j, *intra_indexmap.map_indices([ii, jj])) + + sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) + + smooth_gmem_layout_rewrite( + sch, a_g2s, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") + smooth_gmem_layout_rewrite( + sch, b_g2s, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "warp") + B_mat = sch.cache_read(block_outer, 1, "warp") + sch.compute_at(A_mat, k1) + sch.compute_at(B_mat, k1) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + if cache_write_required: + accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) + + store = sch.cache_write(block_outer, 0, "warp") + sch.reverse_compute_at(store, j2) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, micro_size_x], preserve_unit_iters=False) + j0, j1 = sch.split(j, factors=[None, micro_size_y], preserve_unit_iters=False) + sch.reorder(i0, j0, i1, j1) + + if cache_write_required: + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + sch.reverse_compute_at( + accumulator_shared_to_global, + sch.get_loops(store)[-5], + preserve_unit_loops=True, + ) + vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) + f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + sch.unroll(f0) + sch.annotate(f0, "pragma_unroll_explicit", False) + else: + auto_inline_consumer_chain(sch, store) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + index_map_a, index_map_b, index_map_c = intrin_group["index_map"] + + sch.transform_layout( + A_mat, + ("write", 0), + get_warp_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), + ) + sch.transform_layout( + B_mat, + ("write", 0), + get_warp_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), + ) + sch.transform_layout( + store, + ("read", 0), + get_warp_index_map(index_map_c, is_5d=True), + ) + + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, a_lr[0]]) + j0, j1 = sch.split(j, factors=[None, a_lr[1]]) + sch.reorder(i0, j0, i1, j1) + ba = sch.blockize(i1) + # sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) + sch.tensorize(ba, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, b_lr[0]]) + j0, j1 = sch.split(j, factors=[None, b_lr[1]]) + sch.reorder(i0, j0, i1, j1) + bb = sch.blockize(i1) + # sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) + sch.tensorize(bb, intrin_group["load_b"]) + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + tensorize_init_store_compute() + + if stage > 1: + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + if use_async: + sch.annotate(k0, "software_pipeline_async_stages", [0]) + + # plan rasteration + if not isinstance(config.rasterization_plan, NoRasterization): + device_func, invoke_func = config.rasterization_plan.get_code() + import_source.append(device_func) + sch.annotate( + sch.get_loops(block_init_c)[-2], + ann_key="inject_customized_code_prepend", + ann_val=invoke_func, + ) + # plan import source + if len(import_source) > 0: + sch.annotate( + thread_idz, + ann_key="pragma_import_c", + ann_val=("\n").join(import_source), + ) + return sch diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py new file mode 100644 index 000000000..679e84395 --- /dev/null +++ b/bitblas/gpu/matmul_mma_dequantize.py @@ -0,0 +1,2295 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +from typing import Optional, List +from contextlib import suppress + +from tvm import tir, DataType + +from ..base.roller.hint import Hint, IntrinInfo +from tvm.target import Target +from ..base.roller.rasterization import NoRasterization +from ..base import analysis +from .base import GPUScheduleRule +from ..base.analysis import get_coalesced_veclen +from .matmul_analysis import ( + auto_inline_consumer_chain, + auto_inline_producers, + get_reduction_blocks, + normalize_to_matmul, + get_propagate_map, + layout_propagate_chain, + find_last_producer_from_buffer, + _collect_producers, + get_in_out_dtypes, +) + + +def _bind_thread_based_on_config(sch, block, block_row_warps, block_col_warps, warp_size): + # assume the block loops has been fused + last_loop = sch.get_loops(block)[-1] + loop_extent = sch.get(last_loop).extent + vec = get_coalesced_veclen(sch.get(block)) + + if loop_extent // (vec * warp_size) == 0: + last_loop, B_shared_tx, B_shared_vi = sch.split(last_loop, factors=[1, warp_size, None]) + sch.bind(B_shared_tx, "threadIdx.x") + sch.vectorize(B_shared_vi) + loop_extent = sch.get(last_loop).extent + + # warp_size - 1 handling for 32x7 alike case, which may cause unaligned threadIdx.x mapping. + if loop_extent // vec >= 1 and loop_extent & (vec - 1) == 0 and (loop_extent // + vec) & (warp_size - 1) == 0: + last_loop, B_shared_vi = sch.split(last_loop, factors=[None, vec]) + sch.vectorize(B_shared_vi) + loop_extent = sch.get(last_loop).extent + + if loop_extent // warp_size >= 1 and loop_extent % warp_size == 0: + last_loop, B_shared_tx = sch.split(last_loop, factors=[None, warp_size]) + sch.bind(B_shared_tx, "threadIdx.x") + loop_extent = sch.get(last_loop).extent + + if loop_extent // block_row_warps >= 1 and loop_extent % block_row_warps == 0: + last_loop, B_shared_ty = sch.split(last_loop, factors=[None, block_row_warps]) + sch.bind(B_shared_ty, "threadIdx.y") + loop_extent = sch.get(last_loop).extent + + if loop_extent // block_col_warps >= 1 and loop_extent % block_col_warps == 0: + last_loop, B_shared_tz = sch.split(last_loop, factors=[None, block_col_warps]) + sch.bind(B_shared_tz, "threadIdx.z") + loop_extent = sch.get(last_loop).extent + + sch.unroll(last_loop) + sch.annotate(last_loop, "pragma_unroll_explicit", False) + + +def _bind_thread_based_with_block_reduce_on_config(sch, block, block_row_warps, block_col_warps, + warp_size, reduce_k): + # assume the block loops has been fused + last_loop = sch.get_loops(block)[-1] + loop_extent = sch.get(last_loop).extent + vec = get_coalesced_veclen(sch.get(block)) + + if loop_extent // (vec * warp_size) == 0: + last_loop, B_shared_tx, B_shared_vi = sch.split(last_loop, factors=[1, warp_size, None]) + sch.bind(B_shared_tx, "threadIdx.x") + sch.vectorize(B_shared_vi) + loop_extent = sch.get(last_loop).extent + + # warp_size - 1 handling for 32x7 alike case, which may cause unaligned threadIdx.x mapping. + if loop_extent // vec >= 1 and loop_extent & (vec - 1) == 0 and (loop_extent // + vec) & (warp_size - 1) == 0: + last_loop, B_shared_vi = sch.split(last_loop, factors=[None, vec]) + sch.vectorize(B_shared_vi) + loop_extent = sch.get(last_loop).extent + + if loop_extent // warp_size >= 1 and loop_extent % warp_size == 0: + last_loop, B_shared_tx = sch.split(last_loop, factors=[None, warp_size]) + sch.bind(B_shared_tx, "threadIdx.x") + loop_extent = sch.get(last_loop).extent + + B_shared_ty = None + if loop_extent // block_row_warps >= 1 and loop_extent % block_row_warps == 0: + last_loop, B_shared_ty = sch.split(last_loop, factors=[None, block_row_warps]) + loop_extent = sch.get(last_loop).extent + + B_shared_tz = None + if loop_extent // block_col_warps >= 1 and loop_extent % block_col_warps == 0: + last_loop, B_shared_tz = sch.split(last_loop, factors=[None, block_col_warps]) + loop_extent = sch.get(last_loop).extent + + if B_shared_ty and B_shared_tz: + B_shared_ty = sch.fuse(B_shared_tz, B_shared_ty) + sch.bind(B_shared_ty, "threadIdx.y") + elif B_shared_ty: + sch.bind(B_shared_ty, "threadIdx.y") + + if loop_extent // reduce_k >= 1 and loop_extent % reduce_k == 0: + last_loop, B_shared_tk = sch.split(last_loop, factors=[None, reduce_k]) + sch.bind(B_shared_tk, "threadIdx.z") + loop_extent = sch.get(last_loop).extent + + sch.unroll(last_loop) + sch.annotate(last_loop, "pragma_unroll_explicit", False) + + +def get_index_map_3d(index_map, l=16, r=16): # noqa: E741 + + def index_map_3d(b, i, j): + return ( + b, + i // l, + j // r, + *index_map(i % l, j % r), + ) + + return index_map_3d + + +def get_index_map_5d(index_map): + """ + for layout transformed gemm, the index map should be 5d + """ + + def index_map_5d(b, i, j, ii, jj): + return ( + b, + i, + j, + *index_map(ii, jj), + ) + + return index_map_5d + + +def get_index_map(index_map, l=16, r=16, is_5d=False): # noqa: E741 + if is_5d: + return get_index_map_5d(index_map) + return get_index_map_3d(index_map, l, r) + + +class MatmulTensorizationMMAWithDequantizeInfo(GPUScheduleRule): + """ + The schedule rule for float16 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. + """ + + def apply( + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ): + """ + For devices without async copy, we can use a simple dequantize schedule without shared memory prefetch. + quantized weight + | + V + dequantized in register + | + V + save into shared memory + | + V + compute + """ + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group,) + from .intrin import get_lop3_intrin_group + + import_source: List[str] = [] + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + # always enable shared memory rewrite + cache_write_required = True + + # Check Dequantize Info + dequantize_info = func.attrs["dequantize_info"] + + def check_dequantize_info(dequantize_info): + conditions = [] + # currently only support weight only dequantization + conditions.append(len(dequantize_info) == 1) + # TODO(@lei) check if the dequantize value name is weight + return all(conditions) + + assert check_dequantize_info(dequantize_info) + + (weight_decode_info,) = list(dequantize_info.values()) + + def check_weight_decode_info(weight_decode_info): + conditions = [] + # check source format in ["int", "fp", "nf"] + conditions.append("source_format" in weight_decode_info) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e4m3"]) + # check source bits in [1, 2, 4, 8] + conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) + # check target format in ["float16", "int8"] + conditions.append("target_format" in weight_decode_info) + conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) + return all(conditions) + + assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info" + + # Start Schedule + # Step 1. Get default schedule config. + + # tensor core intrinsic size + in_dtype, out_dtype = get_in_out_dtypes(sch.get(main_block)) + intrin_info = IntrinInfo( + in_dtype=in_dtype, + out_dtype=out_dtype, + trans_b=True, + ) + if "weight_transform_kind" in func.attrs: + intrin_info.weight_transform_kind = int(func.attrs["weight_transform_kind"]) + + if "input_transform_kind" in func.attrs: + intrin_info.input_transform_kind = int(func.attrs["input_transform_kind"]) + # default Hint + config = Hint().from_dict({ + "block": [128, 128], + "warp": [64, 64], + "rstep": [32], + "pipeline_stage": 1, + "use_async": False, + "intrin_info": intrin_info, + "shared_scope": "shared.dyn", + }) + shared_scope = config.shared_scope + + intrin_group = get_mma_intrin_group( + load_scope=shared_scope, + store_scope=shared_scope if cache_write_required else "global", + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_a=intrin_info.trans_a, + trans_b=intrin_info.trans_b, + smooth_a=intrin_info.smooth_a, + smooth_b=intrin_info.smooth_b, + not_use_mma_store_intrinic=False, + ) + + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + stage = config.pipeline_stage + use_async = config.use_async + chunk = config.rstep[0] + + micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] + + # get the axis for layout transform + def get_axis(l, r, trans): # noqa: E741 + return (r, l) if trans else (l, r) # noqa: E741 + + a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) + b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) + + def can_enable_swizzle(dtype: str, smooth: bool): + # inject_permuted_layout only support float16 currently + if dtype == "float16" or dtype == "int8": + if chunk * DataType(dtype).bits != 512: + # currently the swizzle rule only support 512 bit. + return False + # if we use smooth layout, we don't need to do swizzling + return not smooth + return False + + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) + + # rewrite global smooth layout, for dequantize, currently only support weight only recover. + def smooth_gmem_layout_rewrite(sch, main_block, enable=True, trans=False, matrix_name="A"): + if not enable: + return + + # normalized block may have three read buffers, while the first one is the write buffer. + buffer_offset = (1 if sch.get(main_block).reads[0].buffer + == sch.get(main_block).writes[0].buffer else 0) + buffer_idx = 0 if matrix_name == "A" else 1 + source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer + + # step1: find the first producer block + # Notes: we assume the layout propagate happens in the first producer block + # otherwise, the layout transform will have no effect as it will transform both + # read and write buffer + propagate_block: tir.Block = find_last_producer_from_buffer( + sch, main_block, source_buffer) + # some trick impl may not have reindex block + (weight_dequantize_info,) = dequantize_info.values() + if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]): + return + + # step2: transform the layout with inverse permutation + intra_indexmap, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + + # step3: propagate the matmul layout to the first reindex block + + intra_indexmap = layout_propagate_chain( + sch, + start_block=main_block, + start_buffer=source_buffer, + end_block=propagate_block, + index_map=intra_indexmap, + ) + + def inverse_permutation(i, j, ii, jj): + return (i, j, *intra_indexmap.map_indices([ii, jj])) + + sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) + + smooth_gmem_layout_rewrite( + sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") + + smooth_gmem_layout_rewrite( + sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") + + warp_size = 32 + + i_factors, j_factors, k_factors = ( + [None, 1, block_row_warps, warp_row_tiles // micro_size_x], + [1, None, block_col_warps, warp_col_tiles // micro_size_y], + [None, chunk // micro_size_k], + ) + + num_ty = i_factors[2] + num_tz = j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): + sch = normalize_to_matmul(sch, main_block, ["n", "t", "n"]) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + + sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) + + block_idy = sch.fuse(i0, j0) + block_idx = sch.fuse(i1, j1) + thread_idy = i2 + thread_idz = j2 + + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + sch.bind(thread_idz, "threadIdx.z") + + def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 + if not enable: + return + sch.transform_layout( + block, + scope, + lambda b, i, j: ( + b, + i // l, + j // r, + i % l, + j % r, + ), + ) + + smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) + smooth_layout_recover( + block_outer, + ("read", 1), + *b_lr, + enable=intrin_info.inter_transform_b, + ) + smooth_layout_recover(block_outer, ("write", 0), enable=True) + + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_read).iter_vars) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + f_0, f_1, f_2, f_3, f_4 = sch.split( + fused, factors=[None, num_ty, num_tz, warp_size, vec_len]) + + sch.bind(f_3, "threadIdx.x") + sch.bind(f_2, "threadIdx.z") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_4) + sch.unroll(f_0) + # Apply Swizzling + sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) + # if not, apply padding to alleviate bank conflict + if not (can_swizzle or is_smooth): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + sch.annotate(f_0, "pragma_unroll_explicit", False) + return block_read + + a_g2s = fetch_to_shared( + block_outer, + 0, + vec_len=4, + can_swizzle=can_swizzle_a, + is_smooth=intrin_info.smooth_a, + ) + + auto_inline_producers(sch, a_g2s) + + def decode_fetch_to_shared(block, idx): + # step1. create memory hierarchy + # global -> local -> shared + block_shared = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_shared, k0, preserve_unit_loops=True) + + decode_factor = get_coalesced_veclen(sch.get(block_shared)) + _, B_shared_vi, _ = sch.split( + sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor]) + block_shared_local = sch.cache_read(block_shared, 0, "local") + # global -> dequantzed_local -> shared + # step2. inline to local block + weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"]) + weight_producers = _collect_producers(sch, weight_dequantize_block) + auto_inline_producers(sch, block_shared_local, weight_producers) + + # get target dequantize buffer's idx + def get_idx(): + # for LUT dequantize, the expr is LUT(w), the idx is 1 + # maybe we can use a more general and structural based way + # to analysis the idx + if weight_decode_info["source_format"]["format"] == "nf": + return 1 + return 0 + + b_idx = get_idx() + # global -> prefetch_local -> dequantzed_local -> shared + block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local") + + sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True) + sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True) + + dequantize_block_local = block_shared_local + if ("zeros_mode" in weight_decode_info and + weight_decode_info["zeros_mode"] == "quantized"): + if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]): + block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local") + sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True) + # pop the scale block + auto_inline_producers(sch, block_local_scales) + + if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]): + block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local") + sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_zeros) + + for producer in weight_producers: + with suppress(Exception): + auto_inline_producers(sch, producer) + sch.compute_inline(producer) + + # fast type conversion + if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): + source_bit = weight_decode_info["source_format"]["bits"] + out_dtype = weight_decode_info["target_format"] + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + storage_dtype=weight_decode_info["storage_dtype"], + source_format=weight_decode_info["source_format"]["format"], + source_bit=source_bit, + with_scaling=weight_decode_info["with_scaling"], + with_zeros=weight_decode_info["with_zeros"], + zeros_mode=weight_decode_info["zeros_mode"], + ) + sch.tensorize( + sch.get_loops(dequantize_block_local)[-1], + lop3_intrin_info["compute"], + ) + import_source.append(lop3_intrin_info["c_source"]) + + sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b) + union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2) + B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2]) + _, B_shared_ty, B_shared_tz, B_shared_tx = sch.split( + B_shared_fused, factors=[None, num_ty, num_tz, warp_size]) + if not (can_swizzle_b or intrin_info.smooth_b): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset) + sch.bind(B_shared_tx, "threadIdx.x") + sch.bind(B_shared_ty, "threadIdx.y") + sch.bind(B_shared_tz, "threadIdx.z") + sch.vectorize(sch.get_loops(block_shared)[-1]) + sch.vectorize(sch.get_loops(block_shared_local_local)[-1]) + + # cache small tensors, e.g. LUT + if b_idx: + block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope) + sch.reverse_compute_at(block_shared_lut, j2) + _, B_shared_tx = sch.split( + sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size]) + sch.bind(B_shared_tx, "threadIdx.x") + return block_shared_local + + _ = decode_fetch_to_shared(block_outer, 1) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "warp") + B_mat = sch.cache_read(block_outer, 1, "warp") + sch.compute_at(A_mat, k1, preserve_unit_loops=True) + sch.compute_at(B_mat, k1, preserve_unit_loops=True) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + if cache_write_required: + accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) + + store = sch.cache_write(block_outer, 0, "warp") + sch.reverse_compute_at(store, j2) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, micro_size_x]) + j0, j1 = sch.split(j, factors=[None, micro_size_y]) + sch.reorder(i0, j0, i1, j1) + + if cache_write_required: + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + sch.reverse_compute_at( + accumulator_shared_to_global, + sch.get_loops(store)[-5], + preserve_unit_loops=True, + ) + vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) + f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + sch.unroll(f0) + sch.annotate(f0, "pragma_unroll_explicit", False) + else: + auto_inline_consumer_chain(sch, store) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + + index_map_a, index_map_b, index_map_c = intrin_group["index_map"] + + sch.transform_layout( + A_mat, + ("write", 0), + get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), + ) + sch.transform_layout( + B_mat, + ("write", 0), + get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), + ) + sch.transform_layout( + store, + ("read", 0), + get_index_map(index_map_c, is_5d=True), + ) + + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, a_lr[0]]) + j0, j1 = sch.split(j, factors=[None, a_lr[1]]) + sch.reorder(i0, j0, i1, j1) + ba = sch.blockize(i1) + sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) + sch.tensorize(ba, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, b_lr[0]]) + j0, j1 = sch.split(j, factors=[None, b_lr[1]]) + sch.reorder(i0, j0, i1, j1) + bb = sch.blockize(i1) + sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) + sch.tensorize(bb, intrin_group["load_b"]) + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + tensorize_init_store_compute() + + if stage > 1: + sch.annotate( + k0, + ann_key="software_pipeline_stage", + ann_val=[0, 0, stage - 1], + ) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + if use_async: + sch.annotate(k0, "software_pipeline_async_stages", [0]) + # plan rasteration + if not isinstance(config.rasterization_plan, NoRasterization): + device_func, invoke_func = config.rasterization_plan.get_code() + import_source.append(device_func) + sch.annotate( + sch.get_loops(block_init_c)[-2], + ann_key="inject_customized_code_prepend", + ann_val=invoke_func, + ) + # plan import source + if len(import_source) > 0: + sch.annotate( + thread_idz, + ann_key="pragma_import_c", + ann_val=("\n").join(import_source), + ) + return sch + + def sch_dequantize_in_register_with_config( + self, + func: tir.PrimFunc, + config: Hint, + ): + """ + For devices without async copy, we can use a simple dequantize schedule without shared memory prefetch. + quantized weight + | + V + dequantized in register + | + V + save into shared memory + | + V + compute + """ + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group,) + from .intrin import get_lop3_intrin_group + + import_source: List[str] = [] + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + # always enable shared memory rewrite + cache_write_required = True + + # Check Dequantize Info + dequantize_info = func.attrs["dequantize_info"] + + def check_dequantize_info(dequantize_info): + conditions = [] + # currently only support weight only dequantization + conditions.append(len(dequantize_info) == 1) + # TODO(@lei) check if the dequantize value name is weight + return all(conditions) + + assert check_dequantize_info(dequantize_info) + + (weight_decode_info,) = list(dequantize_info.values()) + + def check_weight_decode_info(weight_decode_info): + conditions = [] + # check source format in ["int", "fp", "nf"] + conditions.append("source_format" in weight_decode_info) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e4m3"]) + # check source bits in [1, 2, 4, 8] + conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) + # check target format in ["float16", "int8"] + conditions.append("target_format" in weight_decode_info) + conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) + return all(conditions) + + assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info" + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # tensor core intrinsic size + intrin_info = config.intrin_info + shared_scope = config.shared_scope + + intrin_group = get_mma_intrin_group( + load_scope=shared_scope, + store_scope=shared_scope if cache_write_required else "global", + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_a=intrin_info.trans_a, + trans_b=intrin_info.trans_b, + smooth_a=intrin_info.smooth_a, + smooth_b=intrin_info.smooth_b, + not_use_mma_store_intrinic=False, + ) + + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + stage = config.pipeline_stage + use_async = config.use_async + chunk = config.rstep[0] + + micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] + + # get the axis for layout transform + def get_axis(l, r, trans): # noqa: E741 + return (r, l) if trans else (l, r) # noqa: E741 + + a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) + b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) + + def can_enable_swizzle(dtype: str, smooth: bool): + # inject_permuted_layout only support float16 currently + if dtype == "float16" or dtype == "int8": + if chunk * DataType(dtype).bits != 512: + # currently the swizzle rule only support 512 bit. + return False + # if we use smooth layout, we don't need to do swizzling + return not smooth + return False + + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) + + # rewrite global smooth layout, for dequantize, currently only support weight only recover. + def smooth_gmem_layout_rewrite(sch, main_block, enable=True, trans=False, matrix_name="A"): + if not enable: + return + + # normalized block may have three read buffers, while the first one is the write buffer. + buffer_offset = (1 if sch.get(main_block).reads[0].buffer + == sch.get(main_block).writes[0].buffer else 0) + buffer_idx = 0 if matrix_name == "A" else 1 + source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer + + # step1: find the first producer block + # Notes: we assume the layout propagate happens in the first producer block + # otherwise, the layout transform will have no effect as it will transform both + # read and write buffer + propagate_block: tir.Block = find_last_producer_from_buffer( + sch, main_block, source_buffer) + # some trick impl may not have reindex block + (weight_dequantize_info,) = dequantize_info.values() + if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]): + return + + # step2: transform the layout with inverse permutation + intra_indexmap, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + + # step3: propagate the matmul layout to the first reindex block + + intra_indexmap = layout_propagate_chain( + sch, + start_block=main_block, + start_buffer=source_buffer, + end_block=propagate_block, + index_map=intra_indexmap, + ) + + def inverse_permutation(i, j, ii, jj): + return (i, j, *intra_indexmap.map_indices([ii, jj])) + + sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) + + smooth_gmem_layout_rewrite( + sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") + + smooth_gmem_layout_rewrite( + sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") + + warp_size = 32 + + i_factors, j_factors, k_factors = ( + [None, 1, block_row_warps, warp_row_tiles // micro_size_x], + [1, None, block_col_warps, warp_col_tiles // micro_size_y], + [None, chunk // micro_size_k], + ) + + num_ty = i_factors[2] + num_tz = j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): + sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + + sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) + + block_idy = sch.fuse(i0, j0) + block_idx = sch.fuse(i1, j1) + thread_idy = i2 + thread_idz = j2 + + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + sch.bind(thread_idz, "threadIdx.z") + + def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 + if not enable: + return + sch.transform_layout( + block, + scope, + lambda b, i, j: ( + b, + i // l, + j // r, + i % l, + j % r, + ), + ) + + smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) + smooth_layout_recover( + block_outer, + ("read", 1), + *b_lr, + enable=intrin_info.inter_transform_b, + ) + smooth_layout_recover(block_outer, ("write", 0), enable=True) + + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_read).iter_vars) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + f_0, f_1, f_2, f_3, f_4 = sch.split( + fused, factors=[None, num_ty, num_tz, warp_size, vec_len]) + + sch.bind(f_3, "threadIdx.x") + sch.bind(f_2, "threadIdx.z") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_4) + sch.unroll(f_0) + sch.annotate(f_0, "pragma_unroll_explicit", False) + + # Apply Swizzling + sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) + # if not, apply padding to alleviate bank conflict + if not (can_swizzle or is_smooth): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + return block_read + + a_g2s = fetch_to_shared( + block_outer, + 0, + vec_len=list(config.vectorize.values())[0], + can_swizzle=can_swizzle_a, + is_smooth=intrin_info.smooth_a, + ) + + auto_inline_producers(sch, a_g2s) + + def decode_fetch_to_shared(block, idx): + # step1. create memory hierarchy + # global -> local -> shared + block_shared = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_shared, k0, preserve_unit_loops=True) + + decode_factor = get_coalesced_veclen(sch.get(block_shared)) + _, B_shared_vi, _ = sch.split( + sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor]) + block_shared_local = sch.cache_read(block_shared, 0, "local") + # global -> dequantzed_local -> shared + # step2. inline to local block + weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"]) + weight_producers = _collect_producers(sch, weight_dequantize_block) + auto_inline_producers(sch, block_shared_local, weight_producers) + + # get target dequantize buffer's idx + def get_idx(): + # for LUT dequantize, the expr is LUT(w), the idx is 1 + # maybe we can use a more general and structural based way + # to analysis the idx + if weight_decode_info["source_format"]["format"] == "nf": + return 1 + return 0 + + b_idx = get_idx() + # global -> prefetch_local -> dequantzed_local -> shared + block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local") + + sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True) + sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True) + + dequantize_block_local = block_shared_local + if ("zeros_mode" in weight_decode_info and + weight_decode_info["zeros_mode"] == "quantized"): + if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]): + block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local") + sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True) + # pop the scale block + auto_inline_producers(sch, block_local_scales) + + if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]): + block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local") + sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_zeros) + + for producer in weight_producers: + with suppress(Exception): + auto_inline_producers(sch, producer) + sch.compute_inline(producer) + + # fast type conversion + if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): + source_bit = weight_decode_info["source_format"]["bits"] + out_dtype = weight_decode_info["target_format"] + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + storage_dtype=weight_decode_info["storage_dtype"], + source_format=weight_decode_info["source_format"]["format"], + source_bit=source_bit, + with_scaling=weight_decode_info["with_scaling"], + with_zeros=weight_decode_info["with_zeros"], + zeros_mode=weight_decode_info["zeros_mode"], + ) + sch.tensorize( + sch.get_loops(dequantize_block_local)[-1], + lop3_intrin_info["compute"], + ) + import_source.append(lop3_intrin_info["c_source"]) + + sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b) + union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2) + B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2]) + _, B_shared_ty, B_shared_tz, B_shared_tx = sch.split( + B_shared_fused, factors=[None, num_ty, num_tz, warp_size]) + if not (can_swizzle_b or intrin_info.smooth_b): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset) + sch.bind(B_shared_tx, "threadIdx.x") + sch.bind(B_shared_ty, "threadIdx.y") + sch.bind(B_shared_tz, "threadIdx.z") + sch.vectorize(sch.get_loops(block_shared)[-1]) + sch.vectorize(sch.get_loops(block_shared_local_local)[-1]) + + # cache small tensors, e.g. LUT + if b_idx: + block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope) + sch.reverse_compute_at(block_shared_lut, j2) + _, B_shared_tx = sch.split( + sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size]) + sch.bind(B_shared_tx, "threadIdx.x") + return block_shared_local + + _ = decode_fetch_to_shared(block_outer, 1) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "warp") + B_mat = sch.cache_read(block_outer, 1, "warp") + sch.compute_at(A_mat, k1, preserve_unit_loops=True) + sch.compute_at(B_mat, k1, preserve_unit_loops=True) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + if cache_write_required: + accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) + + store = sch.cache_write(block_outer, 0, "warp") + sch.reverse_compute_at(store, j2) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, micro_size_x]) + j0, j1 = sch.split(j, factors=[None, micro_size_y]) + sch.reorder(i0, j0, i1, j1) + + if cache_write_required: + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + sch.reverse_compute_at( + accumulator_shared_to_global, + sch.get_loops(store)[-5], + preserve_unit_loops=True, + ) + vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) + f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + sch.unroll(f0) + sch.annotate(f0, "pragma_unroll_explicit", False) + else: + auto_inline_consumer_chain(sch, store) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + + index_map_a, index_map_b, index_map_c = intrin_group["index_map"] + + sch.transform_layout( + A_mat, + ("write", 0), + get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), + ) + sch.transform_layout( + B_mat, + ("write", 0), + get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), + ) + sch.transform_layout( + store, + ("read", 0), + get_index_map(index_map_c, is_5d=True), + ) + + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, a_lr[0]]) + j0, j1 = sch.split(j, factors=[None, a_lr[1]]) + sch.reorder(i0, j0, i1, j1) + ba = sch.blockize(i1) + sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) + sch.tensorize(ba, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, b_lr[0]]) + j0, j1 = sch.split(j, factors=[None, b_lr[1]]) + sch.reorder(i0, j0, i1, j1) + bb = sch.blockize(i1) + sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) + sch.tensorize(bb, intrin_group["load_b"]) + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + tensorize_init_store_compute() + + if stage > 1: + sch.annotate( + k0, + ann_key="software_pipeline_stage", + ann_val=[0, 0, stage - 1], + ) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + if use_async: + sch.annotate(k0, "software_pipeline_async_stages", [0]) + # plan rasteration + if not isinstance(config.rasterization_plan, NoRasterization): + device_func, invoke_func = config.rasterization_plan.get_code() + import_source.append(device_func) + sch.annotate( + sch.get_loops(block_init_c)[-2], + ann_key="inject_customized_code_prepend", + ann_val=invoke_func, + ) + # plan import source + if len(import_source) > 0: + sch.annotate( + thread_idz, + ann_key="pragma_import_c", + ann_val=("\n").join(import_source), + ) + return sch + + def sch_shared_memory_prefetch_with_config( + self, + func: tir.PrimFunc, + config: Hint, + ): + """ + For A100 Like devices, the shared memory prefetch(async) is required + to achieve optimal performance. + quantized weight + | + V + shared memory prefetch (with async copy) + | + V + dequantized into shared memory + | + V + compute + """ + if hasattr(config, "block_reduction_depth") and config.block_reduction_depth is not None: + return self.sch_shared_memory_prefetch_block_reduction_with_config(func, config) + + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group,) + from .intrin import get_lop3_intrin_group + + import_source: List[str] = [] + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + # always enable shared memory rewrite + cache_write_required = True + + # Check Dequantize Info + # TODO(leiwang): this is a hack to get the configuration, can be improved by writing a pass to analysis the dequantize block. + dequantize_info = func.attrs["dequantize_info"] + + def check_dequantize_info(dequantize_info): + conditions = [] + # currently only support weight only dequantization + conditions.append(len(dequantize_info) == 1) + # TODO(@lei) check if the dequantize value name is weight + return all(conditions) + + assert check_dequantize_info(dequantize_info) + + (weight_decode_info,) = list(dequantize_info.values()) + + def check_weight_decode_info(weight_decode_info): + conditions = [] + # check source format in ["int", "fp", "nf"] + conditions.append("source_format" in weight_decode_info) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e4m3"]) + # check source bits in [1, 2, 4, 8] + conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) + # check target format in ["float16", "int8"] + conditions.append("target_format" in weight_decode_info) + conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) + return all(conditions) + + assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info" + + # Start Schedule + # Step 0. Get schedule config. + # tensor core intrinsic size + shared_scope = config.shared_scope + + intrin_info = config.intrin_info + intrin_group = get_mma_intrin_group( + load_scope=shared_scope, + store_scope=shared_scope if cache_write_required else "global", + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_a=intrin_info.trans_a, + trans_b=intrin_info.trans_b, + smooth_a=intrin_info.smooth_a, + smooth_b=intrin_info.smooth_b, + not_use_mma_store_intrinic=False, + ) + + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + stage = config.pipeline_stage + use_async = config.use_async + chunk = config.rstep[0] + + micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] + + # get the axis for layout transform + def get_axis(l, r, trans): # noqa: E741 + return (r, l) if trans else (l, r) # noqa: E741 + + a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) + b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) + + def can_enable_swizzle(dtype: str, smooth: bool): + # inject_permuted_layout only support float16 currently + if dtype == "float16" or dtype == "int8": + if chunk * DataType(dtype).bits != 512: + # currently the swizzle rule only support 512 bit. + return False + # if we use smooth layout, we don't need to do swizzling + return not smooth + return False + + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) + + # rewrite global smooth layout, for dequantize, currently only support weight only recover. + def smooth_gmem_layout_rewrite( + sch, + main_block, + enable=True, + trans=False, + matrix_name="A", + intrin_group=intrin_group, + ): + if not enable: + return + + # normalized block may have three read buffers, while the first one is the write buffer. + buffer_offset = (1 if sch.get(main_block).reads[0].buffer + == sch.get(main_block).writes[0].buffer else 0) + buffer_idx = 0 if matrix_name == "A" else 1 + source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer + + # step1: find the first producer block + # Notes: we assume the layout propagate happens in the first producer block + # otherwise, the layout transform will have no effect as it will transform both + # read and write buffer + propagate_block: tir.Block = find_last_producer_from_buffer( + sch, main_block, source_buffer) + # some trick impl may not have reindex block + (weight_dequantize_info,) = dequantize_info.values() + if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]): + return + # step2: transform the layout with inverse permutation + intra_indexmap, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + + # step3: propagate the matmul layout to the first reindex block + + intra_indexmap = layout_propagate_chain( + sch, + start_block=main_block, + start_buffer=source_buffer, + end_block=propagate_block, + index_map=intra_indexmap, + ) + + def inverse_permutation(i, j, ii, jj): + return (i, j, *intra_indexmap.map_indices([ii, jj])) + + sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) + + intra_index_map, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + + # get target dequantize buffer's offset + def get_offset(): + # for LUT dequantize, the expr is LUT(w), the idx is 1 + # maybe we can use a more general and structural based way + # to analysis the idx + if weight_dequantize_info["source_format"]["format"] == "nf": + return 1 + return 0 + + offset = get_offset() + dequantize_block = sch.get_block(weight_dequantize_info["decode_block"]) + group_size = weight_dequantize_info["group_size"] + + _, mn, mk = intrin_group["micro_kernel"] + + def get_param_indices( + indexmap, + l=mn, + r=mk, + group_size=group_size # noqa: E741 + ): # noqa: E741 + # assume the param layout is n, k + rl, rr = [x.var for x in sch.get(dequantize_block).iter_vars] + warp_i, warp_j = rl % l, rr % r + spatial_i, spatial_j = rl // l, rr // r + warp_i, warp_j = indexmap.map_indices([warp_i, warp_j]) + new_indices = ( + spatial_i * l + warp_i, + (spatial_j * r + warp_j) // group_size, + ) + return new_indices + + with_scaling = bool(weight_dequantize_info["with_scaling"]) + if with_scaling: + sch.unsafe_rewrite_buffer_region( + dequantize_block, + ("read", offset + 1), + get_param_indices(intra_index_map), + ) + with_zeros = bool(weight_dequantize_info["with_zeros"]) + if with_zeros: + sch.unsafe_rewrite_buffer_region( + dequantize_block, + ("read", offset + 2), + get_param_indices(intra_index_map), + ) + + smooth_gmem_layout_rewrite( + sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") + + smooth_gmem_layout_rewrite( + sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") + + warp_size = 32 + + i_factors, j_factors, k_factors = ( + [None, 1, block_row_warps, warp_row_tiles // micro_size_x], + [1, None, block_col_warps, warp_col_tiles // micro_size_y], + [None, chunk // micro_size_k], + ) + + num_ty = i_factors[2] + num_tz = j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): + sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + + sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) + + block_idy = sch.fuse(i0, j0) + block_idx = sch.fuse(i1, j1) + thread_idy = i2 + thread_idz = j2 + + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + sch.bind(thread_idz, "threadIdx.z") + + def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 + if not enable: + return + sch.transform_layout( + block, + scope, + lambda b, i, j: ( + b, + i // l, + j // r, + i % l, + j % r, + ), + ) + + smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) + smooth_layout_recover( + block_outer, + ("read", 1), + *b_lr, + enable=intrin_info.inter_transform_b, + ) + smooth_layout_recover(block_outer, ("write", 0), enable=True) + + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_read).iter_vars) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + f_0, f_1, f_2, f_3, f_4 = sch.split( + fused, factors=[None, num_ty, num_tz, warp_size, vec_len]) + + sch.bind(f_3, "threadIdx.x") + sch.bind(f_2, "threadIdx.z") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_4) + sch.unroll(f_0) + sch.annotate(f_0, "pragma_unroll_explicit", False) + + # Apply Swizzling + sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) + # if not, apply padding to alleviate bank conflict + if not (can_swizzle or is_smooth): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + return block_read + + a_g2s = fetch_to_shared( + block_outer, + 0, + vec_len=list(config.vectorize.values())[0], + can_swizzle=can_swizzle_a, + is_smooth=intrin_info.smooth_a, + ) + + auto_inline_producers(sch, a_g2s) + + def decode_fetch_to_shared(block, idx): + # step1. create memory hierarchy + # global -> local -> shared + block_shared = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_shared, k0, preserve_unit_loops=True) + + # TODO(lei): the factor should be analyzed more deeper. + decode_factor = get_coalesced_veclen(sch.get(block_shared)) + _, B_shared_vi, _ = sch.split( + sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor]) + block_shared_local = sch.cache_read(block_shared, 0, "local") + # global -> dequantzed_local -> shared + # step2. inline to local block, should skip qzeros + is_qzeros = ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"] and + weight_decode_info["zeros_mode"] == "quantized") + weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"]) + weight_producers = ( + _collect_producers(sch, weight_dequantize_block) if is_qzeros else []) + auto_inline_producers(sch, block_shared_local, weight_producers) + + # get target dequantize buffer's idx + def get_idx(): + # for LUT dequantize, the expr is LUT(w), the idx is 1 + # maybe we can use a more general and structural based way + # to analysis the idx + if weight_decode_info["source_format"]["format"] == "nf": + return 1 + return 0 + + b_idx = get_idx() + # global -> prefetch_local -> dequantzed_local -> shared + block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local") + # global -> prefetch_shared -> vector load -> dequantzed_local -> shared + block_shared_local_local_shared = sch.cache_read(block_shared_local_local, 0, + shared_scope) + sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True) + sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True) + + dequantize_block_local = block_shared_local + if is_qzeros: + if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]): + block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local") + sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_scales) + + if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]): + block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local") + sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_zeros) + + for producer in weight_producers: + with suppress(Exception): + auto_inline_producers(sch, producer) + sch.compute_inline(producer) + + # fast type conversion + if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): + source_bit = weight_decode_info["source_format"]["bits"] + out_dtype = weight_decode_info["target_format"] + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + storage_dtype=weight_decode_info["storage_dtype"], + source_format=weight_decode_info["source_format"]["format"], + source_bit=source_bit, + with_scaling=weight_decode_info["with_scaling"], + with_zeros=weight_decode_info["with_zeros"], + zeros_mode=weight_decode_info["zeros_mode"], + ) + sch.tensorize( + sch.get_loops(dequantize_block_local)[-1], + lop3_intrin_info["compute"], + ) + import_source.append(lop3_intrin_info["c_source"]) + + sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b) + union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2) + B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2]) + _, B_shared_ty, B_shared_tz, B_shared_tx = sch.split( + B_shared_fused, factors=[None, num_ty, num_tz, warp_size]) + if not (can_swizzle_b or intrin_info.smooth_b): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset) + sch.bind(B_shared_tx, "threadIdx.x") + sch.bind(B_shared_ty, "threadIdx.y") + sch.bind(B_shared_tz, "threadIdx.z") + sch.vectorize(sch.get_loops(block_shared)[-1]) + sch.vectorize(sch.get_loops(block_shared_local_local)[-1]) + + sch.compute_at(block_shared_local_local_shared, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_shared_local_local_shared).iter_vars) + _ = sch.fuse(*sch.get_loops(block_shared_local_local_shared)[-ndim:]) + + _bind_thread_based_on_config(sch, block_shared_local_local_shared, num_ty, num_tz, + warp_size) + + # cache small tensors, e.g. LUT + if b_idx: + block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope) + sch.reverse_compute_at(block_shared_lut, j2) + _, B_shared_tx = sch.split( + sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size]) + sch.bind(B_shared_tx, "threadIdx.x") + return block_shared_local + + _ = decode_fetch_to_shared(block_outer, 1) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "warp") + B_mat = sch.cache_read(block_outer, 1, "warp") + sch.compute_at(A_mat, k1, preserve_unit_loops=True) + sch.compute_at(B_mat, k1, preserve_unit_loops=True) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + if cache_write_required: + accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) + + store = sch.cache_write(block_outer, 0, "warp") + sch.reverse_compute_at(store, j2) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, micro_size_x]) + j0, j1 = sch.split(j, factors=[None, micro_size_y]) + sch.reorder(i0, j0, i1, j1) + + if cache_write_required: + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + sch.reverse_compute_at( + accumulator_shared_to_global, + sch.get_loops(store)[-5], + preserve_unit_loops=True, + ) + vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) + f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + sch.unroll(f0) + sch.annotate(f0, "pragma_unroll_explicit", False) + else: + auto_inline_consumer_chain(sch, store) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + + index_map_a, index_map_b, index_map_c = intrin_group["index_map"] + + sch.transform_layout( + A_mat, + ("write", 0), + get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), + ) + sch.transform_layout( + B_mat, + ("write", 0), + get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), + ) + sch.transform_layout( + store, + ("read", 0), + get_index_map(index_map_c, is_5d=True), + ) + + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, a_lr[0]]) + j0, j1 = sch.split(j, factors=[None, a_lr[1]]) + sch.reorder(i0, j0, i1, j1) + ba = sch.blockize(i1) + sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) + sch.tensorize(ba, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, b_lr[0]]) + j0, j1 = sch.split(j, factors=[None, b_lr[1]]) + sch.reorder(i0, j0, i1, j1) + bb = sch.blockize(i1) + sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) + sch.tensorize(bb, intrin_group["load_b"]) + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + tensorize_init_store_compute() + + if stage > 1: + sch.annotate( + k0, + ann_key="software_pipeline_stage", + ann_val=[0, 0, stage - 1, stage - 1], + ) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2, 3]) + if use_async: + sch.annotate(k0, "software_pipeline_async_stages", [0]) + # plan rasteration + if not isinstance(config.rasterization_plan, NoRasterization): + device_func, invoke_func = config.rasterization_plan.get_code() + import_source.append(device_func) + sch.annotate( + sch.get_loops(block_init_c)[-2], + ann_key="inject_customized_code_prepend", + ann_val=invoke_func, + ) + # plan import source + if len(import_source) > 0: + sch.annotate( + thread_idz, + ann_key="pragma_import_c", + ann_val=("\n").join(import_source), + ) + return sch + + def sch_shared_memory_prefetch_block_reduction_with_config( + self, + func: tir.PrimFunc, + config: Hint, + ): + """ + For A100 Like devices, the shared memory prefetch(async) is required + to achieve optimal performance. + quantized weight + | + V + shared memory prefetch (with async copy) + | + V + dequantized into shared memory + | + V + compute + """ + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group,) + from .intrin import get_lop3_intrin_group + + import_source: List[str] = [] + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + # always enable shared memory rewrite + cache_write_required = True + + # Check Dequantize Info + # TODO(leiwang): this is a hack to get the configuration, can be improved by writing a pass to analysis the dequantize block. + dequantize_info = func.attrs["dequantize_info"] + + def check_dequantize_info(dequantize_info): + conditions = [] + # currently only support weight only dequantization + conditions.append(len(dequantize_info) == 1) + # TODO(@lei) check if the dequantize value name is weight + return all(conditions) + + assert check_dequantize_info(dequantize_info) + + (weight_decode_info,) = list(dequantize_info.values()) + + def check_weight_decode_info(weight_decode_info): + conditions = [] + # check source format in ["int", "fp", "nf"] + conditions.append("source_format" in weight_decode_info) + conditions.append(weight_decode_info["source_format"]["format"] in + ["uint", "int", "fp", "nf", "fp_e4m3"]) + # check source bits in [1, 2, 4, 8] + conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) + # check target format in ["float16", "int8"] + conditions.append("target_format" in weight_decode_info) + conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) + return all(conditions) + + assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info" + + # Start Schedule + # Step 0. Get schedule config. + # tensor core intrinsic size + shared_scope = config.shared_scope + + intrin_info = config.intrin_info + intrin_group = get_mma_intrin_group( + load_scope=shared_scope, + store_scope=shared_scope if cache_write_required else "global", + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_a=intrin_info.trans_a, + trans_b=intrin_info.trans_b, + smooth_a=intrin_info.smooth_a, + smooth_b=intrin_info.smooth_b, + not_use_mma_store_intrinic=False, + ) + + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + stage = config.pipeline_stage + use_async = config.use_async + assert (config.block_reduction_depth is not None), "block_reduction_depth is required" + reduce_k = config.block_reduction_depth + chunk = config.rstep[0] // reduce_k + + micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] + + # get the axis for layout transform + def get_axis(l, r, trans): # noqa: E741 + return (r, l) if trans else (l, r) # noqa: E741 + + a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) + b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) + + def can_enable_swizzle(dtype: str, smooth: bool): + # inject_permuted_layout only support float16 currently + if dtype == "float16" or dtype == "int8": + if (chunk * reduce_k) * DataType(dtype).bits != 512: + # currently the swizzle rule only support 512 bit. + return False + # if we use smooth layout, we don't need to do swizzling + return not smooth + return False + + can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) + can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) + + # rewrite global smooth layout, for dequantize, currently only support weight only recover. + def smooth_gmem_layout_rewrite( + sch, + main_block, + enable=True, + trans=False, + matrix_name="A", + intrin_group=intrin_group, + ): + if not enable: + return + + # normalized block may have three read buffers, while the first one is the write buffer. + buffer_offset = (1 if sch.get(main_block).reads[0].buffer + == sch.get(main_block).writes[0].buffer else 0) + buffer_idx = 0 if matrix_name == "A" else 1 + source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer + + # step1: find the first producer block + # Notes: we assume the layout propagate happens in the first producer block + # otherwise, the layout transform will have no effect as it will transform both + # read and write buffer + propagate_block: tir.Block = find_last_producer_from_buffer( + sch, main_block, source_buffer) + # some trick impl may not have reindex block + (weight_dequantize_info,) = dequantize_info.values() + if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]): + return + # step2: transform the layout with inverse permutation + intra_indexmap, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + + # step3: propagate the matmul layout to the first reindex block + + intra_indexmap = layout_propagate_chain( + sch, + start_block=main_block, + start_buffer=source_buffer, + end_block=propagate_block, + index_map=intra_indexmap, + ) + + def inverse_permutation(i, j, ii, jj): + return (i, j, *intra_indexmap.map_indices([ii, jj])) + + sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) + + intra_index_map, _ = get_propagate_map( + trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) + + # get target dequantize buffer's offset + def get_offset(): + # for LUT dequantize, the expr is LUT(w), the idx is 1 + # maybe we can use a more general and structural based way + # to analysis the idx + if weight_dequantize_info["source_format"]["format"] == "nf": + return 1 + return 0 + + offset = get_offset() + dequantize_block = sch.get_block(weight_dequantize_info["decode_block"]) + group_size = weight_dequantize_info["group_size"] + + _, mn, mk = intrin_group["micro_kernel"] + + def get_param_indices( + indexmap, + l=mn, + r=mk, + group_size=group_size # noqa: E741 + ): # noqa: E741 + # assume the param layout is n, k + rl, rr = [x.var for x in sch.get(dequantize_block).iter_vars] + warp_i, warp_j = rl % l, rr % r + spatial_i, spatial_j = rl // l, rr // r + warp_i, warp_j = indexmap.map_indices([warp_i, warp_j]) + new_indices = ( + spatial_i * l + warp_i, + (spatial_j * r + warp_j) // group_size, + ) + return new_indices + + with_scaling = bool(weight_dequantize_info["with_scaling"]) + if with_scaling: + sch.unsafe_rewrite_buffer_region( + dequantize_block, + ("read", offset + 1), + get_param_indices(intra_index_map), + ) + with_zeros = bool(weight_dequantize_info["with_zeros"]) + if with_zeros: + sch.unsafe_rewrite_buffer_region( + dequantize_block, + ("read", offset + 2), + get_param_indices(intra_index_map), + ) + + smooth_gmem_layout_rewrite( + sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") + + smooth_gmem_layout_rewrite( + sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") + + warp_size = 32 + + i_factors, j_factors, k_factors = ( + [None, 1, block_row_warps, warp_row_tiles // micro_size_x], + [1, None, block_col_warps, warp_col_tiles // micro_size_y], + [None, chunk // micro_size_k], + ) + + num_ty = i_factors[2] + num_tz = j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): + sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + k0, kr = sch.split(k0, [None, reduce_k]) + + sch.reorder(i0, j0, i1, j1, i2, j2, kr, i3, j3, k0, k1) + + block_idy = sch.fuse(i0, j0) + block_idx = sch.fuse(i1, j1) + thread_idy = i2 + thread_idz = j2 + + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz) + sch.bind(thread_idy, "threadIdx.y") + sch.bind(kr, "threadIdx.z") + + def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 + if not enable: + return + sch.transform_layout( + block, + scope, + lambda b, i, j: ( + b, + i // l, + j // r, + i % l, + j % r, + ), + ) + + smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) + smooth_layout_recover( + block_outer, + ("read", 1), + *b_lr, + enable=intrin_info.inter_transform_b, + ) + smooth_layout_recover(block_outer, ("write", 0), enable=True) + + def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): + block_read = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_read, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_read).iter_vars) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + f_0, f_r, f_1, f_2, f_3, f_4 = sch.split( + fused, factors=[None, reduce_k, num_ty, num_tz, warp_size, vec_len]) + + sch.bind(f_3, "threadIdx.x") + f_1 = f_2 = sch.fuse(f_1, f_2) + sch.bind(f_1, "threadIdx.y") + sch.bind(f_r, "threadIdx.z") + sch.vectorize(f_4) + sch.unroll(f_0) + sch.annotate(f_0, "pragma_unroll_explicit", False) + + # Apply Swizzling + sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) + # if not, apply padding to alleviate bank conflict + if not (can_swizzle or is_smooth): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) + return block_read + + a_g2s = fetch_to_shared( + block_outer, + 0, + vec_len=list(config.vectorize.values())[0], + can_swizzle=can_swizzle_a, + is_smooth=intrin_info.smooth_a, + ) + + auto_inline_producers(sch, a_g2s) + + def decode_fetch_to_shared(block, idx): + # step1. create memory hierarchy + # global -> local -> shared + block_shared = sch.cache_read(block, idx, shared_scope) + sch.compute_at(block_shared, k0, preserve_unit_loops=True) + + # TODO(lei): the factor should be analyzed more deeper. + decode_factor = get_coalesced_veclen(sch.get(block_shared)) + _, B_shared_vi, _ = sch.split( + sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor]) + block_shared_local = sch.cache_read(block_shared, 0, "local") + # global -> dequantzed_local -> shared + # step2. inline to local block, should skip qzeros + is_qzeros = ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"] and + weight_decode_info["zeros_mode"] == "quantized") + weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"]) + weight_producers = ( + _collect_producers(sch, weight_dequantize_block) if is_qzeros else []) + auto_inline_producers(sch, block_shared_local, weight_producers) + + # get target dequantize buffer's idx + def get_idx(): + # for LUT dequantize, the expr is LUT(w), the idx is 1 + # maybe we can use a more general and structural based way + # to analysis the idx + if weight_decode_info["source_format"]["format"] == "nf": + return 1 + return 0 + + b_idx = get_idx() + # global -> prefetch_local -> dequantzed_local -> shared + block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local") + # global -> prefetch_shared -> vector load -> dequantzed_local -> shared + block_shared_local_local_shared = sch.cache_read(block_shared_local_local, 0, + shared_scope) + sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True) + sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True) + + dequantize_block_local = block_shared_local + if is_qzeros: + if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]): + block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local") + sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_scales) + + if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]): + block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local") + sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True) + auto_inline_producers(sch, block_local_zeros) + + for producer in weight_producers: + with suppress(Exception): + auto_inline_producers(sch, producer) + sch.compute_inline(producer) + + # fast type conversion + if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): + source_bit = weight_decode_info["source_format"]["bits"] + out_dtype = weight_decode_info["target_format"] + lop3_intrin_info = get_lop3_intrin_group( + out_dtype=out_dtype, + storage_dtype=weight_decode_info["storage_dtype"], + source_format=weight_decode_info["source_format"]["format"], + source_bit=source_bit, + with_scaling=weight_decode_info["with_scaling"], + with_zeros=weight_decode_info["with_zeros"], + zeros_mode=weight_decode_info["zeros_mode"], + ) + sch.tensorize( + sch.get_loops(dequantize_block_local)[-1], + lop3_intrin_info["compute"], + ) + import_source.append(lop3_intrin_info["c_source"]) + + sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b) + union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2) + B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2]) + _, B_shared_rk, B_shared_ty, B_shared_tz, B_shared_tx = sch.split( + B_shared_fused, factors=[None, reduce_k, num_ty, num_tz, warp_size]) + if not (can_swizzle_b or intrin_info.smooth_b): + pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 + sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset) + sch.bind(B_shared_tx, "threadIdx.x") + B_shared_tz = B_shared_ty = sch.fuse(B_shared_ty, B_shared_tz) + sch.bind(B_shared_ty, "threadIdx.y") + sch.bind(B_shared_rk, "threadIdx.z") + sch.vectorize(sch.get_loops(block_shared)[-1]) + sch.vectorize(sch.get_loops(block_shared_local_local)[-1]) + + sch.compute_at(block_shared_local_local_shared, k0, preserve_unit_loops=True) + ndim = len(sch.get(block_shared_local_local_shared).iter_vars) + _ = sch.fuse(*sch.get_loops(block_shared_local_local_shared)[-ndim:]) + + _bind_thread_based_with_block_reduce_on_config(sch, block_shared_local_local_shared, + num_ty, num_tz, warp_size, reduce_k) + + # cache small tensors, e.g. LUT + if b_idx: + block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope) + sch.reverse_compute_at(block_shared_lut, j2) + _, B_shared_tx = sch.split( + sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size]) + sch.bind(B_shared_tx, "threadIdx.x") + return block_shared_local + + _ = decode_fetch_to_shared(block_outer, 1) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "warp") + B_mat = sch.cache_read(block_outer, 1, "warp") + sch.compute_at(A_mat, k1, preserve_unit_loops=True) + sch.compute_at(B_mat, k1, preserve_unit_loops=True) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + if cache_write_required: + accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) + + store = sch.cache_write(block_outer, 0, "warp") + sch.reverse_compute_at(store, j2) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, micro_size_x]) + j0, j1 = sch.split(j, factors=[None, micro_size_y]) + sch.reorder(i0, j0, i1, j1) + + if cache_write_required: + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + sch.reverse_compute_at( + accumulator_shared_to_global, + sch.get_loops(store)[-5], + preserve_unit_loops=True, + ) + vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) + f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + sch.unroll(f0) + sch.annotate(f0, "pragma_unroll_explicit", False) + else: + auto_inline_consumer_chain(sch, store) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + + index_map_a, index_map_b, index_map_c = intrin_group["index_map"] + + sch.transform_layout( + A_mat, + ("write", 0), + get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), + ) + sch.transform_layout( + B_mat, + ("write", 0), + get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), + ) + sch.transform_layout( + store, + ("read", 0), + get_index_map(index_map_c, is_5d=True), + ) + + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, a_lr[0]]) + j0, j1 = sch.split(j, factors=[None, a_lr[1]]) + sch.reorder(i0, j0, i1, j1) + ba = sch.blockize(i1) + sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) + sch.tensorize(ba, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, b_lr[0]]) + j0, j1 = sch.split(j, factors=[None, b_lr[1]]) + sch.reorder(i0, j0, i1, j1) + bb = sch.blockize(i1) + sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) + sch.tensorize(bb, intrin_group["load_b"]) + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + tensorize_init_store_compute() + + if stage > 1: + sch.annotate( + k0, + ann_key="software_pipeline_stage", + ann_val=[0, 0, stage - 1, stage - 1], + ) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2, 3]) + if use_async: + sch.annotate(k0, "software_pipeline_async_stages", [0]) + # plan rasteration + if not isinstance(config.rasterization_plan, NoRasterization): + device_func, invoke_func = config.rasterization_plan.get_code() + import_source.append(device_func) + sch.annotate( + sch.get_loops(block_init_c)[-2], + ann_key="inject_customized_code_prepend", + ann_val=invoke_func, + ) + # plan import source + if len(import_source) > 0: + sch.annotate( + j2, + ann_key="pragma_import_c", + ann_val=("\n").join(import_source), + ) + return sch + + def apply_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config: Hint, + ) -> Optional[tir.Schedule]: + + def check_sm_version(arch: str) -> int: + sm_version = arch.replace("sm_", "") + return int(sm_version) if sm_version.isdigit() else -1 + + if check_sm_version(config.arch.target.arch) < 80: + """MMA Template only support sm_80 and above""" + return None + + if (config.arch.target.kind.name == "cuda" and + check_sm_version(config.arch.target.arch) == 80): + return self.sch_shared_memory_prefetch_with_config(func, config) + else: + return self.sch_dequantize_in_register_with_config(func, config) diff --git a/bitblas/gpu/matmul_wmma.py b/bitblas/gpu/matmul_wmma.py new file mode 100644 index 000000000..60817258f --- /dev/null +++ b/bitblas/gpu/matmul_wmma.py @@ -0,0 +1,892 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +from typing import Literal, Optional + +from tvm import DataType, tir +from tvm.target import Target + +from ..base.roller.rasterization import NoRasterization +from ..base import analysis +from .base import GPUScheduleRule +from .matmul_analysis import ( + auto_inline_consumer_chain, + auto_inline_producers, + get_index_map, + get_reduction_blocks, + normalize_to_matmul, +) + + +class MatmulTensorizationWMMA(GPUScheduleRule): + """ + The schedule rule for float16 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. + """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + block_stmt = sch.get(main_block) + index_maps = get_index_map(block_stmt) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + block_m = 128 + block_n = 128 + block_k = 32 + + # tensor core intrinsic size + micro_size_m = 16 + micro_size_n = 16 + micro_size_k = 16 + + thread_z = 2 + thread_y = 2 + warp_size = 32 + + vector_size = 8 + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + # Step 2. Padding for dynamic shape kernels + + # # Step 2.1 Swizzle for l2, for better performance on inputs exceeding l2 size + # # Get input shape + batch, i, j, k = sch.get_loops(main_block) + # input_b, input_m, input_n, input_k = [sch.get(loop).extent for loop in [batch, i, j, k]] + + # # Get input/output dtype + dtype_a, dtype_b = [DataType(region.buffer.dtype) for region in sch.get(main_block).reads] + dtype_c = DataType(sch.get(main_block).writes[0].buffer.dtype) + # dtype_a_bytes, dtype_b_bytes = [math.ceil(d.bits / 8) for d in [dtype_a, dtype_b]] + + # # Get l2 size + # l2_size = target.l2_cache_size_bytes + + # # Analyse swizzle factor + # def get_swizzle_factor(l2_size, input_k, dtype_bytes, input_spatial, block_size): + # if l2_size != 0 and isinstance(input_k, (int, tir.IntImm)): + # # div by 3: suppose the two inputs and the output uses the same amount of l2 + # swizzle_factor = l2_size / 3 / int(input_k) / dtype_bytes / block_size + # # optimization: try find the best swizzle factor (aka the least additional padding) + # if isinstance(input_spatial, (int, tir.IntImm)): + # block_cnt = math.ceil(int(input_spatial) / block_size) + # swizzle_factor = math.ceil(block_cnt / math.ceil(block_cnt / swizzle_factor)) + # else: + # swizzle_factor = math.floor(swizzle_factor) + # return [None, swizzle_factor] + # else: + # return [4, None] + + # swizzle_factor_m = get_swizzle_factor(l2_size, input_k, dtype_a_bytes, input_m, block_m) + # swizzle_factor_n = get_swizzle_factor(l2_size, input_k, dtype_b_bytes, input_n, block_n) + + swizzle_factor_m = [4, None] + swizzle_factor_n = [4, None] + + # Step 2.2 Add padding + sch.pad_einsum( + main_block, + [ + 1, + (swizzle_factor_m[0] or swizzle_factor_m[1]) * block_m, + (swizzle_factor_n[0] or swizzle_factor_n[1]) * block_n, + block_k, + ], + ) + + # Step 3. Reorder loops for tiling + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_m]) + j, j_inner = sch.split(j, factors=[None, micro_size_n]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = main_block + block_outer = sch.blockize(i_inner) + + # split factors for i, j, and k + in_wrap_block_cnt_m = block_m // thread_z // micro_size_m + in_wrap_block_cnt_n = block_n // thread_y // micro_size_n + in_wrap_block_cnt_k = block_k // micro_size_k + + i_factors = swizzle_factor_m + [thread_z, in_wrap_block_cnt_m] + j_factors = swizzle_factor_n + [thread_y, in_wrap_block_cnt_n] + k_factors = [None, in_wrap_block_cnt_k] + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, factors=k_factors) + + sch.reorder(i0, j0, i1, j1, k0, i2, j2, k1, i3, j3) + block_axis = sch.fuse(batch, i0, j0, i1, j1) + + sch.bind(block_axis, "blockIdx.x") + sch.bind(i2, "threadIdx.z") + sch.bind(j2, "threadIdx.y") + + # Step 4. Read to/write from shared mem, and from/to wmma fragments + def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], wmma_name): + block_read = sch.cache_read(block_outer, read_buffer_idx, "shared.dyn") + sch.compute_at(block_read, k0) + fused = sch.fuse(*sch.get_loops(block_read)[-2:]) + + f0, f1, f2, f3, f4 = sch.split(fused, + [None, thread_z, thread_y, warp_size, vector_size]) + + sch.bind(f1, "threadIdx.z") + sch.bind(f2, "threadIdx.y") + sch.bind(f3, "threadIdx.x") + sch.vectorize(f4) + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8) + + auto_inline_producers(sch, block_read) + + wmma_read = sch.cache_read(block_outer, read_buffer_idx, wmma_name) + sch.compute_at(wmma_read, k1) + + micro_size_spatial = micro_size_m if tensor_name == "A" else micro_size_n + v0, v1 = sch.get_loops(wmma_read)[-2:] + sch.split(v0, factors=[None, micro_size_spatial]) + + return wmma_read + + wmma_read_a = fetch_input(block_outer, 0, [block_m, block_k, micro_size_m, micro_size_k], + "wmma.matrix_a") + wmma_read_b = fetch_input(block_outer, 1, [block_n, block_k, micro_size_n, micro_size_k], + "wmma.matrix_b") + + def store_output(block_outer, write_buffer_idx, wmma_name): + block_write = sch.cache_write(block_outer, write_buffer_idx, "shared.dyn") + sch.reverse_compute_at(block_write, block_axis) + + fused = sch.fuse(*sch.get_loops(block_write)[-2:]) + + f0, f1, f2, f3, f4 = sch.split(fused, + [None, thread_z, thread_y, warp_size, vector_size]) + + sch.bind(f1, "threadIdx.z") + sch.bind(f2, "threadIdx.y") + sch.bind(f3, "threadIdx.x") + sch.vectorize(f4) + # sch.storage_align(block_write, 0, axis=-2, factor=128, offset=16) + + auto_inline_consumer_chain(sch, block_write) + + wmma_store = sch.cache_write(block_outer, write_buffer_idx, wmma_name) + v0, v1 = sch.get_loops(wmma_store)[-2:] + v00, v01, v02 = sch.split(v0, factors=[thread_z, None, micro_size_m]) + v10, v11, v12 = sch.split(v1, factors=[thread_y, None, micro_size_n]) + sch.reorder(v00, v10, v01, v11, v02, v12) + sch.bind(v00, "threadIdx.z") + sch.bind(v10, "threadIdx.y") + return wmma_store + + wmma_store = store_output(block_outer, 0, "wmma.accumulator") + + block_init = sch.decompose_reduction(block_outer, k0) + block_init_inner = sch.get_child_blocks(block_init)[0] + + # unroll k + sch.unroll(k0) + + # Step 5. Schedule tensor core computation + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_wmma_intrin_group,) + + intrin_group = get_wmma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype=str(dtype_a), + out_dtype=str(dtype_c), + trans_b=True, + ) + + sch.tensorize(sch.get_loops(block_init_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(wmma_read_a)[-2], intrin_group["load_a"]) + sch.tensorize(sch.get_loops(wmma_read_b)[-2], intrin_group["load_b"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + sch.tensorize(sch.get_loops(wmma_store)[-2], intrin_group["store"]) + + return sch + + +class MatmulInt8Tensorization(GPUScheduleRule): + """ + The schedule rule for int8 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. + """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_wmma_intrin_group,) + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + block_stmt = sch.get(main_block) + index_maps = get_index_map(block_stmt) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # tensor core intrinsic size + micro_size_x = 16 + micro_size_y = 16 + micro_size_k = 16 + + warp_size = 32 + vector_size = 4 + + i_factors, j_factors, k_factors = ( + [None, 1, 4, 2], + [1, None, 4, 2], + [None, 1], + ) + + num_ty = i_factors[2] * j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6]) + sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1]) + sch.annotate(k1, "software_pipeline_order", [0, 1, 2]) + sch.annotate(k1, "software_pipeline_stage", [0, 0, 1]) + + sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) + + block_idx = sch.fuse(i0, j0) + block_idy = sch.fuse(i1, j1) + thread_idy = sch.fuse(j2, i2) + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + + def fetch_to_shared(block, idx, ndim): + block_read = sch.cache_read(block, idx, "shared.dyn") + sch.compute_at(block_read, k0) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) + + sch.bind(f_2, "threadIdx.x") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_3) + + sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16) + sch.annotate(block_read, "tir.manifest_shared_memory_local_stage", 1) + sch.annotate(block_read, "double_buffer_scope", 0) + return block_read + + a_g2s = fetch_to_shared(block_outer, 0, 2) + b_g2s = fetch_to_shared(block_outer, 1, 2) + + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a") + B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b") + sch.compute_at(A_mat, k1) + sch.compute_at(B_mat, k1) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") + sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) + + store = sch.cache_write(block_outer, 0, "wmma.accumulator") + sch.reverse_compute_at(store, thread_idy) + sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + intrin_group = get_wmma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype="int8", + out_dtype="int32", + trans_b=True, + ) + + try: + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_b"]) + except Exception: # pylint: disable=bare-except + return None + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + try: + tensorize_init_store_compute() + except Exception: # pylint: disable=bare-except + return None + + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) + _, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + + return sch + + +class MatmulTensorizationLegacy(GPUScheduleRule): + """ + The schedule rule for float16 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. + """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_wmma_intrin_group,) + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + block_stmt = sch.get(main_block) + index_maps = get_index_map(block_stmt) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # tensor core intrinsic size + micro_size_x = 16 + micro_size_y = 16 + micro_size_k = 16 + + warp_size = 32 + vector_size = 4 + + i_factors, j_factors, k_factors = ( + [None, 1, 4, 2], + [1, None, 4, 2], + [None, 4], + ) + + num_ty = i_factors[2] * j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6]) + sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1]) + sch.annotate(k1, "software_pipeline_order", [0, 1, 2]) + sch.annotate(k1, "software_pipeline_stage", [0, 0, 1]) + + sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) + + block_idx = sch.fuse(i0, j0) + block_idy = sch.fuse(i1, j1) + thread_idy = sch.fuse(j2, i2) + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + + def fetch_to_shared(block, idx, ndim): + block_read = sch.cache_read(block, idx, "shared.dyn") + sch.compute_at(block_read, k0) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) + + sch.bind(f_2, "threadIdx.x") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_3) + + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8) + sch.annotate(block_read, "tir.manifest_shared_memory_local_stage", 1) + sch.annotate(block_read, "double_buffer_scope", 0) + return block_read + + a_g2s = fetch_to_shared(block_outer, 0, 2) + b_g2s = fetch_to_shared(block_outer, 1, 2) + + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a") + B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b") + sch.compute_at(A_mat, k1) + sch.compute_at(B_mat, k1) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") + sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) + + store = sch.cache_write(block_outer, 0, "wmma.accumulator") + sch.reverse_compute_at(store, thread_idy) + sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + intrin_group = get_wmma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype="float16", + out_dtype="float32", + trans_b=True, + ) + + try: + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_b"]) + except Exception: # pylint: disable=bare-except + return None + + # Try to tensorize the init, store and compute block with f16 or f32 intrinsics + tensorize_success: bool = False + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + try: + tensorize_init_store_compute() + tensorize_success = True + except Exception: # pylint: disable=bare-except + intrin_group = get_wmma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype="float16", + out_dtype="float16", + trans_b=True, + ) + + if not tensorize_success: + try: + tensorize_init_store_compute() + tensorize_success = True + except Exception: # pylint: disable=bare-except + return None + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) + _, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + + return sch if tensorize_success else None + + def apply_config( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + config, + ) -> Optional[tir.Schedule]: + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_wmma_intrin_group,) + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # tensor core intrinsic size + intrin_info = config.intrin_info + warp_row_tiles = config.warp[0] + warp_col_tiles = config.warp[1] + block_row_warps = config.block[0] // warp_row_tiles + block_col_warps = config.block[1] // warp_col_tiles + stage = config.pipeline_stage + use_async = config.use_async + chunk = config.rstep[0] + + micro_size_x = 16 + micro_size_y = 16 + micro_size_k = 16 + + warp_size = 32 + + i_factors, j_factors, k_factors = ( + [None, 1, block_row_warps, warp_row_tiles // micro_size_x], + [1, None, block_col_warps, warp_col_tiles // micro_size_y], + [None, chunk // micro_size_k], + ) + + num_ty = i_factors[2] * j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] + if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): + sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + + sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) + + block_idx = sch.fuse(i0, j0) + block_idy = sch.fuse(i1, j1) + thread_idy = sch.fuse(j2, i2) + # plan rasteration + if (not isinstance(config.rasterization_plan, NoRasterization) and + sch.get(batch).extent.value == 1): + device_func, invoke_func = config.rasterization_plan.get_code() + factor = config.rasterization_plan.panel_width_ + + # TODO(lei): this is a trick for rasterization implementation + # wait for https://github.com/apache/tvm/pull/16113 to be merged + # require a solution for general block rasterization + factor = 8 # should be divisible by block_idy + if sch.get(block_idy).extent.value % factor == 0: + block_k, block_idy = sch.split(block_idy, factors=[None, factor]) + sch.bind(block_k, "blockIdx.z") + else: + sch.bind(batch, "blockIdx.z") + + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + + def fetch_to_shared(block, idx, ndim, vec_len, dtype="float16"): + block_read = sch.cache_read(block, idx, "shared.dyn") + sch.compute_at(block_read, k0) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vec_len]) + + sch.bind(f_2, "threadIdx.x") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_3) + offset: int = 0 + if dtype == "float16": + offset = 8 + elif dtype == "int8": + offset = 16 + # todo(lei): the pad value should be varied according to the data type + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=offset) + return block_read + + a_g2s = fetch_to_shared( + block_outer, + 0, + 2, + vec_len=list(config.vectorize.values())[0], + dtype=intrin_info.in_dtype, + ) + b_g2s = fetch_to_shared( + block_outer, + 1, + 2, + vec_len=list(config.vectorize.values())[1], + dtype=intrin_info.in_dtype, + ) + + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a") + B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b") + sch.compute_at(A_mat, k1) + sch.compute_at(B_mat, k1) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") + sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) + + store = sch.cache_write(block_outer, 0, "wmma.accumulator") + sch.reverse_compute_at(store, thread_idy) + sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + intrin_group = get_wmma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype=intrin_info.in_dtype, + out_dtype=intrin_info.out_dtype, + trans_b=intrin_info.trans_b, + ) + + try: + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_b"]) + except Exception: # pylint: disable=bare-except + return None + + # Try to tensorize the init, store and compute block with f16 or f32 intrinsics + tensorize_success: bool = False + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + try: + tensorize_init_store_compute() + tensorize_success = True + except Exception: # pylint: disable=bare-except + return None + + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) + _, f1, f2 = sch.split( + fused, factors=[None, warp_size, max(list(config.vectorize.values()))]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + + if stage > 1: + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + if use_async: + sch.annotate(k0, "software_pipeline_async_stages", [0]) + + return sch if tensorize_success else None diff --git a/bitblas/gpu/reduction.py b/bitblas/gpu/reduction.py new file mode 100644 index 000000000..9d6aada75 --- /dev/null +++ b/bitblas/gpu/reduction.py @@ -0,0 +1,301 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm reduction.py in dlight. +"""A rule for reduction. """ +from typing import List, Optional, Tuple, Union + +from tvm import arith, ir, tir +from tvm.target import Target + +from ..base import ( + BlockInfo, + normalize_prim_func, + try_inline_contiguous_spatial, + detect_dominant_read, + is_broadcast_epilogue, +) +from . import utils +from .base import GPUScheduleRule + + +def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: + # Detect and return `Y` in `X[...] = X[...] + Y` + buffer_store = block.body + if not isinstance(buffer_store, tir.BufferStore): + return None + if not isinstance(buffer_store.value, tir.Add): + return None + if not ir.structural_equal( + buffer_store.value.a, + tir.BufferLoad(buffer_store.buffer, block.body.indices), + map_free_vars=True, + ): + return None + return buffer_store.value.b + + +class Reduction(GPUScheduleRule): + """A rule for Reduction.""" + + def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Union[None, tir.Schedule, List[tir.Schedule]]: + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + sch = tir.Schedule(func) + block_infos = normalize_prim_func(sch) + if block_infos is None: + return None + block_infos = try_inline_contiguous_spatial(sch, block_infos) + if len(block_infos) == 1: + epilogue = None + elif len(block_infos) == 2: + epilogue = block_infos[1] + if not epilogue.is_injective(): + return None + else: + return None + + block_info = block_infos[0] + block = block_info.block_rv + block_stmt = sch.get(block) + + # Step 1. Check reduction block + if ( + (not block_info.is_reduction()) + or len(block_stmt.writes) != 1 + or _get_reduction_expr(block_stmt) is None + ): + return None + # Step 2. Normalize the block, merge spatial and reduction iters + is_inner_reduction, c_factor, loop_order, s_split_index = self._normalize( + sch, + block_info, + arith.normalize_to_iter_sum( + detect_dominant_read(block_stmt), + input_iters={i.var: i.dom for i in block_stmt.iter_vars}, + ), + ) + if is_inner_reduction is None and c_factor is None: + return None + # Step 3. Do the scheduling + if is_inner_reduction: + self._sch_inner_reduction( + sch, target, block, c_factor, epilogue, loop_order, s_split_index + ) + else: + self._sch_inner_spatial( + sch, target, block, block_info, c_factor, epilogue, loop_order, s_split_index + ) + return sch + + def _normalize( # pylint: disable=too-many-branches + self, + sch: tir.Schedule, + block_info: BlockInfo, + access: arith.IterSumExpr, + ) -> Tuple[Optional[bool], Optional[int]]: + if access.base != 0: + return None, None, None, None + iter_to_info = {i.var: i for i in block_info.iters} + s_loops, r_loops, c_loops, c_factor = [], [], [], None + s_split_loop, s_split_index = None, None + for split_expr in access.args: + var = split_expr.source.source + info = iter_to_info.pop(var) + loop = info.loop_rv + is_inner_reduction = info.kind == "R" + if split_expr.lower_factor > 1: + if c_loops: + return None, None, None, None + s_split_loop = loop + s_split_index = len(s_loops) + loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) + c_loops.append(c_loop) + if not is_inner_reduction: + c_factor = split_expr.lower_factor + if is_inner_reduction: + r_loops.append(loop) + else: + s_loops.append(loop) + + if iter_to_info: + for var, info in iter_to_info.items(): + if info.kind == "S" and info.dom.extent == 1: + s_loops.append(info.loop_rv) + else: + return None, None, None, None + + loop_order = {} + s_block_var_loops = [] + for i in block_info.iters: + if i.loop_rv in s_loops or i.loop_rv == s_split_loop: + s_block_var_loops.append(i.loop_rv) + + for i in range(len(s_block_var_loops)): + for j in range(len(s_loops)): + if s_block_var_loops[i] == s_loops[j]: + loop_order[i] = j + break + if s_block_var_loops[i] == s_split_loop: + loop_order[i] = s_split_index + break + + assert s_loops + assert r_loops + if len(s_loops) != len([i for i in block_info.iters if i.kind == "S"]): + return None, None + if not c_loops: + c_loops = [sch.add_unit_loop(block_info.block_rv)] + sch.reorder(*s_loops, *r_loops, *c_loops) + sch.fuse(*s_loops) + sch.fuse(*r_loops) + return is_inner_reduction, c_factor, loop_order, s_split_index + + def _sch_inner_reduction( # pylint: disable=too-many-arguments + self, + sch: tir.Schedule, + target: Target, + block: tir.schedule.BlockRV, + unroll_spatial_factor: Optional[int], + epilogue_info: Optional[BlockInfo], + loop_order, + s_split_index, + ): + # pylint: disable=invalid-name + _, r, _ = sch.get_loops(block) + (len_tx,) = utils.suggest_threads_per_block( # pylint: disable=unbalanced-tuple-unpacking + target, [sch.get(r)] + ) + + _, tx = sch.split(r, factors=[None, len_tx]) + # Schedule the RF block + rf = sch.rfactor(tx, 0) + bx, r, tx, _ = sch.get_loops(rf) + sch.reorder(bx, tx, r) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=256) + sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1) + sch.set_scope(rf, 0, "local") + sch.decompose_reduction(rf, r) + # Schedule the write back block + sch.reverse_compute_at(block, bx, preserve_unit_loops=True) + _, tx, *s = sch.get_loops(block) + + if unroll_spatial_factor: + assert len(s) == len(loop_order) + new_order_s = [s[loop_order[i]] for i in range(len(s))] + sch.reorder(*new_order_s) + new_order_s[s_split_index], c = sch.split( + new_order_s[s_split_index], factors=[None, unroll_spatial_factor] + ) + sch.reorder(*new_order_s, c) + s = sch.fuse(*new_order_s) + sch.reorder(s, tx, c) + else: + s = sch.fuse(*s) + sch.reorder(s, tx) + sch.bind(tx, "threadIdx.x") + # Schedule epilogue + if epilogue_info is not None: + epilogue = epilogue_info.block_rv + sch.reverse_compute_at(epilogue, bx) + if is_broadcast_epilogue(sch, block, epilogue): + sch.set_scope(block, 0, "shared") + _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name + _, tx = sch.split(sch.fuse(*s), factors=[None, len_tx]) + sch.bind(tx, "threadIdx.x") + else: + sch.set_scope(block, 0, "local") + # pylint: enable=invalid-name + + def _sch_inner_spatial( + self, + sch: tir.Schedule, + _: Target, + block: tir.schedule.BlockRV, + block_info: BlockInfo, + unroll_spatial_factor: Optional[int], + epilogue_info: Optional[BlockInfo], + loop_order, + s_split_index, + ): + # pylint: disable=invalid-name + s, r, _ = sch.get_loops(block) + len_tx, len_ty = 16, 16 + s_factor = [i.dom.extent for i in block_info.iters if i.kind == "S"][-1] + # get perfect spatial factor, spatial factor should be divide the innermost spatial loop so + # that the block after r_factor and be reversed compute at the original scope + while len_tx > 1: + if s_factor % len_tx == 0: + break + len_tx -= 1 + _, _ = sch.split(s, factors=[None, len_tx]) + _, ty = sch.split(r, factors=[None, len_ty]) + # Schedule the RF block + rf = sch.rfactor(ty, 0) + bx, tx, r, ty, _ = sch.get_loops(rf) + sch.reorder(bx, tx, ty, r) + sch.bind(tx, "threadIdx.x") + sch.bind(ty, "threadIdx.y") + sch.bind(bx, "blockIdx.x") + sch.set_scope(rf, 0, "local") + sch.decompose_reduction(rf, r) + # Schedule the write back block + sch.reverse_compute_at(block, bx, preserve_unit_loops=True) + _, r, *s = sch.get_loops(block) + if unroll_spatial_factor: + assert len(s) == len(loop_order) + new_order_s = [s[loop_order[i]] for i in range(len(s))] + sch.reorder(*new_order_s) + new_order_s[s_split_index], c = sch.split( + new_order_s[s_split_index], factors=[None, unroll_spatial_factor] + ) + sch.reorder(*new_order_s, c) + s = sch.fuse(*new_order_s) + sch.reorder(s, c, r) + else: + s = sch.fuse(*s) + sch.reorder(s, r) + sch.bind(s, "threadIdx.x") + sch.bind(r, "threadIdx.y") + + # Schedule epilogue + if epilogue_info is not None: + epilogue = epilogue_info.block_rv + sch.reverse_compute_at(epilogue, bx) + if is_broadcast_epilogue(sch, block, epilogue): + sch.set_scope(block, 0, "shared") + _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name + _, tx, ty = sch.split(sch.fuse(*s), factors=[None, len_tx, len_ty]) + sch.bind(tx, "threadIdx.x") + sch.bind(ty, "threadIdx.y") + else: + # The epilogue is element-wise without broadcasting. + # Thus the remaining spatial part should be bind to tx. + sch.set_scope(block, 0, "local") + _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name + tx, _ = sch.split(sch.fuse(*s), factors=[len_tx, None]) + sch.bind(tx, "threadIdx.x") + # pylint: enable=invalid-name diff --git a/bitblas/gpu/rmsnorm.py b/bitblas/gpu/rmsnorm.py new file mode 100644 index 000000000..6e6d3e247 --- /dev/null +++ b/bitblas/gpu/rmsnorm.py @@ -0,0 +1,144 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm rmsnorm.py in dlight. +# pylint: disable=missing-docstring +"""A RMS norm schedule rule for GPU operators.""" + +import tvm +from tvm import tir +from tvm.tir import Block, BufferStore +from tvm.tir.expr import Cast, BufferLoad, Call +from tvm.target import Target + +from ..base import ScheduleRule + + +def identify_cast_or_load_block(block: Block) -> bool: + if len(block.reads) != 1 or len(block.writes) != 1: + return False + + if not isinstance(block.body, BufferStore): + return False + store = block.body + + # check types + if isinstance(store.value, BufferLoad): + load = store.value + elif isinstance(store.value, Cast): + load = store.value.value + if not isinstance(load, BufferLoad): + return False + else: + return False + + # check indices + if len(load.indices) != len(store.indices): + return False + + for lhs, rhs in zip(load.indices, store.indices): + if not lhs.same_as(rhs): + return False + + return True + + +def identify_rsqrt_block(block: Block) -> bool: + if len(block.reads) != 1 or len(block.writes) != 1: + return False + + if not isinstance(block.body, BufferStore): + return False + store = block.body + + if not isinstance(store.value, Call): + return False + call = store.value + op = call.op + + return op == tvm.ir.op.Op.get("tir.rsqrt") + + +class RMSNorm(ScheduleRule): + """A rule for RMS norm.""" + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> tir.Schedule: + if target.kind.name == "cuda": + num_tx = 512 + else: + num_tx = 64 + + sch = tir.Schedule(func) + root = sch.get_block(name="root", func_name="main") + + blocks = sch.get_child_blocks(root) + + if not any([identify_rsqrt_block(sch.get(block)) for block in blocks]): + return None + + read = sch.cache_read(block=blocks[0], read_buffer_index=0, storage_scope="local") + write = sch.cache_write(block=blocks[-1], write_buffer_index=0, storage_scope="local") + + for block in blocks: + if identify_cast_or_load_block(sch.get(block)): + sch.compute_inline(block) + + blocks = sch.get_child_blocks(root) + + read, sqr, redsum, rsqrt, norm, write = blocks + + if not identify_rsqrt_block(sch.get(rsqrt)): + return None + + for name in [read, sqr, redsum, rsqrt, norm, write]: + loops = sch.get_loops(name) + sch.fuse(*loops[:-1]) + + block_loop, loops = sch.get_loops(block=read) + thread_loop, _, _ = sch.split( + loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True + ) + sch.bind(block_loop, thread_axis="blockIdx.x") + sch.bind(thread_loop, thread_axis="threadIdx.x") + sch.vectorize(sch.get_loops(block=read)[-1]) + sch.reverse_compute_at(block=sqr, loop=thread_loop) + sch.reverse_compute_at(block=redsum, loop=thread_loop) + + sch.reverse_compute_at(block=rsqrt, loop=block_loop, index=-1) + sch.reverse_compute_at(block=norm, loop=block_loop, index=-1) + block_loop, loops = sch.get_loops(block=norm) + thread_loop, _, _ = sch.split( + loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True + ) + sch.bind(thread_loop, thread_axis="threadIdx.x") + + sch.reverse_compute_at(block=write, loop=thread_loop, index=-1) + sch.vectorize(sch.get_loops(block=write)[-1]) + + sch.set_scope(block=sqr, buffer_index=0, storage_scope="local") + sch.set_scope(block=redsum, buffer_index=0, storage_scope="local") + sch.set_scope(block=rsqrt, buffer_index=0, storage_scope="shared") + sch.set_scope(block=norm, buffer_index=0, storage_scope="local") + + return sch diff --git a/bitblas/gpu/transpose.py b/bitblas/gpu/transpose.py new file mode 100644 index 000000000..6dc025c07 --- /dev/null +++ b/bitblas/gpu/transpose.py @@ -0,0 +1,133 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from apache/tvm transpose.py in dlight. +"""Reduction rule for operators including softmax, layer norm, RMS norm, etc""" +from typing import List, Union + +from tvm import arith, tir +from tvm.target import Target +from tvm.tir import Schedule +from tvm.tir.schedule import BlockRV + +from ..base import ( + detect_dominant_read, + normalize_prim_func, + try_inline_contiguous_spatial, +) +from .base import GPUScheduleRule + + +class Transpose(GPUScheduleRule): + """Schedule rule for transpose""" + + def is_transpose(self, sch: Schedule, block_rv: BlockRV): + block = sch.get(block_rv) + if isinstance(block.body, tir.BufferStore): + rhs = block.body.value + if isinstance(rhs, tir.BufferLoad): + lhs_indices = block.body.indices + rhs_indices = rhs.indices + if list(lhs_indices) != list(rhs_indices) and set(lhs_indices) == set(rhs_indices): + return True + return False + + def apply( # pylint: disable=too-many-locals + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Union[None, tir.Schedule, List[tir.Schedule]]: + # pylint: disable=invalid-name + if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): + return None + if target.kind.name == "cuda": + len_tx = 16 + len_ty = 8 + unroll_depth = 256 + else: + len_tx = 8 + len_ty = 4 + unroll_depth = 64 + len_vec = 4 + + sch = tir.Schedule(func) + blocks = normalize_prim_func(sch) + transpose_block_idx = -1 + for idx, block in reversed(list(enumerate(blocks))): + if self.is_transpose(sch, block.block_rv): + transpose_block_idx = idx + break + if not block.is_injective(): + return None + if transpose_block_idx == -1: + return None + transpose_block = blocks[transpose_block_idx].block_rv + + prologue = None # the optional decoding block + if transpose_block_idx > 0: + spatials = try_inline_contiguous_spatial(sch, blocks[: transpose_block_idx - 1]) + assert len(spatials) == 0 + prologue = blocks[transpose_block_idx - 1].block_rv + + loops = sch.get_loops(transpose_block) + if len(loops) != 2: + # transpose with more than 2 axes is not supported + return None + + c_factor = 1 + if prologue is not None: + block_stmt = sch.get(prologue) + result = arith.normalize_to_iter_sum( + detect_dominant_read(block_stmt), + input_iters={i.var: i.dom.extent for i in block_stmt.iter_vars}, + ) + if len(result.args) > 0: + c_factor = int(result.args[0].lower_factor) + + i, j = loops + i, vi = sch.split(i, factors=[None, c_factor], preserve_unit_iters=True) + bi, ti = sch.split(i, factors=[None, len_ty], preserve_unit_iters=True) + bj, tj = sch.split(j, factors=[None, len_tx], preserve_unit_iters=True) + sch.reorder(bi, bj, ti, tj, vi) + sch.bind(bi, "blockIdx.y") + sch.bind(bj, "blockIdx.x") + sch.bind(ti, "threadIdx.y") + sch.bind(tj, "threadIdx.x") + len_vec = min(len_vec, c_factor) + _, vi = sch.split(vi, factors=[None, len_vec]) + if len_vec > 1: + sch.vectorize(vi) + + cache_read = sch.cache_read(transpose_block, read_buffer_index=0, storage_scope="shared") + sch.compute_at(cache_read, bj) + loops = sch.get_loops(cache_read)[2:] + fused = sch.fuse(*loops) + _, ty, tx, v = sch.split(fused, factors=[None, len_ty, len_tx, c_factor]) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.unroll(v) + sch.storage_align(block=cache_read, buffer_index=0, axis=0, factor=32, offset=1) + + sch.annotate(bi, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) + sch.annotate(bi, ann_key="pragma_unroll_explicit", ann_val=1) + + if prologue is not None: + sch.compute_inline(prologue) + return sch diff --git a/bitblas/gpu/utils.py b/bitblas/gpu/utils.py new file mode 100644 index 000000000..e3a5b6098 --- /dev/null +++ b/bitblas/gpu/utils.py @@ -0,0 +1,86 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# pylint: disable=missing-docstring +"""Utility methods for generic GPU.""" +from typing import List, Optional + +from tvm import tir +from tvm.target import Target + + +def max_threads_per_block(target: Target) -> int: + """Get the maximum number of threads per block for a given target. + + Parameters + ---------- + target : Target + The target to get the maximum number of threads per block for. + + Returns + ------- + max_threads_per_block : int + The maximum number of threads per block for the given target. + """ + for name in ["max_threads_per_block", "max_num_threads"]: + result = target.attrs.get(name, None) + if result is not None: + return result + if target.kind.name == "cuda": + return 1024 + return 256 + + +def suggest_threads_per_block( + target: Target, + loops: List[tir.For], + max_threads_for_dynamic_loop: int = 32, +) -> List[int]: + if target.kind.name == "cuda": + threads = 1024 + elif target.kind.name == "rocm": + threads = 256 + elif target.kind.name == "metal": + threads = 256 + else: + threads = 64 + results: List[Optional[int]] = [] + dynamic: List[int] = [] + for i, loop in enumerate(loops): + loop_extent = loop.extent + if isinstance(loop_extent, tir.IntImm): + loop_extent = loop_extent.value + extent = 1 + while extent <= loop_extent and extent <= threads: + extent *= 2 + extent //= 2 + assert extent >= 1 + assert threads % extent == 0 + threads //= extent + results.append(extent) + else: + results.append(None) + dynamic.append(i) + + for i in dynamic: + extent = 1 + while extent <= max_threads_for_dynamic_loop and extent <= threads: + extent *= 2 + extent //= 2 + assert extent >= 1 + assert threads % extent == 0 + threads //= extent + results[i] = extent + + if dynamic: + results[dynamic[0]] *= threads + + return results + + +def get_sm_version(target: Target) -> int: + if target.kind.name != "cuda": + return -1 + arch = target.arch + sm_version = arch.replace("sm_", "") + return int(sm_version) if sm_version.isdigit() else -1 diff --git a/bitblas/module/__init__.py b/bitblas/module/__init__.py new file mode 100644 index 000000000..f353228a5 --- /dev/null +++ b/bitblas/module/__init__.py @@ -0,0 +1,305 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import ctypes +import operator +from functools import reduce +from logging import getLogger + +import torch +import torch.nn as nn + +logger = getLogger(__name__) + +from typing import List, Union, Optional + +from bitblas.cache import global_operator_cache, get_database_path +from bitblas import Matmul, MatmulConfig +from bitblas.quantization.utils import general_compress +from bitblas import auto_detect_nvidia_target + +BITBLAS_TARGET = auto_detect_nvidia_target() +BITBLAS_DATABASE_PATH = get_database_path() + + +def unpack_qzeros(qzeros, bits): + qzeros = qzeros.view(torch.int32) + elems_per_int32 = 32 // bits + unpacked_zeros = torch.zeros( + (qzeros.shape[0], qzeros.shape[1] * elems_per_int32), + dtype=torch.int8, + device=qzeros.device, + requires_grad=False, + ) + for col in range(unpacked_zeros.shape[1]): + i = col % elems_per_int32 + unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i)) + + # Follow the instruction in AutoGPTQ qlinear_cuda_old.py line 303 + # NOTE: It appears that casting after the `unpacked_zeros + 1` is important. + return torch.bitwise_and(unpacked_zeros + 1, 2**bits - 1) + + +class Linear(nn.Module): + opt_M = [1, 16, 32, 64, 128, 256, 512] + STORAGE_DTYPE = "int8" # assume int8 storage + TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE) + BITBLAS_DTYPES = { + torch.float32: "float32", + torch.float16: "float16", + torch.half: "float16", + torch.int8: "int8", + } + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + A_dtype: str = "float16", + W_dtype: str = "float16", + accum_dtype: str = "float16", + out_dtype: str = "float16", + # configs for weight only quantization + group_size: int = -1, + with_scaling: bool = None, + with_zeros: bool = False, + zeros_mode: str = None, + opt_M: Union[int, List[int]] = opt_M, + # performance related configs + enable_tuning: bool = True, + fast_decoding: Optional[bool] = None, + propagate_b: bool = False, + ): + """ + @opt_M: optimize range of the input shape for dynamic symbolic + if the input shape is a range, we will optimize the matmul with dynamic symbolic. + if the input shape is int, we will optimize the matmul with static symbolic. + """ + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.opt_M = opt_M + self.group_size = self._set_group_size(group_size, in_features) + self.torch_dtype = getattr(torch, A_dtype) + self.is_consitent = A_dtype == W_dtype + self.zeros_mode = zeros_mode + self._validate_parameters(self.group_size, in_features, out_features) + self._configure_bitblas_matmul( + A_dtype, + W_dtype, + accum_dtype, + out_dtype, + with_scaling, + with_zeros, + zeros_mode, + enable_tuning, + fast_decoding, + bias, + propagate_b, + ) + self._initialize_buffers(in_features, out_features, bias) + + def init_params(self): + # eliminate runtime overhead like exllama state + if self.is_consitent: + param_list = [self.weight] + if self.bitblas_matmul.config.with_bias: + param_list.append(self.bias) + self.q_params = [ctypes.c_void_p(arr.data_ptr()) for arr in param_list] + else: + param_list = [self.qweight] + if self.bitblas_matmul.config.with_scaling: + param_list.append(self.scales) + if self.bitblas_matmul.config.with_zeros: + param_list.append(self.zeros) + if self.bitblas_matmul.config.with_bias: + param_list.append(self.bias) + self.q_params = [ctypes.c_void_p(arr.data_ptr()) for arr in param_list] + + def _validate_parameters(self, group_size, in_features, out_features): + if in_features % 16 != 0 or out_features % 16 != 0: + raise ValueError("`in_features` and `out_features` must be divisible by 16.") + if in_features % group_size != 0: + raise ValueError("`in_features` must be divisible by `group_size`.") + + def _set_group_size(self, group_size, in_features): + return in_features if (group_size == -1 or group_size is None) else group_size + + def _initialize_buffers(self, in_features, out_features, bias): + if self.consistent: + self.register_buffer( + "weight", + torch.zeros((out_features, in_features // self.group_size), dtype=self.torch_dtype), + ) + else: + self.register_buffer( + "qweight", + torch.zeros( + self.bitblas_matmul.retrieve_weight_shape(), + dtype=self.TORCH_STORAGE_DTYPE, + ), + ) + self.register_buffer( + "scales", + torch.zeros((out_features, in_features // self.group_size), dtype=self.torch_dtype), + ) + if self.zeros_mode == "quantized": + storage_nbit = int("".join(c for c in self.STORAGE_DTYPE if c.isdigit())) + self.register_buffer( + "zeros", + torch.zeros( + ( + in_features // self.group_size, + out_features // storage_nbit * self.bits, + ), + dtype=self.TORCH_STORAGE_DTYPE, + ), + ) + else: + self.register_buffer( + "zeros", + torch.zeros( + (out_features, in_features // self.group_size), + dtype=self.torch_dtype, + ), + ) + if bias: + self.register_buffer("bias", torch.zeros((out_features), dtype=self.torch_dtype)) + else: + self.bias = None + + def _configure_bitblas_matmul( + self, + A_dtype, + W_dtype, + accum_dtype, + out_dtype, + with_scaling, + with_zeros, + zeros_mode, + enable_tuning, + fast_decoding, + bias, + propagate_b, + ): + matmul_config = MatmulConfig( + M=self.opt_M, + N=self.out_features, + K=self.in_features, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + storage_dtype=self.STORAGE_DTYPE, + with_scaling=with_scaling, + with_zeros=with_zeros, + group_size=self.group_size, + fast_decoding=fast_decoding, + with_bias=bias, + propagate_b=propagate_b, + zeros_mode=zeros_mode, + ) + self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, enable_tuning) + self.bits = self.bitblas_matmul.bit + self.source_format = self.bitblas_matmul.source_format + + def _get_or_create_bitblas_operator(self, config, enable_tuning): + if global_operator_cache.size() == 0: + global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + logger.info(f"Loaded {global_operator_cache.size()} operators from database.") + + bitblas_matmul = global_operator_cache.get(config) + if bitblas_matmul is None: + # should disable tuning for the first time because we may require loading bitblas operator from database. + bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False) + if enable_tuning: + bitblas_matmul.hardware_aware_finetune(topk=20) + global_operator_cache.add(config, bitblas_matmul) + global_operator_cache.save_into_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET) + print("BitBLAS Tuning done, appended operator to global_operator_cache.") + else: + print("BitBLAS Operator created.") + else: + print("BitBLAS Operator found in global_operator_cache.") + return bitblas_matmul + + def warmup(self, topk=20): + self.bitblas_matmul.hardware_aware_finetune(topk=topk) + + def forward(self, A, output=None): + if A.dtype != torch.float16: + A = A.half() + # can be lifted to post init. + self.init_params() + + if output is None: + output = torch.empty( + A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device) + m = ctypes.c_int32(reduce(operator.mul, A.shape[:-1], 1)) + A = self.bitblas_matmul.transform_input(A) + stream = torch.cuda.current_stream() + + A_void = ctypes.c_void_p(A.data_ptr()) + stream_handle = ctypes.c_void_p(stream.cuda_stream) + # m is the product of the last n - 1 dimensions of A + self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m, + stream_handle) + + return output + + def load_and_transform_weight( + self, + weight: torch.Tensor, + scales: torch.Tensor = None, + zeros: torch.Tensor = None, + bias: torch.Tensor = None, + ): + if self.consistent: + assert scales is None, "scales should be None for consistent mode." + assert zeros is None, "zeros should be None for consistent mode." + weight = self.bitblas_matmul.transform_weight(weight) + self.weight = nn.Parameter(weight) + if bias is not None: + self.bias = bias + else: + weight = self.bitblas_matmul.transform_weight(weight) + self.qweight = weight + if scales is not None: + self.scales = scales + if zeros is not None: + self.zeros = zeros + if bias is not None: + self.bias = bias + + def repack_from_gptq(self, gptq_module): + # qweight in gptq old quant linear stored with (out_features, in_features), should be transposed. + qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE) + if self.bitblas_matmul.weight_transform is not None: + qweight = self.bitblas_matmul.weight_transform(qweight.cpu()).cuda() + self.qweight = qweight + # scales in gptq old quant linear stored with (in_features // group_size, out_features), should be transposed. + scales = gptq_module.scales.T.contiguous().view(self.torch_dtype) + self.scales = scales + # qzeros should be dequantized to int zeros. + intzeros = unpack_qzeros(gptq_module.qzeros, self.bits).T.contiguous() + if self.bitblas_matmul.config.zeros_mode == "original": + self.zeros = intzeros.to(torch.float16).contiguous() + elif self.bitblas_matmul.config.zeros_mode == "rescale": + self.zeros[:, :] = intzeros.to(torch.float16)[:, :] * self.scales[:, :] + elif self.bitblas_matmul.config.zeros_mode == "quantized": + self.zeros = ( + torch.Tensor(general_compress(intzeros.T.contiguous().cpu().numpy(), self.bits)).to( + self.qweight.device).to(self.zeros.dtype).contiguous()) + else: + raise ValueError(f"Unsupported zeros type: {self.bitblas_matmul.config.zeros_mode}") + if self.bias is not None: + self.bias = gptq_module.bias.data.to(torch.float16).contiguous() + + @property + def consistent(self): + return self.is_consitent + + +__all__ = ["Linear"] diff --git a/bitblas/ops/__init__.py b/bitblas/ops/__init__.py new file mode 100644 index 000000000..cdacc5bad --- /dev/null +++ b/bitblas/ops/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .operator import Operator # noqa: F401 +from .matmul import Matmul, MatmulConfig # noqa: F401 +from .matmul_dequantize import MatmulWeightOnlyDequantize, MatmulWeightOnlyDequantizeConfig # noqa: F401 +from .ladder_permutate import LadderPermutate, LadderPermutateConfig # noqa: F401 +from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig # noqa: F401 diff --git a/bitblas/ops/general_matmul.py b/bitblas/ops/general_matmul.py new file mode 100644 index 000000000..af2da3f02 --- /dev/null +++ b/bitblas/ops/general_matmul.py @@ -0,0 +1,588 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.target import Target +import operator +from functools import reduce +from bitblas.base.roller.arch.cuda import CUDA +from typing import Any, Literal, Optional, Tuple, Union +from .operator import Operator, TransformKind, OPExecutorCPU +from .impl.matmul_dequantize_impl import ( + select_implementation as weight_dequantize_implementation,) +from .impl.matmul_impl import select_implementation as consistent_implementation +from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 +from bitblas.utils.target_detector import auto_detect_nvidia_target +from dataclasses import dataclass +from .ladder_permutate import LadderPermutate, LadderPermutateConfig +from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig +import logging +import torch + +logger = logging.getLogger(__name__) + +WORKSPACE_SIZE = 1024 * 1024 * 256 + +# TODO(lei): This should be improved into a general +# Method to get the consistent compute patterns. +NATIVE_COMPUTE_PATTERNS = [ + # A_dtype, W_dtype + ("float64", "float64"), + ("float32", "float32"), + ("float16", "float16"), + ("int8", "int8"), + ("e4m3_float8", "e4m3_float8"), + ("e4m3_float8", "e5m2_float8"), + ("e5m2_float8", "e4m3_float8"), + ("e5m2_float8", "e5m2_float8"), +] + + +def is_native_compute(A_dtype, W_dtype) -> bool: + return (A_dtype, W_dtype) in NATIVE_COMPUTE_PATTERNS + + +@dataclass(frozen=True) +class MatmulConfig: + M: Union[int, Tuple[int]] = None + N: int = None + K: int = None + A_dtype: str = "float16" + # is a wrapper for source_format and bit + W_dtype: str = A_dtype # W_dtype is the same as A_dtype by default + out_dtype: str = "float16" + accum_dtype: str = "float16" + layout: Literal["nn", "nt", "tn", "tt"] = "nt" + with_bias: bool = False + group_size: int = -1 + with_scaling: bool = False + with_zeros: bool = False + # documents for zeros_mode: + # original: target = (dequantize_weight - zero_point) * scale + # rescale: target = dequantize_weight * scale - zero_point + # quantized: target = (dequantize_weight - dequantize_zeros) * scale + # The auto-gptq framework prefer "quantized" and "original" for alignment with cuda. + zeros_mode: Literal["original", "rescale", "quantized"] = "original" + storage_dtype: str = "int8" + + # weight transform related flags + fast_decoding: Optional[bool] = ( + None # enable fast decoding by default, if not specified, it is enabled by a rule. + ) + propagate_a: Optional[TransformKind] = ( + None # propagate_a is a flag to control the ladder permutation. + ) + propagate_b: Optional[TransformKind] = ( + None # propagate_b is a flag to control the ladder permutation + ) + + def __legalize_dynamic_symbolic(self, M): + return tuple(self.M) if isinstance(self.M, list) else self.M + + def __legalize_propagate(self, propagate): + if isinstance(propagate, bool): + return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform) + elif isinstance(propagate, int): + return TransformKind(propagate) + + return propagate + + def __initialize_propagate(self, propagate_a: Optional[TransformKind], + propagate_b: Optional[TransformKind]): + MICRO_KERNEL_SIZE = 16 + if (isinstance(self.M, int) and (self.M % MICRO_KERNEL_SIZE) == 0 and + (self.K % MICRO_KERNEL_SIZE) == 0): + object.__setattr__(self, "propagate_a", TransformKind.IntraWarpTransform) + else: + object.__setattr__(self, "propagate_a", TransformKind.NonTransform) + + if (self.M == 1 or (self.N % MICRO_KERNEL_SIZE) != 0 or (self.K % MICRO_KERNEL_SIZE) != 0 or + isinstance(self.M, Tuple) or (self.with_zeros and self.zeros_mode == "quantized")): + object.__setattr__(self, "propagate_a", TransformKind.NonTransform) + object.__setattr__(self, "propagate_b", TransformKind.NonTransform) + else: + object.__setattr__(self, "propagate_b", TransformKind.IntraWarpTransform) + + # set a and b value if is not None + if propagate_a is not None: + object.__setattr__(self, "propagate_a", propagate_a) + if propagate_b is not None: + object.__setattr__(self, "propagate_b", propagate_b) + + # TODO(lei): This is a limitation arose by pytorch and llvm + # Should be removed in the future. + if self.A_dtype in ["e4m3_float8", "e5m2_float8"]: + object.__setattr__(self, "propagate_a", TransformKind.NonTransform) + object.__setattr__(self, "propagate_b", TransformKind.NonTransform) + + def __initialize_zeros_mode(self, zeros_mode: Optional[str]): + if zeros_mode is None: + object.__setattr__(self, "zeros_mode", "original") + + def __initialize_fast_decoding(self, fast_decoding: Optional[bool]): + + def is_not_fast_decoding_supported(): + conditions = [] + conditions.append("int" not in self.W_dtype) + conditions.append(self.W_dtype == self.A_dtype) + # int8,uint8 also do not implement and also do not require fast decoding + conditions.append(self.W_dtype in ["int8", "uint8"]) + return any(conditions) + + if fast_decoding is not None: + object.__setattr__(self, "fast_decoding", fast_decoding) + elif is_not_fast_decoding_supported(): + object.__setattr__(self, "fast_decoding", False) + else: + object.__setattr__(self, "fast_decoding", True) + + def __post_init__(self): + # set M to default dynamic range if it is None + if self.M is None: + object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024]) + if self.N is None: + raise ValueError("N should be specified currently.") + if self.K is None: + raise ValueError("K should be specified currently.") + + # set M to tuple if it is list + # otherwise, M is not hashable + object.__setattr__(self, "M", self.__legalize_dynamic_symbolic(self.M)) + + # set propagate_a and propagate_b to default value if it is None + object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a)) + object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b)) + + # This is hack to legalize propagate_a and b + # TODO(lei): should be removed in the future when tc+br template is ready. + self.__initialize_propagate(self.propagate_a, self.propagate_b) + + self.__initialize_zeros_mode(self.zeros_mode) + + self.__initialize_fast_decoding(self.fast_decoding) + + if self.with_bias is None: + object.__setattr__(self, "with_bias", False) + + if self.group_size is None: + object.__setattr__(self, "group_size", -1) + + if self.with_scaling is None: + object.__setattr__(self, "with_scaling", False) + + if self.with_zeros is None: + object.__setattr__(self, "with_zeros", False) + + if self.A_dtype == self.W_dtype and self.W_dtype in [ + "float16", + "int8", + "e4m3_float8", + "e5m2_float8", + ]: + object.__setattr__(self, "storage_dtype", self.W_dtype) + + +class Matmul(Operator): + + # TODO(lei): This should be improved into a general datatype. + BITBLAS_TRICK_DTYPE_MAP = { + "float64": ("fp", 64), + "float32": ("fp", 32), + "float16": ("fp", 16), + "int32": ("int", 32), + "uint32": ("uint", 32), + "int16": ("int", 16), + "uint16": ("uint", 16), + "int8": ("int", 8), + "uint8": ("uint", 8), + "int4": ("int", 4), + "uint4": ("uint", 4), + "int2": ("int", 2), + "uint2": ("uint", 2), + "int1": ("int", 1), + "uint1": ("uint", 1), + "nf4": ("nf", 4), + "fp4_e2m1": ("fp", 4), + "e4m3_float8": ("fp_e4m3", 8), # "e4m3_float8" is a trick for "float8_e4m3fn" + "e5m2_float8": ("fp_e5m2", 8), + } + + def __init__( + self, + config: MatmulConfig, + name: str = "matmul", + target: Optional[Union[str, Target]] = None, + enable_tuning: bool = True, + from_database: bool = False, + ): + # if from database, we should disable default schedule + # to save compilation time + if target is None: + target = auto_detect_nvidia_target() + logger.info(f"Auto detected target: {target}") + assert (config.A_dtype + in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}" + source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype] + + self.source_format = source_format + self.bit = bit + super().__init__(name, config, target) + + if source_format == "int" and self.with_zeros: + logger.warning( + "[BitBLAS][Warning] with_zeros is not supported for int source format as int has a constant zeropoints already." + ) + + target = self.target + if target.kind.name != "cuda": + raise ValueError("Currently only support cuda target") + + self.arch = CUDA(target) + + if isinstance(self.M, Tuple): + self.dynamic_range = {"m": self.M} + self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs( + {"opt_shapes": self.dynamic_range}) + else: + self.dynamic_range = None + + if not from_database: + self._build_default_module(target) + + self.workspace = None + if self.propagate_a: + # for general purpose, we use propagate_a to control the ladder permutation. + ladder_permutate_config = LadderPermutateConfig( + M=self.M, + N=self.K, + datatype=self.A_dtype, + storage_dtype=self.A_dtype, + propagate_kind="A", + transpose_matrix=False, + transform_kind=self.propagate_a, + ) + self.ladder_permutate_a = LadderPermutate( + config=ladder_permutate_config, + target=target, + enable_tuning=enable_tuning, + ) + self.workspace = torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda() + else: + self.ladder_permutate_a = None + + if self.propagate_b: + ladder_permutate_config = LadderPermutateConfig( + M=self.N, + N=self.K, + datatype=self.A_dtype, + dequantize_bits=self.bit, + storage_dtype=self.storage_dtype, + propagate_kind="B", + transpose_matrix=self.layout == "nt", + transform_kind=self.propagate_b, + ) + self.ladder_permutate_b = LadderPermutate( + config=ladder_permutate_config, + target=tvm.target.Target("llvm"), + ) + else: + self.ladder_permutate_b = None + + if self.fast_decoding: + assert self.source_format in ["int", "uint"] + lop3_permutate_config = LOP3PermutateConfig( + M=self.N, + N=self.K, + datatype=self.A_dtype, + dequantize_bits=self.bit, + storage_dtype=self.storage_dtype, + ) + self.lop3_permutate = LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) + else: + self.lop3_permutate = None + + input_executors = OPExecutorCPU() + if self.ladder_permutate_a is not None: + input_executors.append(self.ladder_permutate_a) + self.input_executors = input_executors + + weight_executors = OPExecutorCPU() + if self.lop3_permutate is not None: + weight_executors.append(self.lop3_permutate) + + if self.ladder_permutate_b is not None: + weight_executors.append(self.ladder_permutate_b) + + self.weight_executors = weight_executors + + if enable_tuning: + self.hardware_aware_finetune() + + if source_format == "nf": + self.lut = torch.tensor( + [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ], + dtype=getattr(torch, self.A_dtype), + ).cuda() + else: + self.lut = None + + # output data type + self.torch_output_dtype = getattr(torch, self.out_dtype) + + def _build_default_module(self, target: Target): + try: + self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) + except Exception: + self.optimized_func = None + logger.warning( + "[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." + ) + + self._build_runtime_module(target) + + def _select_implementation(self): + if is_native_compute(self.A_dtype, self.W_dtype): + return consistent_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.A_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + with_bias=self.with_bias, + layout=self.layout, + propagate_a=self.propagate_a, + propagate_b=self.propagate_b, + ) + else: + return weight_dequantize_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.A_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + bit=self.bit, + storage_dtype=self.storage_dtype, + source_format=self.source_format, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + group_size=self.group_size, + fast_decoding=self.fast_decoding, + with_bias=self.with_bias, + layout=self.layout, + zeros_mode=self.zeros_mode, + propagate_a=self.propagate_a, + propagate_b=self.propagate_b, + ) + + def post_process(self, code: str) -> str: + code = tensor_replace_dp4a(code) + code = tensor_remove_make_int4(code) + code = tensor_remove_make_int2(code) + return code + + def retrieve_weight_shape(self): + return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape] + + def transform_weight(self, weight, scale=None, zeros=None, bias=None): + """ + Transforms the given weight tensor based on the specified quantization parameters and + returns the transformed weight along with optional scale, zeros, and bias. + + Parameters: + - weight: The input weight tensor to be transformed. + - scale: Optional scaling factor for the weight tensor. + - zeros: Optional zero-point adjustment for the weight tensor. + - bias: Optional bias to be added to the weight tensor. + + Returns: + A list containing the transformed weight tensor and optionally the scale, zeros, and bias. + """ + weight = weight.contiguous() + if self.W_dtype == self.A_dtype: + if self.weight_transform is not None: + return self.weight_transform(weight.cpu()).cuda().contiguous() + return weight + + from bitblas.quantization import general_compress + import torch + import numpy as np + + source_format, bit = self.source_format, self.bit + + # Process integer source format + if source_format == "int" and bit < 8: + assert not self.with_scaling, "scale should be False for int source format" + assert not self.with_zeros, "zeros should be False for int source format" + maxq = 2**(bit - 1) + # Clamp weight values to be within the quantizable range and adjust + weight = torch.clamp(weight, -maxq, maxq).int() + maxq + elif source_format in ["fp_e5m2", "fp_e4m3"]: + weight = weight.view(torch.int8) + weight = weight.int() + else: + # For non-integer formats, simply convert weights to integers + weight = weight.int() + + np_storage_dtype = getattr(np, self.storage_dtype) + + weight = general_compress( + weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype) + + weight = torch.from_numpy(weight).cuda().contiguous() + + # Apply an optional weight transformation if specified + if self.weight_transform is not None: + weight = self.weight_transform(weight.cpu()).cuda().contiguous() + + # Prepare the return list with the transformed weight and optionally include scale, zeros, and bias + result = [weight] + if scale is not None: + result.append(scale) + if zeros is not None: + result.append(zeros) + if bias is not None: + result.append(bias) + + return next(iter(result), result) + + def transform_input(self, input_tensor): + if self.propagate_a is not TransformKind.NonTransform: + # check workspace size + if input_tensor.numel() > WORKSPACE_SIZE: + raise ValueError( + f"Input size {input_tensor.numel()} is larger than the workspace size {WORKSPACE_SIZE}, please increase the workspace size." + ) + self.ladder_permutate_a._forward_from_prebuild_lib(input_tensor, self.workspace) + return self.workspace + return input_tensor + + def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: + args = [] + args.append(self.transform_input(A)) + args.append(W) + + if self.lut is not None: + args.append(self.lut) + + if output is None: + output = torch.empty( + A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device) + if scale is not None: + args.append(scale) + if zeros is not None: + args.append(zeros) + if bias is not None: + args.append(bias) + args.append(output) + + if self.dynamic_range is not None: + m = reduce(operator.mul, A.shape[:-1], 1) + args.append(m) + + stream = torch.cuda.current_stream() + + if self.lib is None: + self._forward_from_torch_func(*args) + self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream) + + return output + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.forward(*args, **kwds) + + @property + def M(self): + return self.config.M + + @property + def N(self): + return self.config.N + + @property + def K(self): + return self.config.K + + @property + def A_dtype(self): + return self.config.A_dtype + + @property + def W_dtype(self): + return self.config.W_dtype + + @property + def out_dtype(self): + return self.config.out_dtype + + @property + def accum_dtype(self): + return self.config.accum_dtype + + @property + def storage_dtype(self): + return self.config.storage_dtype + + @property + def with_scaling(self): + return self.config.with_scaling + + @property + def with_zeros(self): + return self.config.with_zeros + + @property + def group_size(self): + return self.config.group_size + + @property + def fast_decoding(self): + return self.config.fast_decoding + + @property + def with_bias(self): + return self.config.with_bias + + @property + def propagate_a(self): + return self.config.propagate_a + + @property + def propagate_b(self): + return self.config.propagate_b + + @property + def layout(self): + return self.config.layout + + @property + def zeros_mode(self): + return self.config.zeros_mode + + @property + def input_transform(self): + return self.input_executors if self.input_executors.size else None + + @property + def weight_transform(self): + return self.weight_executors if self.weight_executors.size else None diff --git a/bitblas/ops/general_matmul_splitk.py b/bitblas/ops/general_matmul_splitk.py new file mode 100644 index 000000000..28e3cbbf2 --- /dev/null +++ b/bitblas/ops/general_matmul_splitk.py @@ -0,0 +1,199 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tvm.target import Target +import operator +from functools import reduce +from typing import Any, Optional, Union +from .operator import TransformKind +from .impl.matmul_splitk_impl import select_implementation as consistent_implementation +from .impl.matmul_dequantize_splitk_impl import select_implementation as weight_dequantize_implementation +from dataclasses import dataclass +import logging +import torch +from .general_matmul import MatmulConfig, Matmul +from .general_matmul import is_native_compute + +logger = logging.getLogger(__name__) + +WORKSPACE_SIZE = 1024 * 1024 * 256 + + +@dataclass(frozen=True) +class MatmulConfigWithSplitK(MatmulConfig): + k_split: int = 1 # split K dimension + + +class MatmulWithSplitK(Matmul): + + def __init__( + self, + config: MatmulConfig, + name: str = "matmul", + target: Optional[Union[str, Target]] = None, + enable_tuning: bool = True, + from_database: bool = False, + ): + super().__init__(config, name, target, enable_tuning, from_database) + + def _select_implementation(self): + # the major implementation + if is_native_compute(self.A_dtype, self.W_dtype): + return consistent_implementation( + SplitK=self.k_split, + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.A_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + with_bias=self.with_bias, + layout=self.layout, + propagate_a=self.propagate_a, + propagate_b=self.propagate_b, + ) + else: + return weight_dequantize_implementation( + SplitK=self.k_split, + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.A_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + bit=self.bit, + storage_dtype=self.storage_dtype, + source_format=self.source_format, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + group_size=self.group_size, + fast_decoding=self.fast_decoding, + with_bias=self.with_bias, + layout=self.layout, + zeros_mode=self.zeros_mode, + propagate_a=self.propagate_a, + propagate_b=self.propagate_b, + ) + + def retrieve_weight_shape(self): + return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape] + + def transform_weight(self, weight, scale=None, zeros=None, bias=None): + """ + Transforms the given weight tensor based on the specified quantization parameters and + returns the transformed weight along with optional scale, zeros, and bias. + + Parameters: + - weight: The input weight tensor to be transformed. + - scale: Optional scaling factor for the weight tensor. + - zeros: Optional zero-point adjustment for the weight tensor. + - bias: Optional bias to be added to the weight tensor. + + Returns: + A list containing the transformed weight tensor and optionally the scale, zeros, and bias. + """ + weight = weight.contiguous() + if self.W_dtype == self.A_dtype: + if self.weight_transform is not None: + return self.weight_transform(weight.cpu()).cuda().contiguous() + return weight + + from bitblas.quantization import general_compress + import torch + import numpy as np + + source_format, bit = self.source_format, self.bit + + # Process integer source format + if source_format == "int" and bit < 8: + assert not self.with_scaling, "scale should be False for int source format" + assert not self.with_zeros, "zeros should be False for int source format" + maxq = 2**(bit - 1) + # Clamp weight values to be within the quantizable range and adjust + weight = torch.clamp(weight, -maxq, maxq).int() + maxq + elif source_format in ["fp_e5m2", "fp_e4m3"]: + weight = weight.view(torch.int8) + weight = weight.int() + else: + # For non-integer formats, simply convert weights to integers + weight = weight.int() + + np_storage_dtype = getattr(np, self.storage_dtype) + + weight = general_compress( + weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype) + + weight = torch.from_numpy(weight).cuda().contiguous() + + # Apply an optional weight transformation if specified + if self.weight_transform is not None: + weight = self.weight_transform(weight.cpu()).cuda().contiguous() + + # Prepare the return list with the transformed weight and optionally include scale, zeros, and bias + result = [weight] + if scale is not None: + result.append(scale) + if zeros is not None: + result.append(zeros) + if bias is not None: + result.append(bias) + + return next(iter(result), result) + + def transform_input(self, input_tensor): + if self.propagate_a is not TransformKind.NonTransform: + # check workspace size + if input_tensor.numel() > WORKSPACE_SIZE: + raise ValueError( + f"Input size {input_tensor.numel()} is larger than the workspace size {WORKSPACE_SIZE}, please increase the workspace size." + ) + self.ladder_permutate_a._forward_from_prebuild_lib(input_tensor, self.workspace) + return self.workspace + return input_tensor + + def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: + args = [] + args.append(self.transform_input(A)) + args.append(W) + + if self.lut is not None: + args.append(self.lut) + + if output is None: + output = torch.empty( + A.shape[:-1] + (self.N,), + dtype=self.torch_output_dtype, + device=A.device) + if scale is not None: + args.append(scale) + if zeros is not None: + args.append(zeros) + if bias is not None: + args.append(bias) + + sk_output = torch.empty((self.k_split,) + + A.shape[:-1] + (self.N,), + dtype=self.torch_output_dtype, + device=A.device) + args.append(sk_output) + + if self.dynamic_range is not None: + m = reduce(operator.mul, A.shape[:-1], 1) + args.append(m) + + stream = torch.cuda.current_stream() + + if self.lib is None: + self._forward_from_torch_func(*args) + self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream) + torch.sum(sk_output, dim=0, out=output) + return output + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.forward(*args, **kwds) + + @property + def k_split(self): + return self.config.k_split + + +__all__ = ["MatmulConfigWithSplitK", "MatmulWithSplitK"] diff --git a/bitblas/ops/impl/__init__.py b/bitblas/ops/impl/__init__.py new file mode 100644 index 000000000..a254dc7fb --- /dev/null +++ b/bitblas/ops/impl/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .lop3_permutate_impl import tir_interleave_weight diff --git a/bitblas/ops/impl/batch_matmul_dequantize_impl.py b/bitblas/ops/impl/batch_matmul_dequantize_impl.py new file mode 100644 index 000000000..a3ab5ebef --- /dev/null +++ b/bitblas/ops/impl/batch_matmul_dequantize_impl.py @@ -0,0 +1,392 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +import tvm +from tvm import te, DataType +from tvm.tir import IndexMap +from bitblas.ops.operator import TransformKind +from bitblas.gpu.matmul_analysis import get_propagate_map +from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, + _tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16, + _tir_u8_to_f8_e4m3_to_f16) + + +def matmul_nt_dequantize_b( + Batch, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype) + B = te.placeholder((Batch, N, K // storage_nbit * bit), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((Batch, N, K // group_size), name="Scale", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def decode_func(b, n, k): + if source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B[b, n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "int": + if bit == 1: + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B[b, n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, B[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, B[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[b, n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B[b, n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[b, n, k // group_size] + + return w + + B_decode = te.compute((Batch, N, K), decode_func, name="B_decode") + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (Batch, M, N), + lambda b, i, j: te.sum( + A[b, i, k].astype(accum_dtype) * B_decode[b, j, k].astype(accum_dtype), axis=k), + name="C", + ) + D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") + args = [A, B] + last_output = D + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_bias: + E = te.compute((Batch, M, N), lambda b, i, j: D[b, i, j] + Bias[j], name="E") + last_output = E + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_scaling": with_scaling, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "group_size": group_size, + } + }, + ) + return tvm.IRModule.from_expr(func) + + +def matmul_nt_dequantize_b_propagate_b( + Batch, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + transform_kind: TransformKind = TransformKind.IntraWarpTransform, +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + _, inverse_indexmap = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = inverse_indexmap.initial_indices + scaling_final_indices = inverse_indexmap.map_indices(initial_indices[:-1] + + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inverse_indexmap = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + qr = r * bit // storage_nbit + A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype) + B = te.placeholder((Batch, N // l, (K // scaling_factor) // qr, l, qr), + name="B", + dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((Batch, N, K // group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((Batch, N, K // group_size), name="Zeros", dtype=in_dtype) + Bias = te.placeholder(( + Batch, + N, + ), name="Bias", dtype=in_dtype) + + def fcompute(b, i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inverse_indexmap.map_indices([warp_i, warp_j]) + new_index = (b, *spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (Batch, N, K // storage_nbit * bit), + fcompute, + name="B_reindex", + ) + + def decode_func(b, n, k): + if source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[b, n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[b, n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "int": + if bit == 1: + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, + B_reindex[b, n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[b, n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, + B_reindex[b, n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, + B_reindex[b, n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[b, n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[b, n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[b, n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[b, n, k // group_size]) * Scale[b, n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[b, n, k // group_size] - Zeros[b, n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + B_decode = te.compute((Batch, N, K), decode_func, name="B_decode") + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (Batch, M, N), + lambda b, i, j: te.sum( + A[b, i, k].astype(accum_dtype) * B_decode[b, j, k].astype(accum_dtype), axis=k), + name="C", + ) + D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") + args = [A, B] + last_output = D + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_zeros: + args.append(Zeros) + if with_bias: + E = te.compute((Batch, M, N), lambda b, i, j: D[b, i, j] + Bias[j], name="E") + last_output = E + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "with_scaling": with_scaling, + "group_size": group_size, + } + }, + ) + func = func.with_attr("weight_transform_kind", transform_kind.value) + return tvm.IRModule.from_expr(func) + + +def select_implementation( + Batch=1, + M=None, + N=1024, + K=1024, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + layout="nt", + zeros_mode="original", + propagate_a=False, + propagate_b=False, +): + if layout == "nn": + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn in Dequantize Implementation" + ) + elif layout == "nt": + if propagate_a and propagate_b: + raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") + elif propagate_a: + raise ValueError("Currently only support propagate_a=False for layout=nt") + elif propagate_b: + return matmul_nt_dequantize_b_propagate_b( + Batch, + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + zeros_mode, + transform_kind=propagate_b, + ) + else: + return matmul_nt_dequantize_b( + Batch, + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + zeros_mode, + ) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/bitblas/ops/impl/batch_matmul_impl.py b/bitblas/ops/impl/batch_matmul_impl.py new file mode 100644 index 000000000..1828ed15d --- /dev/null +++ b/bitblas/ops/impl/batch_matmul_impl.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +import tvm +from tvm import te +from bitblas.ops.operator import TransformKind + + +def matmul_nt( + Batch, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + if not isinstance(M, int): + M = tvm.te.var("m") + A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype) + B = te.placeholder((Batch, N, K), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (Batch, M, N), + lambda b, i, j: te.sum( + A[b, i, k].astype(accum_dtype) * B[b, j, k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((Batch, M, N), lambda b, i, j: last_output[b, i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def matmul( + Batch, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", +): + if layout == "nn": + raise ValueError("Currently only support layout=nt") + return matmul_nt(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) + + +def select_implementation( + Batch=1, + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform, +): + if layout == "nn": + if propagate_a or propagate_b: + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn") + return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + elif layout == "nt": + if propagate_a and propagate_b: + raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") + elif propagate_a: + raise ValueError("Currently only support propagate_a=False for layout=nt") + elif propagate_b: + raise ValueError("Currently only support propagate_b=False for layout=nt") + else: + return matmul(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/bitblas/ops/impl/convolution2d_impl.py b/bitblas/ops/impl/convolution2d_impl.py new file mode 100644 index 000000000..d77d8f573 --- /dev/null +++ b/bitblas/ops/impl/convolution2d_impl.py @@ -0,0 +1,190 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +import tvm +from tvm import te, tir + + +def conv2d_nhwc_ohwi( + n, + f, + h, + w, + c, + kh, + kw, + s, + d, + p, + in_dtype="float16", + accum_dtype="float16", + out_dtype="float16", +): + + A = te.placeholder((n, h, w, c), name="input", dtype=in_dtype) + B = te.placeholder((f, kh, kw, c), name="weight", dtype=in_dtype) + + pad_shape = (n, h + 2 * p, w + 2 * p, c) + pad_value = tir.const(0.0, A.dtype) + pad = te.compute( + pad_shape, + lambda n, h, w, c: te.if_then_else( + tir.all( + h >= p, + w >= p, + h < pad_shape[1] - p, + w < pad_shape[2] - p, + ), + A[n, h - p, w - p, c], + pad_value, + ), + name="pad", + ) + kernel_h, kernel_w = kh, kw + stride_h, stride_w = s, s + dilation_h, dilation_w = d, d + out_h = (h + 2 * p - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 + out_w = (w + 2 * p - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 + out_shape = (n, out_h, out_w, f) + kh = te.reduce_axis((0, kernel_h), name="kh") + kw = te.reduce_axis((0, kernel_w), name="kw") + c = te.reduce_axis((0, c), name="c") + C = te.compute( + out_shape, + lambda n, h, w, f: te.sum( + pad[n, h * stride_h + kh * tir.any(dilation_h), w * stride_w + kw * tir.any(dilation_w), + c,].astype(accum_dtype) * B[f, kh - 1 - tir.any(dilation_h), kw - 1 - tir.any( + dilation_w), c].astype(accum_dtype), + axis=[kh, kw, c], + ), + name="C", + ) + args = [A, B] + last_output = C + if accum_dtype != out_dtype: + D = te.compute(out_shape, lambda n, h, w, c: C[n, h, w, c].astype(out_dtype), name="D") + last_output = D + args.append(last_output) + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def conv2d_nhwc_hwio( + n, + f, + h, + w, + c, + kh, + kw, + s, + d, + p, + in_dtype="float16", + accum_dtype="float16", + out_dtype="float16", +): + + A = te.placeholder((n, h, w, c), name="input", dtype=in_dtype) + B = te.placeholder((kh, kw, c, f), name="weight", dtype=in_dtype) + + pad_shape = (n, h + 2 * p, w + 2 * p, c) + pad_value = tir.const(0.0, A.dtype) + pad = te.compute( + pad_shape, + lambda n, h, w, c: te.if_then_else( + tir.all( + h >= p, + w >= p, + h < pad_shape[1] - p, + w < pad_shape[2] - p, + ), + A[n, h - p, w - p, c], + pad_value, + ), + name="pad", + ) + kernel_h, kernel_w = kh, kw + stride_h, stride_w = s, s + dilation_h, dilation_w = d, d + out_h = (h + 2 * p - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 + out_w = (w + 2 * p - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 + out_shape = (n, out_h, out_w, f) + kh = te.reduce_axis((0, kernel_h), name="kh") + kw = te.reduce_axis((0, kernel_w), name="kw") + c = te.reduce_axis((0, c), name="c") + C = te.compute( + out_shape, + lambda n, h, w, f: te.sum( + pad[n, h * stride_h + kh * tir.any(dilation_h), w * stride_w + kw * tir.any(dilation_w), + c,].astype(accum_dtype) * B[kh - 1 - tir.any(dilation_h), kw - 1 - tir.any( + dilation_w), c, f].astype(accum_dtype), + axis=[kh, kw, c], + ), + name="C", + ) + args = [A, B] + last_output = C + if accum_dtype != out_dtype: + D = te.compute(out_shape, lambda n, h, w, c: C[n, h, w, c].astype(out_dtype), name="D") + last_output = D + args.append(last_output) + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def select_implementation( + n, + f, + h, + w, + c, + kh, + kw, + s, + d, + p, + in_dtype="float16", + accum_dtype="float16", + out_dtype="float16", + input_layout="nhwc", + weight_layout="ohwi", +): + assert input_layout in ["nhwc", "nchw"] + if input_layout == "nhwc" and weight_layout == "ohwi": + return conv2d_nhwc_ohwi( + n, + f, + h, + w, + c, + kh, + kw, + s, + d, + p, + in_dtype, + accum_dtype, + out_dtype, + ) + elif input_layout == "nhwc" and weight_layout == "hwio": + return conv2d_nhwc_hwio( + n, + f, + h, + w, + c, + kh, + kw, + s, + d, + p, + in_dtype, + accum_dtype, + out_dtype, + ) + else: + raise ValueError("Unsupported input_layout: {} and weight_layout: {}".format( + input_layout, weight_layout)) diff --git a/bitblas/ops/impl/ladder_permutate_impl.py b/bitblas/ops/impl/ladder_permutate_impl.py new file mode 100644 index 000000000..8086bf584 --- /dev/null +++ b/bitblas/ops/impl/ladder_permutate_impl.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas.gpu.matmul_analysis import get_propagate_map +from typing import Literal +from tvm import te, IRModule, DataType +from tvm.tir import IndexMap + + +def select_implementation( + M: int, + N: int, + datatype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"] = "float16", + dequantize_bits: int = -1, + storage_dtype: Literal["float16", "int8", "uint8", "int32", "uint32"] = "float16", + propagate_kind: Literal["A", "B"] = "B", + transpose_matrix: bool = False, + transform_kind: int = 0, + target_instruction: Literal["nvidia-mma"] = "nvidia-mma", +): + if target_instruction != "nvidia-mma": + raise ValueError("Currently only support nvidia-mma instruction") + + # This is trick to get the basic tile size for the current datatype + # as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8 + l = r = 16 # noqa: E741 + if datatype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + intra_index_map, _ = get_propagate_map( + transpose_matrix, dtype=datatype, matrix_name=propagate_kind) + + target_dtype = DataType(datatype) + scaling_factor = 1 + if dequantize_bits > 0 and dequantize_bits < target_dtype.bits: + scaling_factor = ((target_dtype.bits // dequantize_bits) * DataType(storage_dtype).bits // + target_dtype.bits) + r = r // scaling_factor + initial_indices = intra_index_map.initial_indices + scaling_final_indices = intra_index_map.map_indices(initial_indices[:-1] + + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + intra_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + inp = te.placeholder((M, N // scaling_factor), name="inp", dtype=storage_dtype) + args = [inp] + + if transform_kind >= 1: + arg = args[-1] + + inter_warp = te.compute( + (M // l, (N // scaling_factor) // r, l, r), + lambda i, j, ii, jj: arg[i * l + ii, j * r + jj], + name="inter_warp_permutate", + ) + args.append(inter_warp) + if transform_kind >= 2: + arg = args[-1] + + def fcompute(*args): + warp_i, warp_j = args[-2:] + spatial_args = args[:-2] + permutate_i, permutate_j = intra_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, permutate_i, permutate_j) + return arg[new_index] + + intra_warp = te.compute( + (M // l, (N // scaling_factor) // r, l, r), + fcompute, + name="intra_warp_permutate", + ) + args.append(intra_warp) + args = [args[0], args[-1]] + + func = te.create_prim_func(args) + + return IRModule.from_expr(func) diff --git a/bitblas/ops/impl/lop3_permutate_impl.py b/bitblas/ops/impl/lop3_permutate_impl.py new file mode 100644 index 000000000..07d8f4f0c --- /dev/null +++ b/bitblas/ops/impl/lop3_permutate_impl.py @@ -0,0 +1,152 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import Literal +from tvm import DataType +from tvm import IRModule +from tvm.ir import GlobalVar +from tvm.script import tir as T + + +# fmt: off +# TIR interleave weight impl-> 2D implementation +def tir_interleave_weight( + N: int = 2, + K: int = 16, + bits: int = 4, + QK: int = -1, + target_dtype: str = "float16", + storage_dtype: str = "int32", +): + if QK == -1: + QK = K * bits // 32 + bits_stride = DataType(target_dtype).bits + mask = (1 << bits) - 1 # for 4bit the val is 0x0000000f + num_groups = 32 // bits_stride + elems_per_group = bits_stride // bits + + @T.prim_func + def interleave_weight(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), storage_dtype)): + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) + + @T.prim_func + def interleave_weight_f16_2b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), + storage_dtype)): + B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xFF0000FF) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x00FF0000)) << 8) >> 16 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x0000FF00)) << 16) >> 8 + B[v0, v1] = B_tmp_1[v0, v1] | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] + + @T.prim_func + def interleave_weight_f16_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), + storage_dtype)): + B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_4 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_5 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_6 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_7 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xF000000F) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x000000F0)) >> 4) << 8 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x00000F00)) >> 8) << 16 + B_tmp_4[v0, v1] = ((B[v0, v1] & T.uint32(0x0000F000)) >> 12) << 24 + B_tmp_5[v0, v1] = ((B[v0, v1] & T.uint32(0x000F0000)) >> 16) << 8 + B_tmp_6[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 20) << 12 + B_tmp_7[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 24) << 20 + B[v0, v1] = ( + B_tmp_1[v0, v1] + | B_tmp_2[v0, v1] + | B_tmp_3[v0, v1] + | B_tmp_4[v0, v1] + | B_tmp_5[v0, v1] + | B_tmp_6[v0, v1] + | B_tmp_7[v0, v1]) + + @T.prim_func + def interleave_weight_int8_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), + storage_dtype)): + B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_4 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + B_tmp_5 = T.alloc_buffer((N, QK), storage_dtype, scope="local") + for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): + with T.block("B_tmp"): + v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + offset = v2 * elems_per_group + v3 + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits + B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) + + for ax0, ax1 in T.grid(N, QK): + with T.block("B"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xF0F00F0F) + B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x000000F0)) >> 4) << 16 + B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x0000F000)) >> 12) << 24 + B_tmp_4[v0, v1] = ((B[v0, v1] & T.uint32(0x000F0000)) >> 16) << 4 + B_tmp_5[v0, v1] = ((B[v0, v1] & T.uint32(0x0F000000)) >> 24) << 12 + B[v0, v1] = ( + B_tmp_1[v0, v1] + | B_tmp_2[v0, v1] + | B_tmp_3[v0, v1] + | B_tmp_4[v0, v1] + | B_tmp_5[v0, v1]) + + if target_dtype == "float16" and bits == 2: + return interleave_weight_f16_2b + elif target_dtype == "float16" and bits == 1: + return interleave_weight_f16_1b + elif target_dtype == "int8" and bits == 1: + return interleave_weight_int8_1b + + return interleave_weight + + +# fmt: on + + +def select_implementation( + M: int, + N: int, + datatype: Literal["float16", "int8"] = "float16", + storage_dtype: Literal["int8", "uint8", "int32", "uint32"] = "int32", + dequantize_bits: int = 4, +): + func = tir_interleave_weight( + N=M, + K=N, + bits=dequantize_bits, + target_dtype=datatype, + storage_dtype=storage_dtype, + ) + mod = IRModule() + mod.update_func(GlobalVar("main"), func) + return mod diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py new file mode 100644 index 000000000..d4aa02c84 --- /dev/null +++ b/bitblas/ops/impl/matmul_dequantize_impl.py @@ -0,0 +1,644 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +import tvm +from tvm import te, DataType +from tvm.tir import IndexMap +from bitblas.ops.operator import TransformKind +from bitblas.gpu.matmul_analysis import get_propagate_map +from bitblas.quantization import ( + _tir_packed_int_to_int_convert, + _tir_packed_to_signed_convert, + _tir_packed_to_unsigned_convert, + _tir_u32_to_f4_to_f16, + _tir_u8_to_f8_e4m3_to_f16, + _tir_packed_to_unsigned_convert_with_zeros, +) + + +def matmul_nt_dequantize_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N, K // storage_nbit * bit), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) + QZeros = te.placeholder(((K // group_size), N // storage_nbit * bit), + name="QZeros", + dtype=storage_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def qzeros_dequantize(k, n): + return _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + QZeros[k, n // n_float_per_elem], + n % n_float_per_elem, + dtype=storage_dtype, + ) + + Dequantize_qzeros = None + if with_zeros and zeros_mode == "quantized": + Dequantize_qzeros = te.compute( + (K // group_size, N), + qzeros_dequantize, + name="Dequantize_zeros", + ) + + def decode_func(n, k): + if with_zeros and zeros_mode == "quantized": + assert Dequantize_qzeros is not None, "Dequantize_zeros is None" + w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)( + bit, + B[n, k // n_float_per_elem], + k % n_float_per_elem, + Dequantize_qzeros[k // group_size, n], + dtype=in_dtype, + ) + elif source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "int": + if bit == 1: + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B[n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + elif zeros_mode == "quantized": + w = w * Scale[n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), + name="C", + ) + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + args = [A, B] + last_output = D + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_zeros: + if zeros_mode == "quantized": + args.append(QZeros) + else: + args.append(Zeros) + if with_bias: + E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") + last_output = E + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_scaling": with_scaling, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "group_size": group_size, + } + }, + ) + return tvm.IRModule.from_expr(func) + + +def matmul_nt_dequantize_b_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + transform_kind: TransformKind = TransformKind.IntraWarpTransform, +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + _, inverse_indexmap = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = inverse_indexmap.initial_indices + scaling_final_indices = inverse_indexmap.map_indices(initial_indices[:-1] + + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inverse_indexmap = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + qr = r * bit // storage_nbit + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inverse_indexmap.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (N, K // storage_nbit * bit), + fcompute, + name="B_reindex", + ) + + def decode_func(n, k): + if source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "int": + if bit == 1: + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), + name="C", + ) + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + args = [A, B] + last_output = D + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_zeros: + args.append(Zeros) + if with_bias: + E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") + last_output = E + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "with_scaling": with_scaling, + "group_size": group_size, + } + }, + ) + func = func.with_attr("weight_transform_kind", transform_kind.value) + return tvm.IRModule.from_expr(func) + + +def matmul_nt_dequantize_b_propagate_a_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", + transform_kind_input: TransformKind = TransformKind.IntraWarpTransform, + transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform, +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") + A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind_input >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return A[new_index] + + A_reindex = te.compute( + (M, K), + fcompute, + name="A_reindex", + ) + + _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + target_dtype = DataType(in_dtype) + scaling_factor = 1 + if bit > 0 and bit < target_dtype.bits: + scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // + target_dtype.bits) + initial_indices = inversed_index_map.initial_indices + scaling_final_indices = inversed_index_map.map_indices( + initial_indices[:-1] + [initial_indices[-1] * scaling_factor]) + scaling_final_indices = scaling_final_indices[:-1] + [ + scaling_final_indices[-1] // scaling_factor + ] + inversed_index_map = IndexMap( + initial_indices, + scaling_final_indices, + None, + ) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + qr = r * bit // storage_nbit + B = te.placeholder((N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % qr + spatial_args = i // l, j // qr + if transform_kind_weight >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (N, K // storage_nbit * bit), + fcompute, + name="B_reindex", + ) + + def decode_func(n, k): + if source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "int": + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. + if bit == 1: + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B_reindex[n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype=in_dtype, + ) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B_reindex[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[n, k // group_size] + + if zeros_mode == "original": + w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] + elif zeros_mode == "rescale": + w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] + else: + raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) + + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A_reindex[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), + axis=k, + ), + name="C", + ) + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + args = [A, B] + last_output = D + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_zeros: + args.append(Zeros) + if with_bias: + E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") + last_output = E + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "with_scaling": with_scaling, + "group_size": group_size, + } + }, + ) + func = func.with_attr("input_transform_kind", transform_kind_input.value) + func = func.with_attr("weight_transform_kind", transform_kind_weight.value) + return tvm.IRModule.from_expr(func) + + +def select_implementation( + M=None, + N=1024, + K=1024, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + layout="nt", + zeros_mode="original", + propagate_a=False, + propagate_b=False, +): + if layout == "nn": + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn in Dequantize Implementation" + ) + elif layout == "nt": + if propagate_a and propagate_b: + return matmul_nt_dequantize_b_propagate_a_propagate_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + zeros_mode, + transform_kind_input=propagate_a, + transform_kind_weight=propagate_b, + ) + elif propagate_a: + raise NotImplementedError + elif propagate_b: + return matmul_nt_dequantize_b_propagate_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + zeros_mode, + transform_kind=propagate_b, + ) + else: + return matmul_nt_dequantize_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + zeros_mode, + ) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/bitblas/ops/impl/matmul_dequantize_splitk_impl.py b/bitblas/ops/impl/matmul_dequantize_splitk_impl.py new file mode 100644 index 000000000..afe241b65 --- /dev/null +++ b/bitblas/ops/impl/matmul_dequantize_splitk_impl.py @@ -0,0 +1,184 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +import tvm +from tvm import te +from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, + _tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16, + _tir_u8_to_f8_e4m3_to_f16) + + +def matmul_nt_dequantize_b( + SplitK, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + zeros_mode="original", +): + assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) + if not isinstance(M, int): + M = tvm.te.var("m") + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) + n_float_per_elem = storage_nbit // bit + if group_size == -1: + group_size = K + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N, K // storage_nbit * bit), name="B", dtype=storage_dtype) + LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) + Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def decode_func(n, k): + if source_format == "uint": + if bit == 8: + # 8 bit does not need to be compressed + w = B[n, k].astype(in_dtype) + else: + w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "int": + if bit == 1: + # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. + w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif bit == 8: + # 8 bit does not need to be compressed + w = B[n, k].astype(in_dtype) + else: + w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp": + w = _tir_u32_to_f4_to_f16( + bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) + elif source_format == "fp_e4m3": + w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) + elif source_format == "nf": + w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( + bit, + B[n, k // n_float_per_elem], + k % n_float_per_elem, + dtype="int32", # assume the index data type is int32 + )] + else: + raise ValueError("Unsupported source_format: {}".format(source_format)) + + if not with_scaling: + return w + + if not with_zeros: + return w * Scale[n, k // group_size] + + return w + + B_decode = te.compute((N, K), decode_func, name="B_decode") + # Describe the matrix multiplication in TE + RK = K // SplitK + k = te.reduce_axis((0, RK), name="k") + C = te.compute( + (SplitK, M, N), + lambda sk, i, j: te.sum( + A[i, sk * RK + k].astype(accum_dtype) * B_decode[j, sk * RK + k].astype(accum_dtype), + axis=k), + name="C", + ) + D = te.compute((SplitK, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") + args = [A, B] + last_output = D + if source_format == "nf": + args.append(LUT) + if with_scaling: + args.append(Scale) + if with_bias: + E = te.compute((SplitK, M, N), lambda b, i, j: D[b, i, j] + Bias[j], name="E") + last_output = E + args.append(Bias) + args.append(last_output) + + func = te.create_prim_func(args).with_attr( + "dequantize_info", + { + "B_decode": { + "decode_block": "B_decode", + "fast_decoding": fast_decoding, + "source_format": { + "bits": bit, + "format": source_format, + }, + "storage_dtype": storage_dtype, + "target_format": in_dtype, + "with_scaling": with_scaling, + "with_zeros": with_zeros, + "zeros_mode": zeros_mode, + "group_size": group_size, + } + }, + ) + return tvm.IRModule.from_expr(func) + + +def select_implementation( + SplitK=1, + M=None, + N=1024, + K=1024, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + bit=4, + storage_dtype="int8", + source_format="uint", + with_scaling=False, + with_zeros=False, + group_size=-1, + fast_decoding=False, + with_bias=False, + layout="nt", + zeros_mode="original", + propagate_a=False, + propagate_b=False, +): + if layout == "nn": + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn in Dequantize Implementation" + ) + elif layout == "nt": + if propagate_a and propagate_b: + raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") + elif propagate_a: + raise ValueError("Currently only support propagate_a=False for layout=nt") + elif propagate_b: + raise ValueError("Currently only support propagate_b=False for layout=nt") + else: + return matmul_nt_dequantize_b( + SplitK, + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + bit, + storage_dtype, + source_format, + with_scaling, + with_zeros, + group_size, + fast_decoding, + with_bias, + zeros_mode, + ) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/bitblas/ops/impl/matmul_impl.py b/bitblas/ops/impl/matmul_impl.py new file mode 100644 index 000000000..69b426354 --- /dev/null +++ b/bitblas/ops/impl/matmul_impl.py @@ -0,0 +1,356 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +import tvm +from tvm import te +from bitblas.gpu.matmul_analysis import get_propagate_map +from bitblas.ops.operator import TransformKind + + +def matmul_nn( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + if not isinstance(M, int): + M = tvm.te.var("m") + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((K, N), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum(A[i, k].astype(accum_dtype) * B[k, j].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def matmul_nt( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + if not isinstance(M, int): + M = tvm.te.var("m") + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N, K), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum(A[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def matmul( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", +): + if layout == "nn": + return matmul_nn(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) + return matmul_nt(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) + + +def matmul_nt_propagate_a( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + transform_kind: TransformKind = TransformKind.IntraWarpTransform, +): + if not isinstance(M, int): + M = tvm.te.var("m") + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") + + A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) + B = te.placeholder((N, K), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return A[new_index] + + A_reindex = te.compute( + (M, K), + fcompute, + name="A_reindex", + ) + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A_reindex[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + func = func.with_attr("input_transform_kind", transform_kind.value) + + return tvm.IRModule.from_expr(func) + + +def matmul_nt_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + transform_kind: TransformKind = TransformKind.IntraWarpTransform, +): + if not isinstance(M, int): + M = tvm.te.var("m") + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N // l, K // r, l, r), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (N, K), + fcompute, + name="B_reindex", + ) + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A[i, k].astype(accum_dtype) * B_reindex[j, k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + func = func.with_attr("weight_transform_kind", transform_kind.value) + + return tvm.IRModule.from_expr(func) + + +def matmul_nt_propagate_a_propagate_b( + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + transform_kind_input: TransformKind = TransformKind.IntraWarpTransform, + transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform, +): + if not isinstance(M, int): + M = tvm.te.var("m") + l = r = 16 # noqa: E741 + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + + A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) + B = te.placeholder((N // l, K // r, l, r), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind_input >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return A[new_index] + + A_reindex = te.compute( + (M, K), + fcompute, + name="A_reindex", + ) + + _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") + + def fcompute(i, j): + warp_i, warp_j = i % l, j % r + spatial_args = i // l, j // r + if transform_kind_weight >= TransformKind.IntraWarpTransform: + warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) + new_index = (*spatial_args, warp_i, warp_j) + return B[new_index] + + B_reindex = te.compute( + (N, K), + fcompute, + name="B_reindex", + ) + # Describe the matrix multiplication in TE + k = te.reduce_axis((0, K), name="k") + C = te.compute( + (M, N), + lambda i, j: te.sum( + A_reindex[i, k].astype(accum_dtype) * B_reindex[j, k].astype(accum_dtype), + axis=k, + ), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + func = func.with_attr("input_transform_kind", transform_kind_input.value) + func = func.with_attr("weight_transform_kind", transform_kind_weight.value) + + return tvm.IRModule.from_expr(func) + + +def select_implementation( + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform, +): + if layout == "nn": + if propagate_a or propagate_b: + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn") + return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + elif layout == "nt": + if propagate_a and propagate_b: + return matmul_nt_propagate_a_propagate_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + transform_kind_input=propagate_a, + transform_kind_weight=propagate_b, + ) + elif propagate_a: + return matmul_nt_propagate_a( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + transform_kind=propagate_a, + ) + elif propagate_b: + return matmul_nt_propagate_b( + M, + N, + K, + in_dtype, + out_dtype, + accum_dtype, + with_bias, + transform_kind=propagate_b, + ) + else: + return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/bitblas/ops/impl/matmul_splitk_impl.py b/bitblas/ops/impl/matmul_splitk_impl.py new file mode 100644 index 000000000..c437f64cb --- /dev/null +++ b/bitblas/ops/impl/matmul_splitk_impl.py @@ -0,0 +1,94 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# pre-transformed tir expression of matmul +import tvm +from tvm import te +from bitblas.ops.operator import TransformKind + + +def matmul_nt( + SplitK, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, +): + if not isinstance(M, int): + M = tvm.te.var("m") + A = te.placeholder((M, K), name="A", dtype=in_dtype) + B = te.placeholder((N, K), name="B", dtype=in_dtype) + Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) + + # Describe the matrix multiplication in TE + RK = K // SplitK + k = te.reduce_axis((0, RK), name="k") + C = te.compute( + (SplitK, M, N), + lambda sk, i, j: te.sum( + A[i, sk * RK + k].astype(accum_dtype) * B[j, sk * RK + k].astype(accum_dtype), axis=k), + name="C", + ) + last_output = C + if accum_dtype != out_dtype: + D = te.compute((SplitK, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") + last_output = D + + if with_bias: + E = te.compute((SplitK, M, N), lambda b, i, j: last_output[b, i, j] + Bias[j], name="E") + last_output = E + + args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] + + func = te.create_prim_func(args) + + return tvm.IRModule.from_expr(func) + + +def matmul( + SplitK, + M, + N, + K, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", +): + if layout == "nn": + raise ValueError("Currently only support layout=nt") + return matmul_nt(SplitK, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) + + +def select_implementation( + SplitK=1, + M=None, + N=16384, + K=16384, + in_dtype="float16", + out_dtype="float16", + accum_dtype="float16", + with_bias=False, + layout="nt", + propagate_a: TransformKind = TransformKind.NonTransform, + propagate_b: TransformKind = TransformKind.NonTransform, +): + if layout == "nn": + if propagate_a or propagate_b: + raise ValueError( + "Currently only support propagate_a=False and propagate_b=False for layout=nn") + return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + elif layout == "nt": + if propagate_a and propagate_b: + raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") + elif propagate_a: + raise ValueError("Currently only support propagate_a=False for layout=nt") + elif propagate_b: + raise ValueError("Currently only support propagate_b=False for layout=nt") + else: + return matmul(SplitK, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) + else: + raise ValueError(f"Unsupported layout: {layout}") diff --git a/bitblas/ops/impl/param_permutate_impl.py b/bitblas/ops/impl/param_permutate_impl.py new file mode 100644 index 000000000..4ecb17709 --- /dev/null +++ b/bitblas/ops/impl/param_permutate_impl.py @@ -0,0 +1,56 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas.gpu.matmul_analysis import get_propagate_map +from ..operator import TransformKind +from typing import Literal +from tvm import te, IRModule + + +def select_implementation( + M: int, + N: int, + datatype: Literal["float16"] = "float16", + transpose_matrix: bool = True, + group_size: int = -1, + propagate_kind: TransformKind = TransformKind.NonTransform, + target_instruction: Literal["nvidia-mma"] = "nvidia-mma", +): + if target_instruction != "nvidia-mma": + raise ValueError("Currently only support nvidia-mma instruction") + if propagate_kind < TransformKind.IntraWarpTransform: + raise ValueError("Currently only support propagate_kind >= IntraWarpTransform") + if transpose_matrix is not True: + raise ValueError("Currently only support transpose_matrix == True") + # This is trick to get the basic tile size for the current datatype + # as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8 + l = r = 16 # noqa: E741 + if datatype in ["int8", "e4m3_float8", "e5m2_float8"]: + l, r = 16, 32 # noqa: E741 + if group_size == -1: + group_size = N + + intra_index_map, inverse_indexmap = get_propagate_map( + transpose_matrix, dtype=datatype, matrix_name=propagate_kind) + + inp = te.placeholder((M, N // group_size), name="inp", dtype=datatype) + + def fcompute(n, k): + rl, rr = n, k + warp_i, warp_j = rl % l, rr % r + spatial_i, spatial_j = rl // l, rr // r + if propagate_kind >= TransformKind.IntraWarpTransform: + warp_i, warp_j = intra_index_map.map_indices([warp_i, warp_j]) + new_index = (spatial_i * l + warp_i, (spatial_j * r + warp_j) // group_size) + return inp[new_index] + + inp_prmt = te.compute( + (M, N // group_size), + fcompute, + name="intra_warp_permutate", + ) + + args = [inp, inp_prmt] + + func = te.create_prim_func(args) + + return IRModule.from_expr(func) diff --git a/bitblas/ops/ladder_permutate.py b/bitblas/ops/ladder_permutate.py new file mode 100644 index 000000000..70999b09d --- /dev/null +++ b/bitblas/ops/ladder_permutate.py @@ -0,0 +1,97 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tvm.target import Target +from typing import Literal, Union +from .operator import Operator +from .impl.ladder_permutate_impl import select_implementation +from dataclasses import dataclass + + +@dataclass(frozen=True) +class LadderPermutateConfig: + M: int + N: int + datatype: Literal["int8", "e4m3_float8", "e5m2_float8"] = "float16" + dequantize_bits: int = -1 + storage_dtype: Literal["float16", "int8", "uint8", "int32", "uint32"] = "float16" + propagate_kind: Literal["A", "B"] = "B" # "A" or "B" + transpose_matrix: bool = False + transform_kind: int = 2 # 0: none, 1: inter_warp 2: intra_warp + target_instruction: Literal["nvidia-mma"] = ( + "nvidia-mma" # maybe extend to "cdna-mfma" in future. + ) + + +class LadderPermutate(Operator): + + def __init__( + self, + config: LadderPermutateConfig, + name: str = "permutate", + target: Union[str, Target] = "llvm", # assume to do permutation on cpu. + enable_tuning: bool = False, + from_database: bool = False, + ): + # consider to warp the arguments to MatmulConfig + super().__init__(name, config, target) + + target = self.target + if target.kind.name == "cuda": + self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) + if enable_tuning: + self.hardware_aware_finetune() + if not from_database: + self._build_runtime_module(target) + + # select implementation based on the Operator config + def _select_implementation(self): + return select_implementation( + M=self.M, + N=self.N, + datatype=self.datatype, + dequantize_bits=self.dequantize_bits, + storage_dtype=self.storage_dtype, + propagate_kind=self.propagate_kind, + transpose_matrix=self.transpose_matrix, + transform_kind=self.transform_kind, + target_instruction=self.target_instruction, + ) + + @property + def M(self): + return self.config.M + + @property + def N(self): + return self.config.N + + @property + def datatype(self): + return self.config.datatype + + @property + def dequantize_bits(self): + return self.config.dequantize_bits + + @property + def storage_dtype(self): + return self.config.storage_dtype + + @property + def propagate_kind(self): + return self.config.propagate_kind + + @property + def transpose_matrix(self): + return self.config.transpose_matrix + + @property + def transform_kind(self): + return self.config.transform_kind + + @property + def target_instruction(self): + return self.config.target_instruction + + +__all__ = ["LadderPermutate", "LadderPermutateConfig"] diff --git a/bitblas/ops/lop3_permutate.py b/bitblas/ops/lop3_permutate.py new file mode 100644 index 000000000..867432a5e --- /dev/null +++ b/bitblas/ops/lop3_permutate.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tvm.target import Target +from typing import Literal, Union +from .operator import Operator +from .impl.lop3_permutate_impl import select_implementation +from dataclasses import dataclass +import torch + + +@dataclass(frozen=True) +class LOP3PermutateConfig: + M: int + N: int + datatype: Literal["float16", "int8"] = "float16" + storage_dtype: Literal["int8", "uint8", "int32", "uint32"] = "int32" + dequantize_bits: int = 4 + + +class LOP3Permutate(Operator): + + def __init__( + self, + config: LOP3PermutateConfig, + name: str = "permutate", + target: Union[str, Target] = "llvm", # assume to do permutation on cpu. + ): + # consider to warp the arguments to MatmulConfig + super().__init__(name, config, target) + + if target.kind.name != "llvm": + raise ValueError("Currently only support llvm target for Permutation") + + self.target = target + self._build_runtime_module(target) + + def _select_implementation(self): + return select_implementation( + M=self.M, + N=self.N, + datatype=self.datatype, + dequantize_bits=self.dequantize_bits, + ) + + def forward(self, weight, res): + # reinterpret the input tensor to int32 format + args = [arg.view(torch.int32) for arg in [weight, res]] + self.torch_func(*args) + return args[-1].view(weight.dtype) + + @property + def M(self): + return self.config.M + + @property + def N(self): + return self.config.N + + @property + def datatype(self): + return self.config.datatype + + @property + def storage_dtype(self): + return self.config.storage_dtype + + @property + def dequantize_bits(self): + return self.config.dequantize_bits + + +__all__ = ["LOP3Permutate", "LOP3PermutateConfig"] diff --git a/bitblas/ops/matmul.py b/bitblas/ops/matmul.py new file mode 100644 index 000000000..7783c4972 --- /dev/null +++ b/bitblas/ops/matmul.py @@ -0,0 +1,288 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +import numpy as np +from tvm.target import Target +from bitblas.utils.tensor_adapter import tvm_tensor_to_torch +from typing import List, Union, Optional, Any, Tuple +from .operator import Operator, TransformKind +from .impl.matmul_impl import select_implementation +from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 +from dataclasses import dataclass +from .ladder_permutate import LadderPermutate, LadderPermutateConfig +import logging + +logger = logging.getLogger(__name__) + + +class TransformExecutorCPU: + + def __init__(self, operators: Optional[List[Operator]] = None): + if operators is None: + operators = [] + self.operators = operators + + def append(self, op): + self.operators.append(op) + + def is_none(self): + return len(self.operators) == 0 + + def forward(self, weight): + inputs = [weight] + for op in self.operators: + inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) + inputs = [op.forward(*inputs)] + return inputs[-1] + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.forward(*args, **kwds) + + @property + def size(self): + return len(self.operators) + + +@dataclass(frozen=True) +class MatmulConfig: + M: Union[int, Tuple[int]] + N: int + K: int + in_dtype: str = "float16" + out_dtype: str = "float16" + accum_dtype: str = "float16" + with_bias: bool = False + # layout of matrix A and B + # "nn": C[i, j] = A[i, k] * B[k, j] + # "nt": C[i, j] = A[i, k] * B[j, k] + layout: str = "nt" + # weight transformation kind of matrix A + propagate_a: TransformKind = TransformKind.NonTransform + # weight transformation kind of matrix B + propagate_b: TransformKind = TransformKind.NonTransform + + def __post_init__(self): + # set M to tuple if it is list + # otherwise, M is not hashable + object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M) + if isinstance(self.propagate_a, bool): + object.__setattr__( + self, + "propagate_a", + (TransformKind.IntraWarpTransform + if self.propagate_a else TransformKind.NonTransform), + ) + elif isinstance(self.propagate_a, int): + object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a)) + + if isinstance(self.propagate_b, bool): + object.__setattr__( + self, + "propagate_b", + (TransformKind.IntraWarpTransform + if self.propagate_b else TransformKind.NonTransform), + ) + elif isinstance(self.propagate_b, int): + object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b)) + + +class Matmul(Operator): + + def __init__( + self, + config: MatmulConfig, + name: str = "matmul", + target: Union[str, Target] = "cuda", + enable_tuning: bool = False, + from_database: bool = False, + ): + super().__init__(name, config, target) + target = self.target + if target.kind.name != "cuda": + raise ValueError("Currently only support cuda target") + + if isinstance(self.M, Tuple): + self.dynamic_range = {"m": self.M} + self.update_func(self.prim_func.with_attrs({"opt_shapes": self.dynamic_range})) + else: + self.dynamic_range = None + + if not from_database: + self._build_default_module(target) + + if self.propagate_a: + assert (self.propagate_a is + TransformKind.NonTransform), "Currently only support NonTransform for input" + ladder_permutate_config = LadderPermutateConfig( + M=self.M, + N=self.K, + datatype=self.in_dtype, + storage_dtype=self.in_dtype, + propagate_kind="A", + transpose_matrix=False, + transform_kind=self.propagate_a, + ) + self.ladder_permutate_a = LadderPermutate( + config=ladder_permutate_config, + target=tvm.target.Target("llvm"), + ) + else: + self.ladder_permutate_a = None + + if self.propagate_b: + ladder_permutate_config = LadderPermutateConfig( + M=self.N, + N=self.K, + datatype=self.in_dtype, + storage_dtype=self.in_dtype, + propagate_kind="B", + transpose_matrix=(self.layout == "nt"), + transform_kind=self.propagate_b, + ) + self.ladder_permutate_b = LadderPermutate( + config=ladder_permutate_config, + target=tvm.target.Target("llvm"), + ) + else: + self.ladder_permutate_b = None + + input_executors = TransformExecutorCPU() + if self.ladder_permutate_a is not None: + input_executors.append(self.ladder_permutate_b) + + self.input_executors = input_executors + + weight_executors = TransformExecutorCPU() + if self.ladder_permutate_b is not None: + weight_executors.append(self.ladder_permutate_b) + + self.weight_executors = weight_executors + + if enable_tuning: + self.hardware_aware_finetune() + + def _build_default_module(self, target: Target): + try: + self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) + except Exception: + self.optimized_func = None + logger.warning( + "[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." + ) + + self._build_runtime_module(target) + + def _select_implementation(self): + return select_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.in_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + with_bias=self.with_bias, + layout=self.layout, + propagate_a=self.propagate_a, + propagate_b=self.propagate_b, + ) + + def post_process(self, code: str) -> str: + code = tensor_replace_dp4a(code) + code = tensor_remove_make_int4(code) + code = tensor_remove_make_int2(code) + return code + + def _profile_latency_with_dynamic_range(self) -> List: + func = self.prim_func_mod["main"] + device = self.arch.device + + def var_warpper(v, m): + if isinstance(v, tvm.tir.Var): + assert "opt_shapes" in func.attrs + assert v.name in func.attrs["opt_shapes"] + return m + elif isinstance(v, tvm.tir.IntImm): + return v.value + else: + raise RuntimeError("Not supported type: ", type(v)) + + benchmark_latencies = [] + for m in self.dynamic_range["m"]: + profile_tensors = [] + for param in func.params: + if param not in func.buffer_map: + # in case of dynamic symbolic may in params + continue + arg = func.buffer_map[param] + profile_tensors.append( + tvm.nd.array( + np.random.uniform(0, 1, + [var_warpper(i, m) for i in arg.shape]).astype(arg.dtype), + device=device, + )) + self.profile_tensors = profile_tensors + latency = self.time_evaluator(*profile_tensors).mean * 1e3 + benchmark_latencies.append({"m": m, "latency": latency}) + # ms + return benchmark_latencies + + def forward(self, *args) -> Any: + if self.lib is None: + self._forward_from_torch_func(*args) + dynamic_symbolic = [] + if self.dynamic_range is not None: + # assume we only have one dynamic range + m = args[0].shape[0] + dynamic_symbolic.append(m) + self._forward_from_prebuild_lib(*args, *dynamic_symbolic) + + @property + def M(self): + return self.config.M + + @property + def N(self): + return self.config.N + + @property + def K(self): + return self.config.K + + @property + def in_dtype(self): + return self.config.in_dtype + + @property + def out_dtype(self): + return self.config.out_dtype + + @property + def accum_dtype(self): + return self.config.accum_dtype + + @property + def layout(self): + return self.config.layout + + @property + def with_bias(self): + return self.config.with_bias + + @property + def propagate_a(self): + return self.config.propagate_a + + @property + def propagate_b(self): + return self.config.propagate_b + + @property + def input_transform(self): + return self.input_executors if self.input_executors.size else None + + @property + def weight_transform(self): + return self.weight_executors if self.weight_executors.size else None + + +__all__ = ["Matmul", "MatmulConfig"] diff --git a/bitblas/ops/matmul_dequantize.py b/bitblas/ops/matmul_dequantize.py new file mode 100644 index 000000000..25c68b121 --- /dev/null +++ b/bitblas/ops/matmul_dequantize.py @@ -0,0 +1,331 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from tvm.target import Target +from bitblas.base.roller.arch.cuda import CUDA +from typing import Any, List, Literal, Optional, Tuple, Union +from .operator import Operator, TransformKind +from .impl.matmul_dequantize_impl import select_implementation +from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 +from bitblas.utils.tensor_adapter import tvm_tensor_to_torch +from dataclasses import dataclass +from .ladder_permutate import LadderPermutate, LadderPermutateConfig +from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig +import logging + +logger = logging.getLogger(__name__) + + +class OPExecutorCPU: + + def __init__(self, operators: Optional[List[Operator]] = None): + if operators is None: + operators = [] + self.operators = operators + + def append(self, op): + self.operators.append(op) + + def is_none(self): + return len(self.operators) == 0 + + def forward(self, weight): + inputs = [weight] + for op in self.operators: + inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) + inputs = [op.forward(*inputs)] + return inputs[-1] + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.forward(*args, **kwds) + + @property + def size(self): + return len(self.operators) + + +@dataclass(frozen=True) +class MatmulWeightOnlyDequantizeConfig: + M: Union[int, Tuple[int]] + N: int + K: int + in_dtype: str = "float16" + out_dtype: str = "float16" + accum_dtype: str = "float16" + bit: int = 4 + storage_dtype: str = "int8" + # documents for source_format: + # the format of the source data, which can be "int", "uint", "fp", "nf" + # "int": dequantize_weight = (target)((int)(quantize_weight - fixed_zero_point)) * scale + # where the fixed_zero_point is 2^(bit - 1) - 1 + # "uint": dequantize_weight = (target)((uint)(quantize_weight - zero_point)) * scale + # where the zero_point is manually set by zeros tensor + # "fp": dequantize_weight = (quantize_weight - zero_point) * scale + # "nf": dequantize_weight = (lut[quantize_weight] - zero_point) * scale + source_format: Literal["int", "uint", "fp", "nf"] = "int" + with_scaling: bool = False + with_zeros: bool = False + group_size: int = -1 + fast_decoding: bool = False + with_bias: bool = False + propagate_a: TransformKind = TransformKind.NonTransform + propagate_b: TransformKind = TransformKind.NonTransform + layout: str = "nt" + # documents for zeros_mode: + # original: target = (dequantize_weight - zero_point) * scale + # rescale: target = dequantize_weight * scale - zero_point + # quantized: target = (dequantize_weight - dequantize_zeros) * scale + # The auto-gptq framework prefer "quantized" and "original" for alignment with cuda. + zeros_mode: Literal["original", "rescale", "quantized"] = "original" + + def __post_init__(self): + # set M to tuple if it is list + # otherwise, M is not hashable + object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M) + if isinstance(self.propagate_a, bool): + object.__setattr__( + self, + "propagate_a", + (TransformKind.IntraWarpTransform + if self.propagate_a else TransformKind.NonTransform), + ) + elif isinstance(self.propagate_a, int): + object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a)) + + if isinstance(self.propagate_b, bool): + object.__setattr__( + self, + "propagate_b", + (TransformKind.IntraWarpTransform + if self.propagate_b else TransformKind.NonTransform), + ) + elif isinstance(self.propagate_b, int): + object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b)) + + +class MatmulWeightOnlyDequantize(Operator): + + def __init__( + self, + config: MatmulWeightOnlyDequantizeConfig, + name: str = "matmul_weight_only_dequantize", + target: Target = "cuda", + enable_tuning: bool = False, + from_database: bool = False, + ): + super().__init__(name, config, target) + + target = self.target + if target.kind.name != "cuda": + raise ValueError("Currently only support cuda target") + + self.arch = CUDA(target) + + if isinstance(self.M, Tuple): + self.dynamic_range = {"m": self.M} + self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs( + {"opt_shapes": self.dynamic_range}) + else: + self.dynamic_range = None + + if not from_database: + self._build_default_module(target) + + if self.propagate_a: + ladder_permutate_config = LadderPermutateConfig( + M=self.M, + N=self.K, + datatype=self.in_dtype, + storage_dtype=self.in_dtype, + propagate_kind="A", + transpose_matrix=False, + transform_kind=self.propagate_a, + ) + self.ladder_permutate_a = LadderPermutate( + config=ladder_permutate_config, + target=tvm.target.Target("llvm"), + ) + else: + self.ladder_permutate_a = None + + if self.propagate_b: + ladder_permutate_config = LadderPermutateConfig( + M=self.N, + N=self.K, + datatype=self.in_dtype, + dequantize_bits=self.bit, + storage_dtype=self.storage_dtype, + propagate_kind="B", + transpose_matrix=self.layout == "nt", + transform_kind=self.propagate_b, + ) + self.ladder_permutate_b = LadderPermutate( + config=ladder_permutate_config, + target=tvm.target.Target("llvm"), + ) + else: + self.ladder_permutate_b = None + + if self.fast_decoding: + lop3_permutate_config = LOP3PermutateConfig( + M=self.N, + N=self.K, + datatype=self.in_dtype, + dequantize_bits=self.bit, + storage_dtype=self.storage_dtype, + ) + self.lop3_permutate = LOP3Permutate( + config=lop3_permutate_config, + target=tvm.target.Target("llvm"), + ) + else: + self.lop3_permutate = None + + input_executors = OPExecutorCPU() + if self.ladder_permutate_a is not None: + input_executors.append(self.ladder_permutate_a) + self.input_executors = input_executors + + weight_executors = OPExecutorCPU() + if self.lop3_permutate is not None: + weight_executors.append(self.lop3_permutate) + + if self.ladder_permutate_b is not None: + weight_executors.append(self.ladder_permutate_b) + + self.weight_executors = weight_executors + + if enable_tuning: + self.hardware_aware_finetune() + + def _build_default_module(self, target: Target): + try: + self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) + except Exception: + self.optimized_func = None + logger.warning( + "[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." + ) + + self._build_runtime_module(target) + + def _select_implementation(self): + return select_implementation( + M=self.M, + N=self.N, + K=self.K, + in_dtype=self.in_dtype, + out_dtype=self.out_dtype, + accum_dtype=self.accum_dtype, + bit=self.bit, + storage_dtype=self.storage_dtype, + source_format=self.source_format, + with_scaling=self.with_scaling, + with_zeros=self.with_zeros, + group_size=self.group_size, + fast_decoding=self.fast_decoding, + with_bias=self.with_bias, + layout=self.layout, + zeros_mode=self.zeros_mode, + propagate_a=self.propagate_a, + propagate_b=self.propagate_b, + ) + + def post_process(self, code: str) -> str: + code = tensor_replace_dp4a(code) + code = tensor_remove_make_int4(code) + code = tensor_remove_make_int2(code) + return code + + def retrieve_weight_shape(self): + return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape] + + def forward(self, *args) -> Any: + if self.lib is None: + self._forward_from_torch_func(*args) + dynamic_symbolic = [] + if self.dynamic_range is not None: + # assume we only have one dynamic range + m = args[0].shape[0] + dynamic_symbolic.append(m) + self._forward_from_prebuild_lib(*args, *dynamic_symbolic) + + @property + def M(self): + return self.config.M + + @property + def N(self): + return self.config.N + + @property + def K(self): + return self.config.K + + @property + def in_dtype(self): + return self.config.in_dtype + + @property + def out_dtype(self): + return self.config.out_dtype + + @property + def accum_dtype(self): + return self.config.accum_dtype + + @property + def bit(self): + return self.config.bit + + @property + def storage_dtype(self): + return self.config.storage_dtype + + @property + def source_format(self): + return self.config.source_format + + @property + def with_scaling(self): + return self.config.with_scaling + + @property + def with_zeros(self): + return self.config.with_zeros + + @property + def group_size(self): + return self.config.group_size + + @property + def fast_decoding(self): + return self.config.fast_decoding + + @property + def with_bias(self): + return self.config.with_bias + + @property + def propagate_a(self): + return self.config.propagate_a + + @property + def propagate_b(self): + return self.config.propagate_b + + @property + def layout(self): + return self.config.layout + + @property + def zeros_mode(self): + return self.config.zeros_mode + + @property + def input_transform(self): + return self.input_executors if self.input_executors.size else None + + @property + def weight_transform(self): + return self.weight_executors if self.weight_executors.size else None diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py new file mode 100644 index 000000000..90930d6d3 --- /dev/null +++ b/bitblas/ops/operator.py @@ -0,0 +1,367 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from abc import ABC, abstractmethod +import tvm +from tvm import IRModule +from tvm.target import Target +from tvm.tir import PrimFunc +from tvm.contrib.dlpack import to_pytorch_func +from tvm._ffi.base import _LIB, raise_last_ffi_error +from tvm._ffi._ctypes.types import TVMValue, ArgTypeCode +import bitblas +import ctypes +from typing import List, Dict, Any, Optional +import numpy as np +from ..base import fast_tune, fast_tune_with_dynamic_range +from copy import deepcopy +from bitblas.base.roller.arch import get_arch +from bitblas.utils.tensor_adapter import tvm_tensor_to_torch +from bitblas.wrapper import CUDASourceWrapper, CUDASourceWrapperWithDynamic +from dataclasses import dataclass +from enum import IntEnum +import logging + +logger = logging.getLogger(__name__) + + +class TransformKind(IntEnum): + NonTransform = 0 + InterWarpTransform = 1 + IntraWarpTransform = 2 + + +@dataclass +class OperatorConfig: + """Base class for operator configurations. Used for typing.""" + + pass + + +class Operator(ABC): + + def __init__(self, name, config: OperatorConfig, target: Target = None): + if isinstance(target, str): + target = Target(target) + self.name = name + self.config = config + self.target = target + self.prim_func_mod = self._select_implementation() + self.optimized_func = None + self.rt_mod = None + self.time_evaluator = None + self.profile_tensors = None + self.arch = get_arch(target) if target else None + self.dynamic_range = None + self.pass_context: Dict = {} + self.num_args = len(self.prim_func.params) + self.function_handle = None + self.num_output_args: int = ( + 1 # todo(lei): should be analyzed from the prim_func. + ) + self.wrapper = None + self.src_name = None + self.lib_name = None + self.lib = None + + def get_source(self, target: Target = None) -> str: + if target is None: + target = self.target + if self.rt_mod is None: + self._build_runtime_module(target) + return self.rt_mod.imported_modules[0].get_source() if self.rt_mod else None + + def _build_runtime_module(self, target: Target): + """ + Builds the runtime module based on the architecture platform. + + This function attempts to build a runtime module (rt_mod) for the specified target. + If the platform is CUDA and an optimized function is available, it tries to build + using the optimized function with a specific pass context. Otherwise, it falls back + to building with the primary function. After successful build, it initializes a + time evaluator for performance measurement. + + Args: + target (Target): The compilation target specification. + + Returns: + The compiled runtime module or None if the build was unsuccessful. + """ + + # Initialize rt_mod as None to handle cases where build fails or is skipped + rt_mod = None + + # Check if the platform is CUDA and we have an optimized function + if self.arch.platform == "CUDA": + if self.optimized_func is None: + return None + + @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) + def tvm_callback_cuda_postproc(code, _): + return self.post_process(code) + + try: + # Use a specific TVM pass context for CUDA platforms + with tvm.transform.PassContext(config={ + "tir.use_async_copy": True, + **self.pass_context + }): + rt_mod = tvm.build(self.optimized_func, target=target, name=self.name) + except Exception as e: + rt_build_error = e # noqa + logger.debug( + "Failed to build optimized function for CUDA target with default schedule, Please consider enable hardware aware tuning!" + ) + else: + # For non-CUDA platforms or when no optimized function is available, build with the primary function + rt_mod = tvm.build(self.prim_func, target=target, name=self.name) + + # If the runtime module was successfully built, set up for evaluation + if rt_mod: + self.rt_mod = rt_mod + # Initialize a time evaluator with the built module, specifying the device and the number of runs + self.time_evaluator = rt_mod.time_evaluator( + rt_mod.entry_name, self.arch.device, number=10) + self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle + self.torch_func = to_pytorch_func(rt_mod) + if self.arch.platform == "CUDA": + try: + if (self.dynamic_range is not None and len(self.optimized_func.functions) > 1): + wrapper = CUDASourceWrapperWithDynamic(self.optimized_func, + self.get_source(target), self.arch) + else: + wrapper = CUDASourceWrapper(self.optimized_func, self.get_source(target), + self.arch) + wrapper.compile_lib() + self.wrapper = wrapper + self.src_name = self.wrapper.src_name + self.lib_name = self.wrapper.lib_name + self.lib = self.wrapper.load_lib() + self.lib.init() + except Exception as e: + build_runtime_library_error = e + logger.debug( + "Failed to build runtime library {}".format(build_runtime_library_error)) + + return rt_mod + + def apply_default_schedule(self, func_mod: IRModule, target: Target) -> IRModule: + mod_for_opt = deepcopy(func_mod) + with target: + optimized_mod = ( + bitblas.ApplyDefaultSchedule( # pylint: disable=not-callable + bitblas.gpu.Matmul(), + bitblas.gpu.GEMV(), + bitblas.gpu.Reduction(), + bitblas.gpu.GeneralReduction(), + bitblas.gpu.Fallback(), + )(mod_for_opt)) + + if optimized_mod is not None: + return optimized_mod + return None + + def post_process(self, code: str) -> str: + return code + + def apply_fast_tuning(self, + func: PrimFunc, + target: Target, + topk: int = 20, + parallel_build=True) -> IRModule: + _, best = fast_tune(func, target, topk=topk, parallel_build=parallel_build) + if best is not None: + return best.sch.mod + self.pass_context = best.config.pass_context + return None + + def apply_fast_tuning_with_dynamic_range( + self, + func: PrimFunc, + target: Target, + topk: int = 20, + dynamic_range: Dict[str, List[int]] = None, + ): + optimized_mod = fast_tune_with_dynamic_range( + func, target, topk=topk, parallel_build=True, dynamic_range=dynamic_range) + if optimized_mod is not None: + return optimized_mod + return None + + def hardware_aware_finetune(self, + topk: int = 20, + target: tvm.target.Target = None, + parallel_build=True): + if target is None: + target = self.target + dynamic_range = self.dynamic_range + func = self.prim_func + if dynamic_range is not None: + self.optimized_func = self.apply_fast_tuning_with_dynamic_range( + func, target, topk, dynamic_range) + else: + self.optimized_func = self.apply_fast_tuning( + func, target, topk, parallel_build=parallel_build) + self._build_runtime_module(self.target) + + def get_profile_tensors(self, dynamic_symbolic_constrains: Optional[Dict] = None): + if dynamic_symbolic_constrains is None: + dynamic_symbolic_constrains = {} + func = self.prim_func + device = self.arch.device + + def var_warpper(v): + if isinstance(v, tvm.tir.Var): + if v.name in dynamic_symbolic_constrains: + return dynamic_symbolic_constrains[v.name] + assert "opt_shapes" in func.attrs + assert v.name in func.attrs["opt_shapes"] + return func.attrs["opt_shapes"][v.name].value + elif isinstance(v, tvm.tir.IntImm): + return v.value + else: + raise RuntimeError("Not supported type: ", type(v)) + + def map_numpy_type(intype): + typemap = { + 'e4m3_float8': 'float8_e4m3fn', + 'e5m2_float8': 'float8_e5m2', + } + if intype in typemap: + return typemap[intype] + else: + return intype + + profile_tensors = [] + for param in func.params: + if param not in func.buffer_map: + # in case of dynamic symbolic may in params + continue + arg = func.buffer_map[param] + numpy_dtype = map_numpy_type(arg.dtype) + profile_tensors.append( + tvm.nd.array( + np.random.uniform(0, 1, + [var_warpper(i) for i in arg.shape]).astype(numpy_dtype), + device=device, + )) + self.profile_tensors = profile_tensors + return profile_tensors + + def profile_latency(self, dynamic_symbolic_constrains: Optional[Dict] = None) -> str: + if dynamic_symbolic_constrains is None: + dynamic_symbolic_constrains = {} + profile_tensors = self.get_profile_tensors(dynamic_symbolic_constrains) + latency = self.time_evaluator(*profile_tensors).mean * 1e3 + return latency + + def _tensor_adapter(self, tensor, device): + import torch + from torch.utils.dlpack import to_dlpack + + if isinstance(tensor, tvm.te.Tensor): + return tensor + elif isinstance(tensor, torch.Tensor): + return tvm.runtime.ndarray.from_dlpack(to_dlpack(tensor)) + elif isinstance(tensor, np.ndarray): + return tvm.nd.array(tensor, device=device) + else: + raise RuntimeError("Not supported type: ", type(tensor)) + + def _forward_from_tvm_args(self, *args): + _tvm_args = [self._tensor_adapter(arg, self.arch.device) for arg in args] + self.rt_mod(*_tvm_args) + + def _forward_from_tvm_nd_array(self, *args): + self.rt_mod(*args) + + def _forward_from_torch_func(self, *args): + # torch func is not reliable as some datatypes they don't support + # like float8. + self.torch_func(*args) + return args[-1] + + def forward(self, *args): + return self._forward_from_torch_func(*args) + + def _forward_from_prebuild_lib(self, *args, stream=0): + ctypes_args = [ + ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args + ] + ctypes_args.append(ctypes.c_void_p(stream)) + self.lib.call(*ctypes_args) + + def call_lib(self, *args, stream=0): + self.lib.call(*args, ctypes.c_void_p(stream)) + + def _forward_from_tvm_lib_func(self, values): + tcodes = (ctypes.c_int * self.num_args)() + ret_val = TVMValue() + ret_tcode = ctypes.c_int() + for i in range(self.num_args): + tcodes[i] = ArgTypeCode.NDARRAY_HANDLE + if (_LIB.TVMFuncCall( + self.function_handle, + values, + tcodes, + ctypes.c_int(self.num_args), + ctypes.byref(ret_val), + ctypes.byref(ret_tcode), + ) != 0): + raise_last_ffi_error() + + def __call__(self, *args: Any) -> Any: + return self.forward(*args) + + def update_func(self, func: PrimFunc): + self.prim_func_mod["main"] = func + + def update_runtime_module(self, rt_mod, src_name=None, lib_name=None): + self.rt_mod = rt_mod + self.time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, self.arch.device, number=10) + self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle + self.torch_func = to_pytorch_func(rt_mod) + if src_name is not None: + self.src_name = src_name + if lib_name is not None: + self.lib_name = lib_name + self.lib = ctypes.CDLL(lib_name) + self.lib.init() + + @abstractmethod + def _select_implementation(self) -> IRModule: + pass + + @property + def prim_func(self): + return self.prim_func_mod["main"] + + +class OPExecutorCPU: + """ + A class to execute a sequence of operators on the CPU. + """ + + def __init__(self, operators: Optional[List[Operator]] = None): + if operators is None: + operators = [] + self.operators = operators + + def append(self, op): + self.operators.append(op) + + def is_none(self): + return len(self.operators) == 0 + + def forward(self, weight): + inputs = [weight] + for op in self.operators: + inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) + inputs = [op.forward(*inputs)] + return inputs[-1] + + def __call__(self, *args: Any, **kwds: Any) -> Any: + return self.forward(*args, **kwds) + + @property + def size(self): + return len(self.operators) diff --git a/bitblas/ops/param_permutate.py b/bitblas/ops/param_permutate.py new file mode 100644 index 000000000..ca28c86eb --- /dev/null +++ b/bitblas/ops/param_permutate.py @@ -0,0 +1,91 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tvm.target import Target +from typing import Literal, Union +from .operator import Operator, TransformKind +from .impl.param_permutate_impl import select_implementation +from dataclasses import dataclass + + +@dataclass(frozen=True) +class ParamPermutateConfig: + M: int + N: int + datatype: Literal["float16"] = "float16" + transpose_matrix: bool = True + group_size: int = -1 + propagate_kind: TransformKind = TransformKind.NonTransform + target_instruction: Literal["nvidia-mma"] = ( + "nvidia-mma" # maybe extend to "cdna-mfma" in future. + ) + + def __post_init__(self): + if isinstance(self.propagate_kind, bool): + object.__setattr__( + self, + "propagate_kind", + (TransformKind.IntraWarpTransform + if self.propagate_kind else TransformKind.NonTransform), + ) + elif isinstance(self.propagate_kind, int): + object.__setattr__(self, "propagate_kind", TransformKind(self.propagate_kind)) + + +class ParamPermutate(Operator): + + def __init__( + self, + config: ParamPermutateConfig, + name: str = "permutate", + target: Union[str, Target] = "llvm", # assume to do permutation on cpu. + ): + super().__init__(name, config, target) + + if target.kind.name != "llvm": + raise ValueError("Currently only support llvm target for Permutation") + + self.target = target + self._build_runtime_module(target) + + # select implementation based on the Operator config + def _select_implementation(self): + return select_implementation( + M=self.M, + N=self.N, + datatype=self.datatype, + transpose_matrix=self.transpose_matrix, + group_size=self.group_size, + propagate_kind=self.propagate_kind, + target_instruction=self.target_instruction, + ) + + @property + def M(self): + return self.config.M + + @property + def N(self): + return self.config.N + + @property + def datatype(self): + return self.config.datatype + + @property + def propagate_kind(self): + return self.config.propagate_kind + + @property + def transpose_matrix(self): + return self.config.transpose_matrix + + @property + def group_size(self): + return self.config.group_size + + @property + def target_instruction(self): + return self.config.target_instruction + + +__all__ = ["ParamPermutate", "ParamPermutateConfig"] diff --git a/bitblas/quantization/__init__.py b/bitblas/quantization/__init__.py new file mode 100644 index 000000000..d29cb679a --- /dev/null +++ b/bitblas/quantization/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .quantization import ( + _tir_packed_int_to_int_convert, # noqa: F401 + _tir_packed_to_signed_convert, # noqa: F401 + _tir_packed_to_unsigned_convert, # noqa: F401 + _tir_u32_to_f4_to_f16, # noqa: F401 + _tir_u8_to_f8_e4m3_to_f16, # noqa: F401 + _tir_packed_to_unsigned_convert_with_zeros, # noqa: F401 +) + +from .utils import gen_quant4, general_compress # noqa: F401 diff --git a/bitblas/quantization/quantization.py b/bitblas/quantization/quantization.py new file mode 100644 index 000000000..71ef224d7 --- /dev/null +++ b/bitblas/quantization/quantization.py @@ -0,0 +1,217 @@ +# Copyright 2018 The apache/tvm Authors. All Rights Reserved. +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Modifications Copyright (c) Microsoft. +# The code below is mostly copied from mlc.ai quantization.py in mlc-llm. +# pylint: disable=invalid-name,missing-function-docstring,unused-variable +"""TIR computation utilities for quantization.""" + +import tvm +from tvm import tir + + +# fmt: off +def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True): + mask = tir.const((1 << 16) - 1, "uint32") + res = [] + for data in [v0, v1]: + u32_val = tir.reinterpret("uint32", data) + if round_to_even: + rounding_bias = ((u32_val >> tir.const(16, "uint32")) + & tir.const(1, "uint32")) + tir.const(0x7FFF, "uint32") + u32_val += rounding_bias + res.append((u32_val >> tir.const(16, "uint32")) & mask) + return res[0] | (res[1] << tir.const(16, "uint32")) + + +def _tir_u32_to_bf16x2_to_f32x2(x: tir.PrimExpr): + mask = tir.const((1 << 16) - 1, "uint32") + x0 = x & mask + x1 = (x >> 16) & mask + return (tir.reinterpret("float32", x << tir.const(16, "uint32")) for x in [x0, x1]) + + +def _tir_u32_to_int_to_float(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == "uint32" + mask = tvm.tir.const((1 << nbit) - 1, "uint32") + return tir.Cast(dtype, (val >> (pos * nbit).astype("uint32")) & mask) + + +def _tir_packed_uint_to_uint_to_float(storage_nbit: int): + storage_dtype = "uint" + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + max_int_value = (1 << (nbit - 1)) - 1 + return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const( + (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) + + return f_convert + + +def _tir_packed_int_to_int_to_float(storage_nbit: int): + storage_dtype = "int" + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tir.const((1 << nbit) - 1, "int32") + unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask + return tir.Cast( + dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) + + return f_convert + + +def _tir_f32_to_uint_to_f4(val: tir.PrimExpr): + assert val.dtype == "float32" + val_u32 = tir.reinterpret("uint32", val) + # e_f32 > 120 -> e_f4 = min(e_f32 - 120 + M_h, 7) + # e_f32 == 120 -> e_f4 = 1 + # e_f32 < 120 -> e_f4 = 0 + m_h = (val_u32 >> tir.const(22, "uint32")) & tir.const(1, "uint32") + e_f32 = (val_u32 >> tir.const(23, "uint32")) & tir.const(255, "uint32") + s = (val_u32 >> tir.const(31, "uint32")) + e_f4 = tir.Select( + e_f32 > tir.const(120, "uint32"), + tir.Min(e_f32 - tir.const(120, "uint32") + m_h, tir.const(7, "uint32")), + tir.Select(e_f32 == tir.const(120, "uint32"), tir.const(1, "uint32"), + tir.const(0, "uint32"))) + return (s << tir.const(3, "uint32")) | e_f4 + + +def _tir_f16_to_uint_to_f4(val: tir.PrimExpr): + assert val.dtype == "float16" + val_u32 = tir.Cast("uint32", tir.reinterpret("uint16", val)) + m_h = (val_u32 >> tir.const(9, "uint32")) & tir.const(1, "uint32") + e_f16 = (val_u32 >> tir.const(10, "uint32")) & tir.const(31, "uint32") + s = (val_u32 >> tir.const(15, "uint32")) + e_f4 = tir.Select( + e_f16 > tir.const(8, "uint32"), + tir.Min(e_f16 - tir.const(8, "uint32") + m_h, tir.const(7, "uint32")), + tir.Select(e_f16 == tir.const(8, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32"))) + return (s << tir.const(3, "uint32")) | e_f4 + + +def _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == "float32" + assert val.dtype == "uint32" + # e_f4 == 0 -> e_f32 = 0 + # e_f4 != 0 -> e_f32 = e_f4 + 120 = e_f4 | (1111000)_2 + mask = tvm.tir.const((1 << nbit) - 1, "uint32") + f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask + s = f4 >> tir.const(3, "uint32") + e_f4 = f4 & tir.const(7, "uint32") + e_f32 = e_f4 | tir.const(120, "uint32") + val_f32 = tir.reinterpret("float32", + (e_f32 | (s << tir.const(8, "uint32"))) << tir.const(23, "uint32")) + return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float32"), val_f32) + + +def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert nbit == 4 + assert dtype == "float16" + assert val.dtype == "uint32" + # e_f4 == 0 -> e_f16 = 0 + # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2 + mask = tvm.tir.const((1 << nbit) - 1, "uint32") + f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask + s = f4 >> tir.const(3, "uint32") + e_f4 = f4 & tir.const(7, "uint32") + e_f16 = e_f4 | tir.const(8, "uint32") + val_f16 = tir.reinterpret("float16", + (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32")) + return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) + + +def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str): + assert nbit == 8 + assert dtype == "float16" + s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16") + e4 = val & tir.const(0x40, "uint16") + prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"), + tir.const(0x4000, "uint16")) + e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | prefix + return tir.reinterpret("float16", s_f16 | e_f16) + + +def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): + assert nbit == 8 + assert dtype == "float16" + s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16") + e4 = val & tir.const(0x40, "uint16") + e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16")) + e_f16 = e_f16 ^ tir.const(0x2000, "uint16") + return tir.reinterpret("float16", s_f16 | e_f16) + + +def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): + assert nbit == 8 + assert dtype == "float16" + return tir.reinterpret("e5m2_float8", val).astype("float16") + + +def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + max_int_value = (1 << (nbit - 1)) + return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const( + (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) + + return f_convert + + +def _tir_packed_to_unsigned_convert(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) + return ((val >> (pos * nbit).astype(storage_dtype)) & mask).astype(dtype) + + return f_convert + + +def _tir_packed_to_unsigned_convert_with_zeros(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, zero: tvm.tir.PrimExpr, + dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) + return (((val >> (pos * nbit).astype(storage_dtype)) & mask) - zero).astype(dtype) + + return f_convert + + +def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8): + storage_dtype = storage_type + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" + mask = tir.const((1 << nbit) - 1, "int32") + unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask + return tir.Cast( + dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) + + return f_convert + + +# fmt: on diff --git a/bitblas/quantization/utils.py b/bitblas/quantization/utils.py new file mode 100644 index 000000000..45890c3d8 --- /dev/null +++ b/bitblas/quantization/utils.py @@ -0,0 +1,110 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import numpy as np +import torch +import torch.nn as nn + + +def gen_quant4(k, n, groupsize=-1): + maxq = 2**4 + w = torch.randn((k, n), dtype=torch.half, device="cpu") + + original_w = w.clone() + + if groupsize == -1: + groupsize = k + + if groupsize != -1: + w = w.reshape((-1, groupsize, n)) + w = w.permute(1, 0, 2) + w = w.reshape((groupsize, -1)) + + s = torch.max(torch.abs(w), 0, keepdim=True)[0] + s *= 2 / maxq + + # Quantize. + w = torch.round(w / s).int() + + # Unsigned storage. + w += (maxq) // 2 + + w = torch.clamp(w, 0, maxq) + + # Dequantize. + ref = (w - (maxq) // 2).half() * s + + if groupsize != -1: + + def reshape(w): + w = w.reshape((groupsize, -1, n)) + w = w.permute(1, 0, 2) + w = w.reshape((k, n)).contiguous() + return w + + ref = reshape(ref) + w = reshape(w) + + s = s.reshape((-1, n)).contiguous() + linear = nn.Linear(k, n, bias=False) + linear.weight.data = ref.t() + + return original_w, linear, s, (w - (maxq) // 2) + + +def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): + elems_per_byte = 8 // source_bits + if lowprecision_weight.dtype == np.float16: + lowprecision_weight = lowprecision_weight.astype(dtype=np.int8) + int8_weight = np.zeros( + ( + *lowprecision_weight.shape[:-1], + lowprecision_weight.shape[-1] // elems_per_byte, + ), + dtype=np.int8, + ) + for j in range(lowprecision_weight.shape[-1] // elems_per_byte): + for k in range(elems_per_byte): + int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << (source_bits * k) + + return int8_weight.view(storage_dtype) + + +# interleave weight numpy implementation +def interleave_weight(qweight, nbits=4, target_dtype="float16"): + assert target_dtype in ["float16", "int8"] + # reinterpret the data type of qweight to int32 + qweight = qweight.view(np.int32) + new_qweight = np.zeros_like(qweight) + bits_stride = 8 if target_dtype == "int8" else 16 + mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f + num_groups = 32 // bits_stride + elems_per_group = bits_stride // nbits + for i in range(num_groups): + for j in range(elems_per_group): + offset = i * elems_per_group + j + shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits + new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift + + if nbits == 1 and target_dtype == "int8": + # special handling for 1b interleave + n16_weight = new_qweight & np.int32(0xF0F00F0F) + n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 + n16_weight |= ((new_qweight & np.int32(0x0000F000)) >> 12) << 24 + n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 + n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 + return n16_weight.view(np.int8) + elif nbits == 2 and target_dtype == "float16": + n8_weight = new_qweight & np.int32(0xFF0000FF) + n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 + n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 + return n8_weight.view(np.int8) + elif nbits == 1 and target_dtype == "float16": + n8_weight = new_qweight & 0xF000000F + n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 + n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 + n8_weight |= ((new_qweight & 0x0000F000) >> 12) << 24 + n8_weight |= ((new_qweight & 0x000F0000) >> 16) << 4 + n8_weight |= ((new_qweight & 0x00F00000) >> 20) << 12 + n8_weight |= ((new_qweight & 0x0F000000) >> 24) << 20 + + return new_qweight.view(np.int8) diff --git a/bitblas/relax/op/interleave_weight.py b/bitblas/relax/op/interleave_weight.py new file mode 100644 index 000000000..98b1f5cd4 --- /dev/null +++ b/bitblas/relax/op/interleave_weight.py @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from tvm.relax.block_builder import BlockBuilder +from tvm.relax.expr import Call, Expr +from tvm.relax.transform.legalize_ops.common import register_legalize + +from bitblas.ops.impl import tir_interleave_weight + + +@register_legalize("bitblas.interleave_weight") +def _interleave_weight(bb: BlockBuilder, call: Call) -> Expr: + nbits = call.attrs.nbits + target_dtype = call.attrs.target_dtype + out_dtype = call.attrs.out_dtype + + return bb.call_te( + tir_interleave_weight(nbits, target_dtype, out_dtype), + call.args[0], + primfunc_name_hint="interleave_weight", + ) + + +__all__ = ["_interleave_weight"] diff --git a/bitblas/relax/transform/__init__.py b/bitblas/relax/transform/__init__.py new file mode 100644 index 000000000..b92f2c0b4 --- /dev/null +++ b/bitblas/relax/transform/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .annotate_decode_block import AnnotateDecodeInformation +from .weight_only_propagate import WeightOnlyLayoutPropagation diff --git a/bitblas/relax/transform/annotate_decode_block.py b/bitblas/relax/transform/annotate_decode_block.py new file mode 100644 index 000000000..601647839 --- /dev/null +++ b/bitblas/relax/transform/annotate_decode_block.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import Dict, Tuple +from tvm.ir import IRModule +from tvm.ir.transform import PassContext, module_pass +from tvm import tir +from tvm.tir.schedule import BlockRV +from mlc_llm.quantization import quantization_schemes, GroupQuantizationSpec +from bitblas.gpu.gemv import is_gemv +from bitblas.gpu.matmul_analysis import ( + get_reduction_blocks, + get_index_map, + get_root_block, + get_dequantize_block, +) +from bitblas.base import ( + normalize_prim_func, + try_inline_contiguous_spatial, +) + + +# Define a module pass to annotate dequantization information +@module_pass(opt_level=0, name="AnnotateDecodeInformation") +class AnnotateDecodeInformation: + + def __init__(self, spec: str = "q4f16_0"): + # Validate and store the specified quantization scheme + if spec not in quantization_schemes: + raise ValueError(f"Quantization scheme {spec} not found") + self.quantize_scheme = quantization_schemes[spec] + + def detect_matmul(self, func: tir.PrimFunc) -> bool: + """Detect if the given function represents a matrix multiplication.""" + sch = tir.Schedule(func) + root_block = get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + # Identify reduction blocks to infer matmul operations + reduction_blocks = get_reduction_blocks(sch, blocks) + if not reduction_blocks: + return False + + # Check for index map patterns typical of matmul operations + main_block = reduction_blocks[0] + main_block_stmt = sch.get(main_block) + index_maps = get_index_map(main_block_stmt) + _is_matmul = index_maps is not None + + block_infos = normalize_prim_func(sch) + block_infos = try_inline_contiguous_spatial(sch, block_infos) + block_info = block_infos[0] + _is_gemv = True + if len(block_info.iters) not in [2, 3]: + # either [B, S, R] = [B, S, R] * [B, R] + # or [S, R] = [S, R] * [R] + _is_gemv = False + if _is_gemv: + _is_gemv = is_gemv(sch, block_info) + return _is_matmul or _is_gemv + + def transform_module(self, mod: IRModule, _: PassContext) -> IRModule: + """Annotate dequantize information for all applicable functions in the module.""" + for g_var, func in mod.functions.items(): + if not isinstance(func, tir.PrimFunc) or g_var.name_hint == "main": + continue + + if not self.detect_matmul(func): + continue # Process only if matmul is detected + + sch = tir.Schedule(func) + root_block = get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + dequantize_block = get_dequantize_block(sch, blocks) + if dequantize_block is None: + continue # Skip if no dequantize block is found + + # Prepare dequantize info annotation + dequantize_info = self.prepare_dequantize_info(sch, dequantize_block) + + # Annotate function with dequantize information + mod[g_var] = func.with_attr("dequantize_info", dequantize_info) + return mod + + def prepare_dequantize_info(self, sch: tir.Schedule, dequantize_block: BlockRV) -> Dict: + """Generate dequantize information for a given block.""" + block_stmt = sch.get(dequantize_block) + block_name = block_stmt.name_hint + dequantize_info = {block_name: {"decode_block": block_name, "fast_decoding": False}} + + quantize_spec = self.quantize_scheme.linear_weight + if isinstance(quantize_spec, GroupQuantizationSpec): + dequantize_info[block_name].update({ + "with_scaling": True, + "group_size": quantize_spec.group_size, + }) + + # Determine source format based on quantization mode + quantize_mod = quantize_spec.mode + bits, source_format = self.parse_quantize_mode(quantize_mod) + dequantize_info[block_name]["source_format"] = { + "bits": bits, + "format": source_format, + } + + # Set storage and target data types + storage_dtype = self.get_storage_dtype(block_stmt, source_format) + dequantize_info[block_name]["storage_dtype"] = storage_dtype + dequantize_info[block_name]["target_format"] = quantize_spec.dtype + + return dequantize_info + + def parse_quantize_mode(self, quantize_mod: str) -> Tuple[int, str]: + """Extract bits and format from quantization mode.""" + if quantize_mod.startswith("int"): + return int(quantize_mod[3:]), "int" + elif quantize_mod.startswith("uint"): + return int(quantize_mod[4:]), "uint" + raise ValueError(f"Unsupported mode {quantize_mod}") + + def get_storage_dtype(self, block_stmt: BlockRV, source_format: str) -> str: + """Determine storage data type based on source format.""" + return (block_stmt.reads[0].buffer.dtype + if "nf" not in source_format else block_stmt.reads[1].buffer.dtype) diff --git a/bitblas/relax/transform/weight_only_propagate.py b/bitblas/relax/transform/weight_only_propagate.py new file mode 100644 index 000000000..709e02085 --- /dev/null +++ b/bitblas/relax/transform/weight_only_propagate.py @@ -0,0 +1,432 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import Optional, Tuple, Union, List, Dict +from tvm.ir import IRModule +from tvm.ir.transform import PassContext, module_pass +from tvm import relax +from tvm import tir +from enum import Enum +from tvm.ir import GlobalVar +from tvm.tir import IndexMap +from tvm.target import Target +from tvm.tir import IterVar +from tvm.tir.schedule.schedule import BlockRV +from tvm.relax import PyExprMutator +from tvm.relax.expr import Call +from bitblas.gpu.matmul_analysis import ( + get_tensorized_func_and_tags, + get_propagate_map, + find_last_producer_from_buffer, + find_arg_idx_from_buffer_chain, + layout_propagate_chain, +) +from tvm.dlight.base import ( + analysis,) +from dataclasses import dataclass + + +def get_reduction_blocks(sch, blocks) -> bool: + # Get the main computation block + def is_reduction(block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.CommReduce, IterVar.DataPar} + + def is_spatial(block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.DataPar} + + # NOTE: We assume there is only one reduction block in the function + # all blocks are required to be spatial or reduction + if not all([is_reduction(block) or is_spatial(block) for block in blocks]): + return None + + # There is only one reduction block + reduction_blocks = [block for block in blocks if is_reduction(block)] + if len(reduction_blocks) != 1: + return None + + return reduction_blocks + + +class TransformKind(Enum): + NonTransform = 0 + InterWarpTransform = 1 + IntraWarpTransform = 2 + + +def check_sm_version(arch: str) -> int: + sm_version = arch.replace("sm_", "") + return int(sm_version) if sm_version.isdigit() else -1 + + +def get_in_out_dtypes(block: tir.Block) -> Tuple[str]: + """ + Detect In/Out data types for the given block based on the analysis if read/write buffers. + """ + assert len(block.reads) > 0 and len(block.writes) > 0 + in_dtype = block.reads[0].buffer.dtype + out_dtype = block.writes[0].buffer.dtype + return (in_dtype, out_dtype) + + +@dataclass +class LayoutTransformHint: + """ + A dataclass to store the layout transformation hint. + """ + + transform_level: TransformKind + inter_warp_layout: IndexMap + intra_warp_layout: IndexMap + apply_arg_idx: int + + +@module_pass(opt_level=0, name="InsertLayoutTransform") +class WeightOnlyLayoutPropagation: + + def __init__( + self, + transform_level: Union[int, TransformKind] = TransformKind.InterWarpTransform, + target: Optional[Target] = None, + faster_conversion: bool = False, + ) -> None: + if isinstance(transform_level, int): + transform_level = TransformKind(transform_level) + assert transform_level in [ + TransformKind.NonTransform, + TransformKind.InterWarpTransform, + TransformKind.IntraWarpTransform, + ] + # transform_level 1: only transform the inter-warp memory layout + # transform_level 2: transform the inter-warp memory layout and the intra-warp memory layout + self.transform_level = transform_level + self.target = Target.current() if target is None else target + # fast type conversion on nvidia gpu also requires weight permutation + self.faster_conversion = faster_conversion + # layout transform info to sync the layout in both graph and tir + self.layout_transform_hints: Dict[str, List[LayoutTransformHint]] = {} + + def detect_propagate_matmul(self, func: tir.PrimFunc, target: Target): + _, tags = get_tensorized_func_and_tags(func, target, skip_normalize=True, allow_gemv=True) + if tags is None: + return False, None + return True, tags["intrin_info"] + + def transform_matmul(self, g_var: GlobalVar, func: tir.PrimFunc, intrin_info): + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group,) + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None or len(reduction_blocks) != 1: + return False + (main_block,) = reduction_blocks + + intrin_group = get_mma_intrin_group( + load_scope="shared", + store_scope="shared", + a_dtype=intrin_info["in_dtype"], + b_dtype=intrin_info["in_dtype"], + out_dtype=intrin_info["out_dtype"], + trans_a=False, + trans_b=intrin_info["trans_b"], + ) + + _, inter_j, inter_k = intrin_group["micro_kernel"] + + # weight only propagation + target_scope = ("read", 1) + weight_buffer = sch.get(main_block).reads[1].buffer + + # checkout whether the weight buffer has dynamic symbol + def check_dynamic_symbol(buffer): + return any([isinstance(axis, tir.Var) for axis in buffer.shape]) + + if check_dynamic_symbol(weight_buffer): + print("[BitBLAS] Weight buffer has dynamic symbol, skip weight propagation.") + return False + + transformed_block = find_last_producer_from_buffer(sch, main_block, weight_buffer) + if transformed_block is None: + return False + if transformed_block != main_block: + target_scope = ("read", 0) + + reindex_block = sch.cache_read(transformed_block, target_scope[1], "global") + + # create inter-warp memory layout index map + inter_warp_layout = IndexMap.from_func( + lambda i, j: (i // inter_j, j // inter_k, i % inter_j, j % inter_k)) + + inter_warp_layout = layout_propagate_chain( + sch, + main_block, + sch.get(main_block).reads[1].buffer, + reindex_block, + inter_warp_layout, + ) + + sch.transform_layout( + reindex_block, + ("read", 0), + lambda i, j: inter_warp_layout.map_indices([i, j]), + ) + arg_idx = find_arg_idx_from_buffer_chain(sch, reindex_block, + sch.get(reindex_block).reads[0].buffer) + + intra_warp_layout = None + if self.transform_level.value >= TransformKind.IntraWarpTransform.value: + intra_warp_layout, _ = get_propagate_map(intrin_info["trans_b"]) + intra_warp_layout = layout_propagate_chain( + sch, + main_block, + sch.get(main_block).reads[1].buffer, + reindex_block, + intra_warp_layout, + ) + sch.transform_layout( + reindex_block, + ("read", 0), + lambda i, j, ii, jj: ( + i, + j, + *intra_warp_layout.map_indices([ii, jj]), + ), + ) + + self.layout_transform_hints[g_var] = [ + LayoutTransformHint( + transform_level=self.transform_level, + inter_warp_layout=inter_warp_layout, + intra_warp_layout=intra_warp_layout, + apply_arg_idx=arg_idx, + ) + ] + + return sch.mod["main"] + + def transform_module( # pylint: disable=missing-function-docstring + self, + mod: IRModule, + _: PassContext, + ) -> IRModule: + if self.target.kind.name != "cuda": + # currently weight propagation only support nvidia gpus + return mod + + propagate_candidates = {} + propagated_funcs = {} # some funcs may not be able to transform + candidates_intrin_info = {} + decoded_funcs = {} + for g_var, func in mod.functions_items(): + if not isinstance(func, tir.PrimFunc): + continue + if g_var.name_hint != "main": + # Note: this can be applied to any function which can be transformed to matmul (e.g., conv2d) + # for mlc we only consider matmul + # detect the pattern + is_matmul, intrin_info = self.detect_propagate_matmul(func, self.target) + + if (func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys()): + # currently we only support tensorize propagation + continue + + if is_matmul: + if "dequantize_info" in func.attrs: + decoded_funcs[g_var] = func + if self.transform_level != TransformKind.NonTransform: + # lift tags to the function as it has intrinsic information that can be reused. + propagate_candidates[g_var] = func + candidates_intrin_info[g_var] = intrin_info + + for g_var, func in propagate_candidates.items(): + updated_func = self.transform_matmul(g_var, func, candidates_intrin_info[g_var]) + if updated_func: + updated_func = updated_func.with_attrs({ + "transform_kind": self.transform_level.value, + "weight_transform_kind": True, + }) + propagated_funcs[g_var] = updated_func + mod[g_var] = updated_func + + @relax.expr_functor.mutator + class TensorCoreLayoutMutator(PyExprMutator): + """Mutator that performs transformation.""" + + def __init__( + self, + transform_level: TransformKind = TransformKind.NonTransform, + layout_transform_hints: Optional[Dict[str, List[LayoutTransformHint]]] = None, + ): + if layout_transform_hints is None: + layout_transform_hints = {} + super().__init__() + self.transform_level = transform_level + self.layout_transform_hints = layout_transform_hints + + def tc_layout_transform(self, call_node: Call) -> Call: + if self.transform_level == TransformKind.NonTransform: + return super().visit_call_(call_node) + g_var = call_node.args[0] + if g_var not in propagated_funcs: + return super().visit_call_(call_node) + args = list(call_node.args[1]) + # assume we only have weight propagation currently + (weight_layout_hint,) = self.layout_transform_hints[g_var] + weight = args[weight_layout_hint.apply_arg_idx] + weight = self.builder_.emit( + relax.op.layout_transform( + weight, + index_map=lambda i, j: weight_layout_hint.inter_warp_layout.map_indices( + [i, j]), + )) + if self.transform_level.value >= TransformKind.IntraWarpTransform.value: + weight = self.builder_.emit( + relax.op.layout_transform( + weight, + index_map=lambda i, j, ii, jj: ( + i, + j, + *weight_layout_hint.intra_warp_layout.map_indices([ii, jj]), + ), + )) + + call_node = self.builder_.emit( + relax.call_tir( + g_var, + args[:weight_layout_hint.apply_arg_idx] + [weight] + + args[weight_layout_hint.apply_arg_idx + 1:], + out_sinfo=call_node.struct_info, + )) + return call_node + + def visit_call_(self, call_node: Call): + return self.tc_layout_transform(call_node) + + def transform( + self, + mod: IRModule, + ): + for gv, func in mod.functions_items(): + if isinstance(func, relax.Function): + updated_func = self.visit_expr(func) + self.builder_.update_func(gv, updated_func) + new_mod = self.builder_.get() + new_mod = new_mod.with_attrs(mod.attrs) if mod.attrs else new_mod + for gv, func in new_mod.functions_items(): + mod.update_func(gv, func) + return mod + + mod = TensorCoreLayoutMutator( + transform_level=self.transform_level, + layout_transform_hints=self.layout_transform_hints, + ).transform(mod) + + @relax.expr_functor.mutator + class FastTypeConversionLayoutMutator(PyExprMutator): + """Mutator that performs transformation.""" + + def __init__(self, faster_conversion: bool = False): + super().__init__() + self.faster_conversion = faster_conversion + + def lop3_layout_transform(self, call_node: Call) -> Call: + if not self.faster_conversion: + return super().visit_call_(call_node) + + from bitblas.ops.impl import tir_interleave_weight + + g_var = call_node.args[0] + if g_var not in decoded_funcs: + return super().visit_call_(call_node) + + args = list(call_node.args[1]) + func = decoded_funcs[g_var] + if "dequantize_info" not in func.attrs: + return super().visit_call_(call_node) + dequantize_info = dict(func.attrs["dequantize_info"]) + assert len(dequantize_info) == 1 + (weight_dequantize_info,) = dequantize_info.values() + + sch = tir.Schedule(func) + dequantize_block = sch.get_block(weight_dequantize_info["decode_block"]) + + # weight is the first read buffer if format in ["int", "uint"], otherwise the second read buffer, nf .etc + source_format = weight_dequantize_info["source_format"]["format"] + source_bits = weight_dequantize_info["source_format"]["bits"] + target_dtype = weight_dequantize_info["target_format"] + + if source_format in ["int", "uint"]: + weight_buffer = sch.get(dequantize_block).reads[0].buffer + elif source_format in ["nf"]: + weight_buffer = sch.get(dequantize_block).reads[1].buffer + else: + raise ValueError(f"Unsupported source format {source_format}") + + # update func with dequantize_info + dequantize_info["fast_decoding"] = True + self.builder_.update_func(g_var, + func.with_attrs({"dequantize_info": dequantize_info})) + + weight_idx = find_arg_idx_from_buffer_chain(sch, dequantize_block, weight_buffer) + weight = args[weight_idx] + + weight_shape = weight_buffer.shape + # reshape the weight shape to 2d + reshape_weight = self.builder_.emit( + relax.op.reshape(weight, (-1, weight_shape[-1]))) + # register g_var to the func + lop3_interleave_func = tir_interleave_weight( + N=reshape_weight.struct_info.shape[0], + QK=reshape_weight.struct_info.shape[1], + bits=source_bits, + target_dtype=target_dtype, + storage_dtype=reshape_weight.struct_info.dtype, + ) + interleave_gvar = self.builder_.add_func( + lop3_interleave_func.without_attr("global_symbol"), + "tir_interleave_weight", + ) + lop3_interleave_weight = self.builder_.emit( + relax.call_tir( + interleave_gvar, + [reshape_weight], + out_sinfo=reshape_weight.struct_info, + ),) + reshape_weight = self.builder_.emit( + relax.op.reshape(lop3_interleave_weight, weight_shape)) + call_node = self.builder_.emit( + relax.call_tir( + g_var, + args[:weight_idx] + [reshape_weight] + args[weight_idx + 1:], + out_sinfo=call_node.struct_info, + ),) + + return call_node + + def visit_call_(self, call_node: Call): + return self.lop3_layout_transform(call_node) + + def transform( + self, + mod: IRModule, + ): + for gv, func in mod.functions_items(): + if isinstance(func, relax.Function): + updated_func = self.visit_expr(func) + self.builder_.update_func(gv, updated_func) + new_mod = self.builder_.get() + new_mod = new_mod.with_attrs(mod.attrs) if mod.attrs else new_mod + for gv, func in new_mod.functions_items(): + mod.update_func(gv, func) + return mod + + mod = FastTypeConversionLayoutMutator( + faster_conversion=self.faster_conversion).transform(mod) + mod = relax.transform.LegalizeOps()(mod) + return mod diff --git a/bitblas/testing/__init__.py b/bitblas/testing/__init__.py new file mode 100644 index 000000000..24f896bd8 --- /dev/null +++ b/bitblas/testing/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import sys +import inspect +import pytest +from bitblas.base import DefaultPolicy, TensorCorePolicy +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags + + +# pytest.main() wrapper to allow running single test file +def main(): + test_file = inspect.getsourcefile(sys._getframe(1)) + sys.exit(pytest.main([test_file] + sys.argv[1:])) + + +def debug_with_schedule(func, arch, sch_rule): + policy = DefaultPolicy(func=func, arch=arch) + try: + tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) + except Exception: + tags = None + if tags: + policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) + configs = policy.emit_config(1) + return sch_rule.apply_config(func, configs[0]) diff --git a/bitblas/utils/__init__.py b/bitblas/utils/__init__.py new file mode 100644 index 000000000..00bddc2a5 --- /dev/null +++ b/bitblas/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 # noqa: F401 +from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor # noqa: F401 +from .target_detector import get_all_nvidia_targets, auto_detect_nvidia_target # noqa: F401 diff --git a/bitblas/utils/post_process.py b/bitblas/utils/post_process.py new file mode 100644 index 000000000..cabee6be1 --- /dev/null +++ b/bitblas/utils/post_process.py @@ -0,0 +1,38 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import re + + +def match_global_kernel(source: str) -> int: + pattern = r"__global__\s+void\s+[__launch_bounds__\(\d+\)\s+]\w+" + matched = re.findall(pattern, source) + assert len(matched) > 1 # may have statement before kernel + return source.index(matched[0]) + + +def tensor_replace_dp4a(source: str) -> str: + # as under block reduction in tir dsl, the dp4a tensorize will fail, so we should do dp4a in post processor. + # TODO(lei): this is a stuff that should be fixed in the tvm in the future + pattern = r"""for\s*\(int\s*(?P\w+)\s*=\s*0;\s*\1\s*<\s*4;\s*\+\+\1\)\s*\{\s*(?P\w+)\[0\]\s*=\s*\(\2\[0\]\s*\+\s*\(\(\(int\)(?P\w+)\[\(\((?P\w+)\s*\*\s*4\)\s*\+\s*\1\)\]\)\s*\*\s*\(\(int\)(?P\w+)\[\(\((?P\w+)\s*\*\s*4\)\s*\+\s*\1\)\]\)\)\);\s*\}""" + replacement = (r"""\2[0] = __dp4a(*(int *)&\3[((\4 * 4))],*(int *)&\5[((\6 * 4))], \2[0]);""") + source = re.sub(pattern, replacement, source) + return source + + +def tensor_remove_make_int4(source: str) -> str: + # remove make_int4 with 16 signed char arguments + # TODO(lei): this is a stuff that should be fixed in the tvm in the future + source = source.replace( + "make_int4((signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0)", + "make_int4(0, 0, 0, 0)", + ) + return source + +def tensor_remove_make_int2(source: str) -> str: + # remove make_int4 with 16 signed char arguments + # TODO(lei): this is a stuff that should be fixed in the tvm in the future + source = source.replace( + "make_int2((signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0)", + "make_int2(0, 0)", + ) + return source diff --git a/bitblas/utils/target_detector.py b/bitblas/utils/target_detector.py new file mode 100644 index 000000000..71d6dcc1f --- /dev/null +++ b/bitblas/utils/target_detector.py @@ -0,0 +1,103 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import os +import subprocess +from typing import List +from thefuzz import process +from tvm.target import Target +from tvm.target.tag import list_tags + +import logging + +logger = logging.getLogger(__name__) + +TARGET_MISSING_ERROR = ( + "TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=`, " + "where is one of the available targets can be found in the output of `tools/get_available_targets.py`." +) + +# Nvidia produces non-public oem gpu models that are part of drivers but not mapped to correct tvm target +# Remap list to match the oem model name to the closest public model name +NVIDIA_GPU_REMAP = { + "NVIDIA PG506-230": "NVIDIA A100", + "NVIDIA PG506-232": "NVIDIA A100", +} + +def get_gpu_model_from_nvidia_smi(gpu_id: int = 0): + """ + Executes the 'nvidia-smi' command to fetch the name of the first available NVIDIA GPU. + + Returns: + str: The name of the GPU, or None if 'nvidia-smi' command fails. + """ + try: + # Execute nvidia-smi command to get the GPU name + output = subprocess.check_output( + ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"], + encoding="utf-8", + ).strip() + except subprocess.CalledProcessError as e: + logger.info("nvidia-smi failed with error: %s", e) + return None + + gpus = output.split("\n") + + # for multiple gpus, CUDA_DEVICE_ORDER=PCI_BUS_ID must be set to match nvidia-smi or else wrong + # gpu is returned for gpu_id + if len(gpus) > 1 and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID": + raise EnvironmentError("Multi-gpu environment must set `CUDA_DEVICE_ORDER=PCI_BUS_ID`.") + + if gpu_id >= len(gpus) or gpu_id < 0: + raise ValueError(f"Passed gpu_id:{gpu_id} but there are {len(gpus)} detected Nvidia gpus.") + + return gpus[gpu_id] + +def find_best_match(tags, query): + """ + Finds the best match for a query within a list of tags using fuzzy string matching. + """ + MATCH_THRESHOLD = 25 + best_match, score = process.extractOne(query, tags) + + def check_target(best, default): + return best if Target(best).arch == Target(default).arch else default + + if check_target(best_match, "cuda") == best_match: + return best_match if score >= MATCH_THRESHOLD else "cuda" + else: + logger.warning(TARGET_MISSING_ERROR) + return "cuda" + + +def get_all_nvidia_targets() -> List[str]: + """ + Returns all available NVIDIA targets. + """ + all_tags = list_tags() + return [tag for tag in all_tags if "nvidia" in tag] + + +def auto_detect_nvidia_target(gpu_id: int = 0) -> str: + """ + Automatically detects the NVIDIA GPU architecture to set the appropriate TVM target. + + Returns: + str: The detected TVM target architecture. + """ + # Return a predefined target if specified in the environment variable + # if "TVM_TARGET" in os.environ: + # return os.environ["TVM_TARGET"] + + # Fetch all available tags and filter for NVIDIA tags + all_tags = list_tags() + nvidia_tags = [tag for tag in all_tags if "nvidia" in tag] + + # Get the current GPU model and find the best matching target + gpu_model = get_gpu_model_from_nvidia_smi(gpu_id=gpu_id) + + # Compat: remap oem devices to their correct non-oem model names for tvm target + if gpu_model in NVIDIA_GPU_REMAP: + gpu_model = NVIDIA_GPU_REMAP[gpu_model] + + target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda" + return target diff --git a/bitblas/utils/tensor_adapter.py b/bitblas/utils/tensor_adapter.py new file mode 100644 index 000000000..55b80d138 --- /dev/null +++ b/bitblas/utils/tensor_adapter.py @@ -0,0 +1,130 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from typing import Union +from enum import IntEnum +import numpy as np +import torch +from torch.utils.dlpack import from_dlpack, to_dlpack +from math import prod + +from tvm.relay import TensorType +from tvm._ffi.base import _LIB, c_str +from tvm._ffi._ctypes.types import TVMValue, check_call +from tvm._ffi.runtime_ctypes import ( + TVMArrayHandle,) +import ctypes + +TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p) +_c_str_dltensor = c_str("dltensor") +_c_str_used_dltensor = c_str("used_dltensor") + + +def get_values_from_torch_tensors(tensors, num_args): + values = (TVMValue * num_args)() + dlpack_tensors = [to_dlpack(torch_tensor) for torch_tensor in tensors] + for i, dltensor in enumerate(dlpack_tensors): + dltensor = ctypes.py_object(dltensor) + if ctypes.pythonapi.PyCapsule_IsValid(dltensor, _c_str_dltensor): + ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, _c_str_dltensor) + # enforce type to make sure it works for all ctypes + ptr = ctypes.cast(ptr, ctypes.c_void_p) + handle = TVMArrayHandle() + check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle))) + # ndarray = tvm.runtime.ndarray._make_array(handle, False, False) + ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor) + ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0)) + values[i].v_handle = ctypes.cast(handle, ctypes.c_void_p) + else: + raise ValueError("Invalid DLTensor") + return values + + +class TensorSupplyType(IntEnum): + Integer = 1 + Uniform = 2 + Normal = 3 + Randn = 4 + Zero = 5 + One = 6 + + +def get_tensor_supply(supply_type: TensorSupplyType, opt_shapes: dict = None): + + def var_wrapper(v, opt_shapes): + if isinstance(v, tvm.tir.Var): + assert opt_shapes + assert v.name in opt_shapes + return opt_shapes[v.name] + elif isinstance(v, tvm.tir.IntImm): + return v.value + else: + raise RuntimeError("Not supported type: ", type(v)) + + def get_tensor(tensor: TensorType) -> torch.Tensor: + dtype = torch.__getattribute__(str(tensor.dtype)) + device = torch.cuda.current_device() + shape = [var_wrapper(i, opt_shapes) for i in tensor.shape] + if supply_type == TensorSupplyType.Integer: + return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype) + elif supply_type == TensorSupplyType.Uniform: + return torch.empty(*shape, device=device, dtype=dtype).uniform_(-1.0, 1.0) + elif supply_type == TensorSupplyType.Normal: + return torch.empty(*shape, device=device, dtype=dtype).normal_(-1.0, 1.0) + elif supply_type == TensorSupplyType.Randn: + return torch.randn(*shape, device=device).to(dtype) + elif supply_type == TensorSupplyType.Zero: + return torch.zeros(*shape, device=device, dtype=dtype) + elif supply_type == TensorSupplyType.One: + return torch.ones(*shape, device=device, dtype=dtype) + else: + raise NotImplementedError(supply_type) + + return get_tensor + + +def tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]): + if isinstance(tensor, tvm.te.Tensor): + return torch.from_numpy(tensor.numpy()) + elif isinstance(tensor, tvm.nd.NDArray): + return from_dlpack(tensor) + else: + raise RuntimeError("Not supported type: ", type(tensor)) + +def lazy_tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]): + # It additionally needs the ctypes type as torch type + def as_tensor(address, shape, elems_inbytes, torch_type): + arr = (ctypes.c_int8 * elems_inbytes).from_address( + address) + return torch.frombuffer(arr, dtype=torch_type).view(*shape) + + if isinstance(tensor, tvm.nd.NDArray): + np_array = tensor.asnumpy() + shape = np_array.shape + dtype = np_array.dtype + torch_dtype = getattr(torch, str(dtype)) + num_elems_inbytes = prod(shape) * np_array.itemsize + data_ptr = np_array.ctypes.data + tensor = as_tensor(data_ptr, shape, num_elems_inbytes, torch_dtype) + return tensor + else: + raise RuntimeError("Not supported type: ", type(tensor)) + +def lazy_torch_to_tvm_tensor(tensor): + # It additionally needs the ctypes type as torch type + def as_tensor(address, shape, elems_inbytes, numpy_type): + arr = (ctypes.c_int8 * elems_inbytes).from_address( + address) + return np.frombuffer(arr, dtype=numpy_type).reshape(shape) + + if isinstance(tensor, torch.Tensor): + data_ptr = tensor.data_ptr() + shape = tensor.shape + torch_dtype = tensor.dtype + numpy_dtype = str(torch_dtype).replace("torch.", "") + num_elems_inbytes = prod(shape) * tensor.itemsize + np_tensor = as_tensor(data_ptr, shape, num_elems_inbytes, numpy_dtype) + tvm_tensor = tvm.nd.array(np_tensor) + return tvm_tensor + else: + raise RuntimeError("Not supported type: ", type(tensor)) diff --git a/bitblas/wrapper/__init__.py b/bitblas/wrapper/__init__.py new file mode 100644 index 000000000..1d87f8020 --- /dev/null +++ b/bitblas/wrapper/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .general import CUDASourceWrapper, CUDASourceWrapperWithDynamic # noqa: F401 diff --git a/bitblas/wrapper/general.py b/bitblas/wrapper/general.py new file mode 100644 index 000000000..58aa8d226 --- /dev/null +++ b/bitblas/wrapper/general.py @@ -0,0 +1,518 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import tvm +from typing import Optional, List, Dict, Union +from tvm import IRModule +from bitblas import TileDevice +from tvm.runtime import ndarray +from bitblas.utils import match_global_kernel +import re +import ctypes +import os +import tempfile +import subprocess +import logging +from tvm.driver import lower +from tvm.target import Target + +logger = logging.getLogger(__name__) + +_TYPE_MAP = { + "float32": "float", + "float16": "half", + "bfloat16": "__nv_bfloat162", + "e4m3_float8": "__nv_fp8_e4m3", + "e5m2_float8": "__nv_fp8_e5m2", + "float64": "double", + "int64": "int64_t", + "int32": "int", + "uint32": "unsigned int", + "bool": "int8_t", + "int8": "int8_t", + "uint8": "uint8_t", + "int16": "int16_t", + "uchar": "uint8_t", +} + + +def get_annotated_device_mod(mod: IRModule, target: Target): + """ + Lower the given IRModule and create a device module for the specified target. + + Parameters: + - mod: The input IRModule. + - target: The compilation target. + + Returns: + - A device module ready for execution. + """ + input_mod = lower(mod) + target_input_mod = {target: input_mod} + annotated_mods = {} + runtime = None + target_host = None + for tgt, mod in target_input_mod.items(): + if not isinstance(tgt, (str, Target)): + raise ValueError("The key of inputs must be str or " + "Target when inputs is dict.") + if not isinstance(mod, tvm.IRModule): + raise ValueError("inputs must be Schedule, IRModule, " + "or dict of str to IRModule.") + annotated_mods[tgt] = mod.with_attr("runtime", runtime) + annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) + if not target_host: + for tar, _ in annotated_mods.items(): + device_type = ndarray.device(tar.kind.name, 0).device_type + if device_type == ndarray.cpu(0).device_type: + target_host = tar + break + if not target_host: + target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" + annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) + for target, mod in annotated_mods.items(): + mixed_mod_passes = tvm.get_global_func("driver.mixed_mod_passes") + device_mod_passes = tvm.get_global_func("driver.device_mod_passes") + mod = mixed_mod_passes(mod, target)(mod) + device_mod = device_mod_passes(mod, target)(mod) + return device_mod + + +def get_thread_block_information(mod: IRModule): + """ + Extracts the thread block and grid dimensions for the reduction block within a given IRModule. + + Parameters: + - mod: The input IRModule from which to extract thread block and grid information. + + Returns: + A tuple containing two lists: + - The first list contains the dimensions of the thread block (threadIdx.x, threadIdx.y, threadIdx.z). + - The second list contains the dimensions of the grid (blockIdx.x, blockIdx.y, blockIdx.z). + """ + + # Initialize the schedule from the IRModule + sch = tvm.tir.Schedule(mod) + + # Get the root block and its child blocks + root_block = sch.get_block("root") + child_blocks = sch.get_child_blocks(root_block) + + # Initialize default block and grid dimensions (1, 1, 1) + block_dims, grid_dims = [1, 1, 1], [1, 1, 1] + + for block in child_blocks: + # Get the loops surrounding the main block + loops = sch.get_loops(block) + + # Iterate over each loop to extract thread and block bindings + for loop in loops: + stmt = sch.get(loop) + thread_binding = stmt.thread_binding + extent = int(stmt.extent) + + # Skip loops without thread binding + if thread_binding: + if "threadIdx" in thread_binding.thread_tag: + block_dims["xyz".index(thread_binding.thread_tag[-1])] = extent + elif "blockIdx" in thread_binding.thread_tag: + grid_dims["xyz".index(thread_binding.thread_tag[-1])] = extent + + return block_dims, grid_dims + + +class CUDASourceWrapper(object): + + def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): + self.mod = optimized_mod + self.arch = arch + self.source = source + self.function_name: Optional[str] = None + self.dynamic_smem_buf: Optional[int] = None + self.block_info: Union[List[int], Dict] = [1, 1, 1] + self.grid_info: Union[List[int], Dict] = [1, 1, 1] + self.parse_source_information() + self.src_name: Optional[str] = None + self.lib_name: Optional[str] = None + self.lib_code: Optional[str] = self.update_lib_code(source) + + def load_lib(self): + return ctypes.CDLL(self.lib_name) + + def remove_lib(self): + if self.lib_name: + os.remove(self.lib_name) + self.lib_name = None + + def compile_lib(self, timeout: float = None): + arch = self.arch + src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) + compute_version = arch.compute_capability + lib_name = src.name.replace(".cu", ".so") + + command = [ + "nvcc", + "-std=c++17", + "-Xcudafe", + "--diag_suppress=177", + "--compiler-options", + "'-fPIC'", + "-lineinfo", + "--shared", + src.name, + "-lcuda", + f"-gencode=arch=compute_{compute_version},code=compute_{compute_version}", + "-o", + lib_name, + ] + src.write(self.lib_code) + src.flush() + try: + ret = subprocess.run(command, timeout=timeout) + except subprocess.TimeoutExpired: + logger.warning(f"Compilation Timeout! {command}") + return None + if ret.returncode != 0: + logger.warning(f"Compilation Failed! {command}") + return None + self.src_name = src.name + self.lib_name = lib_name + + def parse_source_information(self): + device_mod = get_annotated_device_mod(self.mod, self.arch.target) + assert (len(device_mod.functions) == 1 + ), "Only support one function in the module for static shape kernel." + for g_var, func in device_mod.functions.items(): + self.function_name = g_var.name_hint + attrs = func.attrs + if "dyn_shared_memory_buf" in attrs: + self.dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) + if "thread_extent" in attrs: + thread_extent = attrs["thread_extent"] + for tag, extent in thread_extent.items(): + if "threadIdx" in tag: + self.block_info["xyz".index(tag[-1])] = extent + elif "blockIdx" in tag: + self.grid_info["xyz".index(tag[-1])] = extent + + def get_dynamic_symbolic_set(self, prim_func): + # Determine the set of dynamic symbols used in the function + dynamic_symbolic_set = set() + for param in prim_func.params: + buffer = prim_func.buffer_map[param] + for dim in buffer.shape: + if isinstance(dim, tvm.tir.Var): + dynamic_symbolic_set.add(dim.name) + return dynamic_symbolic_set + + def get_cuda_init_func(self): + # Initialize an empty string for the CUDA function call + call_str = """""" + # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call + if self.dynamic_smem_buf is not None: + call_str = """ + cudaFuncSetAttribute({}, + cudaFuncAttributeMaxDynamicSharedMemorySize, {}); + """.format(self.function_name, self.dynamic_smem_buf) + # Format the initialization function using the call_str + init_funcs = """ + extern "C" void init() {{ + {} + }} + """.format(call_str) + return init_funcs + + def update_lib_code(self, code: str): + # Update the library code with the given code string + self.lib_code = code + # Find the index of the global kernel function in the code + index = match_global_kernel(code) + # Extract the declaration of the function starting from the found index + declaration = code[index:].split(";")[0] + + function_name = self.function_name + # Get the CUDA initialization function + init_func = self.get_cuda_init_func() + + # Locate the opening brace of the function to insert arguments + index = code.index("{", index) + function_args = [] + # Populate the function arguments from the primary function's parameters and buffers + for param in self.prim_func.params: + buffer = self.prim_func.buffer_map[param] + function_args.append({ + "name": buffer.name, + "type": _TYPE_MAP[buffer.dtype] + "* __restrict__", + }) + + dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) + # Add dynamic symbolic parameters as integers to the function arguments + for dyn_sym in dynamic_symbolic_set: + function_args.append({"name": dyn_sym, "type": "int"}) + + function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) + # Format the function arguments for declaration + def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) + + def func_call_args(s, function_args): + # Extract the function call arguments matching the function definition + pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" + matches = re.findall(pattern, s) + call_args = [] + for match in matches: + for arg in function_args: + if arg["name"] == match: + call_args.append(match) + return call_args + + call_args = ", ".join(func_call_args(declaration, function_args)) + block_info, grid_info = self.block_info, self.grid_info + + def legalize_c(p): + # Convert TIR expressions to legal C expressions + # Directly convert to string since the special case handling + # does not alter the string representation for `tvm.tir.Var` and `IntImm`. + # Replace Python's floor division operator with C's division operator + if isinstance(p, tvm.tir.IntImm): + p = int(p) + return str(p).replace("//", "/") + + # Prepare the block and grid dimensions for the CUDA kernel launch + block_str = "dim3({}, {}, {})".format( + legalize_c(block_info[0]), + legalize_c(block_info[1]), + legalize_c(block_info[2]), + ) + grid_str = "dim3({}, {}, {})".format( + legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2])) + # Determine the shared memory size, defaulting to 0 if not specified + smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf + # Format the CUDA kernel launch string + if len(dynamic_symbolic_set) != 0: + call_str = "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0]) + else: + call_str = "" + call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str, + smem_str, call_args) + # Create the host function wrapper for the CUDA kernel + host_func = """ + extern "C" void call({}) {{ + {} + }} + """.format(def_args, call_str) + # Combine the source, initialization function, and host function to form the complete library code + lib_code = self.source + init_func + host_func + return lib_code + + @property + def prim_func(self): + return self.mod["main"] + + +class CUDASourceWrapperWithDynamic(CUDASourceWrapper): + + def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): + super().__init__(optimized_mod, source, arch) + + def get_cuda_init_func(self): + # Initialize an empty string to accumulate CUDA function calls for setting dynamic shared memory + call_str = """""" + # Iterate over functions and their dynamic shared memory requirements + for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): + if dynamic_smem_buf is not None: + # Format the cudaFuncSetAttribute call for dynamic shared memory + call_str += """ + cudaFuncSetAttribute({}, + cudaFuncAttributeMaxDynamicSharedMemorySize, {}); + """.format(function_name, dynamic_smem_buf) + # Define the init function that will set the attributes for each kernel + init_funcs = """ +extern "C" void init() {{ + {} +}} + """.format(call_str) + return init_funcs + + def create_dispatch_func(self, code, function_informations): + # Extract the set of dynamic symbolic names used in the primary function + dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) + + # Find the location of the global kernel function in the code + index = match_global_kernel(code) + + # Analyze the function declaration to prepare for argument extraction + dummy_declaration = code[index:].split(";")[0] + + function_name = self.function_name + + # Identify the start of the function body to insert arguments + index = code.index("{", index) + function_args = [] + # Collect function arguments based on primary function's parameters and buffer mappings + for param in self.prim_func.params: + buffer = self.prim_func.buffer_map[param] + function_args.append({ + "name": buffer.name, + "type": _TYPE_MAP[buffer.dtype] + "* __restrict__", + }) + # Add dynamic symbols as integer arguments + for dyn_sym in dynamic_symbolic_set: + function_args.append({"name": dyn_sym, "type": "int"}) + + function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) + + # Format the argument definitions for function declaration + def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) + + def func_call_args(s: str, function_args): + # Extract and clean the function call arguments to match the declaration + pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" + matches = re.findall(pattern, s) + call_args = [] + for match in matches: + match = re.sub(r"\d+", "", match) # Remove numbers + match = re.sub(r"_", "", match) # Remove underscores + for arg in function_args: + if arg["name"] == match: + call_args.append(match) + return call_args + + call_args = ", ".join(func_call_args(dummy_declaration, function_args)) + + def legalize_c(p): + # Convert TIR expressions to legal C expressions + # Directly convert to string since the special case handling + # does not alter the string representation for `tvm.tir.Var` and `IntImm`. + # Replace Python's floor division operator with C's division operator + if isinstance(p, tvm.tir.IntImm): + p = int(p) + return str(p).replace("//", "/") + + last_range = 0 + num_items = len(function_informations) + _call_str = """""" + for function_name, info in function_informations.items(): + # Prepare block and grid configurations for kernel launches + block_info, grid_info = info["block_info"], info["grid_info"] + block_str = "dim3({}, {}, {})".format( + legalize_c(block_info[0]), + legalize_c(block_info[1]), + legalize_c(block_info[2]), + ) + grid_str = "dim3({}, {}, {})".format( + legalize_c(grid_info[0]), + legalize_c(grid_info[1]), + legalize_c(grid_info[2]), + ) + # Handle dynamic shared memory specification + smem_str = (0 if info["dynamic_smem_buf"] is None else info["dynamic_smem_buf"]) + opt_shapes = info["opt_shapes"] + # Generate conditional kernel launch code based on dynamic symbolic ranges + (symbolic,) = list(dynamic_symbolic_set) + range_str = opt_shapes[symbolic] + if last_range == 0: + call_str = "if ({} == 0) return; \n".format(symbolic,) + call_str += "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( + symbolic, + range_str, + function_name, + grid_str, + block_str, + smem_str, + call_args, + ) + else: + call_str = "\t\telse if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( + symbolic, + range_str, + function_name, + grid_str, + block_str, + smem_str, + call_args, + ) + if last_range == num_items - 1: + call_str += ( + "\t\telse {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( + function_name, grid_str, block_str, smem_str, call_args)) + last_range += 1 + _call_str += call_str + + # Wrap the kernel dispatch logic in an external C function + host_func = """ +extern "C" void call({}) {{ + {} +}} + """.format(def_args, _call_str) + return host_func + + def parse_source_information(self): + # Parse device module to extract execution configurations for each function + device_mod = get_annotated_device_mod(self.mod, self.arch.target) + block_info_map = {} + grid_info_map = {} + dynamic_smem_buf_map = {} + for g_var, func in device_mod.functions.items(): + # Default block and grid configurations + block_info = [1, 1, 1] + grid_info = [1, 1, 1] + function_name = g_var.name_hint + attrs = func.attrs + dynamic_smem_buf = None + if "dyn_shared_memory_buf" in attrs: + dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) + if "thread_extent" in attrs: + # Extract block and grid sizes from thread extents + thread_extent = attrs["thread_extent"] + for tag, extent in thread_extent.items(): + if "threadIdx" in tag: + block_info["xyz".index(tag[-1])] = extent + elif "blockIdx" in tag: + grid_info["xyz".index(tag[-1])] = extent + # Map the extracted configurations to each function + block_info_map[function_name] = block_info + grid_info_map[function_name] = grid_info + dynamic_smem_buf_map[function_name] = dynamic_smem_buf + # Store the mappings for use in code generation + self.block_info = block_info_map + self.grid_info = grid_info_map + self.dynamic_smem_buf = dynamic_smem_buf_map + + def update_lib_code(self, code: str): + # Organize function information for code generation + function_informations = {} + for g_var, func in self.mod.functions.items(): + if g_var.name_hint == "main": + continue + function_name = g_var.name_hint + attrs = func.attrs + assert "opt_shapes" in attrs + opt_shapes = attrs["opt_shapes"] + function_informations[function_name] = { + "function_name": function_name, + "opt_shapes": opt_shapes, + "block_info": self.block_info[function_name], + "grid_info": self.grid_info[function_name], + "dynamic_smem_buf": self.dynamic_smem_buf[function_name], + } + + def compare_map_objects(map_obj): + comparable_representation = list(map_obj.values()) + return comparable_representation + + function_informations = dict( + sorted( + function_informations.items(), + key=lambda item: compare_map_objects(item[1]["opt_shapes"]))) + + self.lib_code = code + + # Generate the initialization and dispatch functions + init_func = self.get_cuda_init_func() + host_func = self.create_dispatch_func(code, function_informations) + # Concatenate source code with generated code segments + lib_code = self.source + init_func + host_func + return lib_code + + @property + def prim_func(self): + return self.mod["main"] From 955b379423b839ac457a904b04c15db2f005b0d0 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 3 Jul 2024 10:06:23 +0000 Subject: [PATCH 3/3] Remove unused code files --- install.sh | 26 + python/bitblas/__init__.py | 87 - python/bitblas/base/__init__.py | 18 - python/bitblas/base/analysis.py | 300 --- python/bitblas/base/common_schedules.py | 163 -- python/bitblas/base/roller/__init__.py | 7 - python/bitblas/base/roller/arch/__init__.py | 14 - python/bitblas/base/roller/arch/arch_base.py | 40 - python/bitblas/base/roller/arch/cpu.py | 19 - python/bitblas/base/roller/arch/cuda.py | 67 - python/bitblas/base/roller/bestfit.py | 66 - python/bitblas/base/roller/hint.py | 248 -- python/bitblas/base/roller/node.py | 408 --- python/bitblas/base/roller/policy/__init__.py | 5 - python/bitblas/base/roller/policy/common.py | 56 - python/bitblas/base/roller/policy/default.py | 748 ------ .../bitblas/base/roller/policy/tensorcore.py | 349 --- python/bitblas/base/roller/rasterization.py | 88 - .../base/roller/shape_inference/__init__.py | 4 - .../base/roller/shape_inference/common.py | 66 - .../base/roller/shape_inference/tir.py | 399 --- python/bitblas/base/schedule_rule.py | 149 -- python/bitblas/base/transform.py | 218 -- python/bitblas/base/utils.py | 517 ---- python/bitblas/cache/__init__.py | 9 - python/bitblas/cache/operator.py | 179 -- python/bitblas/generator.py | 15 - python/bitblas/gpu/__init__.py | 23 - python/bitblas/gpu/base.py | 44 - python/bitblas/gpu/element_wise.py | 97 - python/bitblas/gpu/fallback.py | 95 - python/bitblas/gpu/gemv.py | 794 ------ python/bitblas/gpu/gemv_dequantize.py | 369 --- python/bitblas/gpu/general_reduction.py | 465 ---- python/bitblas/gpu/intrin/__init__.py | 3 - python/bitblas/gpu/intrin/lop3.py | 1667 ------------ python/bitblas/gpu/matmul.py | 372 --- python/bitblas/gpu/matmul_analysis.py | 786 ------ python/bitblas/gpu/matmul_mma.py | 1069 -------- python/bitblas/gpu/matmul_mma_dequantize.py | 2295 ----------------- python/bitblas/gpu/matmul_wmma.py | 892 ------- python/bitblas/gpu/reduction.py | 301 --- python/bitblas/gpu/rmsnorm.py | 144 -- python/bitblas/gpu/transpose.py | 133 - python/bitblas/gpu/utils.py | 86 - python/bitblas/module/__init__.py | 305 --- python/bitblas/ops/__init__.py | 7 - python/bitblas/ops/general_matmul.py | 588 ----- python/bitblas/ops/general_matmul_splitk.py | 199 -- python/bitblas/ops/impl/__init__.py | 3 - .../ops/impl/batch_matmul_dequantize_impl.py | 392 --- python/bitblas/ops/impl/batch_matmul_impl.py | 93 - python/bitblas/ops/impl/convolution2d_impl.py | 190 -- .../bitblas/ops/impl/ladder_permutate_impl.py | 81 - .../bitblas/ops/impl/lop3_permutate_impl.py | 152 -- .../ops/impl/matmul_dequantize_impl.py | 644 ----- .../ops/impl/matmul_dequantize_splitk_impl.py | 184 -- python/bitblas/ops/impl/matmul_impl.py | 356 --- python/bitblas/ops/impl/matmul_splitk_impl.py | 94 - .../bitblas/ops/impl/param_permutate_impl.py | 56 - python/bitblas/ops/ladder_permutate.py | 97 - python/bitblas/ops/lop3_permutate.py | 72 - python/bitblas/ops/matmul.py | 288 --- python/bitblas/ops/matmul_dequantize.py | 331 --- python/bitblas/ops/operator.py | 367 --- python/bitblas/ops/param_permutate.py | 91 - python/bitblas/quantization/__init__.py | 12 - python/bitblas/quantization/quantization.py | 217 -- python/bitblas/quantization/utils.py | 110 - python/bitblas/relax/op/interleave_weight.py | 23 - python/bitblas/relax/transform/__init__.py | 5 - .../relax/transform/annotate_decode_block.py | 123 - .../relax/transform/weight_only_propagate.py | 432 ---- python/bitblas/testing/__init__.py | 25 - python/bitblas/utils/__init__.py | 5 - python/bitblas/utils/post_process.py | 38 - python/bitblas/utils/target_detector.py | 103 - python/bitblas/utils/tensor_adapter.py | 130 - python/bitblas/wrapper/__init__.py | 4 - python/bitblas/wrapper/general.py | 518 ---- python/bitblas_cli.py | 2 - .../dsl/test_auto_normalized_tensorcore.py | 112 - 82 files changed, 26 insertions(+), 20323 deletions(-) create mode 100755 install.sh delete mode 100644 python/bitblas/__init__.py delete mode 100644 python/bitblas/base/__init__.py delete mode 100644 python/bitblas/base/analysis.py delete mode 100644 python/bitblas/base/common_schedules.py delete mode 100644 python/bitblas/base/roller/__init__.py delete mode 100644 python/bitblas/base/roller/arch/__init__.py delete mode 100644 python/bitblas/base/roller/arch/arch_base.py delete mode 100644 python/bitblas/base/roller/arch/cpu.py delete mode 100644 python/bitblas/base/roller/arch/cuda.py delete mode 100644 python/bitblas/base/roller/bestfit.py delete mode 100644 python/bitblas/base/roller/hint.py delete mode 100644 python/bitblas/base/roller/node.py delete mode 100644 python/bitblas/base/roller/policy/__init__.py delete mode 100644 python/bitblas/base/roller/policy/common.py delete mode 100644 python/bitblas/base/roller/policy/default.py delete mode 100644 python/bitblas/base/roller/policy/tensorcore.py delete mode 100644 python/bitblas/base/roller/rasterization.py delete mode 100644 python/bitblas/base/roller/shape_inference/__init__.py delete mode 100644 python/bitblas/base/roller/shape_inference/common.py delete mode 100644 python/bitblas/base/roller/shape_inference/tir.py delete mode 100644 python/bitblas/base/schedule_rule.py delete mode 100644 python/bitblas/base/transform.py delete mode 100644 python/bitblas/base/utils.py delete mode 100644 python/bitblas/cache/__init__.py delete mode 100644 python/bitblas/cache/operator.py delete mode 100644 python/bitblas/generator.py delete mode 100644 python/bitblas/gpu/__init__.py delete mode 100644 python/bitblas/gpu/base.py delete mode 100644 python/bitblas/gpu/element_wise.py delete mode 100644 python/bitblas/gpu/fallback.py delete mode 100644 python/bitblas/gpu/gemv.py delete mode 100644 python/bitblas/gpu/gemv_dequantize.py delete mode 100644 python/bitblas/gpu/general_reduction.py delete mode 100644 python/bitblas/gpu/intrin/__init__.py delete mode 100644 python/bitblas/gpu/intrin/lop3.py delete mode 100644 python/bitblas/gpu/matmul.py delete mode 100644 python/bitblas/gpu/matmul_analysis.py delete mode 100644 python/bitblas/gpu/matmul_mma.py delete mode 100644 python/bitblas/gpu/matmul_mma_dequantize.py delete mode 100644 python/bitblas/gpu/matmul_wmma.py delete mode 100644 python/bitblas/gpu/reduction.py delete mode 100644 python/bitblas/gpu/rmsnorm.py delete mode 100644 python/bitblas/gpu/transpose.py delete mode 100644 python/bitblas/gpu/utils.py delete mode 100644 python/bitblas/module/__init__.py delete mode 100644 python/bitblas/ops/__init__.py delete mode 100644 python/bitblas/ops/general_matmul.py delete mode 100644 python/bitblas/ops/general_matmul_splitk.py delete mode 100644 python/bitblas/ops/impl/__init__.py delete mode 100644 python/bitblas/ops/impl/batch_matmul_dequantize_impl.py delete mode 100644 python/bitblas/ops/impl/batch_matmul_impl.py delete mode 100644 python/bitblas/ops/impl/convolution2d_impl.py delete mode 100644 python/bitblas/ops/impl/ladder_permutate_impl.py delete mode 100644 python/bitblas/ops/impl/lop3_permutate_impl.py delete mode 100644 python/bitblas/ops/impl/matmul_dequantize_impl.py delete mode 100644 python/bitblas/ops/impl/matmul_dequantize_splitk_impl.py delete mode 100644 python/bitblas/ops/impl/matmul_impl.py delete mode 100644 python/bitblas/ops/impl/matmul_splitk_impl.py delete mode 100644 python/bitblas/ops/impl/param_permutate_impl.py delete mode 100644 python/bitblas/ops/ladder_permutate.py delete mode 100644 python/bitblas/ops/lop3_permutate.py delete mode 100644 python/bitblas/ops/matmul.py delete mode 100644 python/bitblas/ops/matmul_dequantize.py delete mode 100644 python/bitblas/ops/operator.py delete mode 100644 python/bitblas/ops/param_permutate.py delete mode 100644 python/bitblas/quantization/__init__.py delete mode 100644 python/bitblas/quantization/quantization.py delete mode 100644 python/bitblas/quantization/utils.py delete mode 100644 python/bitblas/relax/op/interleave_weight.py delete mode 100644 python/bitblas/relax/transform/__init__.py delete mode 100644 python/bitblas/relax/transform/annotate_decode_block.py delete mode 100644 python/bitblas/relax/transform/weight_only_propagate.py delete mode 100644 python/bitblas/testing/__init__.py delete mode 100644 python/bitblas/utils/__init__.py delete mode 100644 python/bitblas/utils/post_process.py delete mode 100644 python/bitblas/utils/target_detector.py delete mode 100644 python/bitblas/utils/tensor_adapter.py delete mode 100644 python/bitblas/wrapper/__init__.py delete mode 100644 python/bitblas/wrapper/general.py delete mode 100644 python/bitblas_cli.py delete mode 100644 testing/python/dsl/test_auto_normalized_tensorcore.py diff --git a/install.sh b/install.sh new file mode 100755 index 000000000..584392820 --- /dev/null +++ b/install.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# install requirements +pip install -r requirements.txt + +# install llvm +apt-get install llvm-10 + +# clone and build tvm +git submodule update --init --recursive + +cd 3rdparty/tvm +mkdir build +cp cmake/config.cmake build +cd build +echo "set(USE_LLVM llvm-config-10)" >> config.cmake && echo "set(USE_CUDA ON)" >> config.cmake + +cmake .. && make -j && cd ../../.. + +echo "export TVM_HOME=$(pwd)/3rdparty/tvm" >> ~/.bashrc +echo "export PYTHONPATH=\$TVM_HOME/python:$(pwd)/python:\$PYTHONPATH" >> ~/.bashrc + +source ~/.bashrc diff --git a/python/bitblas/__init__.py b/python/bitblas/__init__.py deleted file mode 100644 index 14b510845..000000000 --- a/python/bitblas/__init__.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import sys -import os - -# installing tvm -install_tvm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "3rdparty", "tvm", "python") -if os.path.exists(install_tvm_path) and install_tvm_path not in sys.path: - os.environ["PYTHONPATH"] = install_tvm_path + ":" + os.environ.get("PYTHONPATH", "") - sys.path.insert(0, install_tvm_path) - -develop_tvm_path = os.path.join( - os.path.dirname(os.path.abspath(__file__)), "..", "..", "3rdparty", "tvm", "python") -if os.path.exists(develop_tvm_path) and develop_tvm_path not in sys.path: - os.environ["PYTHONPATH"] = develop_tvm_path + ":" + os.environ.get("PYTHONPATH", "") - sys.path.insert(0, develop_tvm_path) - -from . import gpu # noqa: F401 -from .base import ( - TileDevice, # noqa: F401 - fast_tune, # noqa: F401 - ApplyDefaultSchedule, # noqa: F401 - ApplyFastTuning, # noqa: F401 - BlockInfo, # noqa: F401 - IterInfo, # noqa: F401 - ScheduleRule, # noqa: F401 - normalize_prim_func, # noqa: F401 - try_inline, # noqa: F401 - try_inline_contiguous_spatial, # noqa: F401 -) - -from . import testing # noqa: F401 -from .utils import auto_detect_nvidia_target # noqa: F401 -from .ops.general_matmul import MatmulConfig, Matmul # noqa: F401 -from .ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK # noqa: F401 -from .ops.matmul_dequantize import MatmulWeightOnlyDequantizeConfig, MatmulWeightOnlyDequantize # noqa: F401 -from .module import Linear # noqa: F401 - -import logging -from tqdm import tqdm - - -class TqdmLoggingHandler(logging.Handler): - """ Custom logging handler that directs log output to tqdm progress bar to avoid interference. """ - - def __init__(self, level=logging.NOTSET): - """ Initialize the handler with an optional log level. """ - super().__init__(level) - - def emit(self, record): - """ Emit a log record. Messages are written to tqdm to ensure output in progress bars isn't corrupted. """ - try: - msg = self.format(record) - tqdm.write(msg) - except Exception: - self.handleError(record) - - -def set_log_level(level): - """ Set the logging level for the module's logger. - - Args: - level (str or int): Can be the string name of the level (e.g., 'INFO') or the actual level (e.g., logging.INFO). - OPTIONS: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL' - """ - if isinstance(level, str): - level = getattr(logging, level.upper(), logging.INFO) - logger = logging.getLogger(__name__) - logger.setLevel(level) - - -def _init_logger(): - """ Initialize the logger specific for this module with custom settings and a Tqdm-based handler. """ - logger = logging.getLogger(__name__) - handler = TqdmLoggingHandler() - formatter = logging.Formatter( - fmt="%(asctime)s [BitBLAS:%(levelname)s]: %(message)s", datefmt="%Y-%m-%d %H:%M:%S") - handler.setFormatter(formatter) - logger.addHandler(handler) - logger.propagate = False - set_log_level('WARNING') - - -_init_logger() - -__version__ = "0.0.1.dev12" diff --git a/python/bitblas/base/__init__.py b/python/bitblas/base/__init__.py deleted file mode 100644 index 122c44cbd..000000000 --- a/python/bitblas/base/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -"""Base infra""" -from .analysis import ( - BlockInfo, - IterInfo, - collect_block_iter_vars_used_in_access_region, - collect_vars_used_in_prim_expr, - detect_dominant_read, - is_broadcast_epilogue, - normalize_prim_func, -) -from .common_schedules import get_block, get_output_blocks, try_inline, try_inline_contiguous_spatial -from .schedule_rule import ScheduleRule -from .transform import ApplyDefaultSchedule, ApplyFastTuning -from .utils import fast_tune, fast_tune_with_dynamic_range -from .roller import * diff --git a/python/bitblas/base/analysis.py b/python/bitblas/base/analysis.py deleted file mode 100644 index eb9c19415..000000000 --- a/python/bitblas/base/analysis.py +++ /dev/null @@ -1,300 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Analysis on TIR blocks, loops and functions.""" -from typing import List, Optional, Set, Union -from typing_extensions import Literal - -from tvm import ir, tir, DataType -from tvm._ffi import get_global_func -from tvm.target.target import Target -from tvm.tir import Schedule, IterVar -from tvm.tir.schedule import BlockRV - - -class IterInfo: - """Information about a loop/iter var.""" - - kind: Literal["S", "R", "O"] - var: tir.Var - _dom: tir.PrimExpr - loop_rv: tir.schedule.LoopRV - - def __init__( - self, - kind: Literal["S", "R", "O"], - var: tir.Var, - dom: tir.PrimExpr, - loop_rv: tir.schedule.LoopRV, - ): - """Construct an IterInfo object.""" - self.kind = kind - self.var = var - self._dom = dom - self.loop_rv = loop_rv - - @property - def dom(self) -> Union[int, tir.PrimExpr]: - """The iteration domain of the loop.""" - return int(self._dom) if isinstance(self._dom, tir.IntImm) else self._dom - - def __str__(self) -> str: - return f'Iter("{self.kind}", {self.dom})' - - def __repr__(self) -> str: - return str(self) - - -class BlockInfo: - """Information about a TIR block.""" - - name: str - iters: List[IterInfo] - block_rv: tir.schedule.BlockRV - _reduction_block: bool - - def __init__( - self, - name: str, - iters: List[IterInfo], - block_rv: tir.schedule.BlockRV, - reduction_block: bool = False, - ): - """Construct a BlockInfo object.""" - self.name = name - self.block_rv = block_rv - self.iters = iters - self._reduction_block = reduction_block - - def dom(self) -> List[Union[int, tir.PrimExpr]]: - """The iteration domain of the block.""" - return [i.dom for i in self.iters] - - def dom_kind(self) -> str: - """The iteration domain kind of the block, for example, SSSS, SSSR.""" - return "".join(i.kind for i in self.iters) - - def is_injective(self) -> bool: - """Whether the block is injective, i.e. all its iteration domains are injective.""" - return all(k == "S" for k in self.dom_kind()) - - def is_elementwise(self, sch: tir.Schedule) -> bool: - """Whether the block is elementwise, i.e. trivial mapping between read/write region""" - - def _check_unit_var_range(dom: ir.Range, var: tir.Var) -> bool: - return dom.min.same_as(var) and dom.extent == 1 - - if not self.is_injective(): - return False - block = sch.get(self.block_rv) - if len(block.reads) != 1 or len(block.writes) != 1: - return False - r_region = block.reads[0].region - w_region = block.writes[0].region - if len(r_region) != len(w_region): - return False - for var, r_dom, w_dom in zip(block.iter_vars, r_region, w_region): - if not _check_unit_var_range(var, r_dom) or not _check_unit_var_range(var, w_dom): - return False - return True - - def is_reduction(self) -> bool: - """Whether the block is a reduction workload.""" - # TODO(@junrushao): distinguish GEMV and reduction - return self._reduction_block - - def is_gemv(self) -> bool: - """Whether the block is a GEMV workload.""" - raise NotImplementedError - - def is_gemm(self) -> bool: - """Whether the block is a GEMM workload.""" - raise NotImplementedError - - def __str__(self) -> str: - return f'BlockInfo("{self.name}", "{self.dom_kind()}", {self.dom()})' - - def __repr__(self) -> str: - return str(self) - - -_normalize_prim_func = get_global_func("tir.schedule.NormalizePrimFunc") - - -def normalize_prim_func(sch: tir.Schedule) -> Optional[List[BlockInfo]]: - """Normalize the primfunc to normal form""" - try: - result = _normalize_prim_func(sch) - if result is None: - return None - except Exception: # pylint: disable=broad-except - return None - - def _iter_kind(i: tir.IterVar) -> str: - return { - tir.IterVar.DataPar: "S", - tir.IterVar.CommReduce: "R", - }.get(i.iter_type, "O") - - blocks: List[BlockInfo] = [] - for block, loops, iters, is_reduction in zip(*result): - blocks.append( - BlockInfo( - name=sch.get(block).name_hint, - iters=[ - IterInfo( - kind=_iter_kind(iter), # type: ignore - var=iter.var, - dom=iter.dom, - loop_rv=loop, - ) for loop, iter in zip(loops, iters) - ], - block_rv=block, - reduction_block=is_reduction, - )) - return blocks - - -def find_var_from_func(func, var: str): - for buffer in func.buffer_map.values(): - for i in buffer.shape: - if isinstance(i, tir.Var) and i.name == var: - return i - return None - - -def check_func_with_dynamic(func): - for buffer in func.buffer_map.values(): - for i in buffer.shape: - if isinstance(i, tir.Var): - return True - return False - - -def _assert_gpu_target(target: Target): - if "gpu" not in target.keys: - raise ValueError(f"Expect a GPU target, but got {target}") - - -def get_max_threads_per_block(target: Target) -> int: - _assert_gpu_target(target) - max_threads_per_block = None - for name in ["max_threads_per_block", "max_num_threads"]: - if max_threads_per_block is None: - max_threads_per_block = target.attrs.get(name, None) - if max_threads_per_block is None: - max_threads_per_block = 64 - return int(max_threads_per_block) - - -def get_max_shared_memory_per_block(target: Target) -> int: - _assert_gpu_target(target) - max_shared_memory_per_block = target.attrs.get("max_shared_memory_per_block", None) - if max_shared_memory_per_block is None: - raise ValueError( - f"Cannot find `max_shared_memory_per_block` in {target}, please specify it manually") - return int(max_shared_memory_per_block) - - -def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV: - try: - block = sch.mod[func_name].body.block - except Exception: - raise ValueError(f"The function body is expected to be the root block, but got:\n" - f"{sch.mod[func_name].body}") from None - return sch.get_block(block.name_hint) - - -def collect_block_iter_vars_used_in_access_region(block: tir.Block, - region: List[ir.Range]) -> Set[tir.Var]: - """Collect the block iter variables used in the access region of a buffer region.""" - tir_vars = set() - for expr in region: - if expr.extent != 1: - continue - tir_vars |= collect_vars_used_in_prim_expr(expr.min) - tir_vars &= set(iter_var.var for iter_var in block.iter_vars) - return tir_vars - - -def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> Set[tir.Var]: - """Collect the variables used in the PrimExpr.""" - tir_vars = set() - - def _collect_tir_var(expr): - if isinstance(expr, tir.Var): - tir_vars.add(expr) - - tir.stmt_functor.post_order_visit(expr, _collect_tir_var) - return tir_vars - - -def detect_dominant_read(block: tir.Block) -> tir.PrimExpr: - """Detect the dominant read indices in the block.""" - dominant_read = None - num_read_iters = -1 - for buffer_region in block.reads: - tir_vars = collect_block_iter_vars_used_in_access_region(block, buffer_region.region) - if num_read_iters < len(tir_vars): - num_read_iters = len(tir_vars) - dominant_read = buffer_region - assert dominant_read is not None - (result,) = dominant_read.buffer.offset_of([e.min for e in dominant_read.region]) - return result - - -def is_broadcast_epilogue( - sch: tir.Schedule, - block: tir.schedule.BlockRV, - epilogue: tir.schedule.BlockRV, -) -> bool: - """Check if the epilogue block is a broadcast pattern""" - write_buffers = {r.buffer for r in sch.get(block).writes} - epilogue_iters = {i.var: i for i in sch.get(epilogue).iter_vars if i.dom != 1} - for buffer_region in sch.get(epilogue).reads: - if buffer_region.buffer not in write_buffers: - continue - tir_vars = collect_block_iter_vars_used_in_access_region( - sch.get(epilogue), buffer_region.region) - if len(tir_vars) < len(epilogue_iters): - return True - return False - - -def get_reduction_blocks(sch: tir.Schedule, - blocks: List[tir.schedule.BlockRV]) -> List[tir.schedule.BlockRV]: - # Get the main computation block - def is_reduction(block: BlockRV) -> bool: - block_stmt = sch.get(block) - iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} - return iter_types == {IterVar.CommReduce, IterVar.DataPar} - - def is_spatial(block: BlockRV) -> bool: - block_stmt = sch.get(block) - iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} - return iter_types == {IterVar.DataPar} - - # NOTE: We assume there is only one reduction block in the function - # all blocks are required to be spatial or reduction - if not all([is_reduction(block) or is_spatial(block) for block in blocks]): - return None - - # There is only one reduction block - reduction_blocks = [block for block in blocks if is_reduction(block)] - if len(reduction_blocks) == 0: - return None - return reduction_blocks - - -def get_coalesced_veclen(block_stmt: tir.Block, target_bits: int = 128) -> int: - # gpu memory prefer 128 bits coalesced access (e.g. four banks) - # 128 bits - buffers: List[tir.Buffer] = [] - for read in block_stmt.reads: - buffers.append(read.buffer) - for write in block_stmt.writes: - buffers.append(write.buffer) - # pick the dtype with the largest bits - max_dtype_bits: int = 0 - for buffer in buffers: - max_dtype_bits = max(max_dtype_bits, DataType(buffer.dtype).bits) - return target_bits // max_dtype_bits diff --git a/python/bitblas/base/common_schedules.py b/python/bitblas/base/common_schedules.py deleted file mode 100644 index 7d528c70a..000000000 --- a/python/bitblas/base/common_schedules.py +++ /dev/null @@ -1,163 +0,0 @@ -# Copyright 2018 The apache/tvm Authors. All Rights Reserved. -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# Modifications Copyright (c) Microsoft. -# The code below is mostly copied from apache/tvm common_schedules.py in dlight. -"""Common schedule strategies for TIR.""" -from typing import Callable, List - -from tvm import tir - -from .analysis import BlockInfo - - -def get_block( - sch: tir.Schedule, - blocks: List[BlockInfo], - name: str, -): - """Get the target block from a schedule. - - Parameters - ---------- - sch : tir.Schedule - The TIR schedule used to get target block. - name : str - The name of the target block. - - Returns - ------- - target_block : BlockRV - The target block. - """ - - target_block: tir.BlockRV = None - for block_info in blocks: - block = block_info.block_rv - if sch.get(block).name_hint == name: - target_block = block - return target_block - - -def get_output_blocks( - sch: tir.Schedule, - blocks: List[BlockInfo], -): - """Get the output blocks of a schedule. - - Parameters - ---------- - sch : tir.Schedule - The TIR schedule used to get output blocks. - blocks : List[BlockInfo] - The blocks to be analyzed. - - Returns - ------- - output_blocks : List[BlockInfo] - The output blocks. - """ - - # collect arguments buffer - func = sch.mod["main"] - args = list(func.buffer_map.values()) - - output_blocks = [] - for block_info in blocks: - block = block_info.block_rv - for write in sch.get(block).writes: - if write.buffer in args: - output_blocks.append(block) - - return output_blocks - - -def try_inline( - sch: tir.Schedule, - blocks: List[BlockInfo], -) -> List[BlockInfo]: - """Try to inline as many blocks as possible, and return the remaining blocks. - - Parameters - ---------- - sch : tir.Schedule - The TIR schedule used to inline blocks. - blocks : List[BlockInfo] - The blocks to be inlined. - - Returns - ------- - remaining : List[BlockInfo] - The remaining blocks that cannot be inlined. - """ - - def _trial(func: Callable): - for i, block in enumerate(blocks): - try: - func(block.block_rv) - except Exception: # pylint: disable=bare-except - continue - return i - return None - - while True: - i = _trial(sch.compute_inline) - if i is None: - i = _trial(sch.reverse_compute_inline) - if i is None: - break - blocks.pop(i) - return blocks - - -def try_inline_contiguous_spatial( - sch: tir.Schedule, - block_infos: List[BlockInfo], -) -> List[BlockInfo]: - """Try to inline contiguous spatial blocks in a schedule - - Parameters - ---------- - sch : tir.Schedule - The TIR schedule used to inline blocks. - block_infos : List[BlockInfo] - The blocks to be try. - - Returns - ------- - remaining : List[BlockInfo] - The remaining blocks that cannot be inlined. - """ - - if block_infos is None: - return None - results = [] - spatial_blocks = [] - block: BlockInfo - for block in block_infos: - if block.is_injective(): - spatial_blocks.append(block) - elif spatial_blocks: - results.extend(try_inline(sch, spatial_blocks)) - results.append(block) - spatial_blocks = [] - else: - results.append(block) - if spatial_blocks: - results.extend(try_inline(sch, spatial_blocks)) - return results diff --git a/python/bitblas/base/roller/__init__.py b/python/bitblas/base/roller/__init__.py deleted file mode 100644 index 9afd7cff0..000000000 --- a/python/bitblas/base/roller/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from .node import PrimFuncNode # noqa: F401 -from .rasterization import NoRasterization, Rasterization2DRow, Rasterization2DColumn # noqa: F401 -from .hint import Hint # noqa: F401 -from .policy import DefaultPolicy, TensorCorePolicy # noqa: F401 -from .arch import TileDevice, CUDA # noqa: F401 diff --git a/python/bitblas/base/roller/arch/__init__.py b/python/bitblas/base/roller/arch/__init__.py deleted file mode 100644 index 9cb036792..000000000 --- a/python/bitblas/base/roller/arch/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from .arch_base import TileDevice -from .cuda import * -from .cpu import * - - -def get_arch(target: tvm.target.Target) -> TileDevice: - if target.kind.name == "cuda": - return CUDA(target) - elif target.kind.name == "llvm": - return CPU(target) - else: - raise ValueError(f"Unsupported target: {target.kind.name}") diff --git a/python/bitblas/base/roller/arch/arch_base.py b/python/bitblas/base/roller/arch/arch_base.py deleted file mode 100644 index 6e98838c7..000000000 --- a/python/bitblas/base/roller/arch/arch_base.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from typing import List - - -class TileDevice: - """ - Represents the architecture of a computing device, capturing various hardware specifications. - """ - - def __init__(self) -> None: - self.reg_cap: int = 0 # Register capacity: The amount of register memory available - self.smem_cap: int = 0 # Shared memory capacity: The amount of shared memory available - self.compute_max_core: int = 0 # The maximum number of computing cores - self.warp_size: int = ( - 0 # The size of a warp, a group of threads that execute instructions in lockstep - ) - self.sm_partition: int = 0 # The number of streaming multiprocessor partitions - self.transaction_size: List[int] = [ - 0, - 0, - ] # The size of memory transactions, typically in bytes - self.max_smem_usage: int = 0 # The maximum shared memory usage allowed - self.bandwidth: List[int] = [ - 0, - 0, - ] # Bandwidth specifications, possibly including peak and sustained rates - self.platform: str = "unknown" # The platform or manufacturer of the device - self.compute_capability: str = ( - "unknown" # The compute capability, indicating the feature set and performance level - ) - self.l2_cache_size_bytes: int = 0 - # the number of transaction size in bytes - self.transaction_size: List[int] = [0, 0] # in bytes - # bandwidth in MB/s, will be used for recommend basic tile size - self.bandwidth: List[int] = [0, 0] - - def get_avaliable_tensorintrin_shapes(self): - raise NotImplementedError() diff --git a/python/bitblas/base/roller/arch/cpu.py b/python/bitblas/base/roller/arch/cpu.py deleted file mode 100644 index 98fb14af5..000000000 --- a/python/bitblas/base/roller/arch/cpu.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import tvm -from tvm.target import Target -from .arch_base import TileDevice - - -# For LLVM Backend, we do not provide the detailed information of the CPU -# As the LLVM backend do not required tuning, just maintain the consistency -class CPU(TileDevice): - - def __init__(self, target: Target): - self.target = target - device = tvm.runtime.cpu(0) - if not device.exist: - raise RuntimeError("Cannot find cpu device 0.") - self.device: tvm.runtime.Device = device - self.platform: str = "CPU" diff --git a/python/bitblas/base/roller/arch/cuda.py b/python/bitblas/base/roller/arch/cuda.py deleted file mode 100644 index 2189947e7..000000000 --- a/python/bitblas/base/roller/arch/cuda.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import tvm -from tvm.target import Target -from .arch_base import TileDevice -from typing import List, Dict, Union - - -def check_sm_version(arch: str) -> int: - sm_version = arch.replace("sm_", "") - return int(sm_version) if sm_version.isdigit() else -1 - - -class TensorInstruction(object): - - def __init__( - self, - name: str, - intrin_group: Dict, - shape: List[int], - ): - self.name: str = name - self.intrin_group: Dict = intrin_group - # only maintain the shape of M and N - self.shape: List[int] = shape - - -class CUDA(TileDevice): - - def __init__(self, target: Union[Target, str]): - if isinstance(target, str): - target = tvm.target.Target(target) - self.target = target - self.sm_version = check_sm_version(self.target.arch) - device = tvm.runtime.cuda(0) - if not device.exist: - raise RuntimeError("Cannot find cuda device 0.") - self.device: tvm.runtime.Device = device - self.platform: str = "CUDA" - self.smem_cap = device.max_shared_memory_per_block - self.compute_max_core = device.multi_processor_count - self.warp_size = device.warp_size - self.compute_capability = device.compute_version.replace(".", "") - self.reg_cap: int = 65536 - self.max_smem_usage: int = 2 * self.smem_cap - self.sm_partition: int = 4 - self.l2_cache_size_bytes: int = target.l2_cache_size_bytes - # the number of transaction size in bytes - self.transaction_size: List[int] = [32, 128] # in bytes - # bandwidth in MB/s, will be used for recommend basic tile size - # TODO(lei): find some way to get the real bandwidth - # However, the ratio of bandwidth between different devices can - # be similar. The bandwidth can work for another devices as well. - self.bandwidth: List[int] = [750, 12080] - # get the available tensor instructions during runtime to avoid - # the dependency of the tensor intrinsics registration - self.available_tensor_instructions: List[TensorInstruction] = None - - def get_avaliable_tensorintrin_shapes(self): - from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group, get_mma_intrin_group - - self.available_tensor_instructions = ( - TensorInstruction("mma", get_mma_intrin_group, [16, 16]), - TensorInstruction("wmma", get_wmma_intrin_group, [16, 16]), - ) - return [t.shape for t in self.available_tensor_instructions] diff --git a/python/bitblas/base/roller/bestfit.py b/python/bitblas/base/roller/bestfit.py deleted file mode 100644 index ad8ec20a8..000000000 --- a/python/bitblas/base/roller/bestfit.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -"""Benifit For BitBLAS Schedule""" -class Block: - def __init__(self, start, end, is_free): - self.start = start - self.end = end - self.is_free = is_free - - def size(self) -> int: - return self.end - self.start - - def merge(self, other): - assert self.is_free == other.is_free - self.start = min(self.start, other.start) - self.end = max(self.end, other.end) - - def __repr__(self) -> str: - return "".format(self.start, self.size()) - - -class BestFit: - def __init__(self, align=32): - self.limit = 0 - self.list = [] - self.align = align - - def malloc(self, size) -> Block: - size = (size + self.align - 1) // self.align * self.align - found = None - for block in self.list: - if block.is_free and block.size() >= size: - if not found or found.size() > block.size(): - found = block - if found: - found.is_free = False - remain = found.size() - size - if remain != 0: - found.end -= remain - self.list.insert( - self.list.index(found) + 1, Block(found.end, found.end + remain, True) - ) - return found - elif len(self.list) > 0 and self.list[-1].is_free: - add = size - self.list[-1].size() - self.list[-1].end += add - self.limit = self.list[-1].end - self.list[-1].is_free = False - return self.list[-1] - else: - block = Block(self.limit, self.limit + size, False) - self.list.append(block) - self.limit += size - return block - - def free(self, block: Block) -> None: - assert not block.is_free - idx = self.list.index(block) - self.list[idx] = Block(block.start, block.end, True) - if idx + 1 < len(self.list) and self.list[idx + 1].is_free: - self.list[idx].merge(self.list[idx + 1]) - self.list.pop(idx + 1) - if idx - 1 >= 0 and self.list[idx - 1].is_free: - self.list[idx].merge(self.list[idx - 1]) - self.list.pop(idx - 1) diff --git a/python/bitblas/base/roller/hint.py b/python/bitblas/base/roller/hint.py deleted file mode 100644 index f6e2fb03a..000000000 --- a/python/bitblas/base/roller/hint.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Hint definition for schedule""" -from typing import Dict, List, Tuple -from . import PrimFuncNode -import numpy as np -from .rasterization import * - - -class TensorCoreExtraConfig: - """ - This class is used to store extra information for tensorcore - """ - - def __init__( - self, - AS_shape: Tuple[int], - BS_shape: Tuple[int], - AF_shape: Tuple[int], - BF_shape: Tuple[int], - tc_axis: Tuple[int], - ) -> None: - self.AS_shape: Tuple[int] = AS_shape - self.BS_shape: Tuple[int] = BS_shape - self.AF_shape: Tuple[int] = AF_shape - self.BF_shape: Tuple[int] = BF_shape - self.tc_axis: Tuple[int] = tc_axis - - -class Stride: - """ - Manages stride information for a given axis of a tensor. - """ - - def __init__(self, stride: int = 1, ax: int = -1) -> None: - # which axis to put stride on - self._ax: int = int(ax) - # the stride size of the axis - self._stride: int = int(stride) - - @property - def ax(self) -> int: - return self._ax - - @property - def stride(self) -> int: - return self._stride - - def compute_strides_from_shape(self, shape: List[int]) -> List[int]: - ndim = len(shape) - strides = [1 for _ in shape] - for i in range(ndim - 2, -1, -1): - if i == self.ax: - strides[i] = self.stride - else: - strides[i] = int(strides[i + 1] * shape[i + 1]) - return strides - - def compute_elements_from_shape(self, shape: List[int]) -> int: - original_shape = np.prod(shape) - if not self.is_valid(): - strided_elem = original_shape - else: - assert self.ax < len(shape) - strided_elem = np.prod(shape[0:self.ax + 1]) * self.stride - assert strided_elem >= original_shape - return int(strided_elem) - - def is_valid(self) -> bool: - return self.ax >= 0 - - def __repr__(self) -> str: - return f"" - - -class TileDict: - """ - Manages tiling information and configurations for computational tasks. - """ - - def __init__(self, output_tile) -> None: - self.output_tile = output_tile - # schedule config - self.tile_map = {} - self.rstep_map = {} - self.cached_tensors_map = {} - self.output_strides_map = {} - self.tensor_strides_map = {} - - # analysis - self.traffic = -1 - self.smem_cost = -1 - self.block_per_SM = -1 - self.num_wave = -1 - self.grid_size = -1 - self.valid = True - - def get_tile(self, func) -> List[int]: - return self.tile_map[func] - - def get_rstep(self, func) -> Dict[str, int]: - return self.rstep_map - - def __hash__(self) -> int: - return hash(tuple(self.output_tile)) - - -class IntrinInfo: - """ - The information of tensorcore intrinsic related information - """ - - def __init__( - self, - in_dtype: str, - out_dtype: str, - trans_b: bool, - input_transform_kind: int = 0, - weight_transform_kind: int = 0, - ) -> None: - self.in_dtype = in_dtype - self.out_dtype = out_dtype - self.trans_a = False - self.trans_b = trans_b - self.input_transform_kind = input_transform_kind - self.weight_transform_kind = weight_transform_kind - - def __repr__(self) -> str: - return f"" - - @property - def smooth_a(self) -> bool: - return self.input_transform_kind >= 2 - - @property - def smooth_b(self) -> bool: - return self.weight_transform_kind >= 2 - - @property - def inter_transform_a(self) -> bool: - return self.input_transform_kind >= 1 - - @property - def inter_transform_b(self) -> bool: - return self.weight_transform_kind >= 1 - - -class Hint(object): - """ - Central configuration class for managing various parameters of computational tasks. - """ - - def __init__(self) -> None: - self.arch = None - self.use_tc = None # todo(lei): this should be renamed. - - # Special axes tiling info - self.block = [] - self.thread = [] - # Special axes for MMA - self.warp = [] - # Reduce axes tiling info - self.rstep = [] - self.reduce_thread = [] - self.rasterization_plan = NoRasterization() - self.cached_tensors = [] - self.output_strides = {} - self.schedule_stages = None - # Config for block reduction - self.block_reduction_depth = None # type: int - - # Experimental - self._raxis_order = [] - self._step = [] - self.vectorize: Dict[str, int] = {} - self.pipeline_stage = 1 - self.use_async = False - self.opt_shapes: Dict[str, int] = {} - self.intrin_info = IntrinInfo("float16", "float16", True) - self.shared_scope: str = "shared" - self.pass_context: Dict = {} - - def to_dict(self) -> Dict: - dic = {} - dic["block"] = self.block - if self.use_tc: - dic["warp"] = self.warp - else: - dic["thread"] = self.thread - dic["rstep"] = self.rstep - if np.prod(self.reduce_thread) > 1: - dic["reduce_thread"] = self.reduce_thread - if self.use_tc: - dic["use_tc"] = self.use_tc - if self.output_strides: - dic["strides"] = {} - for k, stride in self.output_strides.items(): - if stride.is_valid(): - dic["strides"][k] = stride - if len(dic["strides"]) == 0: - del dic["strides"] - if np.prod(self._step) > 1: - dic["step"] = self._step - if self._raxis_order != []: - dic["raxis_order"] = self._raxis_order - if self.vectorize != {}: - dic["vectorize"] = self.vectorize - if self.pipeline_stage != 1: - dic["pipeline_stage"] = self.pipeline_stage - if self.block_reduction_depth is not None: - dic["block_reduction_depth"] = self.block_reduction_depth - return dic - - def from_dict(self, dic: Dict) -> "Hint": - self.__init__() - for k, v in dic.items(): - setattr(self, k, v) - return self - - def tensorcore_legalization(self): - # only keep the last 2 axes for tensorcore - self.warp = self.warp[-2:] - self.block = self.block[-2:] - return self - - @property - def raxis_order(self) -> List[int]: - if self._raxis_order != []: - return self._raxis_order - return list(range(len(self.rstep))) - - @property - def step(self) -> List[int]: - if self._step != []: - return self._step - return [1 for _ in self.block] - - def __repr__(self) -> str: - return str(self.to_dict()) - - def complete_config(self, node: PrimFuncNode): - # analysis pass context, for int8 mma, we should merge static shared memory - merge_static_smem = False - # int32 and float32 accum may take too much shared memory - if self.use_tc and self.intrin_info.out_dtype in ["float32", "int32"]: - merge_static_smem = True - self.pass_context = {"tir.merge_static_smem": merge_static_smem} - return self diff --git a/python/bitblas/base/roller/node.py b/python/bitblas/base/roller/node.py deleted file mode 100644 index 8e20440bb..000000000 --- a/python/bitblas/base/roller/node.py +++ /dev/null @@ -1,408 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""PrimFunc Wrapper and Block information Analaysis""" - -import tvm -from tvm import tir -from tvm.tir import IterVar, PrimFunc -from typing import Any, Dict, List, Tuple, Optional -from tvm.tir.schedule.schedule import BlockRV -import numpy as np -import functools -from ..analysis import BlockInfo, get_reduction_blocks -from .. import analysis -from .. import normalize_prim_func -from .shape_inference import get_analyzer_by_tir - - -def pre_order_traverse(block_analyzer, blocks, func): - visited = set() - - def _traverse(block): - if block in visited: - return - visited.add(block) - for dep_blocks in block_analyzer.get_consumer_blocks(block): - _traverse(dep_blocks) - func(block) - - for block in blocks: - _traverse(block) - - -class BlockAnalyzer(object): - - def __init__(self, sch) -> None: - self.sch: tir.Schedule = sch - self.block_infos: List[BlockInfo] = normalize_prim_func(self.sch) - - def get_block_name(self, block: BlockRV) -> str: - return self.sch.get(block).name_hint - - def get_block_info(self, block: BlockRV) -> BlockInfo: - for block_info in self.block_infos: - if self.get_block_name(block) == block_info.name: - return block_info - return None - - def get_spatial_axis(self, block: BlockRV) -> List[IterVar]: - block_info = self.get_block_info(block) - axis = [] - for iter in block_info.iters: - if iter.kind == "S": - axis.append(iter) - return axis - - def get_reduce_axis(self, block: BlockRV) -> List[IterVar]: - block_info = self.get_block_info(block) - raxis = [] - for iter in block_info.iters: - if iter.kind == "R": - raxis.append(iter) - return raxis - - def get_input_buffers(self, block: BlockRV) -> List[tir.Buffer]: - buffers = [] - for read in self.sch.get(block).reads: - buffers.append(read.buffer) - return buffers - - def get_output_buffers(self, block: BlockRV) -> List[tir.Buffer]: - buffers = [] - for write in self.sch.get(block).writes: - buffers.append(write.buffer) - return buffers - - def get_buffers(self, block: BlockRV) -> List[tir.Buffer]: - return self.get_input_buffers(block) + self.get_output_buffers(block) - - def get_producer_blocks(self, block: BlockRV) -> List[BlockRV]: - return self.sch.get_producers(block) - - def get_consumer_blocks(self, block: BlockRV) -> List[BlockRV]: - return self.sch.get_consumers(block) - - -class Node(object): - - def __init__(self, tags: Optional[Dict] = None) -> None: - if tags is None: - tags = {} - self._dtypes = [] - self._tag: Dict = {} - for tag in tags: - self.add_tag(tag, tags[tag]) - - def set_tag(self, k: str, v: Any = True) -> None: - self.add_tag(k, v) - - def add_tag(self, k: str, v: Any = True) -> None: - self._tag[k] = v - - def get_tag(self, k: str) -> Any: - if k not in self._tag: - return None - return self._tag[k] - - -class PrimFuncNode(Node): - - def __init__(self, prim_func: PrimFunc, tags: Optional[Dict] = None) -> None: - super().__init__(tags) - self.prim_func = self._specialize_func(prim_func) - self.sch: tir.Schedule = tir.Schedule(self.prim_func) - self.block_analyzer: BlockAnalyzer = BlockAnalyzer(self.sch) - self.schedule_stages: List[BlockRV] = [] - self.blocks: List[BlockRV] = [] - self.output_blocks: List[BlockRV] = None - self.reduction_block: BlockRV = None - self.raxis = [] - self.input_buffers = [] - self.output_buffers = [] - self.buffers = [] - self.args = [] - self._analysis_funcinfo() - self.ana = get_analyzer_by_tir(self.block_analyzer, self.blocks) - - def _specialize_func(self, func: PrimFunc): - # Specialize the function to make it more friendly for analysis. - # set attrs - for k, v in func.attrs.items(): - self.set_tag(k, v) - if self.get_tag("is_speclized"): - return func - opt_shapes = self.get_tag("opt_shapes") - if opt_shapes: - for name, shape in opt_shapes.items(): - var = analysis.find_var_from_func(func, name) - if var is not None: - func = func.specialize({var: shape.astype(var.dtype)}) - return func - - def _analysis_funcinfo(self): - root_block = analysis.get_root_block(self.sch) - blocks = self.sch.get_child_blocks(root_block) - self.blocks = blocks - - self.output_blocks = self.sch.get_output_blocks(root_block) - reduction_blocks = get_reduction_blocks(self.sch, blocks) - if reduction_blocks is None: - self.reduction_block = None - self.schedule_stages.append(*self.output_blocks) - else: - # analysis on the last reduction block - self.reduction_block = reduction_blocks[-1] - # set raxis - reduce_block_info = self.block_analyzer.get_block_info(self.reduction_block) - for iter in reduce_block_info.iters: - if iter.kind == "R": - self.raxis.append(iter) - self.schedule_stages.append(self.reduction_block) - - # collect output buffers - for output_block in self.output_blocks: - for write in self.sch.get(output_block).writes: - if write not in self.output_buffers: - self.output_buffers.append(write.buffer) - - for param in self.prim_func.params: - if param not in self.prim_func.buffer_map: - # in case of dynamic symbolic may in params - continue - buffer = self.prim_func.buffer_map[param] - if buffer not in self.output_buffers: - self.input_buffers.append(buffer) - - self.args = self.input_buffers + self.output_buffers - self.buffers = [buffer for buffer in self.prim_func.buffer_map.values()] - - # set dtype - self.set_dtype(tvm.DataType(self.output_buffers[0].dtype)) - - def get_opt_shape(self, name) -> int: - opt_shapes = self.get_tag("opt_shapes") - if opt_shapes is None: - return None - return opt_shapes[name] - - def extent_wrapper(self, value) -> int: - if isinstance(value, tvm.tir.Var): - return self.get_opt_shape(value.name) - elif isinstance(value, tvm.tir.IntImm): - return int(value) - else: - return value - - @functools.lru_cache() - def get_space_dim(self) -> List[int]: - dim_size = [] - if self.reduction_block: - block_info = self.block_analyzer.get_block_info(self.reduction_block) - for iter in block_info.iters: - if iter.kind == "S": - if isinstance(iter.dom.extent, tvm.tir.IntImm): - dim_size.append(int(iter.dom.extent)) - else: - assert isinstance(iter.dom.extent, tvm.tir.Var) - dim_size.append(self.get_opt_shape(iter.dom.extent.name)) - else: - # assume outer stage has the same shape - loops = self.sch.get_loops(self.schedule_stages[0]) - for loop in loops: - dim_size.append(int(self.sch.get(loop).extent)) - return [int(x) for x in dim_size] - - def set_dtype(self, dtype: tvm.DataType, id=0) -> None: - assert isinstance(dtype, tvm.DataType), type(dtype) - if dtype == tvm.DataType("bool"): - dtype = tvm.DataType("int8") - if len(self._dtypes) <= id: - self._dtypes.extend([None for _ in range(id - len(self._dtypes) + 1)]) - elif self._dtypes[id] is not None: - assert self._dtypes[id] == dtype, (self._dtypes, dtype) - self._dtypes[id] = dtype - - def get_dtype(self, id=0) -> tvm.DataType: - return self._dtypes[id] - - def get_buffer_dtype(self, buffer: tir.Buffer) -> tvm.DataType: - return tvm.DataType(buffer.dtype) - - def propagate(self, tile, rstep: Optional[Dict] = None, targets=None): - if rstep is None: - rstep = {} - shape = { - self.block_analyzer.get_output_buffers(block)[0].name: - [tvm.arith.ConstIntBound(0, val - 1) for val in tile] for block in self.schedule_stages - } - return self.ana.infer(shape, rstep, targets) - - def propagate_inputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]: - if rstep is None: - rstep = {} - read_idx_offset = len(self.input_buffers) - targets = [t.name for t in self.args[:read_idx_offset]] - shapes, intermediate_bind = self.propagate(tile, rstep, targets) - results = [] - for i, arg in enumerate(self.args[:read_idx_offset]): - if arg.name in intermediate_bind: - results.append(shapes[arg.name]) - continue - # should not exceed original shape - trimmed_shape = [ - self.extent_wrapper(i) - for i in list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape))) - ] - results.append(trimmed_shape) - return results - - # Propagate inputs only on reduction block - def propagate_inputs_on_reduction(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]: - if rstep is None: - rstep = {} - reduction_block = self.reduction_block - args = self.block_analyzer.get_input_buffers(reduction_block) - targets = [t.name for t in args] - shapes, intermediate_bind = self.propagate(tile, rstep, targets) - results = [] - for i, arg in enumerate(args): - if arg.name in intermediate_bind: - results.append(shapes[arg.name]) - continue - # should not exceed original shape - propagate_shape = shapes[arg.name] - buffer_shape = args[i].shape - if len(buffer_shape) > len(propagate_shape): - buffer_shape = buffer_shape[-len(propagate_shape):] - trimmed_shape = [ - self.extent_wrapper(j) for j in list(map(min, zip(propagate_shape, buffer_shape))) - ] - results.append(trimmed_shape) - return results - - def propagate_outputs(self, tile, rstep: Optional[Dict] = None) -> List[List[int]]: - if rstep is None: - rstep = {} - read_idx_offset = len(self.input_buffers) - targets = [t.name for t in self.args[read_idx_offset:]] - shapes, _ = self.propagate(tile, rstep, targets) - results = [] - for i, arg in enumerate(self.args[read_idx_offset:]): - # should not exceed original shape - trimmed_shape = list(map(min, zip(shapes[arg.name], self.input_buffers[i].shape))) - results.append(trimmed_shape) - return results - - def propagate_reduction_inputs(self, - shape, - rstep: Optional[Dict] = None) -> Dict[str, List[int]]: - if rstep is None: - rstep = {} - if self.reduction_block is None: - return {} - targets = [b.name for b in self.block_analyzer.get_input_buffers(self.reduction_block)] - results, _ = self.propagate(shape, rstep, targets) - return results - - def get_reduce_inputs_dtype(self): - if self.reduction_block is None: - return {} - return { - b.name: tvm.DataType(b.dtype) - for b in self.block_analyzer.get_input_buffers(self.reduction_block) - } - - @functools.lru_cache() - def infer_tensorcore_axis(self) -> Tuple[int]: - # axis is fixed for one expression, so only inference and cached - assert self.get_tag("tensorcore_config") - - C_ax_m, C_ax_n = self.get_tag("tensorcore_config") - wmma_m, wmma_n, wmma_k = [16, 16, 16] # just for testing, any number is ok - - output_buffer_shape = ( - self.block_analyzer.sch.get(self.reduction_block).writes[0].buffer.shape) - valid_region = [] - for region in output_buffer_shape: - if region.value == 1: - continue - valid_region.append(region) - - num_nvalid_regions = len(output_buffer_shape) - len(valid_region) - self.set_tag("num_nvalid_regions", num_nvalid_regions) - - def get_cl_shapes(c_ax_m, c_ax_n, num_nvalid_regions): - spatial_dim = self.get_space_dim() - assert len(valid_region) == len( - spatial_dim), f" {valid_region} mismatch with {spatial_dim}" - cl_shapes = [1] * len(spatial_dim) - cl_shapes[c_ax_m - num_nvalid_regions] = wmma_m - cl_shapes[c_ax_n - num_nvalid_regions] = wmma_n - return cl_shapes - - CL_shape = get_cl_shapes(C_ax_m, C_ax_n, num_nvalid_regions) - self.set_tag("tensorcore_config", [s - num_nvalid_regions for s in [C_ax_m, C_ax_n]]) - shapes = self.propagate_reduction_inputs(CL_shape, {x.var.name: 1 for x in self.raxis}) - A_deps, B_deps = shapes.values() - A_ax_m = A_deps.index(wmma_m) - B_ax_n = B_deps.index(wmma_n) - - CL_shape = [1] * len(self.get_space_dim()) - shapes = self.propagate_reduction_inputs(CL_shape, {x.var.name: wmma_k for x in self.raxis}) - A_deps, B_deps = shapes.values() - A_ax_k = len(A_deps) - 1 - A_deps[::-1].index(wmma_k) - B_ax_k = len(B_deps) - 1 - B_deps[::-1].index(wmma_k) - tc_axis = (A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n) - return tc_axis - - def footprint(self, shape, rstep, stride_map: Optional[Dict] = None) -> int: - if stride_map is None: - stride_map = {} - result = 0 - shapes, _ = self.propagate(shape, rstep) - - def is_broadcast_pattern(buffer, output_buffer): - return (buffer in self.args and - len(shapes[output_buffer.name]) > len(shapes[buffer.name]) and - np.prod(shapes[output_buffer.name]) > np.prod(shapes[buffer.name])) - - def is_after_reduce_stage(block): - if not self.reduction_block: - return False - reduce_dependent_blocks = getattr(self, "reduce_dependent_blocks", None) - if reduce_dependent_blocks is None: - reduce_dependent_blocks = set() - pre_order_traverse( - self.block_analyzer, - [self.reduction_block], - lambda block: reduce_dependent_blocks.add(block), - ) - self.reduce_dependent_blocks = reduce_dependent_blocks - return block not in reduce_dependent_blocks - - # compute cached stages - cached_tensor = [] - for block in self.blocks: - output_buffer = self.block_analyzer.get_output_buffers(block)[0] - for buffer in self.block_analyzer.get_input_buffers(block): - cache = buffer.name not in cached_tensor and ( - is_broadcast_pattern(buffer, output_buffer) or - self.block_analyzer.get_block_info(block).is_reduction) - if not cache: - continue - cached_tensor.append(buffer.name) - if is_after_reduce_stage(block): - continue # cache after reduce op can often reuse buffer in reduce stage - - if buffer.name in stride_map: - num_elem = stride_map[buffer.name].compute_elements_from_shape( - shapes[buffer.name]) - else: - num_elem = np.prod(shapes[buffer.name]) - buffer_len = num_elem * int((tvm.DataType(buffer.dtype).bits + 7) // 8) - buffer_len = (buffer_len + 31) // 32 * 32 - result += buffer_len - return result, cached_tensor - - def get_input_buffers(self) -> List[tir.Buffer]: - return self.block_analyzer.input_buffers diff --git a/python/bitblas/base/roller/policy/__init__.py b/python/bitblas/base/roller/policy/__init__.py deleted file mode 100644 index 09ed1d51b..000000000 --- a/python/bitblas/base/roller/policy/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from .default import DefaultPolicy -from .tensorcore import TensorCorePolicy diff --git a/python/bitblas/base/roller/policy/common.py b/python/bitblas/base/roller/policy/common.py deleted file mode 100644 index 9141550c8..000000000 --- a/python/bitblas/base/roller/policy/common.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from typing import List -import numpy as np - - -def get_all_factors(n: int) -> List[int]: - # Calculate the square root of n and round it up to the nearest integer - n0 = int(np.ceil(np.sqrt(n))) - - # Find all divisors of n that are less than n0 - val = np.where(n % np.arange(1, n0) == 0)[0] + 1 - - # If n is a perfect square, add the square root to the list of factors - mid = np.array([], dtype=int) if n0 * n0 != n else [n0] - - # Combine the factors and their corresponding larger pair factors - return [int(x) for x in np.concatenate([val, mid, n // val[::-1]])] - - -def factorize(n: int) -> List[int]: - i = 2 # Start with the smallest prime number - result = [] - - # Iterate through numbers to find factors - while n > 1: - if n % i == 0: # If i is a factor of n - n //= i # Divide n by i and keep the integer part - result.append(i) - else: - i += 1 # Try the next number - return result - - -def coalesced_factor(subtensor: List[int], tensor: List[int]) -> int: - # If the last dimension of the subtensor and tensor differ, or subtensor has only one dimension - if subtensor[-1] != tensor[-1] or len(subtensor) == 1: - return subtensor[-1] - else: - # Recursively calculate the coalesced factor for the remaining dimensions - return subtensor[-1] * coalesced_factor(subtensor[:-1], tensor[:-1]) - - -def coalesced_tensor_shape(subtensor: List[int], tensor: List[int], transaction_size: int) -> int: - # Calculate the total number of elements in the subtensor - bytes = int(np.prod(subtensor)) - - if bytes == 0: - return 0 - - # Calculate the coalesced factor for the subtensor - factor = int(coalesced_factor(subtensor, tensor)) - - # Compute the shape of the coalesced tensor - return transaction_size * bytes / min(transaction_size, factor) diff --git a/python/bitblas/base/roller/policy/default.py b/python/bitblas/base/roller/policy/default.py deleted file mode 100644 index 81aeba123..000000000 --- a/python/bitblas/base/roller/policy/default.py +++ /dev/null @@ -1,748 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Policy for cuda core schedule""" -import functools -import math -from queue import PriorityQueue -from typing import Iterable, Dict, List, Optional - -import numpy as np -import tvm - -from ..arch import TileDevice -from ..bestfit import BestFit -from ..hint import Hint, Stride, TileDict -from .common import coalesced_factor, coalesced_tensor_shape, factorize, get_all_factors -from ..node import PrimFuncNode -from ..rasterization import NoRasterization - - -class DefaultPolicy: - """ - Default Policy for fastdlight, a heuristic plan that tries to - minimize memory traffic and maximize parallelism.for BitBLAS Schedule. - """ - - def __init__(self, - func: tvm.tir.PrimFunc, - arch: TileDevice, - tags: Optional[Dict] = None) -> None: - if tags is None: - tags = {} - self.arch = arch - self.prim_func_node = PrimFuncNode(func, tags) - self.ordered_nodes = [self.prim_func_node] - self.output_nodes = [self.prim_func_node] - - def emit_config(self, topk: int) -> List[Hint]: - base_tile = self.get_base_tile() - if base_tile is None: - return [] - - rstep_map = self._assign_reduce_step(self.prim_func_node) - smem_tile_condidates = self.dfs_smem_tile(base_tile, rstep_map) - results = [] - for td in smem_tile_condidates: - if not self.check_tile_shape_isvalid(td): - continue - - self._expand_reduce_axis(td) - for codegen_dicts in self.assign_block_size(td): - results.append(codegen_dicts) - if len(results) >= topk: - break - if len(results) >= topk: - break - return results - - def dfs_smem_tile(self, init_tile, rstep_map) -> Iterable[TileDict]: - _steps = [get_all_factors(n) for n in self.prim_func_node.get_space_dim()] - steps = [step[step.index(t):] for step, t in zip(_steps, init_tile)] - for i in range(len(steps)): - added = list( - filter( - lambda s: s < steps[i][-1] and s > steps[i][0] and s not in steps[i], - [2, 4, 8, 16, 32], - )) - steps[i].extend(added) - steps[i] = sorted(steps[i]) - visited_tiles = {} - queue = PriorityQueue() - - def prio(td: TileDict): - return (td.traffic + 1) * td.num_wave - - def add_to_queue(tile): - if tuple(tile) in visited_tiles: - return - td = self.compute_tile_dict(tile, rstep_map) - visited_tiles[tuple(tile)] = td - if td.valid: - queue.put([prio(td), tile]) - - add_to_queue(init_tile) - while not (queue.empty() or len(visited_tiles) > 2000): - _, tile = queue.get() - dim_ids = [step.index(t) for step, t in zip(steps, tile)] - for i in reversed(range(len(dim_ids))): - if dim_ids[i] + 1 < len(steps[i]): - new_tile = tile.copy() - new_tile[i] = steps[i][dim_ids[i] + 1] - add_to_queue(new_tile) - - visited_tiles = filter(lambda td: td.valid, visited_tiles.values()) - sorted_tiles = sorted(visited_tiles, key=lambda td: prio(td)) - return sorted_tiles - - def get_base_tile(self): - """ - Gets the minimum tile configuration that satisfies no redundancy in computation. - - Returns - ------- - List[int] - The base tile configuration, which is a list of 1s equal in length to the space dimensions - of the primary function node. - """ - shape = self.prim_func_node.get_space_dim() - base_tile = [1 for _ in shape] - - return base_tile - - # handles multiple output cases - def _get_output_tile_map(self, tile): - """ - Handles multiple output cases by mapping output nodes to their respective tile configurations. - - Parameters - ---------- - tile : List[int] - The tile configuration. - - Returns - ------- - Dict - A dictionary mapping the primary function node to its corresponding tile configuration - based on the output nodes' space dimensions. - """ - tile_map = {} - tile_map[self.prim_func_node] = [ - tile[i] * self.prim_func_node.get_space_dim()[i] // - self.output_nodes[0].get_space_dim()[i] for i in range(len(tile)) - ] - return tile_map - - def score_block_size(self, n): - """ - Scores a block size based on its efficiency and fit relative to the architecture's warp size and SM partition. - - Parameters - ---------- - n : int - The block size to score. - - Returns - ------- - Tuple[float, float] - A tuple containing two scores representing efficiency and fit, respectively. - """ - num_wrap = (n + self.arch.warp_size - 1) // self.arch.warp_size - r1 = max(num_wrap / self.arch.sm_partition, self.arch.sm_partition / num_wrap) - r2 = (num_wrap * self.arch.warp_size - n) / n - return (r1, r2) - - def get_block_size(self, n): - """ - Determines the optimal block size for a given constraint, based on scoring various factors. - - Parameters - ---------- - n : int - The constraint size. - - Returns - ------- - int - The optimal block size chosen from the factors of n, constrained by a maximum of 1024 and - scored by the `score_block_size` method. - """ - factors = get_all_factors(n) - factors = list(filter(lambda x: x <= 1024, factors)) - factor_ordered = sorted(factors, key=self.score_block_size) - return factor_ordered[0] - - def get_node_reduce_step_candidates(self, node: PrimFuncNode): - """ - Calculates reduction step candidates for each reduction axis in a PrimFuncNode. General idea : use factor first, since it does not require extra boundary check. for large prime number, which is rare case, use power of 2. - - Parameters - ---------- - node : PrimFuncNode - The node for which to calculate reduction step candidates. It contains reduction axes (raxis) - with their domains (dom.extent). - - Returns - ------- - Dict[str, List[int]] - A dictionary mapping axis variable names to lists of step candidates. For each axis in the node, - this function calculates possible step sizes. For axes with a large prime domain, it uses powers of 2 - as step candidates; for others, it uses all factors of the domain. - """ - - results = {} - for k_iter in node.raxis: - all_factors = get_all_factors(int(k_iter.dom.extent)) - if len(all_factors) == 2 and int(k_iter.dom.extent) > 64: - all_factors = [1] - while all_factors[-1] * 2 < int(k_iter.dom.extent): - all_factors.append(all_factors[-1] * 2) - results[k_iter.var.name] = all_factors - return results - - def _assign_reduce_step(self, node: PrimFuncNode): - """ - Assigns an optimal reduction step for the given PrimFuncNode. - - Parameters - ---------- - node : PrimFuncNode - The node for which the reduction step is to be assigned. - - Returns - ------- - Dict - A dictionary mapping reduction axis variable names to their optimal reduction steps. - """ - if node.reduction_block is None: - return {} - - raxis = node.raxis - tile = [1] * len(node.get_space_dim()) - all_steps = self.get_node_reduce_step_candidates(node) - - def sim(a: int, b: int): - return (2 * a * b) / (a * a + b * b) - - def _score(rstep_id): - rstep = {k: all_steps[k][rstep_id[k]] for k in rstep_id} - score = 0 - shape = node.propagate_inputs(tile, rstep=rstep) - for i, input_buffer in enumerate(node.input_buffers): - read_transaction_elements = self.arch.transaction_size[1] // ( - (node.get_buffer_dtype(input_buffer).bits + 7) // 8) - score += sim( - int(coalesced_factor(shape[i], input_buffer.shape)), - read_transaction_elements, - ) - return score - - def _enlarge(rstep_id): - candidates = [] - candidates.append((rstep_id, _score(rstep_id))) - for ax in rstep_id: - if rstep_id[ax] + 1 == len(all_steps[ax]): - continue - r = rstep_id.copy() - r[ax] += 1 - candidates.append((r, _score(r))) - best = max(candidates, key=lambda x: x[1]) - return best - - # enlarge rstep to ensure read is coaleased - cur_rstep_id = {ax.var.name: 0 for ax in raxis} - cur_score = _score(cur_rstep_id) - while True: - if cur_score == 0: - break - new_rstep, new_score = _enlarge(cur_rstep_id) - if new_score <= cur_score: - break - else: - cur_rstep_id, cur_score = new_rstep, new_score - rstep = {k: all_steps[k][cur_rstep_id[k]] for k in cur_rstep_id} - return rstep - - def _expand_reduce_axis(self, td: TileDict): - """ - Expands the reduction axis in the TileDict based on shared memory limits. - - Parameters - ---------- - td : TileDict - The TileDict object to be optimized. - - Returns - ------- - None - This function modifies the TileDict in place. - """ - smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) - rstep_map = td.rstep_map.copy() - - def _optimize(node, rstep): - all_steps = self.get_node_reduce_step_candidates(node) - for k in all_steps: - all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k])) - - def _score(rstep_id): - rstep = { - k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis - } - score = 0 - shape = node.propagate_inputs(td.get_tile(node), rstep=rstep) - for i, input_buffer in enumerate(node.input_buffers): - score += coalesced_factor(shape[i], input_buffer.shape) - return score - - def _enlarge(rstep_id): - candidates = [] - for ax in rstep_id: - if rstep_id[ax] + 1 == len(all_steps[ax]): - continue - r = rstep_id.copy() - r[ax] += 1 - candidates.append((r, _score(r))) - if len(candidates) == 0: - return None - return max(candidates, key=lambda x: x[1])[0] - - cur_rstep_id = { - k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis - } - new_rstep_map = rstep_map.copy() - while True: - new_rstep_id = _enlarge(cur_rstep_id) - if new_rstep_id is None: - break - new_rstep_map = { - k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis - } - old_rstep_map = td.rstep_map - td.rstep_map = new_rstep_map - smem_usage, _ = self._compute_shared_memory_usage(td) - td.rstep_map = old_rstep_map - if smem_usage > smem_limit: - break - else: - cur_rstep_id = new_rstep_id - rstep = { - k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis - } - return rstep - - for node in self.ordered_nodes: - if len(node.raxis) > 0: - rstep = _optimize(node, rstep_map) - rstep_map = rstep - td.rstep_map = rstep_map - td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) - - def _compute_memory_traffic(self, output_tile): - """ - Computes the memory traffic for a given output tile configuration. - - Parameters - ---------- - output_tile : List[int] - The output tile configuration. - - Returns - ------- - Tuple[int, Dict] - The total memory traffic and a map of operation tiles. - """ - op_tile_map = self._get_output_tile_map(output_tile) - traffic = 0 - for node in reversed(self.ordered_nodes): - tile = op_tile_map[node] - input_shapes = node.propagate_inputs(tile) - output_shapes = node.propagate_outputs(tile) - for i, buffer in enumerate(node.input_buffers): - nbytes = (node.get_buffer_dtype(buffer).bits + 7) // 8 - read_transaction_elements = self.arch.transaction_size[1] // nbytes - traffic += ( - coalesced_tensor_shape(input_shapes[i], buffer.shape, read_transaction_elements) - * nbytes) - for i, buffer in enumerate(node.output_buffers): - nbytes = (node.get_buffer_dtype(buffer).bits + 7) // 8 - write_transaction_elements = self.arch.transaction_size[0] // nbytes - traffic += ( - coalesced_tensor_shape(output_shapes[i], buffer.shape, - write_transaction_elements) * nbytes) - return traffic, op_tile_map - - def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode): - """ - Infers the shared memory usage of a node given a TileDict configuration. - - Parameters - ---------- - td : TileDict - The TileDict object containing the tile configuration. - node : PrimFuncNode - The node for which to infer the shared memory usage. - - Returns - ------- - int - The estimated amount of shared memory used by the node. - """ - return node.footprint(td.get_tile(node), td.get_rstep(node), td.tensor_strides_map[node]) - - def _compute_shared_memory_usage(self, td: TileDict): - """ - Computes the stride map for a given node and TileDict configuration. - - Parameters - ---------- - node : PrimFuncNode - The node for which to compute the stride map. - td : TileDict - The TileDict object containing the tile configuration. - - Returns - ------- - Tuple[Dict, Dict] - The output strides and tensor strides. - """ - self._compute_stride_map(td) - allocator = BestFit() - block_map = {} - cached_tensors_map = {} - - node_internal_bytes, cached_tensors_map[self.prim_func_node] = self.infer_node_smem_usage( - td, self.prim_func_node) - block = allocator.malloc(node_internal_bytes) - allocator.free(block) - assert len(block_map) == 0 - return allocator.limit, cached_tensors_map - - def compute_node_stride_map(self, node: PrimFuncNode, td: TileDict): - """ - Computes the stride map for a given node based on the TileDict configuration. - - Parameters - ---------- - node : PrimFuncNode - The node for which to compute the stride map. - td : TileDict - The TileDict object containing the tile configuration. - - Returns - ------- - Tuple[Dict, Dict] - A tuple of dictionaries containing the output strides and tensor strides. - """ - output_strides = { - int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers) - } - tensor_strides = {} - return output_strides, tensor_strides - - def _compute_stride_map(self, td: TileDict): - """ - Computes the stride map for all nodes in a TileDict. - - Parameters - ---------- - td : TileDict - The TileDict object for which to compute the stride maps. - - Returns - ------- - None - This function updates the TileDict object in place with the computed stride maps. - """ - output_strides_map = {} - tensor_strides_map = {} - for node in self.ordered_nodes: - output_strides_map[node], tensor_strides_map[node] = self.compute_node_stride_map( - node, td) - td.output_strides_map, td.tensor_strides_map = output_strides_map, tensor_strides_map - - def compute_tile_dict(self, output_tile: List[int], rstep_map) -> TileDict: - """ - Computes and returns a TileDict object for a given output tile configuration and reduction step map. - - Parameters - ---------- - output_tile : List[int] - The output tile configuration. - rstep_map : Dict - The reduction step map. - - Returns - ------- - TileDict - A TileDict object containing the computed tile configuration, memory traffic, shared memory cost, - grid size, and other related parameters. - """ - td = TileDict(output_tile) - td.rstep_map = rstep_map - td.traffic, td.tile_map = self._compute_memory_traffic(output_tile) - td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) - if td.smem_cost > self.arch.smem_cap: - td.valid = False - return td - output_shape = self.output_nodes[0].get_space_dim() - td.grid_size = int(np.prod([(y + x - 1) // x for x, y in zip(output_tile, output_shape)])) - # estimated reg usage - reg_usage = int(2 * max([ - np.prod(td.get_tile(node)) * node.get_dtype().bits / 32 for node in self.ordered_nodes - ])) - if reg_usage > self.arch.reg_cap: - td.valid = False - return td - td.block_per_SM = min( - self.arch.max_smem_usage // max(td.smem_cost, 1), - self.arch.reg_cap // max(reg_usage, 1), - self.arch.sm_partition, - ) - td.num_wave = int(np.ceil(td.grid_size / int(td.block_per_SM * self.arch.compute_max_core))) - return td - - def check_tile_shape_isvalid(self, td: TileDict) -> bool: - """ - Checks if the tile shapes in the TileDict are valid for the nodes in this context. - - Parameters: - - td (TileDict): The TileDict object containing tile shapes and other configurations. - - Returns: - - bool: True if all tile shapes are valid, False otherwise. - """ - for node in self.ordered_nodes: - if np.prod(td.get_tile(node)) == 0: - return False - node_grid_size = np.prod([ - (y + x - 1) // x for x, y in zip(td.get_tile(node), node.get_space_dim()) - ]) - if node_grid_size != td.grid_size: - return False - if (hasattr(node, "reduce_op") and node.reduce_op is not None and - len(node.reduce_op.axis) == len(td.output_tile)): - for i, tile_extent in enumerate(td.output_tile): - if node.reduce_op.axis[i].dom.extent % tile_extent: - return False - - return True - - def recommend_block_size(self, td: TileDict) -> List[int]: - """ - Recommends optimal block sizes based on the TileDict configuration. - - Parameters - ---------- - td : TileDict - The TileDict object containing the tile configuration. - - Returns - ------- - List[int] - A list of recommended block sizes sorted based on their score. - """ - node_space_sizes = [int(np.prod(td.get_tile(node))) for node in self.ordered_nodes] - max_block_size = functools.reduce(math.gcd, node_space_sizes) - - if max_block_size < self.arch.warp_size * self.arch.sm_partition and max_block_size == min( - node_space_sizes): - node_reduce_sizes = [ - int(np.prod(list(td.get_rstep(node).values()))) for node in self.ordered_nodes - ] - total_sizes = [x * y for x, y in zip(node_space_sizes, node_reduce_sizes)] - max_possible_size = functools.reduce(math.gcd, total_sizes) - possible_block_sizes = list( - filter( - lambda x: x % max_block_size == 0 and x <= 1024, - get_all_factors(max_possible_size), - )) - possible_block_sizes = list( - filter( # either be a factor of space or cover fully cover the space - lambda x: all([x % s == 0 or s % x == 0 for s in node_space_sizes]), - possible_block_sizes, - )) - factor_ordered = sorted(possible_block_sizes, key=self.score_block_size) - return factor_ordered - else: - possible_block_sizes = get_all_factors(max_block_size) - possible_block_sizes = list(filter(lambda x: x <= 1024, possible_block_sizes)) - factor_ordered = sorted(possible_block_sizes, key=self.score_block_size) - return factor_ordered - - def assign_block_size(self, td: TileDict, topk=1): - """ - Assigns block sizes to the TileDict based on the recommended block sizes. - - Parameters - ---------- - td : TileDict - The TileDict object to assign block sizes to. - topk : int, optional - The number of top block sizes to consider. - - Yields - ------- - Dict - The block size assignment for the primary function node. - """ - block_size_ordered = self.recommend_block_size(td) - for block_size in block_size_ordered: - result = {} - failed = False - result = self._assign_block_size(self.prim_func_node, td, block_size) - if result is None: - failed = True - break - if failed: - continue - else: - yield result - topk -= 1 - if topk == 0: - break - - def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int): - """ - Assigns a block size to a given PrimFuncNode based on the TileDict configuration and the specified block size. - - Parameters - ---------- - node : PrimFuncNode - The node to assign the block size to. - td : TileDict - The TileDict object containing the tile configuration. - block_size : int - The block size to be assigned. - - Returns - ------- - Hint - A Hint object containing the assigned block size and other related settings. - """ - tile, rsteps = td.get_tile(node), td.get_rstep(node) - factors = factorize(block_size) - cur_threads = [1 for _ in tile] - reduce_thread = {k: 1 for k in rsteps} - ndim = len(tile) - - def _score(node, thread): # small is better - score = 0 - block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)] - shape = node.propagate_inputs(block_tile) - for i, _ in enumerate(node.input_buffers): - score += np.prod(shape[i]) / self.arch.bandwidth[1] - for buffer in node.output_buffers: - score += coalesced_tensor_shape(thread, buffer.shape, 8) / self.arch.bandwidth[0] - return score - - for factor in reversed(factors): - score_map = {} - for i in range(ndim): - if cur_threads[i] >= tile[i]: - continue - if (tile[i] % (cur_threads[i] * factor)) != 0: - continue - cur_threads[i] *= factor - score_map[i] = (_score(node, cur_threads), i) - cur_threads[i] //= factor - if len(score_map) > 0: - # assign to space axis - dim_order = sorted(score_map.keys(), key=lambda x: score_map[x]) - cur_threads[dim_order[0]] *= factor - else: - # assign to reduce axis - target_ax = None - for ax, ax_len in reversed(list(rsteps.items())): - if ax_len % (reduce_thread[ax] * factor) == 0: - target_ax = ax - break - assert target_ax - reduce_thread[target_ax] *= factor - - codegen_dict = Hint() - codegen_dict.block = tile - codegen_dict.thread = cur_threads - codegen_dict.rstep = [rsteps[ax.var.name] for ax in node.raxis] - codegen_dict.reduce_thread = [reduce_thread[ax.var.name] for ax in node.raxis] - codegen_dict.cached_tensors = td.cached_tensors_map[node] - codegen_dict.rasterization_plan = self.plan_rasterization(td) - - if node.get_dtype().bits == 16: # set step=2 for 16bit case to ensure coalesced access - codegen_dict._step = [1 for _ in range(ndim)] - for i in reversed(range(ndim)): - if codegen_dict.block[i] // codegen_dict.thread[i] % 2 == 0: - codegen_dict._step[i] = 2 - break - elif node.get_dtype().bits == 8: # set step=4 for 8bit case to ensure coalesced access - codegen_dict._step = [1 for _ in range(ndim)] - for i in reversed(range(ndim)): - if codegen_dict.block[i] // codegen_dict.thread[i] % 4 == 0: - codegen_dict._step[i] = 4 - break - # Plan vectorize - codegen_dict.vectorize = self._plan_vectorize(node, td, block_size) - codegen_dict.arch = self.arch - codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes") - return codegen_dict - - def _plan_vectorize(self, node: PrimFuncNode, td: TileDict, block_size: int): - """ - Plans vectorization for a given PrimFuncNode based on the TileDict configuration and block size. - - Parameters - ---------- - node : PrimFuncNode - The node for which to plan vectorization. - td : TileDict - The TileDict object containing the tile configuration. - block_size : int - The block size used for vectorization planning. - - Returns - ------- - Dict - A dictionary mapping tensors to their vectorization size. - """ - - def is_cont(shape, vec): - if len(shape) == 0: - return vec == 1 - last = shape[-1] - if last == 1: - return is_cont(shape[0:-1], vec // last) - else: - return last % vec == 0 - - def is_shape_aligned(shape, factor): - return int(np.prod(shape)) % factor == 0 - - def is_type_allowed(dtype, vec): - return dtype.bits * vec <= 128 - - vectorize_sizes = [16, 8, 4, 2] - dtypes = node.get_reduce_inputs_dtype() - shapes = node.propagate_reduction_inputs(td.get_tile(node), td.get_rstep(node)) - vectorize_result = {} - for tensor, shape in shapes.items(): - for v in vectorize_sizes: - if (is_shape_aligned(shape, block_size * v) and is_cont(shape, v) and - is_type_allowed(dtypes[tensor], v)): - vectorize_result[tensor] = v - break - return vectorize_result - - def plan_rasterization(self, td: TileDict): # pylint: disable=unused-argument - """ - Plans the rasterization for the given TileDict. This function is not implemented yet. - - Parameters - ---------- - td : TileDict - The TileDict object to plan rasterization for. - - Raises - ------- - RasterRationPlan - This function is not implemented yet. - """ - return NoRasterization() diff --git a/python/bitblas/base/roller/policy/tensorcore.py b/python/bitblas/base/roller/policy/tensorcore.py deleted file mode 100644 index f4047ef08..000000000 --- a/python/bitblas/base/roller/policy/tensorcore.py +++ /dev/null @@ -1,349 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Policy for tensorcore schedule""" -import tvm -from typing import Dict, List, Tuple, Optional -import numpy as np - -from ..arch import TileDevice -from ..hint import Hint, Stride, TileDict, IntrinInfo -from ..node import PrimFuncNode -from .common import coalesced_factor, factorize, get_all_factors -from .default import DefaultPolicy -from ..rasterization import NoRasterization, Rasterization2DColumn - - -class TensorCorePolicy(DefaultPolicy): - - def __init__(self, - func: tvm.tir.PrimFunc, - arch: TileDevice, - tags: Optional[Dict] = None) -> None: - super().__init__(func, arch, tags) - # this is the trick for wmma. - # However, for int8 mma, the wmma_k should be 32. - self.wmma_k = 16 - self.pipeline_stage: int = 1 - self.use_async_copy: bool = False - self.block_reduction_depth: Optional[int] = None - self._legalize_info() - - def _legalize_info(self): - pipleline_stage = self.prim_func_node.get_tag("pipeline_stage") - if pipleline_stage: - self.pipeline_stage = pipleline_stage - else: - if self.arch.compute_capability == "sm_80": - self.pipeline_stage = 2 - else: - self.pipeline_stage = 1 - use_async_copy = self.prim_func_node.get_tag("use_async_copy") - if use_async_copy: - self.use_async_copy = use_async_copy - else: - if self.arch.compute_capability == "sm_80": - self.use_async_copy = True - else: - self.use_async_copy = False - # TODO: block reduction depth is not used for now. - # As there still exists some performance issues for block reduction. - # block_reduction_depth = self.prim_func_node.get_tag("block_reduction_depth") - # if block_reduction_depth: - # self.block_reduction_depth = block_reduction_depth - - def _compute_tc_strides( - self, - node: PrimFuncNode, - tile: List[int], - rstep: Optional[Dict[str, int]] = None, - ) -> Tuple[Stride, Stride, Stride]: - if rstep is None: - rstep = {} - # strides was used for shared memory padding. which is necessary for avoiding - # shared memory load bank conflict when we do not applying tensorcore layout. - shapes = node.propagate_reduction_inputs(tile, rstep) - AS_shape, BS_shape = shapes.values() - CS_shape = tile - A_ax_m, A_ax_k, B_ax_k, B_ax_n, C_ax_m, C_ax_n = node.infer_tensorcore_axis() - - # applying strides - # TODO(leiwang1999): offset should be dynamically set. we can use tag -> enable_offset to control this option.. - offset = 8 - A_high_ax = min(A_ax_m, A_ax_k) - B_high_ax = min(B_ax_n, B_ax_k) - C_high_ax = min(C_ax_m, C_ax_n) - A_stride = Stride(stride=np.prod(AS_shape[A_high_ax + 1:]) + offset, ax=A_high_ax) - B_stride = Stride(stride=np.prod(BS_shape[B_high_ax + 1:]) + offset, ax=B_high_ax) - C_stride = Stride(stride=np.prod(CS_shape[C_high_ax + 1:]) + offset, ax=C_high_ax) - return A_stride, B_stride, C_stride - - def infer_node_smem_usage(self, td: TileDict, node: PrimFuncNode): - value, cached_tensors = super().infer_node_smem_usage(td, node) - value *= self.pipeline_stage - return value, cached_tensors - - def _assign_reduce_step(self, node): - if not node.get_tag("tensorcore_config"): - return super()._assign_reduce_step(node) - # get reduce input size - target_transaction = self.arch.transaction_size[0] * 2 - # 512 bytes // type bits - reduce_input_dtype = node.get_buffer_dtype( - node.block_analyzer.get_input_buffers(node.reduction_block)[0]) - basic = (target_transaction * 8) // reduce_input_dtype.bits - - result = {} - for iter_info in node.raxis: - iter_name = iter_info.var.name - iter_dom = iter_info.dom.extent - if iter_dom % 16 > 0: - result[iter_name] = (16 if iter_dom < basic else basic) # for the case of padding - elif iter_dom % basic == 0: - result[iter_name] = basic - else: - return super()._assign_reduce_step(node) - return result - - def _expand_reduce_axis(self, td: TileDict): - # For tensorcore program, if we got a small tilesize, we should consider expand the reduce axis - # to improve compute efficiency. - def _check_small_tile(td: TileDict): - minimal_threadhold = 32 - for node in self.ordered_nodes: - tile = td.get_tile(node) - if any([t <= minimal_threadhold for t in tile]): - return True - return False - - if not _check_small_tile(td): - return None - - smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap) - rstep_map = td.rstep_map.copy() - is_block_reduction = self.block_reduction_depth is not None - - def _optimize(node, rstep): - all_steps = self.get_node_reduce_step_candidates(node) - # todo(lei): optimize the all_steps enlarge policy to be a multiple of the original all_steps[k] - for k in all_steps: - all_steps[k] = list(filter(lambda x: x % rstep[k] == 0, all_steps[k])) - if any([v == [] for v in all_steps.values()]): - return rstep - - def _shared_memory_usage(td: TileDict): - return node.footprint(td.output_tile, new_rstep_map, td.tensor_strides_map[node]) - - def _score(rstep_id): - rstep = { - k.var.name: all_steps[k.var.name][rstep_id[k.var.name]] for k in node.raxis - } - score = 0 - shape = node.propagate_inputs_on_reduction(td.get_tile(node), rstep=rstep) - input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) - for i, input_buffer in enumerate(input_buffers): - score += coalesced_factor(shape[i], input_buffer.shape) - return score - - def _enlarge(rstep_id): - candidates = [] - for ax in rstep_id: - if rstep_id[ax] + 1 == len(all_steps[ax]): - continue - r = rstep_id.copy() - r[ax] += 1 - candidates.append((r, _score(r))) - if len(candidates) == 0: - return None - return max(candidates, key=lambda x: x[1])[0] - - cur_rstep_id = { - k.var.name: all_steps[k.var.name].index(rstep[k.var.name]) for k in node.raxis - } - new_rstep_map = rstep_map.copy() - while True: - new_rstep_id = _enlarge(cur_rstep_id) - if new_rstep_id is None: - break - new_rstep_map = { - k.var.name: all_steps[k.var.name][new_rstep_id[k.var.name]] for k in node.raxis - } - old_rstep_map = td.rstep_map - td.rstep_map = new_rstep_map - smem_usage, _ = _shared_memory_usage(td) - td.rstep_map = old_rstep_map - if smem_usage > smem_limit: - break - else: - cur_rstep_id = new_rstep_id - rstep = { - k.var.name: all_steps[k.var.name][cur_rstep_id[k.var.name]] for k in node.raxis - } - return rstep - - for node in self.ordered_nodes: - if len(node.raxis) > 0: - rstep = _optimize(node, rstep_map) - rstep_map = rstep - - if is_block_reduction: - # If block reduction, we should constrain the max value is 64 - # Otherwise it will introduce an issue of cuda invalid args. - MAX_REDUCE_K = 64 - for k in rstep_map: - rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K) - td.rstep_map = rstep_map - td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td) - return - - def get_node_reduce_step_candidates(self, node): - if not node.get_tag("tensorcore_config"): - return super().get_node_reduce_step_candidates(node) - else: - # must be a a multiple of wmma_k - return { - k.var.name: - [x * self.wmma_k for x in get_all_factors(int(k.dom.extent) // self.wmma_k)] - for k in node.raxis - } - - def check_tile_shape_isvalid(self, td: TileDict): - for node in self.ordered_nodes: - if node.get_tag("tensorcore_config"): - ax_m, ax_n = node.get_tag("tensorcore_config") - block_m, block_n = ( - td.tile_map[node][ax_m], - td.tile_map[node][ax_n], - ) - # check the tile size is valid - wmma_invalid = [ - block_m < wmma_m or block_n < wmma_n - for wmma_m, wmma_n in self.arch.get_avaliable_tensorintrin_shapes() - ] - if all(wmma_invalid): - return False - if any([y % x for x, y in zip(td.tile_map[node], node.get_space_dim())]): - return False - return super().check_tile_shape_isvalid(td) - - def _can_implement_layout(self, node: PrimFuncNode, td: TileDict): - # Not implemented yet - # This function is used to check whether we can implement swizzling - # layout under this tile config - return False - - def compute_node_stride_map(self, node: PrimFuncNode, td: TileDict): - if not node.get_tag("tensorcore_config"): - return super().compute_node_stride_map(node, td) - use_layout = self._can_implement_layout(node, td) - - AS_stride, BS_stride, C_stride = self._compute_tc_strides(node, td.get_tile(node), - td.get_rstep(node)) - A_stride, B_stride, _ = self._compute_tc_strides(node, td.get_tile(node)) - tensor_strides = {} - output_strides = { - int(i + len(node.input_buffers)): Stride() for i, _ in enumerate(node.output_buffers) - } - tensor_strides = {} - # when connected to shared input, should use full stride without rstep - for i, (_, _) in enumerate(zip([AS_stride, BS_stride], [A_stride, B_stride])): - if use_layout: - continue - _ = node.block_analyzer.get_input_buffers(node.reduction_block)[i].name - # TODO(lei): should dig further for shared memory connection case. - - return output_strides, tensor_strides - - def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int): - if not node.get_tag("tensorcore_config"): - return super()._assign_block_size(node, td, block_size) - ax_m, ax_n = node.get_tag("tensorcore_config") - if block_size % self.arch.warp_size != 0: - return None - tile, rsteps = td.get_tile(node), td.get_rstep(node) - warps = block_size // self.arch.warp_size - ndim = len(tile) - - wmma = self.arch.get_avaliable_tensorintrin_shapes()[-1] - wmma_tile = [1 for _ in range(ndim)] - wmma_tile[ax_m] = wmma[0] - wmma_tile[ax_n] = wmma[1] - - space = [tile[i] // wmma_tile[i] for i in range(ndim)] - if tile[ax_m] < wmma_tile[ax_m] or tile[ax_n] < wmma_tile[ax_n]: - # allow pad, otherwise, we can not get a valid tile shape - return None - - factors = factorize(np.prod(space) // warps) - - def _score(node, thread): # small is better - score = 0 - block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)] - shape = node.propagate_inputs_on_reduction(block_tile) - input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block) - for i, _ in enumerate(input_buffers): - score += np.prod(shape[i]) / self.arch.bandwidth[1] - return score - - warp_tile = wmma_tile.copy() - for factor in reversed(factors): - score_map = {} - for i in range(ndim): - if tile[i] % (warp_tile[i] * factor) != 0: - continue - warp_tile[i] *= factor - score_map[i] = (_score(node, warp_tile), i) - warp_tile[i] //= factor - if len(score_map) == 0: - return None - dim_order = sorted(score_map.keys(), key=lambda x: score_map[x]) - warp_tile[dim_order[0]] *= factor - - codegen_dict = Hint() - codegen_dict.block = tile - codegen_dict.warp = warp_tile - codegen_dict.use_tc = True - codegen_dict.pipeline_stage = self.pipeline_stage - codegen_dict.block_reduction_depth = self.block_reduction_depth - codegen_dict.use_async = self.use_async_copy - codegen_dict.rstep = [int(rsteps[ax.var.name]) for ax in node.raxis] - codegen_dict.cached_tensors = td.cached_tensors_map[node] - codegen_dict.rasterization_plan = self.plan_rasterization(td) - - intrin_info = node.get_tag("intrin_info") - if intrin_info: - codegen_dict.intrin_info = IntrinInfo(**intrin_info) - if intrin_info["out_dtype"] in ["float32"]: - codegen_dict.shared_scope = "shared.dyn" - # smem capacity - if td.smem_cost > self.arch.smem_cap: - codegen_dict.shared_scope = "shared.dyn" - - codegen_dict.complete_config(node) - codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size) - codegen_dict.arch = self.arch - codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes") - codegen_dict.tensorcore_legalization() - return codegen_dict - - def plan_rasterization(self, td: TileDict): - conditions = [] - # only support single node for now - conditions.append(len(self.ordered_nodes) > 1) - # only on Ampere+ arch - conditions.append(self.arch.compute_capability < "80") - - def _check_memory_size(): - overall_gmem_size_in_bytes: int = 0 - for node in self.ordered_nodes: - for buffer in node.input_buffers: - overall_gmem_size_in_bytes += ( - int(np.prod(buffer.shape)) * tvm.DataType(buffer.dtype).bits // 8) - return overall_gmem_size_in_bytes < self.arch.l2_cache_size_bytes - - conditions.append(_check_memory_size()) - if any(conditions): - return NoRasterization() - # otherwise, simply provide a block rasterization factor - raster_factor = int(self.arch.compute_max_core**0.5) - - return Rasterization2DColumn(raster_factor) diff --git a/python/bitblas/base/roller/rasterization.py b/python/bitblas/base/roller/rasterization.py deleted file mode 100644 index 4fb779069..000000000 --- a/python/bitblas/base/roller/rasterization.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Rasteration Plan For L2 Cache Locality""" - -from typing import List - - -class Rasterization: - - def __init__(self) -> None: - pass - - def get_code(self) -> List[str]: - raise NotImplementedError() - - -class NoRasterization(Rasterization): - - def __init__(self) -> None: - super().__init__() - - def __repr__(self) -> str: - return "" - - def get_code(self) -> List[str]: - return [] - - -class Rasterization2DRow(Rasterization): - """ - Rasterization by Row, each Row line width is panel_width - _________ - _________| - |_________ - __________| - """ - - def __init__(self, panel_width=4) -> None: - super().__init__() - self.panel_width_ = panel_width - - def __repr__(self) -> str: - return f"" - - def get_code(self) -> List[str]: - raise NotImplementedError() - - -class Rasterization2DColumn(Rasterization): - """ - Rasterization by Column, each column line width is panel_width - _ - | | | | - | | | | - |_| |_| - """ - - def __init__(self, panel_width=4) -> None: - super().__init__() - self.panel_width_ = panel_width - - def __repr__(self) -> str: - return f"" - - def get_device_function(self) -> str: - return """ -__device__ __inline__ dim3 rasterization2DColumn(const int panel_width) { - const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y; - const auto totalPanel = (gridDim.x * gridDim.y +panel_width * gridDim.x - 1) / (panel_width * gridDim.x); - const auto totalBlock = gridDim.x * gridDim.y; - const auto panelIdx = baseBlockIdx / (panel_width *gridDim.x); - const auto strideLd = panelIdx + 1 < totalPanel ?panel_width : (totalBlock - panelIdx * (panel_width *gridDim.x)) / gridDim.x; - const auto bx = (panelIdx & 1) ? gridDim.x -(baseBlockIdx - panelIdx * panel_width * gridDim.x) /strideLd - 1 : (baseBlockIdx - panelIdx * panel_width *gridDim.x) / strideLd; - const auto by = (baseBlockIdx - panelIdx * panel_width *gridDim.x) % strideLd + panelIdx * panel_width; - const auto bz = blockIdx.z; - - dim3 blockIdx(bx, by, bz); - return blockIdx; -} - """ - - def get_code(self, panel_width: int = None) -> List[str]: - if panel_width is None: - panel_width = self.panel_width_ - return [ - self.get_device_function(), - "const dim3 blockIdx = rasterization2DColumn({});\n".format(panel_width), - ] diff --git a/python/bitblas/base/roller/shape_inference/__init__.py b/python/bitblas/base/roller/shape_inference/__init__.py deleted file mode 100644 index 188aa0bb7..000000000 --- a/python/bitblas/base/roller/shape_inference/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from .tir import get_analyzer_by_tir # pylint: disable=unused-import diff --git a/python/bitblas/base/roller/shape_inference/common.py b/python/bitblas/base/roller/shape_inference/common.py deleted file mode 100644 index 730bbbeef..000000000 --- a/python/bitblas/base/roller/shape_inference/common.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from collections import OrderedDict -from typing import Dict, List - -from tvm import arith - - -class Statement(): - def __init__(self, output: str, dependent_region: dict, var_map: OrderedDict, range_map: OrderedDict): - self.output = output - self.dependent_region = dependent_region - self.var_map = var_map - self.range_map = range_map - -def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound): - return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value)) - -class InputShapeInference(): - def __init__(self, deps: List[Statement]): - self.deps = deps - - def _infer(self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, int]): - shape = shape.copy() - ana = arith.Analyzer() - for dep in reversed(self.deps): - for var, bound in zip(dep.var_map.values(), shape[dep.output]): - ana.update(var, bound) - for var, bound in dep.range_map.items(): - if var.name in rstep: - bound = arith.ConstIntBound(0, min(bound.max_value, rstep[var.name] - 1)) - ana.update(var, bound) - for name, regions in dep.dependent_region.items(): - for region in regions: - bounds = [ana.const_int_bound(index) for index in region] - if name in shape: # simply merge two bounds - bounds = [_merge_two_bounds(x, y) for x, y in zip(shape[name], bounds)] - shape[name] = bounds - - for name, bounds in shape.items(): - shape[name] = [c.max_value - c.min_value + 1 for c in bounds] - return shape - - def infer(self, shape, rstep: Dict[str, int] = {}): - if isinstance(shape, (list, tuple)): - shape = {"output0" : [arith.ConstIntBound(0, val - 1) for val in shape]} - shape = self._infer(shape, rstep) - return shape - - def get_input_exprs(self, output_exprs): - result = output_exprs.copy() - ana = arith.Analyzer() - for dep in reversed(self.deps): - for var, expr in zip(dep.var_map.values(), result[dep.output]): - ana.bind(var, expr) - for var in dep.range_map: - ana.bind(var, 0) - for name, regions in dep.dependent_region.items(): - if name in result: - continue - region = regions[0] - input_expr = [ana.simplify(index) for index in region] - result[name] = input_expr - return result - diff --git a/python/bitblas/base/roller/shape_inference/tir.py b/python/bitblas/base/roller/shape_inference/tir.py deleted file mode 100644 index 35bf0b7d8..000000000 --- a/python/bitblas/base/roller/shape_inference/tir.py +++ /dev/null @@ -1,399 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from typing import Dict, List, Tuple, Set, Mapping -from tvm.tir.schedule.schedule import BlockRV -from tvm.ir import structural_equal -from tvm import arith, tir - - -class Statement: - def __init__(self, block_analyzer, block: BlockRV): - self.block_analyzer = block_analyzer - self.block = block - # assume one tir block only has one output buffer - self.dep_name = block_analyzer.get_output_buffers(block)[0].name - self.dependent_region = _extract_dependent_region(block_analyzer, block) - - self.reverse_bound_inference = {} - - def make_reverse(self, input_name: str, input_iter: List[tir.PrimExpr]): - if len(self.block_analyzer.get_reduce_axis(self.block)) > 0: - return None - if len(self.dependent_region[input_name]) != 1: - return None - indices = self.dependent_region[input_name][0] - iter_map_range = { - _iter.var: _iter.dom for _iter in self.block_analyzer.get_spatial_axis(self.block) - } - iter_map_result = arith.detect_iter_map( - indices, - iter_map_range, - check_level=arith.iter_affine_map.IterMapLevel.Surjective, - simplify_trivial_iterators=False, - ) - if len(iter_map_result.errors) > 0: - return None - results = arith.iter_affine_map.inverse_affine_iter_map(iter_map_result.indices, input_iter) - output_indices = [] - for _iter in self.block_analyzer.get_spatial_axis(self.block): - if _iter.var in results: - output_indices.append(results[_iter.var]) - else: - # not Bijective mapping case - output_indices.append(tir.Var("undefined", dtype="int32") % int(_iter.dom.extent)) - return output_indices - - -def _merge_two_bounds(x: arith.ConstIntBound, y: arith.ConstIntBound): - return arith.ConstIntBound(min(x.min_value, y.min_value), max(x.max_value, y.max_value)) - - -class TensorDepNode(object): - """ - For tensor dependency analysis. - """ - - def __init__(self, name): - self.name = name - self._next = [] - self._prev = [] - - def add_next(self, node): - self._next.append(node) - self.deduplicate(self._next) - - def add_prev(self, node): - self._prev.append(node) - self.deduplicate(self._prev) - - def deduplicate(self, lst): - seen = set() - lst[:] = [n for n in lst if not (n in seen or seen.add(n))] - - def __str__(self): - return self.name - - def __repr__(self): - return self.name - - -class DependencyAnalysis(object): - def __init__(self, deps): - self.deps = deps - # issue: duplicate name when we have two same ops. - self.name2dep = self._construct_unique_name2dep(deps) - self.mapping = {} # name -> TensorDepNode - - def _construct_unique_name2dep(self, deps): - """ - This is a workaround for the issue that we have two same ops' fuse case. - See https://github.com/apache/tvm/issues/16433 - """ - _names:Set = set() - name2dep:Mapping = {} - for dep in deps: - output_buffer = dep.block_analyzer.get_output_buffers(dep.block)[0] - base_name = output_buffer.name - if base_name not in _names: - _names.add(base_name) - else: - i = 1 - while f"{base_name}_{i}" in _names: - i += 1 - base_name = f"{base_name}_{i}" - _names.add(base_name) - name2dep[base_name] = dep - return name2dep - - def get_or_create_node(self, name): - if name not in self.mapping: - self.mapping[name] = TensorDepNode(name) - return self.mapping[name] - - def traverse_dependencies(self, compute): - if isinstance(compute, Statement): - node = self.get_or_create_node( - compute.block_analyzer.get_output_buffers(compute.block)[0].name - ) - # Loop through input tensors - for input_buffer in compute.block_analyzer.get_input_buffers(compute.block): - # Get the input node - input_node = self.traverse_dependencies(input_buffer) - input_node.add_next(node) - node.add_prev(input_node) - elif isinstance(compute, tir.Buffer): - node = self.get_or_create_node(compute.name) - return node - - def analyze(self): - # Starting point for traversal - for _, compute in self.name2dep.items(): - self.traverse_dependencies(compute) - - def print_dependencies(self): - for name, node in self.mapping.items(): - print(f"{name} depends on {', '.join([prev.name for prev in node._prev])}") - - def find_path_from_source(self, start_name, target_name): - """ - Finds the path (if it exists) from a starting node (source) to a target node. - Returns the path as a list of nodes. - """ - visited = set() - path = [] - if self._find_path_recursive(self.mapping[start_name], target_name, visited, path): - return path - return [] - - def _find_path_recursive(self, current_node, target_name, visited, path): - """ - Recursive helper function for find_path_from_source. - """ - if current_node.name == target_name: - path.append(current_node) - return True - - if current_node.name in visited: - return False - - visited.add(current_node.name) - path.append(current_node) - - for next_node in current_node._next: - if self._find_path_recursive(next_node, target_name, visited, path): - return True - - path.pop() - return False - - -class InputShapeInference: - def __init__(self, deps: List[Statement]): - self.deps = deps - self.target_mapping = {} - self.buffer_mapping = {} - self.reduce_axes = [] - for dep in self.deps: - for ax in dep.block_analyzer.get_reduce_axis(dep.block): - self.reduce_axes.append(ax) - self.dep_analysis = DependencyAnalysis(self.deps) - self.dep_analysis.analyze() - - def construct_dependency_target(self, targets: Tuple[str]): - if targets in self.target_mapping: - return self.target_mapping[targets] - # should be buffer name instead of block name - name2dep = { - dep.block_analyzer.get_output_buffers(dep.block)[0].name: dep for dep in self.deps - } - mapping = {} - input_vars = [] - for target in targets: - vars = [ - iter.var - for iter in name2dep[target].block_analyzer.get_spatial_axis(name2dep[target].block) - ] - input_vars.append(vars) - mapping[target] = [vars] - ana = arith.Analyzer() - - for dep in self.deps: - for name in dep.dependent_region: - if name not in mapping: - continue - dep_name = dep.dep_name - indices = mapping[name][0] - output_indices = dep.make_reverse(name, indices) - if dep_name in targets: - continue - if dep_name not in mapping: - mapping[dep_name] = [output_indices] - elif not region_exist_in_list(output_indices, mapping[dep_name]): - mapping[dep_name].append(output_indices) - - for dep in reversed(self.deps): - indices_list = mapping[dep.dep_name] - ax_vars = [iter.var for iter in dep.block_analyzer.get_spatial_axis(dep.block)] - for input_name, regions in dep.dependent_region.items(): - if input_name in targets: - continue - if input_name not in mapping: - mapping[input_name] = [] - for indices in indices_list: - for region in regions: - vmap = { - k: (tir.Cast(k.dtype, v) if v.dtype != k.dtype else v) - for k, v in zip(ax_vars, indices) - } - region = [ - ana.simplify(tir.stmt_functor.substitute(ax, vmap)) for ax in region - ] - if not region_exist_in_list(region, mapping[input_name]): - mapping[input_name].append(region) - buffers = [] - for dep in self.deps: - for buffer in dep.block_analyzer.get_buffers(dep.block): - buffers.append(buffer) - - for buffer in buffers: - self.buffer_mapping[buffer.name] = buffer - - self.target_mapping[targets] = input_vars, mapping - return input_vars, mapping - - def infer( - self, shape: Dict[str, List[arith.ConstIntBound]], rstep: Dict[str, int] = {}, targets=None - ): - compute_targets = tuple(shape.keys()) - input_vars, mapping = self.construct_dependency_target(compute_targets) - ana = arith.Analyzer() - results = {} - intermediate_bind = {} - for vars, bounds in zip(input_vars, shape.values()): - for var, bound in zip(vars, bounds): - ana.update(var, bound, True) - for ax in self.reduce_axes: - # assume the dom.min is always 0, maybe we can extend the IterInfo to include the min value. - if ax.var.name in rstep: - bound = arith.ConstIntBound( - int(ax.dom.min), int(ax.dom.min + min(ax.dom.extent, rstep[ax.var.name]) - 1) - ) - else: - bound = arith.ConstIntBound(int(ax.dom.min), int(ax.dom.min + ax.dom.extent - 1)) - ana.update(ax.var, bound, True) - - for name, regions in mapping.items(): - if targets is not None and name not in targets: - continue - if compute_targets[0:1] == compute_targets: - (compute_target,) = compute_targets - path = self.dep_analysis.find_path_from_source(name, compute_target) - if len(path) > 2: - intermediate_nodes = path[1:-1] - for node in intermediate_nodes: - iters = mapping[node.name] - if len(iters) != len(regions) or len(iters) != 1: - continue - if len(*iters) != len(*regions): - break - regions = iters - intermediate_bind[name] = compute_target - - for region in regions: - bound = [ana.const_int_bound(indice) for indice in region] - if name in results: # simply merge two bounds - bound = [_merge_two_bounds(x, y) for x, y in zip(results[name], bound)] - results[name] = bound - else: - for region in regions: - bound = [ana.const_int_bound(indice) for indice in region] - if name in results: # simply merge two bounds - bound = [_merge_two_bounds(x, y) for x, y in zip(results[name], bound)] - results[name] = bound - - for name, bounds in results.items(): - results[name] = [c.max_value - c.min_value + 1 for c in bounds] - return results, intermediate_bind - - def get_input_exprs(self, output_exprs): - input_vars, mapping = self.construct_dependency_target(tuple(output_exprs.keys())) - ana = arith.Analyzer() - for ax in self.reduce_axes: - ana.bind(ax.var, 0) - vmap = {} - for vars, exprs in zip(input_vars, output_exprs.values()): - for var, expr in zip(vars, exprs): - if expr.dtype != var.dtype: - expr = tir.Cast(var.dtype, expr) - vmap[var] = expr - result = {} - - for name, regions in mapping.items(): - region = regions[0] - result[name] = [ - ana.simplify(tir.stmt_functor.substitute(index, vmap)) for index in region - ] - return result - - -def region_exist_in_list(a, list) -> bool: - def expr_is_same(a, b) -> bool: - if isinstance(a, tir.IntImm) and isinstance(b, tir.IntImm): - return a.value == b.value - return structural_equal(a, b) - - def region_is_same(a, b) -> bool: - for indice_a, indice_b in zip(a, b): - if not expr_is_same(indice_a, indice_b): - return False - return True - - return any([region_is_same(a, x) for x in list]) - - -def walk_indice(expr): - if isinstance(expr, tir.expr.BinaryOpExpr): - a = walk_indice(expr.a) - b = walk_indice(expr.b) - if a is not None and b is not None: - return expr - else: - return None - elif isinstance(expr, tir.expr.ConstExpr): - return expr - elif isinstance(expr, tir.Var): - return expr - elif isinstance(expr, tir.ProducerLoad): - return None - elif isinstance(expr, tir.Cast): - a = walk_indice(expr.value) - if a is not None: - return expr - return None - elif isinstance(expr, tir.Call): - return None - else: - raise Exception("Unhandled node type in walk_indice(): %s" % expr) - - -def _extract_dependent_region(block_analyzer, block: BlockRV) -> Dict[str, List[tir.PrimExpr]]: - input_buffers = block_analyzer.get_input_buffers(block) - dependent_region = {buffer.name: [] for buffer in input_buffers} - - def fvisit(x): - if not isinstance(x, tir.BufferLoad): - return - if x.buffer.name not in dependent_region: - return - index = [] - for indice, shape_limit in zip(x.indices, x.buffer.shape): - expr = walk_indice(indice) - if expr is None: - expr = tir.Var("undefined", dtype="int8") % shape_limit - if isinstance(expr, tir.IntImm) and expr.value == 0: - """for tensor ir zero dim smplification case. - for ax0, ax1, ax2 in T.grid(T.int64(1024), T.int64(1024), T.int64(1024)): - with T.block("T_dense"): - v0, v1, v2 = T.axis.remap("SSR", [ax0, ax1, ax2]) - T.reads(A_reindex[T.int64(0), v0, v2], B_reindex[T.int64(0), v1, v2]) - T.writes(T_dense_reindex[T.int64(0), v0, v1]) - with T.init(): - T_dense_reindex[T.int64(0), v0, v1] = T.float16(0) - T_dense_reindex[T.int64(0), v0, v1] = T_dense_reindex[T.int64(0), v0, v1] + A_reindex[T.int64(0), v0, v2] * B_reindex[T.int64(0), v1, v2] - For exmaple, the T_dense_reindex has three dims, however there're only two spatial loops. - """ - continue - index.append(expr) - if not region_exist_in_list(index, dependent_region[x.buffer.name]): - dependent_region[x.buffer.name].append(index) - - stmt = block_analyzer.sch.get(block) - tir.stmt_functor.post_order_visit(stmt, fvisit=fvisit) - return dependent_region - - -def get_analyzer_by_tir(block_analyzer, args) -> InputShapeInference: - deps = [Statement(block_analyzer, block) for block in args] - - return InputShapeInference(deps) diff --git a/python/bitblas/base/schedule_rule.py b/python/bitblas/base/schedule_rule.py deleted file mode 100644 index 53319b4fc..000000000 --- a/python/bitblas/base/schedule_rule.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2018 The apache/tvm Authors. All Rights Reserved. -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# Modifications Copyright (c) Microsoft. -# The code below is mostly copied from apache/tvm schedule_rule.py in dlight. -"""A lightweight wrapper on an arbitrary function that can be used to schedule a TIR PrimFunc.""" -from typing import Callable, List, Union - -from tvm import tir -from tvm.target import Target - - -class ScheduleRule: # pylint: disable=too-few-public-methods - """A thin wrapper on an arbitrary function that can be used to schedule a TIR PrimFunc. - - Given a PrimFunc, a target, and a tunable flag, the apply method of a ScheduleRule - returns either a Schedule, a list of Schedules, or None, where None means that the rule - is not applicable to the given PrimFunc. If the tunable flag is True, the ScheduleRule is - allowed to return either a Schedule or a list of Schedules, and the Schedules are allowed to - contain tunable instructions. If the tunable flag is False, the ScheduleRule is only allowed to - return a Schedule, and the Schedule is not allowed to contain tunable instructions. - """ - - def apply( - self, - func: tir.PrimFunc, - target: Target, - tunable: bool, - ) -> Union[None, tir.Schedule, List[tir.Schedule]]: - """Apply the ScheduleRule to the given PrimFunc. - - Parameters - ---------- - func : tir.PrimFunc - The PrimFunc to apply the ScheduleRule to. - target : Target - The compilation target the schedule is supposed to be built for. - tunable : bool - Whether the schedule is allowed to contain tunable instructions. - - Returns - ------- - results : Union[None, tir.Schedule, List[tir.Schedule]] - Either a Schedule, a list of Schedules, or None, where None means that the rule - is not applicable to the given PrimFunc. - """ - raise NotImplementedError - - def apply_config( - self, - func: tir.PrimFunc, - config, - ): - """Apply the ScheduleRule to the given PrimFunc. - - Parameters - ---------- - func : tir.PrimFunc - The PrimFunc to apply the ScheduleRule to. - target : Target - The compilation target the schedule is supposed to be built for. - configs : - # todo: Discribe the configs - Returns - ------- - results : Union[None, tir.Schedule, List[tir.Schedule]] - Either a Schedule, a list of Schedules, or None, where None means that the rule - is not applicable to the given PrimFunc. - """ - raise NotImplementedError - - @staticmethod - def from_callable( - name, - ) -> Callable[ - [ - Callable[ - [tir.PrimFunc, Target, bool], - Union[None, tir.Schedule, List[tir.Schedule]], - ], - ], - "ScheduleRule", - ]: - """Create a ScheduleRule from a callable. - - Parameters - ---------- - name : str - - Returns - ------- - decorator : Callable - A decorator that takes a callable and returns a ScheduleRule. - - Examples - -------- - .. code-block:: python - - @ScheduleRule.from_callable("MyRule") - def my_rule(func: tir.PrimFunc, target: Target, tunable: bool) -> Union[None, Schedule] - # Do something with func and target - """ - - def decorator(f) -> "ScheduleRule": # pylint: disable=invalid-name - class _Rule(ScheduleRule): - def apply( - self, - func: tir.PrimFunc, - target: Target, - tunable: bool, - ) -> Union[None, tir.Schedule, List[tir.Schedule]]: - return f(func, target, tunable) - - _Rule.__name__ = name - return _Rule() - - return decorator - - def is_target_available( - self, target: Target - ) -> bool: # pylint: disable=unused-argument - """Check whether the rule is available for the given target. - - Parameters - ---------- - target : Target - The compilation target the schedule is supposed to be built for. - - Returns - ------- - available : bool - Whether the rule is available for the given target. - """ - return True diff --git a/python/bitblas/base/transform.py b/python/bitblas/base/transform.py deleted file mode 100644 index 647efa772..000000000 --- a/python/bitblas/base/transform.py +++ /dev/null @@ -1,218 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -""" -Apply ScheduleRules onto an IRModule to generate default schedules without tuning, -or a space for MetaSchedule tuning -""" -from typing import List, Optional, Dict -import os -import shutil -import tempfile -import os.path as osp -import tvm -from tvm import tir -from tvm import meta_schedule as ms -from tvm.ir import IRModule -from tvm.ir.transform import PassContext, module_pass -from tvm.target import Target -from .schedule_rule import ScheduleRule -from ..base.analysis import check_func_with_dynamic -from .utils import fast_tune, fast_tune_with_dynamic_range -import logging - -logger = logging.getLogger(__name__) - - -def _is_scheduled(func: tir.PrimFunc) -> bool: - if not isinstance(func, tir.PrimFunc): - return False - if not func.attrs: - return False - if "tir.is_scheduled" not in func.attrs: - return False - return func.attrs["tir.is_scheduled"] == 1 - - -@module_pass(opt_level=0, name="ApplyDefaultSchedule") -class ApplyDefaultSchedule: # pylint: disable=too-few-public-methods - """A IRModule pass that applies a list of ScheduleRules to all PrimFuncs in the module.""" - - def __init__(self, *rules: ScheduleRule): - """Construct a new ApplyDefaultSchedule pass. - - Parameters - ---------- - *rules : ScheduleRule - The ScheduleRules to apply to all PrimFuncs in the module. - """ - self.rules = list(rules) - - def transform_module( # pylint: disable=missing-function-docstring - self, - mod: IRModule, - _: PassContext, - ) -> IRModule: - target = Target.current(allow_none=False) - - updated_functions = {} - for g_var, func in mod.functions_items(): - if isinstance(func, tir.PrimFunc) and not _is_scheduled(func): - sch = _apply_rules(func, target, self.rules, tunable=False) - if sch is not None: - assert len(sch) == 1 - updated_functions[g_var] = (sch[0].mod["main"].with_attr("tir.is_scheduled", 1)) - for g_var, func in updated_functions.items(): - mod[g_var] = func - return mod - - -@module_pass(opt_level=0, name="ApplyFastTuning") -class ApplyFastTuning: # pylint: disable=too-few-public-methods - """A IRModule pass that applies a list of ScheduleRules to all PrimFuncs in the module.""" - - def __init__( - self, - topk: int = 10, - target: Optional[Target] = None, - parallel_build: bool = True, - meta_database_dir: str = None, - whitelist: Optional[List[str]] = None, - dynamic_range: Optional[Dict[str, List[int]]] = None, - ): - """Construct a new ApplyFastTuning pass. - - Parameters - ---------- - meta_database : str - The path of database. - dynamic_range : Dict[str, List[int]] - Use for generate kernel based on dynamic range. - """ - if whitelist is None: - whitelist = [] - if dynamic_range is None: - dynamic_range = {} - self.topk = topk - self.target = Target.current() if target is None else target - self.parallel_build = parallel_build - self.meta_database_dir = meta_database_dir - self.whitelist = whitelist - self.dynamic_range = dynamic_range - self.temp_dir = tempfile.TemporaryDirectory() - path_workload = osp.join(self.temp_dir.name, "database_workload.json") - path_tuning_record = osp.join(self.temp_dir.name, "database_tuning_record.json") - self.cache_meta_database = ms.database.JSONDatabase( - path_workload, path_tuning_record, module_equality="structural") - - def _in_white_list(self, func_name: str) -> bool: - if len(self.whitelist) == 0: - return True - return any([name in func_name for name in self.whitelist]) - - def transform_module( # pylint: disable=missing-function-docstring - self, - mod: IRModule, - _: PassContext, - ) -> IRModule: - target = self.target - updated_functions = {} - - for g_var, func in mod.functions_items(): - if isinstance(func, tir.PrimFunc) and not _is_scheduled(func): - if not self._in_white_list(g_var.name_hint): - continue - normalize_mod_func_ = tvm._ffi.get_global_func("tvm.meta_schedule.normalize_mod") - _normalized_func_mod = normalize_mod_func_(func) - - if self.cache_meta_database.has_workload(_normalized_func_mod): - tuning_record = self.cache_meta_database.query_tuning_record( - _normalized_func_mod, - target, - g_var.name_hint, - ) - if tuning_record: - trace = tuning_record.trace - sch = tvm.tir.Schedule(func) - trace.apply_to_schedule(sch, remove_postproc=False) - updated_functions[g_var] = sch.mod["main"].with_attr("tir.is_scheduled", 1) - continue - - if check_func_with_dynamic(func): - - dispatch_mod = fast_tune_with_dynamic_range( - func, - target=target, - topk=self.topk, - parallel_build=self.parallel_build, - global_symbol=g_var.name_hint, - dynamic_range=self.dynamic_range, - ) - - if dispatch_mod: - for g, f in dispatch_mod.functions_items(): - if g.name_hint == g_var.name_hint: - # avoid duplicated global symbol - updated_functions[g_var] = f.without_attr( - "global_symbol").with_attr("tir.is_scheduled", 1) - else: - updated_functions[g] = f.with_attr("tir.is_scheduled", 1) - # cannot reuse meta database as it cannot be recorvered from the trace - workload = self.cache_meta_database.commit_workload(_normalized_func_mod) - else: - # otherwise is static shape analysis - _, best = fast_tune( - func, - target=target, - topk=self.topk, - parallel_build=self.parallel_build, - ) - - if best is not None: - updated_functions[g_var] = best.sch.mod["main"].with_attr( - "tir.is_scheduled", 1) - workload = self.cache_meta_database.commit_workload(_normalized_func_mod) - # only record the best schedule - self.cache_meta_database.commit_tuning_record( - ms.database.TuningRecord( - best.sch.trace, - workload, - [best.latency], - target, - ms.arg_info.ArgInfo.from_prim_func(func=best.sch.mod["main"]), - )) - - for g_var, func in updated_functions.items(): - mod[g_var] = func - - # copy database - if self.meta_database_dir is not None: - if not osp.exists(self.meta_database_dir): - os.makedirs(self.meta_database_dir) - # TODO(lei): maybe another way to copy the database - shutil.copytree(self.temp_dir.name, self.meta_database_dir, dirs_exist_ok=True) - - return mod - - def __del__(self): - # clean up the temp cache - self.temp_dir.cleanup() - - -def _apply_rules( - func: tir.PrimFunc, - target: Target, - rules: List[ScheduleRule], - tunable: bool, -) -> Optional[List[tir.Schedule]]: - for rule in rules: - try: - space = rule.apply(func, target, tunable) - except Exception: - logger.debug(f"[BitBLAS][Error] applying rule {rule} failed") - space = None - if space is None: - continue - if isinstance(space, tir.Schedule): - space = [space] - return space - return None diff --git a/python/bitblas/base/utils.py b/python/bitblas/base/utils.py deleted file mode 100644 index 50adc135f..000000000 --- a/python/bitblas/base/utils.py +++ /dev/null @@ -1,517 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import tvm -import os -from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind -from concurrent.futures import ThreadPoolExecutor, as_completed -import numpy as np -from typing import List, Tuple, Optional, Dict, Union, Literal -from tvm import tir, IRModule -from tvm.runtime import Module -from tvm.tir import Schedule -from tvm.relax.expr import Function -import bitblas -from .analysis import get_root_block, get_reduction_blocks, find_var_from_func -from bitblas.base.roller.arch import CUDA -from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags -import tempfile -import itertools -from tvm.ir.supply import GlobalVarSupply -from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 -import logging - -logger = logging.getLogger(__name__) - - -def get_rasterization_code(pannel_width: int = 8) -> str: - return f""" - const int MAX_BLOCK_N = {pannel_width}; - const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y; - const auto totalPanel = (gridDim.x * gridDim.y +MAX_BLOCK_N * gridDim.x - 1) / (MAX_BLOCK_N * gridDim.x); - const auto totalBlock = gridDim.x * gridDim.y; - const auto panelIdx = baseBlockIdx / (MAX_BLOCK_N *gridDim.x); - const auto strideLd = panelIdx + 1 < totalPanel ?MAX_BLOCK_N : (totalBlock - panelIdx * (MAX_BLOCK_N *gridDim.x)) / gridDim.x; - const auto bx = (panelIdx & 1) ? gridDim.x -(baseBlockIdx - panelIdx * MAX_BLOCK_N * gridDim.x) /strideLd - 1 : (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) / strideLd; - const auto by = (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) % strideLd + panelIdx * MAX_BLOCK_N; - const auto bz = blockIdx.z; - const dim3 blockIdx(bx, by, bz); - """ - - -class CompileResult: - """ - Class to store the result of compilation - """ - - def __init__(self, config, sch, mod: Module): - self.config = config - self.sch = sch - self.mod = mod - self.code = mod.imported_modules[0].get_source() if mod else None - self.latency = 1e9 - self.profile_tensors = [] - self.time_evaluator = None - - def profile(self): - profile_tensors = self.profile_tensors - return self.time_evaluator(*profile_tensors).mean * 1e3 - - -def _apply_config( - func: tir.PrimFunc, - config=None, # todo(lei): update typing -) -> Optional[tir.Schedule]: - """ - find rules: - case 1. if the main block has no reduce op, then use the Elementwise rule. - case 2. if the config enabled tensorcore, then use the TensorCore rule. - case 3. if any([t > 1 for t in config.reduce_thread]), we should use the InnerThread Reduction Rule. - case 4. else we should use general reduction rule. - """ - logger.debug("Apply config {}".format(config)) - - sch = tir.Schedule(func) - root_block = get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - reduction_blocks = get_reduction_blocks(sch, blocks) - - if not reduction_blocks: - return bitblas.gpu.ElementWise().apply_config(func, config) - elif config.use_tc: - if config.arch.sm_version >= 80: - # For A100(sm_80) or more advanced gpu, use MMA tensorization. - return bitblas.gpu.MatmulTensorizationMMA().apply_config(func, config) - else: - # For other GPUs, use WMMA tensorization. - return bitblas.gpu.MatmulTensorizationWMMA().apply_config(func, config) - else: - _reduction_rules = [] - - _reduction_rules.append(bitblas.gpu.GEMV()) - if not any([t > 1 for t in config.reduce_thread]): - # Matrix multiplication template doesn't support inner thread reduction - _reduction_rules.append(bitblas.gpu.Matmul()) - _reduction_rules.append(bitblas.gpu.GeneralReduction()) - - for rule in _reduction_rules: - sch = rule.apply_config(func, config) - try: - sch = rule.apply_config(func, config) - except Exception as e_msg: - logger.debug("Apply config failed: ", e_msg) - continue - if sch is not None: - return sch - return None - - -def get_dummy_input_arrays( - func: Union[tir.PrimFunc, Function], - device: tvm.runtime.Device, - distribution: Literal["uniform", "onefill"] = "uniform", -): - - def var_wrapper(v): - if isinstance(v, tvm.tir.Var): - assert "opt_shapes" in func.attrs - assert v.name in func.attrs["opt_shapes"] - return func.attrs["opt_shapes"][v.name].value - elif isinstance(v, tvm.tir.IntImm): - return v.value - else: - raise RuntimeError("Not supported type: ", type(v)) - - profile_tensors = [] - for param in func.params: - if isinstance(func, tir.PrimFunc): - if param not in func.buffer_map: - # in case of dynamic symbolic may in params - continue - arg = func.buffer_map[param] - elif isinstance(func, Function): - arg = param.struct_info - else: - raise ValueError("Not supported type: ", type(func)) - - def map_numpy_type(intype): - typemap = { - 'e4m3_float8': 'float8_e4m3fn', - 'e5m2_float8': 'float8_e5m2', - } - if intype in typemap: - return typemap[intype] - else: - return intype - - numpy_dtype = map_numpy_type(arg.dtype) - if distribution == "uniform": - profile_tensors.append( - tvm.nd.array( - np.random.rand(*[var_wrapper(i) for i in arg.shape]).astype(numpy_dtype), - device=device, - )) - elif distribution == "onefill": - profile_tensors.append( - tvm.nd.array( - np.ones([var_wrapper(i) for i in arg.shape]).astype(numpy_dtype), - device=device, - )) - else: - raise ValueError("Not supported distribution: ", distribution) - return profile_tensors - - -def apply_and_build_parallel(func, - configs, - arch, - num_repeats=3, - max_workers=10, - timeout=30, - data_distribution="uniform") -> CompileResult: - cpresults = [] - - profile_tensors = get_dummy_input_arrays(func, arch.device, distribution=data_distribution) - max_workers = min(len(configs), os.cpu_count(), max_workers) - - # apply config in thread parallel - _sched: List[Schedule] = [] - - def _apply_schedule(f, c): - try: - sch = _apply_config(f, c) - except Exception as apply_schedule_error: - logger.debug("Apply schedule failed: {}".format(apply_schedule_error)) - sch = None - return sch - - with ThreadPoolExecutor(max_workers=4) as scheduler: - futures = {scheduler.submit(_apply_schedule, func, config) for config in configs} - for future in as_completed(futures, timeout=timeout): - _sched.append(future.result()) - - builder = PopenPoolExecutor(max_workers=max_workers, timeout=timeout) - - # build in process parallel - def _build(context) -> str: - idx, mod, arch = context - if mod is None: - return idx, None, None - # TODO(lei): - # this is a trick to implement rasteration, will be removed in the future - config = configs[idx] - - @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) - def tvm_callback_cuda_postproc(code, _): - code = tensor_replace_dp4a(code) - code = tensor_remove_make_int4(code) - code = tensor_remove_make_int2(code) - return code - - with tvm.transform.PassContext(config={"tir.use_async_copy": True, **config.pass_context}): - rt_mod = tvm.build(mod, target=arch.target) - - from tvm.contrib.tar import tar # pylint: disable=import-outside-toplevel - - artifact_path = os.path.join(tempfile.mkdtemp(), "tvm_tmp_mod." + tar.output_format) - code = rt_mod.imported_modules[0].get_source() - rt_mod.export_library(artifact_path, fcompile=tar) - return idx, code, artifact_path - - _mods = [sch.mod if sch is not None else None for sch in _sched] - - for map_result in builder.map_with_error_catching( - _build, - [(i, mod, arch) for i, mod in enumerate(_mods)], - ): - if map_result.status == StatusKind.TIMEOUT: - logger.debug("LocalBuilder: Timeout") - elif map_result.status == StatusKind.EXCEPTION: - # TODO(lei): redirect the exception to file if needed - logger.debug("LocalBuilder: An exception occurred {}".format(map_result.value)) - continue - elif map_result.status == StatusKind.COMPLETE: - idx, code, artifact_path = map_result.value - if artifact_path is None: - logger.debug("Artifact path is None") - continue - sch = _sched[idx] - config = configs[idx] - rt_mod = tvm.runtime.load_module(artifact_path) - cpresult = CompileResult(config, sch, rt_mod) - timer_cuda_mod = rt_mod.time_evaluator( - rt_mod.entry_name, arch.device, number=num_repeats) - cpresult.profile_tensors = profile_tensors - cpresult.time_evaluator = timer_cuda_mod - cpresult.code = code - cpresults.append(cpresult) - else: - raise ValueError(f"Unreachable: unexpected result: {map_result}") - - del builder - - best = None - best_latency = 1e9 - for cpresult in cpresults: - config = cpresult.config - try: - latency = cpresult.profile() - except Exception as e_mesg: - logger.debug(f"Evaluation with config failed {e_mesg}") - continue - logger.info("Evaluation with config {}".format(config)) - logger.info("Time cost of this config: {:.3f} ms".format(latency)) - - cpresult.latency = latency - if latency < best_latency: - best_latency = latency - best = cpresult - - return cpresults, best - - -def apply_and_build( - func, - configs, - arch, - parallel_build=False, - data_distribution="uniform", -) -> Tuple[List[CompileResult], CompileResult]: - max_workers = 10 if parallel_build else 1 - return apply_and_build_parallel( - func, configs, arch, max_workers=max_workers, data_distribution=data_distribution) - - -def fast_tune( - func: tir.PrimFunc, - target: tvm.target.Target, - topk: int = 10, - parallel_build: bool = True, - data_distribution: Literal["uniform", "onefill"] = "uniform", -): - # check the function is a primfunc - if not isinstance(func, tir.PrimFunc): - raise ValueError("Only support func is PrimFunc") # pragma: no cover - - if target.kind.name != "cuda": - logger.error("Only support CUDA target") - return None, None - - specilized_func = func - if func.attrs is not None and "opt_shapes" in func.attrs: - opt_shapes = func.attrs["opt_shapes"] - # should be int value - if not all([isinstance(v.value, int) for v in opt_shapes.values()]): - logger.error("The opt_shapes should be int value") - return None, None - # currently only support one dynamic range - if len(opt_shapes) > 1: - logger.error("Currently only support one dynamic range") - return None, None - - for buffer in func.buffer_map.values(): - for axis in buffer.shape: - if isinstance(axis, tvm.tir.Var) and axis.name not in opt_shapes: - raise NotImplementedError( - "Currently do not support fast tune with none-dynamic range set") - if opt_shapes: - for name, shape in opt_shapes.items(): - var = find_var_from_func(func, name) - specilized_func = func.specialize({ - var: shape.astype(var.dtype) - }).with_attr("is_specialized") - - arch = CUDA(target) - - policy = DefaultPolicy(func=func, arch=arch) - try: - specilized_func, tags = get_tensorized_func_and_tags(specilized_func, arch.target) - except Exception as e_msg: - logger.debug("Get tensorized func and tags failed: ", e_msg) - tags = None - if tags: - policy = TensorCorePolicy(func=specilized_func, arch=arch, tags=tags) - - configs = policy.emit_config(topk) - - if len(configs) == 0: - raise ValueError("No valid config generated") - - cpresults, best = apply_and_build( - func, - configs, - arch, - parallel_build=parallel_build, - data_distribution=data_distribution, - ) - - return cpresults, best - - -# always use the first function as the base -def collect_buffers_to_declare(func): - params = [] - # collect dynamic symbolic - dyn_symbolic: List[tvm.tir.Var] = [] - buffers_to_declare = [] - for param in func.params: - if param not in func.buffer_map: - continue - buffer = func.buffer_map[param] - for axis in buffer.shape: - if isinstance(axis, tvm.tir.Var) and axis not in dyn_symbolic: - dyn_symbolic.append(axis) - buffers_to_declare.append(buffer) - params.append(buffer.data) - - # the args should be buffers + dynamic symbolic - params += list(dyn_symbolic) - - return params, buffers_to_declare - - -def refactor_specialized_func(g_var, func, params, buffers_to_declare): - body = func.body - attrs = func.attrs - global_symbol = g_var - if "opt_shapes" in func.attrs: - opt_shapes = func.attrs["opt_shapes"] - - def serialize_name(opt_shapes: Dict): - return "_opt_" + "_".join([f"{k}_{v}" for k, v in opt_shapes.items()]) - - global_symbol += serialize_name(opt_shapes) - ret_type = func.ret_type - for buf in buffers_to_declare: - body = tvm.tir.DeclBuffer(buf, body=body) - - # device func must be private - device_func = tvm.tir.PrimFunc( - params, body, ret_type, attrs=attrs).without_attr("global_symbol") - return global_symbol, device_func - - -def create_dispatch_func(g_var: str, func: tir.PrimFunc, refactored_funcs: List[str]): - global_symbol = g_var - attrs = func.attrs - buffer_map = func.buffer_map - params = func.params - ret_type = func.ret_type - - # collect dynamic symbolic - dyn_symbolic: List[tvm.tir.Var] = [] - _invoke_params = [] - for param in func.params: - if param not in func.buffer_map: - continue - buffer = func.buffer_map[param] - for axis in buffer.shape: - if isinstance(axis, tvm.tir.Var) and axis not in dyn_symbolic: - dyn_symbolic.append(axis) - _invoke_params.append(buffer.data) - _invoke_params += list(dyn_symbolic) - - func_range: List[int] = [] - global_symbols = [] - for g_var, refactor_func in refactored_funcs: - opt_shapes = refactor_func.attrs["opt_shapes"] - func_range.append(list(opt_shapes.values())[0]) - global_symbols.append(g_var) - - # TODO(lei): general the dispatch function to support multiple dynamic symbolics - assert len(dyn_symbolic) == 1, "Only support one dynamic symbolics currently" - - ib = tvm.tir.ir_builder.create() - syb = list(dyn_symbolic)[-1] - last_range = 0 - for i, (_range, g_var) in enumerate(zip(func_range, global_symbols)): - if i == 0: - with ib.if_scope(syb <= _range): - ib.emit(tvm.tir.Call(None, g_var, _invoke_params)) - else: - with ib.if_scope(tvm.tir.all(syb > last_range, syb <= _range)): - ib.emit(tvm.tir.Call(None, g_var, _invoke_params)) - last_range = _range - with ib.if_scope(syb > last_range): - ib.emit(tvm.tir.Call(None, g_var, _invoke_params)) - stmt = ib.get() - dispatch_func = tvm.tir.PrimFunc(params, stmt, ret_type, buffer_map, attrs).with_attrs({ - "tir.is_global_func": True, - "global_symbol": global_symbol - }) - return dispatch_func - - -def create_dispatch_mod(g_var: str, original_func: tir.PrimFunc, - specialized_funcs: List[tir.PrimFunc]) -> IRModule: - dispatch_mod: IRModule = tvm.IRModule() - g_var_supply = GlobalVarSupply(dispatch_mod) - refactored_funcs = [] - for func in specialized_funcs: - params, buffers_to_declare = collect_buffers_to_declare(func) - global_symbol, device_func = refactor_specialized_func(g_var, func, params, - buffers_to_declare) - global_symbol = g_var_supply.fresh_global(global_symbol, add_prefix=False) - dispatch_mod[global_symbol] = device_func - refactored_funcs.append((global_symbol, device_func)) - dispatch_func = create_dispatch_func(g_var, original_func, refactored_funcs=refactored_funcs) - dispatch_mod.update(tvm.IRModule.from_expr(dispatch_func)) - return dispatch_mod - - -def fast_tune_with_dynamic_range( - func: tir.PrimFunc, - target: tvm.target.Target, - topk: int = 10, - parallel_build: bool = True, - global_symbol: Optional[str] = None, - dynamic_range: Optional[Dict[str, List[int]]] = None, -) -> IRModule: - if dynamic_range is None: - dynamic_range = {} - if target.kind.name != "cuda": - logger.error("Only support CUDA target") - return None - if not global_symbol: - global_symbol = func.attrs["global_symbol"] - - # set opt_shapes for the primfunc with dynamic symbolic - opt_shapes: Dict[str, List[int]] = {} - for buffer in func.buffer_map.values(): - for axis in buffer.shape: - if isinstance(axis, tvm.tir.Var): - if axis.name in dynamic_range: - opt_shapes[axis.name] = dynamic_range[axis.name] - else: - raise ValueError(f"[BitBLAS] The axis {axis.name} is not in dynamic_range") - func = func.with_attr("opt_shapes", opt_shapes) - - if "opt_shapes" not in func.attrs: - logger.error( - "[BitBLAS] The primfunc has no opt_shapes, please set opt_shapes for the primfunc") - return None - else: - # should be list value - if not all([isinstance(v, tvm.ir.Array) for v in func.attrs["opt_shapes"].values()]): - logger.error("The opt_shapes should be list value") - return None - - logger.info("Start fast tuning with dynamic range") - opt_shapes = func.attrs["opt_shapes"] - - # Step 1.Calculate the Cartesian product using itertools.product - product_list = list(itertools.product(*(opt_shapes[key] for key in opt_shapes))) - - # Convert the Cartesian product to a list of dictionaries - specialize_items: List[Dict] = [dict(zip(opt_shapes.keys(), values)) for values in product_list] - - specilized_tuned_funcs: List[tir.PrimFunc] = [] - for item in specialize_items: - func = func.with_attr("opt_shapes", item) - _, best = fast_tune(func, target, topk, parallel_build) - if best is None: - return None - specilized_tuned_funcs.append(best.sch.mod["main"]) - - return create_dispatch_mod(global_symbol, func, specilized_tuned_funcs) diff --git a/python/bitblas/cache/__init__.py b/python/bitblas/cache/__init__.py deleted file mode 100644 index 0c8fd3b9c..000000000 --- a/python/bitblas/cache/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from .operator import ( - global_operator_cache, # noqa: F401 - load_global_ops_cache, # noqa: F401 - get_database_path, # noqa: F401 - set_database_path, # noqa: F401 -) diff --git a/python/bitblas/cache/operator.py b/python/bitblas/cache/operator.py deleted file mode 100644 index 9b30a6200..000000000 --- a/python/bitblas/cache/operator.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import bitblas -from bitblas.ops.operator import OperatorConfig, Operator -from dataclasses import asdict -import os -import json -import tempfile -from hashlib import sha256 -import shutil -import tvm -from tvm.contrib.tar import tar -import logging - -logger = logging.getLogger(__name__) - -BITBLAS_DATABASE_PATH = os.path.expanduser("~/.cache/bitblas") - - -class OperatorCache: - """ - Manages a cache for operator instances (e.g., Matmul, Convolution) based on their configurations. - """ - - def __init__(self): - self.cache = {} - - def add(self, config: OperatorConfig, op_inst: Operator): - self.cache[config] = op_inst - - def get(self, config: OperatorConfig): - return self.cache.get(config) - - def exists(self, config): - return config in self.cache - - def clear(self): - self.cache.clear() - - def size(self): - return len(self.cache) - - def save_into_database(self, database_path=None, target=None): - database_path = self._ensure_database_path(database_path) - for config, op_inst in self.cache.items(): - arch_str = self._determine_arch_str(op_inst, target) - arch_path = os.path.join(database_path, arch_str) - self._ensure_directory(arch_path) - hash_str = sha256(repr(config).encode()).hexdigest() - config_path = os.path.join(arch_path, hash_str) - # if the config already exists, skip saving - if os.path.exists(config_path): - continue - self._ensure_directory(config_path) - self._save_operator_config_and_artifact(config, op_inst, config_path) - - def load_from_database(self, database_path, target=None): - if not os.path.exists(database_path): - logger.info( - f"Database path {database_path} does not exist, skipping loading operators from the database" - ) - return - arch_str = self._determine_target_arch_str(target) - arch_path = os.path.join(database_path, arch_str) - if not os.path.exists(arch_path): - logger.info( - f"Target {arch_str} does not exist in the database, skipping loading operators from the database" - ) - return - self._load_operators_from_arch_path(arch_path, target) - - def _ensure_database_path(self, database_path): - if database_path is None: - return tempfile.mkdtemp() - os.makedirs(database_path, exist_ok=True) - return database_path - - def _determine_arch_str(self, op_inst, target): - return (target if target else "-".join(list(op_inst.target.keys) + [op_inst.target.arch])) - - def _ensure_directory(self, path): - os.makedirs(path, exist_ok=True) - - def _save_operator_config_and_artifact(self, config, op_inst, config_path): - config_type, operator_type = type(config).__name__, type(op_inst).__name__ - with open(os.path.join(config_path, f"{config_type}.json"), "w") as json_file: - json.dump(asdict(config), json_file) - artifact_path = os.path.join(config_path, "tvm_rt_mod." + tar.output_format) - try: - op_inst.rt_mod.export_library(artifact_path, fcompile=tar) - except Exception as e: - # library does not support export_library - export_error = e # noqa: F841 - pass - json_data = {"config_type": config_type, "operator_type": operator_type} - json_file_path = os.path.join(config_path, "mapping.json") - with open(json_file_path, "w") as json_file: - json.dump(json_data, json_file) - - # For writing source.cu file - source_file_path = os.path.join(config_path, "source.cu") - with open(source_file_path, "w") as source_file: - source_file.write(op_inst.get_source()) - - # For writing optimized.py file - optimized_file_path = os.path.join(config_path, "optimized.py") - with open(optimized_file_path, "w") as optimized_file: - if op_inst.optimized_func is not None: - optimized_file.write(op_inst.optimized_func.script(show_meta=False)) - if op_inst.wrapper.lib_name is not None: - # copy lib name to the same directory as the artifact - src_name = op_inst.wrapper.src_name - shutil.copy( - src_name, - os.path.join(config_path, os.path.basename("wrapper_source.cu")), - ) - lib_name = op_inst.wrapper.lib_name - shutil.copy( - lib_name, - os.path.join(config_path, os.path.basename("wrapper_compiled.so")), - ) - - def _determine_target_arch_str(self, target): - return (target if isinstance(target, str) else "-".join(list(target.keys) + [target.arch])) - - def _load_operators_from_arch_path(self, arch_path, target): - for root, dirs, _ in os.walk(arch_path): - for directory in dirs: - config_path = os.path.join(root, directory) - self._load_operator(config_path, target) - - def _load_operator(self, config_path, target): - mapping, config, rt_mod, src_name, lib_name = None, None, None, None, None - for file in os.listdir(config_path): - full_path = os.path.join(config_path, file) - if file == "mapping.json": - with open(full_path) as f: - mapping = json.load(f) - elif file.endswith(".json"): - with open(full_path) as f: - config = json.load(f) - elif file.endswith(".tar"): - rt_mod = tvm.runtime.load_module(full_path) - elif file == "wrapper_compiled.so": - lib_name = full_path - elif file == "wrapper_source.cu": - src_name = full_path - - if mapping and config and rt_mod: - self._instantiate_and_add_operator(mapping, config, rt_mod, src_name, lib_name, target) - - def _instantiate_and_add_operator(self, mapping, config, rt_mod, src_name, lib_name, target): - config_cls = getattr(bitblas, mapping["config_type"]) - operator_cls = getattr(bitblas, mapping["operator_type"]) - op_inst = operator_cls( - config=config_cls(**config), target=target, enable_tuning=False, from_database=True) - op_inst.update_runtime_module(rt_mod, src_name=src_name, lib_name=lib_name) - self.add(config_cls(**config), op_inst) - - -global_operator_cache = OperatorCache() - - -def load_global_ops_cache(database_path=BITBLAS_DATABASE_PATH, target=None): - if target is None: - target = bitblas.auto_detect_nvidia_target() - logger.info(f"Loading operators from database {database_path} for target {target}") - global_operator_cache.load_from_database(database_path, target) - return global_operator_cache - - -def get_database_path(): - return BITBLAS_DATABASE_PATH - - -def set_database_path(path): - global BITBLAS_DATABASE_PATH - BITBLAS_DATABASE_PATH = path - return BITBLAS_DATABASE_PATH diff --git a/python/bitblas/generator.py b/python/bitblas/generator.py deleted file mode 100644 index 4ac6f2be2..000000000 --- a/python/bitblas/generator.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -class BitBLASGenerator: - - def __init__(self): - # Initialize the generator with configuration - pass - - def generate_cuda_code(self): - pass - - def generate_header(self): - pass diff --git a/python/bitblas/gpu/__init__.py b/python/bitblas/gpu/__init__.py deleted file mode 100644 index df0635b3c..000000000 --- a/python/bitblas/gpu/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -""" -GPU-generic schedule rules. -For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead -""" -from .fallback import Fallback # noqa: F401 -from .element_wise import ElementWise # noqa: F401 -from .gemv import GEMV # noqa: F401 -from .gemv_dequantize import GEMVWithDequantizeInfo # noqa: F401 -from .general_reduction import GeneralReduction # noqa: F401 -from .matmul import ( - Matmul, # noqa: F401 - MatmulTensorizationMMA, # noqa: F401 - MatmulTensorizationWMMA, # noqa: F401 -) -from .matmul_mma_dequantize import ( - MatmulTensorizationMMAWithDequantizeInfo, # noqa: F401 -) -from .matmul_wmma import MatmulTensorizationLegacy # noqa: F401 - -from .reduction import Reduction # noqa: F401 -from .transpose import Transpose # noqa: F401 diff --git a/python/bitblas/gpu/base.py b/python/bitblas/gpu/base.py deleted file mode 100644 index 3bf927244..000000000 --- a/python/bitblas/gpu/base.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2018 The apache/tvm Authors. All Rights Reserved. -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# /* Modifications Copyright (c) Microsoft. */ -# The code below is mostly copied from apache/tvm base.py in dlight. -"""Base schedule rule for GPU operators.""" - -from tvm.target import Target - -from ..base import ScheduleRule - - -class GPUScheduleRule(ScheduleRule): # pylint: disable=too-few-public-methods - """The Schedule Rule specific to GPU targets, will return None if the target is not GPU.""" - - def is_target_available(self, target: Target) -> bool: - """Check whether the target is available for gpu rule. - - Parameters - ---------- - target : Target - The compilation target to check. - - Returns - ------- - available : bool - Whether the target is available for this rule. - """ - return super().is_target_available(target) and "gpu" in target.keys diff --git a/python/bitblas/gpu/element_wise.py b/python/bitblas/gpu/element_wise.py deleted file mode 100644 index 07ea3a27e..000000000 --- a/python/bitblas/gpu/element_wise.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -# pylint: disable=missing-docstring -"""A fallback schedule rule for GPU operators.""" -from typing import List - -from tvm import tir - -from ..base import ScheduleRule, normalize_prim_func, try_inline - - -class ElementWise(ScheduleRule): - """ - An elementwise schedule rule for GPU operators. - """ - - def apply_config( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - config, - ) -> tir.Schedule: - block_factors = config.block - thread_factors = config.thread - step_factors = config.step - - sch = tir.Schedule(func) - block_infos = normalize_prim_func(sch) - - if block_infos is None: - return None - - block_infos = try_inline(sch, block_infos) - - for block in block_infos: - s_loops: List[tir.schedule.LoopRV] = [] - r_loops: List[tir.schedule.LoopRV] = [] - o_loops: List[tir.schedule.LoopRV] = [] - dom_kind = block.dom_kind() - block = block.block_rv - - if ( - any( - [ - sch.get(loop_rv).thread_binding is not None - for loop_rv in sch.get_loops(block) - ] - ) - or len(sch.get_loops(block)) == 0 - ): - continue - - for loop, iter_type in zip(sch.get_loops(block), dom_kind): - {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) - - if not s_loops: - s_loops.append(sch.add_unit_loop(block)) - sch.reorder(*s_loops, *r_loops, *o_loops) - - block_loops = [] - vthread_loops = [] - thread_loops = [] - inner_loops = [] - for s_loop, block_factor, step_factor, thread_factor in zip( - s_loops, block_factors, step_factors, thread_factors - ): - block_loop, inner_loop = sch.split(s_loop, factors=[None, block_factor]) - vthread_loop, inner_loop = sch.split( - inner_loop, factors=[None, thread_factor * step_factor] - ) - thread_loop, inner_loop = sch.split( - inner_loop, factors=[None, step_factor] - ) - block_loops.append(block_loop) - vthread_loops.append(vthread_loop) - thread_loops.append(thread_loop) - inner_loops.append(inner_loop) - - # inner virtual thread first - vthread_loops = list(reversed(vthread_loops)) - sch.reorder( - *block_loops, - *vthread_loops, - *thread_loops, - *inner_loops, - *r_loops, - *o_loops - ) - sch.bind(sch.fuse(*block_loops), "blockIdx.x") - sch.bind(sch.fuse(*thread_loops), "threadIdx.x") - if len(vthread_loops) > 3: - vthread_loops = vthread_loops[0:2] + [sch.fuse(*vthread_loops[2:])] - - for i, ax in enumerate(vthread_loops): - sch.bind(ax, "vthread" + [".x", ".y", ".z"][i]) - - return sch diff --git a/python/bitblas/gpu/fallback.py b/python/bitblas/gpu/fallback.py deleted file mode 100644 index 3711d3682..000000000 --- a/python/bitblas/gpu/fallback.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2018 The apache/tvm Authors. All Rights Reserved. -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# Modifications Copyright (c) Microsoft. -# The code below is mostly copied from apache/tvm fallback.py in dlight. -# pylint: disable=missing-docstring -"""A fallback schedule rule for GPU operators.""" -from typing import List, Tuple - -from tvm import tir -from tvm.target import Target - -from ..base import normalize_prim_func, try_inline -from . import utils -from .base import GPUScheduleRule - - -class Fallback(GPUScheduleRule): - """ - A fallback schedule rule for all GPU operators. It will try to inline all the blocks first, - and then apply a simple block/grid mapping to the spatial loops on top of the remaining blocks. - """ - - def apply( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ) -> tir.Schedule: - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): - return None - max_threads_per_block = utils.max_threads_per_block(target) - - sch = tir.Schedule(func) - block_infos = normalize_prim_func(sch) - - if block_infos is None: - return None - - block_infos = try_inline(sch, block_infos) - reduction_blocks: List[Tuple[tir.schedule.BlockRV, tir.schedule.LoopRV]] = [] - for block in block_infos: - s_loops: List[tir.schedule.LoopRV] = [] - r_loops: List[tir.schedule.LoopRV] = [] - o_loops: List[tir.schedule.LoopRV] = [] - dom_kind = block.dom_kind() - block = block.block_rv - - if ( - any( - [ - sch.get(loop_rv).thread_binding is not None - for loop_rv in sch.get_loops(block) - ] - ) - or len(sch.get_loops(block)) == 0 - ): - continue - - for loop, iter_type in zip(sch.get_loops(block), dom_kind): - {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) - - if not s_loops: - s_loops.append(sch.add_unit_loop(block)) - sch.reorder(*s_loops, *r_loops, *o_loops) - bx, tx = sch.split( # pylint: disable=invalid-name - sch.fuse(*s_loops), - factors=[None, max_threads_per_block], - ) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - - if len(r_loops) > 0: - reduction_blocks.append((block, r_loops[0])) - - for block, r_loop in reduction_blocks: - sch.decompose_reduction(block, r_loop) - - return sch - \ No newline at end of file diff --git a/python/bitblas/gpu/gemv.py b/python/bitblas/gpu/gemv.py deleted file mode 100644 index 60a290a81..000000000 --- a/python/bitblas/gpu/gemv.py +++ /dev/null @@ -1,794 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# Copyright 2018 The apache/tvm Authors. All Rights Reserved. -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# Modifications Copyright (c) Microsoft. -# The code below is mostly copied from apache/tvm gemv.py in dlight. -"""A rule for GEMV and DecodeGEMV.""" - -from functools import reduce -from typing import List, Optional, Union, Dict - -from tvm import DataType, arith, ir, tir -from tvm.target import Target - -from ..base import ( - BlockInfo, - collect_block_iter_vars_used_in_access_region, - collect_vars_used_in_prim_expr, - detect_dominant_read, - is_broadcast_epilogue, - normalize_prim_func, - try_inline_contiguous_spatial, - get_output_blocks, -) -from .base import GPUScheduleRule -from .gemv_dequantize import GEMVWithDequantizeInfo - - -def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: - # Detect and return `Y` in `X[...] = X[...] + Y` - buffer_store = block.body - if not isinstance(buffer_store, tir.BufferStore): - return None - if not isinstance(buffer_store.value, tir.Add): - return None - if not ir.structural_equal( - buffer_store.value.a, - tir.BufferLoad(buffer_store.buffer, block.body.indices), - map_free_vars=True, - ): - return None - return buffer_store.value.b - - -def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV): - loop: tir.For = sch.get(loop_rv) - return loop.extent.value if isinstance(loop.extent, tir.IntImm) else loop.extent - - -def get_bytes(dtype: Union[DataType, str]) -> int: - if isinstance(dtype, str): - dtype = DataType(dtype) - return int(dtype.bits) // 8 - - -def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: - """Check if the block is a GEMV. - - Parameters - ---------- - - sch : tir.Schedule - The schedule - - block_info : BlockInfo - The block info to be checked - - - Returns - ------- - ret : Optional[List[tir.Buffer]] - The vector buffers used in the GEMV if it is a GEMV, otherwise None. - """ - block = block_info.block_rv - block_stmt = sch.get(block) - conditions = [] - conditions.append(block_info.is_reduction()) - conditions.append(len(block_stmt.reads) >= 2) - conditions.append(len(block_stmt.writes) == 1) - conditions.append(_get_reduction_expr(block_stmt) is not None) - conditions.append( - len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) - > 0) - if not all(conditions): - return None - - iter_num = len(block_stmt.iter_vars) - ret = [ - read.buffer - for read in block_stmt.reads - if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) < iter_num - and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) > 0 - ] - if len(ret) == len(block_stmt.reads): - func = sch.mod["main"] - opt_shapes: Dict = {} - if "opt_shapes" in func.attrs: - opt_shapes = func.attrs["opt_shapes"] - # check with dynamic symbolic and at least one is unit - if not all([opt_shapes.get(buf.name, (1,))[0] == 1 for buf in ret]): - return None - elif len(ret) == 0: - return None - return ret - - -def normalize( - sch: tir.Schedule, - block_info: BlockInfo, -) -> Optional[bool]: - """Normalize the main block.""" - block_stmt: tir.Block = sch.get(block_info.block_rv) - access = arith.normalize_to_iter_sum( - detect_dominant_read(block_stmt), - input_iters={i.var: i.dom for i in block_stmt.iter_vars}, - ) - buffers_use_vars = [ - collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) - for buf in block_stmt.writes - ] - buffers_use_vars.extend([ - collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) - for buf in block_stmt.reads - ]) - if collect_vars_used_in_prim_expr(access.base) & set( - iter_var.var for iter_var in block_stmt.iter_vars): - return None - iter_to_info = {i.var: i for i in block_info.iters} - batch_loops, s_loops, r_loops, c_loops = [], [], [], [] - inner_axis = access.args[-1].source.source - is_inner_reduction = iter_to_info[inner_axis].kind == "R" - - for split_expr in access.args: - var = split_expr.source.source - info = iter_to_info.get(var) - loop = info.loop_rv - is_reduction = info.kind == "R" - if split_expr.lower_factor > 1: - if c_loops: - return None - loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) - # we only support the reduction dim being grouped atm - if not is_reduction: - return None - c_loops.append(c_loop) - if is_reduction: - r_loops.append(loop) - elif all([var in buf_vars for buf_vars in buffers_use_vars]): - batch_loops.append(loop) - else: - s_loops.append(loop) - - assert s_loops - assert r_loops - if not c_loops: - c_loops = [sch.add_unit_loop(block_info.block_rv)] - if not batch_loops: - batch_loops = [sch.add_unit_loop(block_info.block_rv)] - sch.reorder(*batch_loops, *s_loops, *r_loops, *c_loops) - sch.fuse(*batch_loops) - sch.fuse(*s_loops) - sch.fuse(*r_loops) - return is_inner_reduction - - -class GEMV(GPUScheduleRule): - """A rule for GEMV and DecodeGEMV.""" - - def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ) -> Union[None, tir.Schedule, List[tir.Schedule]]: - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): - return None - if "dequantize_info" in func.attrs: - dequantize_rule = GEMVWithDequantizeInfo() - return dequantize_rule.apply(func, target, False) - sch = tir.Schedule(func) - block_infos = normalize_prim_func(sch) - block_infos = try_inline_contiguous_spatial(sch, block_infos) - if len(block_infos) == 1: - epilogue = None - elif len(block_infos) == 2: - epilogue = block_infos[1] - if not epilogue.is_injective(): - return None - else: - return None - - block_info = block_infos[0] - if len(block_info.iters) not in [2, 3]: - # either [B, S, R] = [B, S, R] * [B, R] - # or [S, R] = [S, R] * [R] - return None - block = block_info.block_rv - vector_input_buffers = is_gemv(sch, block_info) - if vector_input_buffers is None: - return None - - # Step 1. Normalize the block, merge spatial and reduction iters - is_inner_reduction = normalize(sch, block_info) - - # Step 2. Do the scheduling - if is_inner_reduction is None: - return None - elif is_inner_reduction: - self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) - return sch - else: - return self.sch_outer_reduction(sch, target, block, vector_input_buffers, epilogue) - - def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument - self, - sch: tir.Schedule, - target: Target, - block: tir.schedule.BlockRV, - vector_input_buffers: List[tir.Buffer], - epilogue_info: Optional[BlockInfo], - ): - """Schedule the inner reduction block.""" - - def get_max_factor(n, factors): - factors = sorted(factors, reverse=True) - for factor in factors: - if n % factor == 0: - return factor - return 1 - - def apply( - sch: tir.Schedule, - gemv, - TAG_S, - TAG_R, - TS, - TR, - TILE_S, - TILE_R, - VEC_LOAD, - VEC_C, - LOAD_V_SHARED, - LOAD_V_VEC, - UNROLL, - ): - # rfactor: reduce to tx * vec_c - _, s, r, c = sch.get_loops(block=gemv) - s = sch.fuse(_, s) - r = sch.fuse(r, c) - bx, ts, tile_s = sch.split(s, factors=[None, TS, TILE_S], preserve_unit_iters=True) - r, tr, tile_r_vec_n, vec_c = sch.split( - r, factors=[None, TR, TILE_R // VEC_C, VEC_C], preserve_unit_iters=True) - sch.reorder(r, tile_r_vec_n, tr, vec_c) - tr_vec_c = sch.fuse(tr, vec_c) - rf = sch.rfactor(tr_vec_c, 0) - - # rfactor: reduce to tx - bx, ts, tile_s, tr_vec_c = sch.get_loops(block=gemv) - tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True) - rf2 = sch.rfactor(tr, 0) - - # bind, vectorize compute - bx, ts, tile_s, r, tile_r_vec_n, tr_vec_c = sch.get_loops(block=rf) - tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True) - sch.reorder(bx, ts, tr, r, tile_s, tile_r_vec_n, vec_c) - sch.bind(bx, "blockIdx.x") - sch.bind(ts, TAG_S) - sch.bind(tr, TAG_R) - sch.vectorize(vec_c) - - shared_mem_usage = 0 - for buf in vector_input_buffers: - buf_size = reduce(lambda x, y: x * y, buf.shape, tir.IntImm( - buf.shape[0].dtype, 1)) * get_bytes(buf.dtype) - shared_mem_usage += buf_size - try: - max_shared_memory_per_block = target.max_shared_memory_per_block - except Exception: - max_shared_memory_per_block = 49152 - LOAD_V_SHARED = ( - LOAD_V_SHARED and isinstance(shared_mem_usage, tir.IntImm) and - shared_mem_usage.value <= max_shared_memory_per_block) - - # vectorize load A - # (TODO) this is now actually problematic since the number of loops is dependent on the - # number of dimensions of A_q - Aq_local = sch.cache_read(rf, read_buffer_index=1, storage_scope="local") - sch.compute_at(Aq_local, r, preserve_unit_loops=True) - s_local, r_local = sch.get_loops(block=Aq_local)[-2:] - s_local, vec_load = sch.split( - s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True) - sch.reorder(s_local, r_local, vec_load) # either s_local or r_local should be 1 - sch.vectorize(vec_load) - - # load vector into shared memory, shape should be the whole vector - if LOAD_V_SHARED: - V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared") - sch.compute_at(V_shared, tr, preserve_unit_loops=True) - l = sch.get_loops(block=V_shared)[-1] # noqa: E741 - loop: tir.For = sch.get(l) - if isinstance(loop.extent, tir.IntImm): - # avoid introducing predicates when vector length is too large - vec_length = max( - min( - get_max_factor( - (int)(loop.extent), - [TS * TR * 1, TS * TR * 2, TS * TR * 4, TS * TR * 8], - ) // TS // TR, - LOAD_V_VEC, - ), - 1, - ) - else: - vec_length = LOAD_V_VEC - if TAG_R == "threadIdx.x": - _, ty, tx, vec = sch.split( - l, factors=[None, TS, TR, vec_length], preserve_unit_iters=True) - else: - _, ty, tx, vec = sch.split( - l, factors=[None, TR, TS, vec_length], preserve_unit_iters=True) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - sch.vectorize(vec) - - # reduce tile_s * tr * vec to tile_s * tr - sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True) - tr, vec_c, *ts_tile_s = sch.get_loops(block=rf2)[1:] - ts_tile_s = sch.fuse(*ts_tile_s) - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) - tile_s, vec_s = sch.split( - tile_s, - factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])], - preserve_unit_iters=True, - ) - sch.reorder(ts, tr, tile_s, vec_s, vec_c) - sch.bind(ts, TAG_S) - sch.bind(tr, TAG_R) - sch.vectorize(vec_s) - - # reduce tile_s * tr to tile_s - sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True) - tr, *ts_tile_s = sch.get_loops(block=gemv)[1:] - ts_tile_s = sch.fuse(*ts_tile_s) - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) - sch.reorder(tile_s, ts, tr) - sch.bind(ts, TAG_S) - sch.bind(tr, TAG_R) - - sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[3]) - sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1]) - - sch.set_scope(rf, buffer_index=0, storage_scope="local") - sch.set_scope(rf2, buffer_index=0, storage_scope="local") - - unroll_factor = UNROLL - - sch.annotate( - block_or_loop=sch.get_loops(rf)[3], - ann_key="pragma_auto_unroll_max_step", - ann_val=unroll_factor, - ) - sch.annotate( - block_or_loop=sch.get_loops(rf)[3], - ann_key="pragma_unroll_explicit", - ann_val=1, - ) - - sch.annotate( - block_or_loop=sch.get_loops(rf2)[3], - ann_key="pragma_auto_unroll_max_step", - ann_val=unroll_factor, - ) - sch.annotate( - block_or_loop=sch.get_loops(rf2)[3], - ann_key="pragma_unroll_explicit", - ann_val=1, - ) - - if LOAD_V_SHARED: - sch.annotate( - block_or_loop=sch.get_loops(V_shared)[-4], - ann_key="pragma_unroll_explicit", - ann_val=unroll_factor, - ) - sch.annotate( - block_or_loop=sch.get_loops(V_shared)[-4], - ann_key="pragma_vectorize", - ann_val=1, - ) - - # Schedule epilogue - if epilogue_info is not None: - epilogue = epilogue_info.block_rv - if is_broadcast_epilogue(sch, block, epilogue): - sch.reverse_compute_at(epilogue, bx) - sch.set_scope(block, 0, "shared") - _, _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name - _, tx = sch.split(sch.fuse(*s), factors=[None, TS]) - sch.bind(tx, "threadIdx.x") - else: - sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) - ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:]) - ts_tile_s = sch.get_loops(epilogue)[-1] - ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) - sch.bind(ts, TAG_S) - sch.set_scope(block, 0, "local") - # pylint: enable=invalid-name - return sch - - # Specify the `len_tx` and `len_ty` according to the loop extent - batch, s, r, c = sch.get_loops(block=block) - len_batch, len_s, len_r, len_c = ( - get_extent(sch, batch), - get_extent(sch, s), - get_extent(sch, r), - get_extent(sch, c), - ) - len_S = len_batch * len_s - len_R = len_r * len_c - - TAG_S, TAG_R = "threadIdx.y", "threadIdx.x" - if target.kind.name == "cuda": - VEC_C = 4 - LOAD_V_SHARED = True - LOAD_V_VEC = 8 - UNROLL = 256 - if isinstance(len_S, int): - if len_S > len_R: - TS, TR = 4, 64 - else: - TS, TR = 16, 32 - elif target.kind.name == "metal": - # Note that the following tile size is tuned on M2 Ultra for 7B - TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" - VEC_C = 1 - LOAD_V_SHARED = False - LOAD_V_VEC = -1 - UNROLL = 256 - if isinstance(len_S, int): - if len_S > len_R: - TS, TR = 4, 16 - else: - TS, TR = 2, 64 - elif target.kind.name == "rocm": - VEC_C = 4 - LOAD_V_SHARED = True - LOAD_V_VEC = 8 - UNROLL = 256 - if isinstance(len_S, int): - if len_S > len_R: - TS, TR = 1, 128 - else: - TS, TR = 8, 64 - elif target.kind.name == "opencl" and "android" in str(target.host): - TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" - VEC_C = 8 - LOAD_V_SHARED = False - LOAD_V_VEC = -1 - UNROLL = 8 - TS, TR = 2, 32 - elif target.kind.name == "vulkan": - VEC_C = 4 - LOAD_V_SHARED = True - LOAD_V_VEC = 4 - UNROLL = 256 - if isinstance(len_S, int): - if len_S > len_R: - TS, TR = 4, 32 - else: - TS, TR = 16, 32 - elif target.kind.name == "opencl" and "mali" in str(target.attrs): - VEC_C = 8 - LOAD_V_SHARED = False - LOAD_V_VEC = -1 - UNROLL = 64 - TS, TR = 1, 64 - else: - VEC_C = 1 - LOAD_V_SHARED = False - LOAD_V_VEC = -1 - UNROLL = 64 - TS, TR = 1, 64 - - if not isinstance(len_S, int): - TS, TR = 1, 64 - - while TS * TR > target.max_num_threads: - if TS > 1: - TS //= 2 - else: - TR //= 2 - - TILE_S, TILE_R = ( - 1, - (len_c if len_c > 1 else max( - get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1)), - ) - VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C) - VEC_LOAD = 1 - - return apply( - sch, - gemv=block, - TAG_S=TAG_S, - TAG_R=TAG_R, - TS=TS, - TR=TR, - TILE_S=TILE_S, - TILE_R=TILE_R, - VEC_LOAD=VEC_LOAD, - VEC_C=VEC_C, - LOAD_V_SHARED=LOAD_V_SHARED, - LOAD_V_VEC=LOAD_V_VEC, - UNROLL=UNROLL, - ) - - def sch_outer_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument - self, - sch: tir.Schedule, - target: Target, - block: tir.schedule.BlockRV, - vector_input_buffers: List[tir.Buffer], - epilogue_info: Optional[BlockInfo], - ): - """Schedule the outer reduction block.""" - # NOTE: Only Android is supported so far - if not (target.kind.name == "opencl" and "android" in str(target.host)): - return None - batch, s, r, c = sch.get_loops(block) - len_s = get_extent(sch, s) - - # The config is designed for Adreno - tx_len = 64 - vec_len = (4 if len_s > 4096 else 2) if isinstance(len_s, int) else 1 - inner_r = 4 - - bx, tx, vec = sch.split(s, factors=[None, tx_len, vec_len]) - r0, r1 = sch.split(r, factors=[None, inner_r]) - sch.bind(batch, "blockIdx.y") - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - sch.reorder(bx, tx, r0, r1, c, vec) - - sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=8) - sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1) - - cache_v = sch.cache_read(block, vector_input_buffers[0], "local") - sch.compute_at(cache_v, r1, preserve_unit_loops=True) - sch.vectorize(sch.get_loops(cache_v)[-1]) - - sch.vectorize(vec) - - # Schedule epilogue - if epilogue_info is not None: - sch.reverse_compute_at(epilogue_info.block_rv, tx) - - sch.set_scope(block, 0, "local") - - sch.decompose_reduction(block, r0) - - return sch - - def sch_inner_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements - self, - func: tir.PrimFunc, - config, - ): - sch = tir.Schedule(func) - - block_infos = normalize_prim_func(sch) - - if block_infos is None: - return None - - reduction_block: tir.schedule.BlockRV = None - for block in block_infos: - s_loops: List[tir.schedule.LoopRV] = [] - r_loops: List[tir.schedule.LoopRV] = [] - o_loops: List[tir.schedule.LoopRV] = [] - dom_kind = block.dom_kind() - block = block.block_rv - - if (any([ - sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) - ]) or len(sch.get_loops(block)) == 0): - continue - - for loop, iter_type in zip(sch.get_loops(block), dom_kind): - {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) - - if not s_loops: - s_loops.append(sch.add_unit_loop(block)) - if len(r_loops) > 0: - reduction_block = block - # skip analysis for following blocks - break - - def prod(iterable): - return reduce(lambda x, y: x * y, iterable, 1) - - vec = 1 - if len(config.vectorize): - vec = list(config.vectorize.values())[-1] - - num_warps = int(prod(config.thread)) - warp_size = int(prod(config.reduce_thread)) - - block_b = reduction_block - output_blocks = get_output_blocks(sch, block_infos) - # compute inline - for block_info in reversed(block_infos): - block = block_info.block_rv - if block not in (reduction_block, *output_blocks): - sch.compute_inline(block) - try: - i, j, k = sch.get_loops(block_b) - except Exception: - j, k = sch.get_loops(block_b) - block_local_A = sch.cache_read(block_b, 0, "local") - block_local_B = sch.cache_read(block_b, 1, "local") - block_local_C = sch.cache_write(block_b, 0, "local") - # reverse inline - if reduction_block is not None and reduction_block != output_blocks[0]: - sch.reverse_compute_inline(output_blocks[0]) - - bx, j = sch.split(j, factors=[None, num_warps]) - k, tx, vk = sch.split(k, factors=[None, warp_size, vec]) - sch.reorder(bx, j, k, tx) - - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - sch.bind(j, "threadIdx.y") - - self.block_size = [sch.get(tx).extent, sch.get(j).extent, 1] - self.grid_size = [sch.get(bx).extent, 1, 1] - - sch.compute_at(block_local_A, tx, preserve_unit_loops=True) - sch.compute_at(block_local_B, tx, preserve_unit_loops=True) - sch.reverse_compute_at(block_local_C, j, preserve_unit_loops=True) - - block_local_a_v = sch.get_loops(block_local_A)[-1] - sch.vectorize(block_local_a_v) - block_local_b_v = sch.get_loops(block_local_B)[-1] - sch.vectorize(block_local_b_v) - - return sch - - def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements - self, - func: tir.PrimFunc, - config, - ): - sch = tir.Schedule(func) - block_infos = normalize_prim_func(sch) - - if block_infos is None: - return None - - reduction_block: tir.schedule.BlockRV = None - for block in block_infos: - s_loops: List[tir.schedule.LoopRV] = [] - r_loops: List[tir.schedule.LoopRV] = [] - o_loops: List[tir.schedule.LoopRV] = [] - dom_kind = block.dom_kind() - block = block.block_rv - - if (any([ - sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) - ]) or len(sch.get_loops(block)) == 0): - continue - - for loop, iter_type in zip(sch.get_loops(block), dom_kind): - {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) - - if not s_loops: - s_loops.append(sch.add_unit_loop(block)) - if len(r_loops) > 0: - reduction_block = block - # skip analysis for following blocks - break - - C = reduction_block - CL = sch.cache_write(reduction_block, 0, "local") - - blck_axis = [] - vthd_axis = [] - thrd_axis = [] - tile_axis = [] - # for gemv, we should skip dynamic symbolic in s_loops - s_loops = [loop for loop in s_loops if isinstance(sch.get(loop).extent, tir.IntImm)] - assert len(s_loops) == len(config.block), f"{len(s_loops)} != {len(config.block)}" - for i, loop in enumerate(s_loops): - if sch.get(loop).extent % config.block[i]: - raise NotImplementedError("Undivisible block in TIR schedule is still buggy.") - bx, _t = sch.split(loop, factors=[None, config.block[i]]) - blck_axis.append(bx) - if config.step[i] > 1: - _t, tn = sch.split(_t, factors=[None, config.step[i]]) - tile_axis.append(tn) - if config.block[i] <= config.thread[i] * config.step[i]: - tx = _t - else: - vx, tx = sch.split(_t, factors=[None, config.thread[i]]) - vthd_axis.append(vx) - thrd_axis.append(tx) - - reduce_outer_axis, reduce_inner_axis = [], [] - - for i in config.raxis_order: - loop = r_loops[i] - ro, ri = sch.split(loop, factors=[None, config.rstep[i]]) - reduce_outer_axis.append(ro) - reduce_inner_axis.append(ri) - - vthd_axis = list(reversed(vthd_axis)) # inner virtual thread first - axis_order = ( - blck_axis + vthd_axis + thrd_axis + reduce_outer_axis + reduce_inner_axis + tile_axis) - - sch.reorder(*axis_order) - blck_fused = sch.fuse(*blck_axis) - thrd_fused = sch.fuse(*thrd_axis) - sch.bind(blck_fused, "blockIdx.x") - sch.bind(thrd_fused, "threadIdx.x") - if len(vthd_axis) > 3: - vthd_axis = vthd_axis[0:2] + [sch.fuse(*vthd_axis[2:])] - for i, ax in enumerate(vthd_axis): - sch.bind(ax, "vthread" + [".x", ".y", ".z"][i]) - for ax in tile_axis: - sch.unroll(ax) - - sch.reverse_compute_at(CL, thrd_fused) - if len(tile_axis) > 0: - for ax in sch.get_loops(CL)[-len(tile_axis):]: - sch.unroll(ax) - - sch.decompose_reduction(C, reduce_outer_axis[0]) - - try_inline_contiguous_spatial(sch, block_infos) - - return sch - - def apply_config( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - config, - ) -> tir.Schedule: - if not isinstance(func, tir.PrimFunc): - return None - sch = tir.Schedule(func) - block_infos = normalize_prim_func(sch) - block_infos = try_inline_contiguous_spatial(sch, block_infos) - if len(block_infos) == 1: - epilogue = None - elif len(block_infos) == 2: - epilogue = block_infos[1] - if not epilogue.is_injective(): - return None - else: - return None - - block_info = block_infos[0] - if len(block_info.iters) not in [2, 3, 4]: - # either [SK, B, S, R] = [SK, B, S, R] * [SK, B, R] - # either [B, S, R] = [B, S, R] * [B, R] - # or [S, R] = [S, R] * [R] - return None - - if is_gemv(sch, block_info) is None: - return None - - if "dequantize_info" in func.attrs: - dequantize_rule = GEMVWithDequantizeInfo() - return dequantize_rule.apply_config(func, config) - - if any([t > 1 for t in config.reduce_thread]): - return self.sch_inner_reduction_with_config(func, config) - - return self.sch_outer_reduction_with_config(func, config) diff --git a/python/bitblas/gpu/gemv_dequantize.py b/python/bitblas/gpu/gemv_dequantize.py deleted file mode 100644 index 5ccc5b40e..000000000 --- a/python/bitblas/gpu/gemv_dequantize.py +++ /dev/null @@ -1,369 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""A rule for GEMV and DecodeGEMV.""" -from functools import reduce -from typing import List, Dict -from tvm.target import Target -from tvm.tir.function import PrimFunc -from tvm import DataType, tir -import logging -from ..base import ( - normalize_prim_func, - get_output_blocks, - get_block, -) -from .base import GPUScheduleRule -from .matmul_analysis import auto_inline_producers, auto_inline_consumers - -logger = logging.getLogger(__name__) - - -class GEMVWithDequantizeInfo(GPUScheduleRule): - """A rule for Dequantized GEMV.""" - - def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ): - sch = tir.Schedule(func) - from .intrin import get_lop3_intrin_group - - dequantize_info = func.attrs["dequantize_info"] - - def check_dequantize_info(dequantize_info): - conditions = [] - # currently only support weight only dequantization - conditions.append(len(dequantize_info) == 1) - # TODO(@lei) check if the dequantize value name is weight - return all(conditions) - - if not check_dequantize_info(dequantize_info): - logger.debug("Dequantize info is not valid") - return None - - (weight_decode_info,) = list(dequantize_info.values()) - - def check_weight_decode_info(weight_decode_info): - conditions = [] - # check source format in ["int", "fp", "nf"] - conditions.append("source_format" in weight_decode_info) - conditions.append(weight_decode_info["source_format"]["format"] in - ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"]) - # check source bits in [1, 2, 4, 8] - conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) - # check target format in ["float16", "int8"] - conditions.append("target_format" in weight_decode_info) - conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) - return all(conditions) - - if not check_weight_decode_info(weight_decode_info): - logger.debug("Weight Dequantize info is not valid") - return None - - block_infos = normalize_prim_func(sch) - - if block_infos is None: - return None - - reduction_block: tir.schedule.BlockRV = None - for block in block_infos: - s_loops: List[tir.schedule.LoopRV] = [] - r_loops: List[tir.schedule.LoopRV] = [] - o_loops: List[tir.schedule.LoopRV] = [] - dom_kind = block.dom_kind() - block = block.block_rv - - if (any([ - sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) - ]) or len(sch.get_loops(block)) == 0): - continue - - for loop, iter_type in zip(sch.get_loops(block), dom_kind): - {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) - - if not s_loops: - s_loops.append(sch.add_unit_loop(block)) - if len(r_loops) > 0: - reduction_block = block - - def prod(iterable): - return reduce(lambda x, y: x * y, iterable, 1) - - def get_vectorize_factor(target_format): - # coalesced access requires the vectorize factor to be the same as the transaction size - return 128 // DataType(target_format).bits - - vec = get_vectorize_factor(weight_decode_info["target_format"]) - num_warps = 1 - warp_size = 32 - - block_b = reduction_block - output_blocks = get_output_blocks(sch, block_infos) # noqa: F841 - B_decode_block = get_block(sch, block_infos, weight_decode_info["decode_block"]) - - block_decode_B = sch.cache_read(block_b, 1, "local") - sch.compute_inline(B_decode_block) - - j, k = sch.get_loops(block_b)[-2:] - if len(sch.get_loops(block_b)) == 3: - i = sch.get_loops(block_b)[0] - sch.bind(i, "blockIdx.z") - elif len(sch.get_loops(block_b)) == 4: - # splitk case - sk, i = sch.get_loops(block_b)[:2] - sch.bind(sk, "blockIdx.y") - sch.bind(i, "blockIdx.z") - - # get target dequantize buffer's idx - def get_idx(weight_decode_info: Dict): - # for LUT dequantize, the expr is LUT(w), the idx is 1 - # maybe we can use a more general and structural based way - # to analysis the idx - if weight_decode_info["source_format"]["format"] == "nf": - return 1 - return 0 - - block_shared_local_A = sch.cache_read(block_b, 0, "local") - block_shared_local_B = sch.cache_read(block_decode_B, get_idx(weight_decode_info), "local") - block_local_C = sch.cache_write(block_b, 0, "local") - - auto_inline_producers(sch, block_shared_local_B) - auto_inline_consumers(sch, block_local_C) - - bx, j = sch.split(j, factors=[None, num_warps]) - k, tx, vk = sch.split(k, factors=[None, warp_size, vec]) - # for dp4a/hfma2 - inst_factor = 2 if weight_decode_info["target_format"] == "float16" else 4 - _, vk = sch.split(vk, factors=[None, inst_factor]) - sch.reorder(bx, j, k, tx) - - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - sch.bind(j, "threadIdx.y") - - self.block_size = [sch.get(tx).extent, sch.get(j).extent, 1] - self.grid_size = [sch.get(bx).extent, 1, 1] - - sch.compute_at(block_decode_B, tx, preserve_unit_loops=True) - sch.compute_at(block_shared_local_A, tx, preserve_unit_loops=True) - sch.compute_at(block_shared_local_B, tx, preserve_unit_loops=True) - sch.reverse_compute_at(block_local_C, j, preserve_unit_loops=True) - - block_local_a_v = sch.get_loops(block_shared_local_A)[-1] - sch.vectorize(block_local_a_v) - block_local_b_v = sch.get_loops(block_shared_local_B)[-1] - sch.vectorize(block_local_b_v) - - skip_blocks = [block_shared_local_B] - - if "zeros_mode" in weight_decode_info and weight_decode_info["zeros_mode"] == "quantized": - if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]: - block_local_scales = sch.cache_read(block_decode_B, - get_idx(weight_decode_info) + 1, "local") - sch.compute_at(block_local_scales, tx, preserve_unit_loops=True) - auto_inline_producers(sch, block_local_scales) - skip_blocks.append(block_local_scales) - - if "with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]: - block_local_zeros = sch.cache_read(block_decode_B, - get_idx(weight_decode_info) + 2, "local") - sch.compute_at(block_local_zeros, tx, preserve_unit_loops=True) - auto_inline_producers(sch, block_local_zeros) - skip_blocks.append(block_local_zeros) - - auto_inline_producers(sch, block_decode_B, skip_blocks) - - if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): - source_bit = weight_decode_info["source_format"]["bits"] - out_dtype = weight_decode_info["target_format"] - intrin_info = get_lop3_intrin_group( - out_dtype=out_dtype, - storage_dtype=weight_decode_info["storage_dtype"], - source_format=weight_decode_info["source_format"]["format"], - source_bit=source_bit, - with_scaling=weight_decode_info["with_scaling"], - with_zeros=weight_decode_info["with_zeros"], - zeros_mode=weight_decode_info["zeros_mode"], - ) - sch.tensorize(sch.get_loops(block_decode_B)[-1], intrin_info["compute"]) - sch.annotate(block_b, ann_key="pragma_import_c", ann_val=intrin_info["c_source"]) - return sch - - def sch_inner_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements - self, - func: tir.PrimFunc, - config, - ): - sch = tir.Schedule(func) - from .intrin import get_lop3_intrin_group - - dequantize_info = func.attrs["dequantize_info"] - - def check_dequantize_info(dequantize_info): - conditions = [] - # currently only support weight only dequantization - conditions.append(len(dequantize_info) == 1) - # TODO(@lei) check if the dequantize value name is weight - return all(conditions) - - if not check_dequantize_info(dequantize_info): - logger.debug("Dequantize info is not valid") - return None - - (weight_decode_info,) = list(dequantize_info.values()) - - def check_weight_decode_info(weight_decode_info): - conditions = [] - # check source format in ["int", "fp", "nf"] - conditions.append("source_format" in weight_decode_info) - conditions.append(weight_decode_info["source_format"]["format"] in - ["uint", "int", "fp", "nf", "fp_e5m2", "fp_e4m3"]) - # check source bits in [1, 2, 4, 8] - conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) - # check target format in ["float16", "int8"] - conditions.append("target_format" in weight_decode_info) - conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) - return all(conditions) - - if not check_weight_decode_info(weight_decode_info): - logger.debug("Weight Dequantize info is not valid") - return None - - block_infos = normalize_prim_func(sch) - - if block_infos is None: - return None - - reduction_block: tir.schedule.BlockRV = None - for block in block_infos: - s_loops: List[tir.schedule.LoopRV] = [] - r_loops: List[tir.schedule.LoopRV] = [] - o_loops: List[tir.schedule.LoopRV] = [] - dom_kind = block.dom_kind() - block = block.block_rv - - if (any([ - sch.get(loop_rv).thread_binding is not None for loop_rv in sch.get_loops(block) - ]) or len(sch.get_loops(block)) == 0): - continue - - for loop, iter_type in zip(sch.get_loops(block), dom_kind): - {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) - - if not s_loops: - s_loops.append(sch.add_unit_loop(block)) - if len(r_loops) > 0: - reduction_block = block - - def prod(iterable): - return reduce(lambda x, y: x * y, iterable, 1) - - def get_vectorize_factor(target_format): - # coalesced access requires the vectorize factor to be the same as the transaction size - return config.arch.transaction_size[-1] // DataType(target_format).bits - - vec = get_vectorize_factor(weight_decode_info["target_format"]) - num_warps = int(prod(config.thread)) - warp_size = int(prod(config.reduce_thread)) - - block_b = reduction_block - output_blocks = get_output_blocks(sch, block_infos) # noqa: F841 - B_decode_block = get_block(sch, block_infos, weight_decode_info["decode_block"]) - - block_decode_B = sch.cache_read(block_b, 1, "local") - sch.compute_inline(B_decode_block) - - j, k = sch.get_loops(block_b)[-2:] - if len(sch.get_loops(block_b)) == 3: - i = sch.get_loops(block_b)[0] - sch.bind(i, "blockIdx.z") - elif len(sch.get_loops(block_b)) == 4: - # splitk case - sk, i = sch.get_loops(block_b)[:2] - sch.bind(sk, "blockIdx.y") - sch.bind(i, "blockIdx.z") - assert len(config.thread) == 2, "SplitK only support 2D thread config" - num_warps = int(num_warps // config.thread[0]) - - # get target dequantize buffer's idx - def get_idx(weight_decode_info: Dict): - # for LUT dequantize, the expr is LUT(w), the idx is 1 - # maybe we can use a more general and structural based way - # to analysis the idx - if weight_decode_info["source_format"]["format"] == "nf": - return 1 - return 0 - - block_shared_local_A = sch.cache_read(block_b, 0, "local") - block_shared_local_B = sch.cache_read(block_decode_B, get_idx(weight_decode_info), "local") - block_local_C = sch.cache_write(block_b, 0, "local") - - auto_inline_producers(sch, block_shared_local_B) - auto_inline_consumers(sch, block_local_C) - - bx, j = sch.split(j, factors=[None, num_warps]) - k, tx, vk = sch.split(k, factors=[None, warp_size, vec]) - # for dp4a/hfma2 - inst_factor = 2 if weight_decode_info["target_format"] == "float16" else 4 - _, vk = sch.split(vk, factors=[None, inst_factor]) - sch.reorder(bx, j, k, tx) - - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - sch.bind(j, "threadIdx.y") - - self.block_size = [sch.get(tx).extent, sch.get(j).extent, 1] - self.grid_size = [sch.get(bx).extent, 1, 1] - - sch.compute_at(block_decode_B, tx, preserve_unit_loops=True) - sch.compute_at(block_shared_local_A, tx, preserve_unit_loops=True) - sch.compute_at(block_shared_local_B, tx, preserve_unit_loops=True) - sch.reverse_compute_at(block_local_C, j, preserve_unit_loops=True) - - block_local_a_v = sch.get_loops(block_shared_local_A)[-1] - sch.vectorize(block_local_a_v) - block_local_b_v = sch.get_loops(block_shared_local_B)[-1] - sch.vectorize(block_local_b_v) - - skip_blocks = [block_shared_local_B] - - if "zeros_mode" in weight_decode_info and weight_decode_info["zeros_mode"] == "quantized": - if "with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]: - block_local_scales = sch.cache_read(block_decode_B, - get_idx(weight_decode_info) + 1, "local") - sch.compute_at(block_local_scales, tx, preserve_unit_loops=True) - auto_inline_producers(sch, block_local_scales) - skip_blocks.append(block_local_scales) - - if "with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]: - block_local_zeros = sch.cache_read(block_decode_B, - get_idx(weight_decode_info) + 2, "local") - sch.compute_at(block_local_zeros, tx, preserve_unit_loops=True) - auto_inline_producers(sch, block_local_zeros) - skip_blocks.append(block_local_zeros) - - auto_inline_producers(sch, block_decode_B, skip_blocks) - - if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): - source_bit = weight_decode_info["source_format"]["bits"] - out_dtype = weight_decode_info["target_format"] - intrin_info = get_lop3_intrin_group( - out_dtype=out_dtype, - storage_dtype=weight_decode_info["storage_dtype"], - source_format=weight_decode_info["source_format"]["format"], - source_bit=source_bit, - with_scaling=weight_decode_info["with_scaling"], - with_zeros=weight_decode_info["with_zeros"], - zeros_mode=weight_decode_info["zeros_mode"], - ) - sch.tensorize(sch.get_loops(block_decode_B)[-1], intrin_info["compute"]) - sch.annotate(block_b, ann_key="pragma_import_c", ann_val=intrin_info["c_source"]) - return sch - - def apply_config(self, func: PrimFunc, config): - if any([t > 1 for t in config.reduce_thread]): - return self.sch_inner_reduction_with_config(func, config) - else: - return None diff --git a/python/bitblas/gpu/general_reduction.py b/python/bitblas/gpu/general_reduction.py deleted file mode 100644 index cc03acd99..000000000 --- a/python/bitblas/gpu/general_reduction.py +++ /dev/null @@ -1,465 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -# pylint: disable=invalid-name -"""Reduction rule for operators including softmax, layer norm, RMS norm, etc""" -from typing import List, Union -from functools import reduce - -from tvm import tir -from tvm.target import Target - -from ..base import normalize_prim_func, try_inline_contiguous_spatial -from ..base.analysis import get_root_block, get_reduction_blocks, BlockInfo -from .base import GPUScheduleRule - - -class GeneralReduction(GPUScheduleRule): - """General Reduction rule for operators including softmax, layer norm, RMS norm, etc""" - - def apply( # pylint: disable=too-many-locals - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ) -> Union[None, tir.Schedule, List[tir.Schedule]]: - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): - return None - - if target.kind.name == "cuda": - len_tx = 256 - unroll_depth = 256 - else: - len_tx = 64 - unroll_depth = 64 - - sch = tir.Schedule(func) - block_infos = normalize_prim_func(sch) - block_infos = try_inline_contiguous_spatial(sch, block_infos) - if block_infos is None or len(block_infos) == 0: - return None - - dom_kind = block_infos[0].dom_kind() - num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S")) - num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R")) - - # Align the number of block iters of the last block. - num_last_block_iter = len(block_infos[-1].dom_kind()) - if num_last_block_iter < len(dom_kind): - index_map = tir.IndexMap.from_func( - lambda *iters: ( - [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) - + list(iters) - ), - ndim=num_last_block_iter, - ) - sch.transform_block_layout(block_infos[-1].block_rv, index_map) - - try: - # TODO: fix num_leading_s = 0 case - assert num_trailing_r > 0 - for block in block_infos[1:-1]: - assert block.dom_kind() == dom_kind - assert block_infos[-1].is_injective() - assert len(block_infos[-1].dom_kind()) <= len(dom_kind) - except AssertionError: - return None - - loops = sch.get_loops(block_infos[-1].block_rv) - bx = sch.fuse(*loops[:num_leading_s]) - r_loop, tx = sch.split(loops[-1], [None, len_tx]) - sch.reorder(tx, r_loop) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) - sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1) - - for block in reversed(block_infos[:-1]): - block = block.block_rv - for i, _ in enumerate(sch.get(block).writes): - sch.set_scope(block, buffer_index=i, storage_scope="shared") - sch.compute_at(block, bx, preserve_unit_loops=True) - r_loop = sch.fuse(*sch.get_loops(block)[-num_trailing_r:]) - r_loop, tx = sch.split(r_loop, [None, len_tx]) - sch.reorder(tx, r_loop) - sch.bind(tx, "threadIdx.x") - sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) - sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1) - - # TODO: It's just a workaround to avoid unroll spatial loops, because of the bug of - # the pass lower-thread-allreduce. We should fix it in the future. - # sch.annotate(bx, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) - # sch.annotate(bx, ann_key="pragma_unroll_explicit", ann_val=1) - return sch - - def sch_inner_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements - self, - func: tir.PrimFunc, - config, - ): - block_factors = config.block - thread_factors = config.thread - reduce_therad_factors = config.reduce_thread - - # For inter thread reduction case, one thread must only compute one element - assert thread_factors == block_factors - - # inline all the other blocks - sch = tir.Schedule(func) - block_infos = normalize_prim_func(sch) - - schedule_block: tir.schedule.BlockRV = None - reduction_blocks: List[tir.schedule.BlockRV] = [] - for block in block_infos: - s_loops: List[tir.schedule.LoopRV] = [] - r_loops: List[tir.schedule.LoopRV] = [] - o_loops: List[tir.schedule.LoopRV] = [] - dom_kind = block.dom_kind() - block_rv = block.block_rv - - if ( - any( - [ - sch.get(loop_rv).thread_binding is not None - for loop_rv in sch.get_loops(block_rv) - ] - ) - or len(sch.get_loops(block.block_rv)) == 0 - ): - continue - - for loop, iter_type in zip(sch.get_loops(block_rv), dom_kind): - {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) - - if not s_loops: - s_loops.append(sch.add_unit_loop(block_rv)) - if len(r_loops) > 0: - # always use the last reduction block for scheduling - schedule_block = block - reduction_blocks.append(block_rv) - - # Align the number of block iters of the last block. - dom_kind = schedule_block.dom_kind() - num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S")) - num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R")) - - schedule_block = schedule_block.block_rv - loops = sch.get_loops(schedule_block) - s_loops = loops[:num_leading_s] - r_loops = loops[-num_trailing_r:] - - block_axis = [] - thread_axis = [] - - for s_loop, block_factor in zip(s_loops, block_factors): - block_loop, thread_loop = sch.split(s_loop, factors=[None, block_factor]) - block_axis.append(block_loop) - thread_axis.append(thread_loop) - - axis_order = block_axis + thread_axis - - sch.reorder(*axis_order) - blck_fused = sch.fuse(*block_axis) - thrd_fused = sch.fuse(*thread_axis) - sch.bind(blck_fused, "blockIdx.x") - sch.bind(thrd_fused, "threadIdx.y") - - reduce_outer_axis, reduce_inner_axis, reduce_inter_threads = [], [], [] - for i in config.raxis_order: - loop = r_loops[i] - ro, ri = sch.split(loop, factors=[None, config.rstep[i]]) - ri, thd = sch.split(ri, factors=[None, config.reduce_thread[i]]) - reduce_inter_threads.append(thd) - reduce_outer_axis.append(ro) - reduce_inner_axis.append(ri) - - axis_order = reduce_inter_threads + reduce_outer_axis + reduce_inner_axis - sch.reorder(*axis_order) - fused_reduce_inter_threads = sch.fuse(*reduce_inter_threads) - sch.bind(fused_reduce_inter_threads, "threadIdx.x") - - def prod(iterable): - return reduce(lambda x, y: x * y, iterable, 1) - - reg_tile = sch.cache_write(schedule_block, 0, "local") - - # todo(lei): should add the shared_inputs/stride memory pad analysis at shared memory fusion stage. - for i, input_region in enumerate(sch.get(schedule_block).reads): - if input_region.buffer.name not in config.cached_tensors: - continue - - # otherwise cooperative fetch in shared memory. - cache_shared = sch.cache_read(schedule_block, i, "shared") - sch.compute_at(cache_shared, reduce_outer_axis[-1]) - - dim_offset = ( - len(reduce_inner_axis) + len(reduce_outer_axis) + 2 - ) # outer loops are: blck_fused, thrd_fused, vthread_axis, reduce_outer_axis - if input_region.buffer.name in config.vectorize: - vectorize = config.vectorize[input_region.buffer.name] - else: - vectorize = 1 - - loops = sch.get_loops(cache_shared) - if len(loops) == dim_offset: - # handle fetching only one element - loops.append(sch.add_unit_loop(schedule_block)) - assert len(loops) > dim_offset - - _, ty, tx, tv = sch.split( - sch.fuse(*loops[dim_offset:]), - factors=[ - None, - int(prod(thread_factors)), - int(prod(reduce_therad_factors)), - vectorize, - ], - ) - sch.vectorize(tv) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - sch.reverse_compute_at(reg_tile, thrd_fused) - - # resolve compute_at - block_infos = try_inline_contiguous_spatial(sch, block_infos) - if block_infos is None or len(block_infos) == 0: - return None - return sch - - def sch_outer_reduction_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements - self, - func: tir.PrimFunc, - config, - ): - block_factors = config.block - thread_factors = config.thread - step_factors = config.step - - # inline all the other blocks - sch = tir.Schedule(func) - block_infos = normalize_prim_func(sch) - - schedule_block: BlockInfo = None - for block in block_infos: - s_loops: List[tir.schedule.LoopRV] = [] - r_loops: List[tir.schedule.LoopRV] = [] - o_loops: List[tir.schedule.LoopRV] = [] - dom_kind = block.dom_kind() - block_rv = block.block_rv - - if ( - any( - [ - sch.get(loop_rv).thread_binding is not None - for loop_rv in sch.get_loops(block_rv) - ] - ) - or len(sch.get_loops(block.block_rv)) == 0 - ): - continue - - for loop, iter_type in zip(sch.get_loops(block_rv), dom_kind): - {"S": s_loops, "R": r_loops, "O": o_loops}[iter_type].append(loop) - - if not s_loops: - s_loops.append(sch.add_unit_loop(block_rv)) - if len(r_loops) > 0: - # always use the last reduction block for scheduling - schedule_block = block - - # Align the number of block iters of the last block. - dom_kind = schedule_block.dom_kind() - num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S")) - num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R")) - - num_last_block_iter = len(block_infos[-1].dom_kind()) - if num_last_block_iter < len(dom_kind): - index_map = tir.IndexMap.from_func( - lambda *iters: ( - [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) - + list(iters) - ), - ndim=num_last_block_iter, - ) - sch.transform_block_layout(block_infos[-1].block_rv, index_map) - - schedule_block = schedule_block.block_rv - loops = sch.get_loops(schedule_block) - s_loops = loops[:num_leading_s] - r_loops = loops[-num_trailing_r:] - - reg_tile = sch.cache_write(schedule_block, 0, "local") - - block_axis = [] - vthread_axis = [] - thread_axis = [] - inner_axis = [] - for s_loop, block_factor, step_factor, thread_factor in zip( - s_loops, block_factors, step_factors, thread_factors - ): - block_loop, inner_loop = sch.split(s_loop, factors=[None, block_factor]) - vthread_loop, inner_loop = sch.split( - inner_loop, factors=[None, thread_factor * step_factor] - ) - thread_loop, inner_loop = sch.split(inner_loop, factors=[None, step_factor]) - block_axis.append(block_loop) - vthread_axis.append(vthread_loop) - thread_axis.append(thread_loop) - inner_axis.append(inner_loop) - - reduce_outer_axis, reduce_inner_axis = [], [] - for i in config.raxis_order: - loop = r_loops[i] - ro, ri = sch.split(loop, factors=[None, config.rstep[i]]) - reduce_outer_axis.append(ro) - reduce_inner_axis.append(ri) - - vthread_axis = list(reversed(vthread_axis)) # inner virtual thread first - axis_order = ( - block_axis - + vthread_axis - + thread_axis - + reduce_outer_axis - + reduce_inner_axis - + inner_axis - ) - - sch.reorder(*axis_order) - blck_fused = sch.fuse(*block_axis) - thrd_fused = sch.fuse(*thread_axis) - sch.bind(blck_fused, "blockIdx.x") - sch.bind(thrd_fused, "threadIdx.x") - if len(vthread_axis) > 3: - vthread_axis = vthread_axis[0:2] + [sch.fuse(*vthread_axis[2:])] - for i, ax in enumerate(vthread_axis): - sch.bind(ax, "vthread" + [".x", ".y", ".z"][i]) - - # todo(lei): should add the shared_inputs/stride memory pad analysis at shared memory fusion stage. - for i, input_region in enumerate(sch.get(schedule_block).reads): - if input_region.buffer.name not in config.cached_tensors: - continue - - # otherwise cooperative fetch in shared memory. - cache_shared = sch.cache_read(schedule_block, i, "shared") - sch.compute_at(cache_shared, reduce_outer_axis[-1]) - - dim_offset = ( - len(vthread_axis) + len(reduce_outer_axis) + 2 - ) # outer loops are: blck_fused, thrd_fused, vthread_axis, reduce_outer_axis - if input_region.buffer.name in config.vectorize: - vectorize = config.vectorize[input_region.buffer.name] - else: - vectorize = 1 - - loops = sch.get_loops(cache_shared) - if len(loops) == dim_offset: - # handle fetching only one element - loops.append(sch.add_unit_loop(schedule_block)) - assert len(loops) > dim_offset - - def prod(iterable): - return reduce(lambda x, y: x * y, iterable, 1) - - _, tx, tv = sch.split( - sch.fuse(*loops[dim_offset:]), factors=[None, int(prod(thread_factors)), vectorize] - ) - sch.vectorize(tv) - sch.bind(tx, "threadIdx.x") - - sch.reverse_compute_at(reg_tile, thrd_fused) - - sch.decompose_reduction(schedule_block, reduce_outer_axis[0]) - - # resolve compute_at - block_infos = try_inline_contiguous_spatial(sch, block_infos) - if block_infos is None or len(block_infos) == 0: - return None - - return sch - - def sch_mutiple_reductions_with_config( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements - self, - func: tir.PrimFunc, - config, - ): - block_factors = config.block - thread_factors = config.thread - reduce_therad_factors = config.reduce_thread - - sch = tir.Schedule(func) - block_infos = normalize_prim_func(sch) - block_infos = try_inline_contiguous_spatial(sch, block_infos) - if block_infos is None or len(block_infos) == 0: - return None - - def prod(iterable): - return reduce(lambda x, y: x * y, iterable, 1) - - len_tx = prod(thread_factors) * prod(reduce_therad_factors) - block_factor = prod(block_factors) - - dom_kind = block_infos[0].dom_kind() - num_leading_s = len(dom_kind) - len(dom_kind.lstrip("S")) - num_trailing_r = len(dom_kind) - len(dom_kind.rstrip("R")) - - # Align the number of block iters of the last block. - num_last_block_iter = len(block_infos[-1].dom_kind()) - if num_last_block_iter < len(dom_kind): - index_map = tir.IndexMap.from_func( - lambda *iters: ( - [tir.const(0, iters[0].dtype)] * (len(dom_kind) - num_last_block_iter) - + list(iters) - ), - ndim=num_last_block_iter, - ) - sch.transform_block_layout(block_infos[-1].block_rv, index_map) - - try: - # TODO: fix num_leading_s = 0 case - assert num_trailing_r > 0 - for block in block_infos[1:-1]: - assert block.dom_kind() == dom_kind - assert block_infos[-1].is_injective() - assert len(block_infos[-1].dom_kind()) <= len(dom_kind) - except AssertionError: - return None - - loops = sch.get_loops(block_infos[-1].block_rv) - bx, _ = sch.split(sch.fuse(*loops[:num_leading_s]), factors=[None, block_factor]) - r_loop, tx = sch.split(loops[-1], [None, len_tx]) - sch.reorder(tx, r_loop) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - - for block in reversed(block_infos[:-1]): - block = block.block_rv - for i, _ in enumerate(sch.get(block).writes): - sch.set_scope(block, buffer_index=i, storage_scope="shared") - sch.compute_at(block, bx, preserve_unit_loops=True) - r_loop = sch.fuse(*sch.get_loops(block)[-num_trailing_r:]) - r_loop, tx = sch.split(r_loop, [None, len_tx]) - sch.reorder(tx, r_loop) - sch.bind(tx, "threadIdx.x") - - return sch - - def apply_config( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - config, - ) -> tir.Schedule: - # check the number of reduction blocks - sch = tir.Schedule(func) - root_block = get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - reduction_blocks = get_reduction_blocks(sch, blocks) - if len(reduction_blocks) > 1: - # schedule for multiple reduction blocks (e.g. softmax) - return self.sch_mutiple_reductions_with_config(func, config) - - if any([t > 1 for t in config.reduce_thread]): - # todo(lei) should implement block reduction schedule - return self.sch_inner_reduction_with_config(func, config) - else: - return self.sch_outer_reduction_with_config(func, config) diff --git a/python/bitblas/gpu/intrin/__init__.py b/python/bitblas/gpu/intrin/__init__.py deleted file mode 100644 index d9d9ba942..000000000 --- a/python/bitblas/gpu/intrin/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from .lop3 import get_lop3_intrin_group # noqa: F401 diff --git a/python/bitblas/gpu/intrin/lop3.py b/python/bitblas/gpu/intrin/lop3.py deleted file mode 100644 index b5426cf59..000000000 --- a/python/bitblas/gpu/intrin/lop3.py +++ /dev/null @@ -1,1667 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import tvm -from tvm.tir.function import TensorIntrin -from tvm.script import tir as T -from typing import Dict, Literal -from bitblas.quantization import ( - _tir_packed_int_to_int_convert, - _tir_packed_to_signed_convert, - _tir_packed_to_unsigned_convert, - _tir_packed_to_unsigned_convert_with_zeros, -) - -decode_i4_to_f16 = """ -template -__device__ void decode_i4b_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x000f000f; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; - uint const i4s = *reinterpret_cast(_i4s); -#pragma unroll - for (int i = 0; i < (N / 2); i++) - { - - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - } -} - -template -__device__ void decode_i4s_to_f16(T1 *_i4s, T2 *B_local_decode, const int N = 8) -{ - decode_i4b_to_f16(_i4s, B_local_decode, N); -} - -template -__device__ void decode_i4u_to_f16(T1 *_i4u, T2 *B_local_decode, const int N = 8) -{ - decode_i4b_to_f16(_i4u, B_local_decode, N); -} -""" - -decode_i4_to_f16_scale = """ -template -__device__ void decode_i4b_to_f16_scale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x000f000f; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - // Minus 7 to scale the value to signed - static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; - uint const i4s = *reinterpret_cast(_i4s); - T3 const scale_r = *scale; - uint const packed_scales = __pack_half2(scale_r, scale_r); - -#pragma unroll - // decode 2 elems at one time. - for (int i = 0; i < (N / 2); i++) - { - - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); - } -} - -template -__device__ void decode_i4s_to_f16_scale(T1 *_i4s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) -{ - decode_i4b_to_f16_scale(_i4s, B_local_decode, N, scale); -} - -template -__device__ void decode_i4u_to_f16_scale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) -{ - decode_i4b_to_f16_scale(_i4u, B_local_decode, N, scale); -} - -""" - -decode_i4_to_f16_scale_zeros_original = """ -template -__device__ void decode_i4b_to_f16_zeros_original(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x000f000f; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - // Minus 7 to scale the value to signed - static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; - uint const i4s = *reinterpret_cast(_i4s); - T3 const scale_r = *scale; - uint const packed_scales = __pack_half2(scale_r, scale_r); - // input zeros maybe int32(qzeros) or half format - T4 const zero_r = *zeros; - uint const packed_zeros = __pack_half2(zero_r, zero_r); - - -#pragma unroll - // decode 2 elems at one time. - for (int i = 0; i < (N / 2); i++) - { - - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); - } -} - -template -__device__ void decode_i4u_to_f16_scale_zeros_original(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) -{ - decode_i4b_to_f16_zeros_original(_i4u, B_local_decode, N, scale, zeros); -} -""" - -decode_i4_to_f16_scale_zeros_rescale = """ -template -__device__ void decode_i4b_to_f16_scale_zeros_rescale(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x000f000f; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - // Minus 7 to scale the value to signed - static constexpr uint MEDIAN_NUM = isSigned ? 0x64086408 : 0x64006400; - uint const i4s = *reinterpret_cast(_i4s); - T3 const scale_r = *scale; - uint const packed_scales = __pack_half2(scale_r, scale_r); - T4 const zero_r = *zeros; - uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r); - -#pragma unroll - // decode 2 elems at one time. - for (int i = 0; i < (N / 2); i++) - { - - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros)); - } -} - -template -__device__ void decode_i4u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) -{ - decode_i4b_to_f16_scale_zeros_rescale(_i4u, B_local_decode, N, scale, zeros); -} - -""" - -decode_i4_to_f16_scale_zeros_quantized = """ -template -__device__ void decode_i4b_to_f16_scale_zeros_quantized(T1 *_i4s, T2 *B_local_decode, const int N = 8, const T3 *scale = nullptr, const T4 *zeros = nullptr) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x000f000f; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - // Minus 7 to scale the value to signed - uint const i4s = *reinterpret_cast(_i4s); - T3 const scale_r = *scale; - uint const packed_scales = __pack_half2(scale_r, scale_r); - // input zeros maybe int32(qzeros) or half format - T4 const zero_r = *zeros; - uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r); - -#pragma unroll - // decode 2 elems at one time. - for (int i = 0; i < (N / 2); i++) - { - - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i4s >> (4 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - - asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num)); - - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); - } -} - -template -__device__ void decode_i4u_to_f16_scale_zeros_quantized(storage_dtype *_i4u, target_dtype *B_local_decode, scale_dtype *scale = nullptr, zero_dtype *zeros = nullptr, const int N = 8) -{ - decode_i4b_to_f16_scale_zeros_quantized(_i4u, B_local_decode, N, scale, zeros); -} -""" - -decode_i2_to_f16 = """ -template -__device__ void decode_i2b_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x00030003; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; - int16_t const i2s_i16 = *reinterpret_cast(_i2s); - // decode 2 elems at one time. - // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} - // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} - // otherwise the pointer of _i2s should be moved to - int i2s = (i2s_i16 & 0x00ff); - i2s |= ((i2s_i16 & 0xff00) << 8); - -#pragma unroll - for (int i = 0; i < (N / 2); i++) - { - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - } -} - -template -__device__ void decode_i2s_to_f16(T1 *_i2s, T2 *B_local_decode, const int N = 8) -{ - decode_i2b_to_f16(_i2s, B_local_decode, N); -} - -template -__device__ void decode_i2u_to_f16(T1 *_i2u, T2 *B_local_decode, const int N = 8) -{ - decode_i2b_to_f16(_i2u, B_local_decode, N); -} -""" - -decode_i2_to_f16_scale = """ -template -__device__ void decode_i2b_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x00030003; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; - int16_t const i2s_i16 = *reinterpret_cast(_i2s); - // decode 2 elems at one time. - // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} - // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} - // otherwise the pointer of _i2s should be moved to - int i2s = (i2s_i16 & 0x00ff); - i2s |= ((i2s_i16 & 0xff00) << 8); - -#pragma unroll - for (int i = 0; i < (N / 2); i++) - { - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); - } -} - -template -__device__ void decode_i2s_to_f16_scale(T1 *_i2s, T2 *B_local_decode, T3 *scale, const int N = 8) -{ - decode_i2b_to_f16_scale(_i2s, B_local_decode, scale, N); -} - -template -__device__ void decode_i2u_to_f16_scale(T1 *_i2u, T2 *B_local_decode, T3 *scale, const int N = 8) -{ - decode_i2b_to_f16_scale(_i2u, B_local_decode, scale, N); -} -""" - -decode_i2_to_f16_scale_zeros_original = """ -template -__device__ void decode_i2b_to_f16_scale_zeros_original(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x00030003; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; - int16_t const i2s_i16 = *reinterpret_cast(_i2s); - // decode 2 elems at one time. - // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} - // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} - // otherwise the pointer of _i2s should be moved to - int i2s = (i2s_i16 & 0x00ff); - i2s |= ((i2s_i16 & 0xff00) << 8); - -#pragma unroll - for (int i = 0; i < (N / 2); i++) - { - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); - } -} - -template -__device__ void decode_i2u_to_f16_scale_zeros_original(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int N = 8) -{ - decode_i2b_to_f16_scale_zeros_original(_i2u, B_local_decode, scale, zeros, N); -} -""" - -decode_i2_to_f16_scale_zeros_rescale = """ -template -__device__ void decode_i2b_to_f16_scale_zeros_rescale(T1 *_i2s, T2 *B_local_decode, T3 *scale = nullptr, T3 *zeros = nullptr, const int N = 8) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x00030003; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - static constexpr uint MEDIAN_NUM = isSigned ? 0x64026402 : 0x64006400; - int16_t const i2s_i16 = *reinterpret_cast(_i2s); - // decode 2 elems at one time. - // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} - // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} - // otherwise the pointer of _i2s should be moved to - int i2s = (i2s_i16 & 0x00ff); - i2s |= ((i2s_i16 & 0xff00) << 8); - -#pragma unroll - for (int i = 0; i < (N / 2); i++) - { - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*scale, *scale)), "r"(0)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(__pack_half2(*zeros, *zeros))); - } -} - -template -__device__ void decode_i2u_to_f16_scale_zeros_rescale(T1 *_i2u, T2 *B_local_decode, T3 *scale, T3 *zeros, const int N = 8) -{ - decode_i2b_to_f16_scale_zeros_rescale(_i2u, B_local_decode, scale, zeros, N); -} -""" - -decode_i2_to_f16_scale_zeros_quantized = """ -template -__device__ void decode_i2b_to_f16_scale_zeros_quantized(T1 *_i2s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x00030003; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400; - int16_t const i2s_i16 = *reinterpret_cast(_i2s); - T3 const scale_r = *scale; - uint const packed_scales = __pack_half2(scale_r, scale_r); - T4 const zero_r = *zeros; - uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r); - - // decode 2 elems at one time. - // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} - // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} - // otherwise the pointer of _i2s should be moved to - int i2s = (i2s_i16 & 0x00ff); - i2s |= ((i2s_i16 & 0xff00) << 8); - -#pragma unroll - for (int i = 0; i < (N / 2); i++) - { - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num)); - - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); - } -} -template -__device__ void decode_i2u_to_f16_scale_zeros_quantized(T1 *_i2u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) -{ - decode_i2b_to_f16_scale_zeros_quantized(_i2u, B_local_decode, N, scale, zeros); -} -""" - -decode_i1_to_f16 = """ -template -__device__ void decode_i1u_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x00010001; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - static constexpr uint MEDIAN_NUM = 0x64006400; - int8_t const i1s_i16 = *reinterpret_cast(_i1s); - int i1s = (i1s_i16 & 0x0f); - i1s |= ((i1s_i16 & 0xf0) << 12); -#pragma unroll - // decode 2 elems at one time. - for (int i = 0; i < (N / 2); i++) - { - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - } -} - -template -__device__ void decode_i1s_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x00010001; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - static constexpr uint MEDIAN_NUM = 0x64006400; - static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1 - - int8_t const i1s_i16 = *reinterpret_cast(_i1s); - int i1s = (i1s_i16 & 0x0f); - i1s |= ((i1s_i16 & 0xf0) << 12); -#pragma unroll - // decode 2 elems at one time. - for (int i = 0; i < (N / 2); i++) - { - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i])); - asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT)); - } -} -""" - -decode_i1_to_f16_scale = """ -template -__device__ void decode_i1u_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x00010001; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - static constexpr uint MEDIAN_NUM = 0x64006400; - // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} - // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 - int8_t const i1s_i16 = *reinterpret_cast(_i1s); - int i1s = (i1s_i16 & 0x0f); - i1s |= ((i1s_i16 & 0xf0) << 12); - T3 const scale_r = *scale; - uint const packed_scales = __pack_half2(scale_r, scale_r); -#pragma unroll - // decode 2 elems at one time. - for (int i = 0; i < (N / 2); i++) - { - - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); - } -} - -template -__device__ void decode_i1s_to_f16_scale(T1 *_i1s, T2 *B_local_decode, T3 *scale = nullptr, const int N = 8) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x00010001; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - static constexpr uint MEDIAN_NUM = 0x64006400; - static constexpr uint TRANSFORM_SUBTRACT = 0xbc00bc00; // for signed int 2x - 1 - // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} - // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 - - int8_t const i1s_i16 = *reinterpret_cast(_i1s); - int i1s = (i1s_i16 & 0x0f); - i1s |= ((i1s_i16 & 0xf0) << 12); - T3 const scale_r = *scale; - uint const packed_scales = __pack_half2(scale_r, scale_r); -#pragma unroll - // decode 2 elems at one time. - for (int i = 0; i < (N / 2); i++) - { - - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(h[i])); - asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(TRANSFORM_SUBTRACT)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); - } -} -""" - -decode_i1_to_f16_scale_zeros_original = """ -template -__device__ void decode_i1b_to_f16_zeros_original(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x00010001; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - static constexpr uint MEDIAN_NUM = 0x64006400; - // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} - // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 - int8_t const i1s_i16 = *reinterpret_cast(_i1s); - int i1s = (i1s_i16 & 0x0f); - i1s |= ((i1s_i16 & 0xf0) << 12); - T3 const scale_r = *scale; - uint const packed_scales = __pack_half2(scale_r, scale_r); - // input zeros maybe int32(qzeros) or half format - T4 const zero_r = *zeros; - uint const packed_zeros = __pack_half2(zero_r, zero_r); - -#pragma unroll - // decode 2 elems at one time. - for (int i = 0; i < (N / 2); i++) - { - - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_zeros)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); - } -} -template -__device__ void decode_i1u_to_f16_scale_zeros_original(T1 *_i1u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) -{ - decode_i1b_to_f16_zeros_original(_i1u, B_local_decode, N, scale, zeros); -} -""" -decode_i1_to_f16_scale_zeros_rescale = """ -template -__device__ void decode_i1b_to_f16_scale_zeros_rescale(T1 *_i1s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) -{ - uint *h = reinterpret_cast(B_local_decode); - - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x00010001; - static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; - static constexpr uint MEDIAN_NUM = 0x64006400; - // interleave {e31,e29,e27,e25,e23,e21,e19,e17,e15,e13,e11,e9,e7,e5,e3,e1,e30,e28,e26,e24,e22,e20,e18,e16,e14,e12,e10,e8,e6,e4,e2,e0} - // only decode e7,e5,e3,e1,e8,e6,e4,e2,e0 - int8_t const i1s_i16 = *reinterpret_cast(_i1s); - int i1s = (i1s_i16 & 0x0f); - i1s |= ((i1s_i16 & 0xf0) << 12); - T3 const scale_r = *scale; - uint const packed_scales = __pack_half2(scale_r, scale_r); - T4 const zero_r = *zeros; - uint const packed_zeros = 0x80008000 | __pack_half2(zero_r, zero_r); - -#pragma unroll - // decode 2 elems at one time. - for (int i = 0; i < (N / 2); i++) - { - - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(h[i]) - : "r"(i1s >> (1 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); - asm volatile("sub.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(MEDIAN_NUM)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(packed_zeros)); - } -} - -template -__device__ void decode_i1u_to_f16_scale_zeros_rescale(T1 *_i4u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) -{ - decode_i1b_to_f16_scale_zeros_rescale(_i4u, B_local_decode, N, scale, zeros); -} -""" - -decode_i1s_to_i8s = """template -__device__ void decode_i1s_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16) -{ - int i8s[4]; - // vector load - *reinterpret_cast(i8s) = *reinterpret_cast(_i8s); - int16_t i1b_i16 = *reinterpret_cast(_i1b); - // permutate: {e0,e4,e8,e12,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15} - // into: {e0,e4,e8,e12,x,x,x,x,e1,e5,e9,x,x,x,x,e13,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15,x,x,x,x} - int i1b = (i1b_i16 & 0x0f0f); - i1b |= ((i1b_i16 & 0xf0f0) << 12); - // i1b {0..,e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} - // interleave {0..,e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} - // First, we extract the i1b and construct an intermediate fp16 number. - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 - static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1 - static constexpr uint I8s_MAGIC_NUM = 0x00000000; - static constexpr uint TRANSFORM_SUBTRACT = 0xffffffff; // for signed int 2x - 1 - - for (int i = 0; i < N / 4; i++) - { - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(i8s[i]) - : "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); - i8s[i] = __vadd4(i8s[i], i8s[i]); - i8s[i] = __vadd4(i8s[i], TRANSFORM_SUBTRACT); - } - *reinterpret_cast(_i8s) = *reinterpret_cast(i8s); -} - -template -__device__ void decode_i1u_to_i8s(T1 *_i1b, T2 *_i8s, const int N = 16) -{ - int *i8s = reinterpret_cast(_i8s); - int16_t i1b_i16 = *reinterpret_cast(_i1b); - // permutate: {e0,e4,e8,e12,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15} - // into: {e0,e4,e8,e12,x,x,x,x,e1,e5,e9,x,x,x,x,e13,e2,e6,e10,e14,e1,e5,e9,e13,e3,e7,e11,e15,x,x,x,x} - int i1b = (i1b_i16 & 0x0f0f); - i1b |= ((i1b_i16 & 0xf0f0) << 12); - // i1b {0..,e15,e14,e13,e12,e11,e10,e9,e8,e7,e6,e5,e4,e3,e2,e1,e0} - // interleave {0..,e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} - // First, we extract the i1b and construct an intermediate fp16 number. - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 - static constexpr uint BOTTOM_MASK = 0x01010101; // 0x1 -> 0b01 select 0,1 - static constexpr uint I8s_MAGIC_NUM = 0x00000000; - static constexpr uint MEDIAN_NUM = 0x00000000; - - for (int i = 0; i < N / 4; i++) - { - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(i8s[i]) - : "r"(i1b >> i), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); - } -} - -""" - -decode_i2s_to_i8s = """template -__device__ void decode_i2s_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) -{ - // convert 8 int2b_t to 8 int8b_t -> 2 int32 - uint *i8s = reinterpret_cast(_i8s); - - // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} - // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} - uint const i2b = *reinterpret_cast(_i2b); - - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 - static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 - static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 - static constexpr uint MEDIAN_NUM = 0x02020202; -#pragma unroll - for (int i = 0; i < (N / 4); i++) - { - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(i8s[i]) - : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); - i8s[i] = __vsub4(i8s[i], MEDIAN_NUM); - } -} -template -__device__ void decode_i2u_to_i8s(T1 *_i2b, T2 *_i8s, const int N = 16) -{ - // convert 8 int2b_t to 8 int8b_t -> 2 int32 - uint *i8s = reinterpret_cast(_i8s); - - // i2b = {e7,e6,e5,e4,e3,e2,e1,e0} - // also require interleave {e7,e3,e6,e2,e5,e1,e4,e0} - uint const i2b = *reinterpret_cast(_i2b); - - // First, we extract the i4s and construct an intermediate fp16 number. - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; // 0b11101010 - static constexpr uint BOTTOM_MASK = 0x03030303; // 0xf -> 0b11 select 0,3 - static constexpr uint I8s_MAGIC_NUM = 0x00000000; // 1024 - -#pragma unroll - for (int i = 0; i < (N / 4); i++) - { - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(i8s[i]) - : "r"(i2b >> (2 * i)), "n"(BOTTOM_MASK), "n"(I8s_MAGIC_NUM), "n"(immLut)); - } -} -""" - -decode_i4s_to_i8s = """template -__device__ void decode_i4s_to_i8s(T1 *_i4b, T2 *_i8s, const int N = 16) -{ - uint *i8s = reinterpret_cast(_i8s); - uint *i4b = reinterpret_cast(_i4b); - // First, we extract the i4s and construct an intermediate i8 number. - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x0f0f0f0f; // 0xf -> 0b1111 select 0,4,8,12 - static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 - static constexpr uint MEDIAN_NUM = 0x07070707; -#pragma unroll - for (int i = 0; i < (N / 8); i++) - { - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(i8s[i]) - : "r"(i4b[0] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); - - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(i8s[i + 2]) - : "r"(i4b[1] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); - i8s[i] = __vsubss4(i8s[i], MEDIAN_NUM); - i8s[i + 2] = __vsubss4(i8s[i + 2], MEDIAN_NUM); - } -} - -template -__device__ void decode_i4u_to_i8s(T1 *_i4b, T2 *_i8s, const int N = 16) -{ - uint *i8s = reinterpret_cast(_i8s); - uint *i4b = reinterpret_cast(_i4b); - // First, we extract the i4s and construct an intermediate i8 number. - static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; - static constexpr uint BOTTOM_MASK = 0x0f0f0f0f; // 0xf -> 0b1111 select 0,4,8,12 - static constexpr uint I4b_TO_I8s_MAGIC_NUM = 0x00000000; // 0 -#pragma unroll - for (int i = 0; i < (N / 8); i++) - { - // Extract elt_01 - (i4s & 0x000f000f) | 0x64006400 - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(i8s[i]) - : "r"(i4b[0] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); - - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" - : "=r"(i8s[i + 2]) - : "r"(i4b[1] >> (4 * i)), "n"(BOTTOM_MASK), "n"(I4b_TO_I8s_MAGIC_NUM), "n"(immLut)); - } -} -""" - - -def get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - source_format="uint", - target_dtype="float16", - loops_extent=8, - with_scale=False, - with_zeros=False, - zeros_mode="original", -): - """ - loops extent is the number of elements to be decoded in one stage - for memory friendly process, the loops_extent should be a multiple of (sizeof(int) // 8). - However, for the case of int1b, it is not possible to decode 8 elements in one stage, so we have to use 16. - """ - if target_dtype == "float16": - d4f = "f16" - elif target_dtype == "int8": - d4f = "i8s" - else: - raise ValueError("Unsupported target dtype: {}".format(target_dtype)) - source_symbol = "u" if source_format == "uint" else "s" - func_name = "decode_i{}{}_to_{}".format(source_bit, source_symbol, d4f) - if with_scale: - func_name += "_scale" - if with_zeros: - func_name += f"_zeros_{zeros_mode}" - assert storage_dtype in ["int8", "int32", "uint32"] - storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) - storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) - elem_per_unit = storage_nbit // source_bit - n_storage_elems = loops_extent // elem_per_unit - if with_zeros and zeros_mode == "quantized": - decode_func = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit) - elif source_format == "int": - if source_bit == 1: - decode_func = _tir_packed_int_to_int_convert(storage_type, storage_nbit) - else: - decode_func = _tir_packed_to_signed_convert(storage_type, storage_nbit) - elif source_format == "uint": - decode_func = _tir_packed_to_unsigned_convert(storage_type, storage_nbit) - else: - raise ValueError("Unsupported source_format: {}".format(source_format)) - - if with_scale is False: - - @T.prim_func - def fast_decode_desc(compressed: T.handle, decompressed: T.handle) -> None: - Compressed = T.match_buffer( - compressed, - [ - n_storage_elems, - ], - dtype=storage_dtype, - scope="local", - ) - Decompressed = T.match_buffer( - decompressed, - [ - loops_extent, - ], - dtype=target_dtype, - scope="local", - ) - - with T.block("root"): - T.reads(Compressed[0:n_storage_elems]) - T.writes(Decompressed[0:loops_extent]) - for i in T.grid(loops_extent): - with T.block("decode"): - vi = T.axis.remap("S", [i]) - Decompressed[vi] = decode_func( - source_bit, - Compressed[vi // elem_per_unit], - vi % elem_per_unit, - dtype=target_dtype, - ) - - @T.prim_func - def fast_decode_impl(compressed: T.handle, decompressed: T.handle) -> None: - Compressed = T.match_buffer( - compressed, - [ - n_storage_elems, - ], - dtype=storage_dtype, - scope="local", - ) - Decompressed = T.match_buffer( - decompressed, - [ - loops_extent, - ], - dtype=target_dtype, - scope="local", - ) - - with T.block("root"): - T.reads(Compressed[0:n_storage_elems]) - T.writes(Decompressed[0:loops_extent]) - T.call_extern( - "handle", - func_name, - Compressed.data, - Decompressed.data, - loops_extent, - ) - - elif with_zeros is False: - - @T.prim_func - def fast_decode_desc(compressed: T.handle, decompressed: T.handle, scale: T.handle) -> None: - Compressed = T.match_buffer( - compressed, - [ - n_storage_elems, - ], - dtype=storage_dtype, - scope="local", - ) - Decompressed = T.match_buffer( - decompressed, - [ - loops_extent, - ], - dtype=target_dtype, - scope="local", - ) - Scale = T.match_buffer( - scale, - [ - 1, - ], - dtype=target_dtype, - scope="global", - ) - with T.block("root"): - T.reads(Compressed[0:n_storage_elems], Scale[0:1]) - T.writes(Decompressed[0:loops_extent]) - for i in T.grid(loops_extent): - with T.block("decode"): - vi = T.axis.remap("S", [i]) - Decompressed[vi] = ( - decode_func( - source_bit, - Compressed[vi // elem_per_unit], - vi % elem_per_unit, - dtype=target_dtype, - ) * Scale[0]) - - @T.prim_func - def fast_decode_impl(compressed: T.handle, decompressed: T.handle, scale: T.handle) -> None: - s0 = T.int32() - - Compressed = T.match_buffer( - compressed, - [ - n_storage_elems, - ], - dtype=storage_dtype, - scope="local", - ) - Decompressed = T.match_buffer( - decompressed, - [ - loops_extent, - ], - dtype=target_dtype, - scope="local", - ) - Scale = T.match_buffer( - scale, - [ - 1, - ], - dtype=target_dtype, - offset_factor=1, - strides=[s0], - scope="global", - ) - with T.block("root"): - T.reads(Compressed[0:n_storage_elems], Scale[0:1]) - T.writes(Decompressed[0:loops_extent]) - T.call_extern( - "handle", - func_name, - Compressed.data, - Decompressed.data, - Scale.access_ptr("r"), - loops_extent, - ) - - elif zeros_mode == "quantized": - - def get_dequantize_buffers_list(weight, scale, zeros, zeros_mode="original"): - if zeros_mode == "original": - return [weight, zeros, scale] - elif zeros_mode == "rescale": - return [weight, scale, zeros] - elif zeros_mode == "quantized": - return [weight, zeros, scale] - else: - raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") - - def get_dequantize_func(weight, scale, zeros, zeros_mode="original"): - if zeros_mode == "original": - return (weight - zeros) * scale - elif zeros_mode == "rescale": - return weight * scale - zeros - elif zeros_mode == "quantized": - return weight * scale - else: - raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") - - # Scale with Zeros - @T.prim_func - def fast_decode_desc( - compressed: T.handle, - decompressed: T.handle, - scale: T.handle, - zeros: T.handle, - ) -> None: - Compressed = T.match_buffer( - compressed, - [ - n_storage_elems, - ], - dtype=storage_dtype, - scope="local", - ) - Decompressed = T.match_buffer( - decompressed, - [ - loops_extent, - ], - dtype=target_dtype, - scope="local", - ) - Scale = T.match_buffer( - scale, - [ - 1, - ], - dtype=target_dtype, - scope="local", - ) - Zeros = T.match_buffer( - zeros, - [ - 1, - ], - dtype=storage_dtype, - scope="local", - ) - with T.block("root"): - T.reads(*get_dequantize_buffers_list( - Compressed[0:n_storage_elems], - Scale[0:1], - Zeros[0:1], - zeros_mode=zeros_mode, - )) - T.writes(Decompressed[0:loops_extent]) - for i in T.grid(loops_extent): - with T.block("decode"): - vi = T.axis.remap("S", [i]) - Decompressed[vi] = get_dequantize_func( - decode_func( - source_bit, - Compressed[vi // elem_per_unit], - vi % elem_per_unit, - Zeros[0], - dtype=target_dtype, - ), - Scale[0], - Zeros[0], - zeros_mode, - ) - - @T.prim_func - def fast_decode_impl( - compressed: T.handle, - decompressed: T.handle, - scale: T.handle, - zeros: T.handle, - ) -> None: - s0 = T.int32() - s1 = T.int32() - Compressed = T.match_buffer( - compressed, - [ - n_storage_elems, - ], - dtype=storage_dtype, - scope="local", - ) - Decompressed = T.match_buffer( - decompressed, - [ - loops_extent, - ], - dtype=target_dtype, - scope="local", - ) - Scale = T.match_buffer( - scale, - [ - 1, - ], - dtype=target_dtype, - offset_factor=1, - strides=[s0], - scope="local", - ) - Zeros = T.match_buffer( - zeros, - [ - 1, - ], - dtype=storage_dtype, - offset_factor=1, - strides=[s1], - scope="local", - ) - with T.block("root"): - T.reads(Compressed[0:n_storage_elems], Scale[0:1], Zeros[0:1]) - T.writes(Decompressed[0:loops_extent]) - T.call_extern( - "handle", - func_name, - Compressed.data, - Decompressed.data, - Scale.access_ptr("r"), - Zeros.access_ptr("r"), - loops_extent, - ) - - else: - - def get_dequantize_buffers_list(weight, scale, zeros, zeros_mode="original"): - if zeros_mode == "original": - return [weight, zeros, scale] - elif zeros_mode == "rescale": - return [weight, scale, zeros] - else: - raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") - - def get_dequantize_func(weight, scale, zeros, zeros_mode="original"): - if zeros_mode == "original": - return (weight - zeros) * scale - elif zeros_mode == "rescale": - return weight * scale - zeros - else: - raise ValueError(f"Unsupported zeros_mode: {zeros_mode}") - - # Scale with Zeros - @T.prim_func - def fast_decode_desc( - compressed: T.handle, - decompressed: T.handle, - scale: T.handle, - zeros: T.handle, - ) -> None: - Compressed = T.match_buffer( - compressed, - [ - n_storage_elems, - ], - dtype=storage_dtype, - scope="local", - ) - Decompressed = T.match_buffer( - decompressed, - [ - loops_extent, - ], - dtype=target_dtype, - scope="local", - ) - Scale = T.match_buffer( - scale, - [ - 1, - ], - dtype=target_dtype, - scope="global", - ) - Zeros = T.match_buffer( - zeros, - [ - 1, - ], - dtype=target_dtype, - scope="global", - ) - with T.block("root"): - T.reads(*get_dequantize_buffers_list( - Compressed[0:n_storage_elems], - Scale[0:1], - Zeros[0:1], - zeros_mode=zeros_mode, - )) - T.writes(Decompressed[0:loops_extent]) - for i in T.grid(loops_extent): - with T.block("decode"): - vi = T.axis.remap("S", [i]) - Decompressed[vi] = get_dequantize_func( - decode_func( - source_bit, - Compressed[vi // elem_per_unit], - vi % elem_per_unit, - dtype=target_dtype, - ), - Scale[0], - Zeros[0], - zeros_mode, - ) - - @T.prim_func - def fast_decode_impl( - compressed: T.handle, - decompressed: T.handle, - scale: T.handle, - zeros: T.handle, - ) -> None: - s0 = T.int32() - s1 = T.int32() - Compressed = T.match_buffer( - compressed, - [ - n_storage_elems, - ], - dtype=storage_dtype, - scope="local", - ) - Decompressed = T.match_buffer( - decompressed, - [ - loops_extent, - ], - dtype=target_dtype, - scope="local", - ) - Scale = T.match_buffer( - scale, - [ - 1, - ], - dtype=target_dtype, - offset_factor=1, - strides=[s0], - scope="global", - ) - Zeros = T.match_buffer( - zeros, - [ - 1, - ], - dtype=target_dtype, - offset_factor=1, - strides=[s1], - scope="global", - ) - with T.block("root"): - T.reads(Compressed[0:n_storage_elems], Scale[0:1], Zeros[0:1]) - T.writes(Decompressed[0:loops_extent]) - T.call_extern( - "handle", - func_name, - Compressed.data, - Decompressed.data, - Scale.access_ptr("r"), - Zeros.access_ptr("r"), - loops_extent, - ) - - return fast_decode_desc, fast_decode_impl - - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u4_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="int8", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u2_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=2, storage_dtype="int8", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u1_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=1, storage_dtype="int8", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u4_to_int32_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="int32", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u4_to_int32_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT32_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int32", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_INTRIN = ("lop3_fast_decode_u4_to_uint32_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="uint32", target_dtype="float16", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u4_to_uint32_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_UINT32_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="uint32", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_original_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="original", - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_rescale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="rescale", - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN = ( - "lop3_fast_decode_u4_to_int8_to_f16_l8_scale_zeros_quantized_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="quantized", - ), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_original_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="original", - ), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_rescale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="rescale", - ), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN = ( - "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_quantized_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="quantized", - ), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN = ( - "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_zeros_original_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_ORIGINAL_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="original", - ), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN = ( - "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_zeros_rescale_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_ZEROS_RESCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - target_dtype="float16", - loops_extent=8, - with_scale=True, - with_zeros=True, - zeros_mode="rescale", - ), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L8_INTRIN = ("lop3_fast_decode_u4_to_int8_to_i8_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="int8", target_dtype="int8", loops_extent=8), -) - -LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u4_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT4_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=4, storage_dtype="int8", target_dtype="int8", loops_extent=16), -) - -LOP3_FAST_DECODE_UINT2_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u2_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT2_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=2, storage_dtype="int8", target_dtype="int8", loops_extent=16), -) - -LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_i2_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - source_format="int", - storage_dtype="int8", - target_dtype="int8", - loops_extent=16), -) - -LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u1_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=1, storage_dtype="int8", target_dtype="int8", loops_extent=16), -) - -LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_i1_to_int8_to_i8_l16_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - source_format="int", - storage_dtype="int8", - target_dtype="int8", - loops_extent=16), -) - -LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i4_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - ), -) - -LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_i4_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=4, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i2_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - ), -) - -LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_i2_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT2_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=2, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - -LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i1_to_int8_to_f16_l8_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - ), -) - -LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( - "lop3_fast_decode_i1_to_int8_to_f16_l8_scale_") -TensorIntrin.register( - LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN, - *get_fast_decode_intrin( - source_bit=1, - storage_dtype="int8", - source_format="int", - target_dtype="float16", - loops_extent=8, - with_scale=True, - ), -) - - -def get_lop3_intrin_group( - out_dtype: Literal["float16", "int8"], - source_format: Literal["int", "uint"] = "uint", - source_bit: int = 4, - storage_dtype: Literal["int32", "int8"] = "int8", - with_scaling: bool = False, - with_zeros: bool = False, - zeros_mode: Literal["original", "rescale", "quantized"] = "original", -) -> Dict[str, str]: - """ - This function is used to get the intrinsic group of the LOP3 operation to avoid the overhead of fast decoding. - LOP3 is a type of logic operation that takes three inputs. The intrinsic group refers to the set of - intrinsic operations that can be performed on these inputs. This function retrieves and returns this group. - - Parameters - ---------- - in_dtype : Literal["int8"] - The data type of the input. It should be "int8". - - out_dtype : Literal["float16", "int8"] - The data type of the output. It can be either "float16" or "int8". - - storage_nbit : int, optional - The number of bits used for storage. By default, it is 4. - - with_scale : bool, optional - A boolean parameter that indicates whether scaling should be applied. By default, it is False. - - Returns - ------- - Dict[str, str] - A dictionary mapping the names of the intrinsics to their corresponding implementations. - """ - assert out_dtype in ["float16", "int8"] - - dtype_mapping = {"float16": "f16", "int8": "i8", "int32": "i32"} - target_dtype = dtype_mapping[out_dtype] - target_bits = tvm.DataType(out_dtype).bits - loop_extent = 128 // target_bits - if source_format not in ["int", "uint"]: - raise ValueError("Invalid source_format. Expected 'int' or 'uint'.") - source_symbol = "i" if source_format == "int" else "u" - - _intrin = f"lop3_fast_decode_{source_symbol}{source_bit}_to_{storage_dtype}_to_{target_dtype}_l{loop_extent}_" - if with_scaling: - _intrin += "scale_" - if with_zeros: - _intrin += f"zeros_{zeros_mode}_" - - import_c_map = { - "i4_to_f16": decode_i4_to_f16, - "i2_to_f16": decode_i2_to_f16, - "i1_to_f16": decode_i1_to_f16, - "i4_to_f16_scale": decode_i4_to_f16_scale, - "i2_to_f16_scale": decode_i2_to_f16_scale, - "i1_to_f16_scale": decode_i1_to_f16_scale, - "i4_to_f16_scale_zeros_original": decode_i4_to_f16_scale_zeros_original, - "i2_to_f16_scale_zeros_original": decode_i2_to_f16_scale_zeros_original, - "i1_to_f16_scale_zeros_original": decode_i1_to_f16_scale_zeros_original, - "i4_to_f16_scale_zeros_rescale": decode_i4_to_f16_scale_zeros_rescale, - "i2_to_f16_scale_zeros_rescale": decode_i2_to_f16_scale_zeros_rescale, - "i1_to_f16_scale_zeros_rescale": decode_i1_to_f16_scale_zeros_rescale, - "i4_to_f16_scale_zeros_quantized": decode_i4_to_f16_scale_zeros_quantized, - "i2_to_f16_scale_zeros_quantized": decode_i2_to_f16_scale_zeros_quantized, - "i1_to_i8": decode_i1s_to_i8s, - "i2_to_i8": decode_i2s_to_i8s, - "i4_to_i8": decode_i4s_to_i8s, - } - key = f"i{source_bit}_to_{target_dtype}" - if with_scaling: - key += "_scale" - if with_zeros: - key += f"_zeros_{zeros_mode}" - - return { - "c_source": import_c_map[key], - "compute": _intrin, - } diff --git a/python/bitblas/gpu/matmul.py b/python/bitblas/gpu/matmul.py deleted file mode 100644 index ad450eff2..000000000 --- a/python/bitblas/gpu/matmul.py +++ /dev/null @@ -1,372 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -# pylint: disable=missing-docstring, invalid-name -"""A GEMM schedule rule for GPU operators.""" -from dataclasses import dataclass -from typing import Optional - -from tvm import tir -from tvm.target import Target -from tvm.tir.stmt import ForKind - -from ..base import analysis -from .base import GPUScheduleRule -from . import utils -from .matmul_analysis import ( - auto_inline_consumer_chain, - auto_inline_producers, - get_in_out_dtypes, - get_index_map, - normalize_to_matmul, - get_reduction_blocks, -) -from .matmul_mma import MatmulTensorizationMMA -from .matmul_wmma import ( - MatmulInt8Tensorization, - MatmulTensorizationWMMA, -) -from functools import reduce -import logging - -logger = logging.getLogger(__name__) - - -class Matmul(GPUScheduleRule): - """The schedule rule for matmul-like computation""" - - @dataclass - class Config: - block_size_x: int = 8 - block_size_y: int = 8 - vthread_x: int = 1 - vthread_y: int = 1 - micro_size_x: int = 4 - micro_size_y: int = 4 - micro_size_k: int = 8 - vector_size: int = 1 - unroll: int = 256 # 0 means no unroll - use_shared: bool = True - storage_align: bool = False - inner_x: bool = False - - def get_configs(self, target: Target) -> Config: - """Get the schedule config for the target""" - if target.kind.name == "cuda" or target.kind.name == "rocm": - return Matmul.Config( - block_size_x=8, - block_size_y=16, - vthread_x=1, - vthread_y=1, - micro_size_x=4, - micro_size_y=4, - micro_size_k=16, - vector_size=2, - unroll=256, - use_shared=True, - storage_align=True, - inner_x=False, - ) - elif target.kind.name == "opencl" and "android" in str(target.host): - return Matmul.Config( - block_size_x=8, - block_size_y=8, - vthread_x=1, - vthread_y=1, - micro_size_x=8, - micro_size_y=2, - micro_size_k=16, - vector_size=8, - unroll=64, - use_shared=False, - storage_align=False, - inner_x=True, - ) - else: - return Matmul.Config() - - def apply( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ) -> Optional[tir.Schedule]: - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): - return None - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - block_stmt = sch.get(main_block) - sch = normalize_to_matmul(sch, main_block) - if sch is None: - return None - - # Step 1. Check Tensor Core support - # Tensorization config: - # If any value of I, J, K is fixed and less than this threshold, - # tensorization rule will not be applied. - minimal_tensorize_threshold = 64 - block_stmt = sch.get(main_block) - if target.kind.name == "cuda" and utils.get_sm_version(target) >= 70: - apply_tensorization: bool = True - # the batch dimension is not taken into consideration. - # Analyze read/write buffers and choose correct tensorizer: int8 or fp16. - in_dtype, out_dtype = get_in_out_dtypes(block_stmt) - if in_dtype not in ["int8", "float16"]: - apply_tensorization = False - for item_var in block_stmt.iter_vars[1:]: - extent = item_var.dom.extent - if isinstance(extent, - tir.expr.IntImm) and extent.value <= minimal_tensorize_threshold: - apply_tensorization = False - if apply_tensorization: - if in_dtype == "int8" and out_dtype == "int32": - tensorize_sch = MatmulInt8Tensorization().apply(func, target, _) - elif utils.get_sm_version(target) >= 80: - # For A100(sm_80) or more advanced gpu, use MMA tensorization. - tensorize_sch = MatmulTensorizationMMA().apply(func, target, _) - else: - # For other GPUs, use WMMA tensorization. - tensorize_sch = MatmulTensorizationWMMA().apply(func, target, _) - if tensorize_sch is not None: - return tensorize_sch - - # Step 2. Get schedule config. - config = self.get_configs(target) - - # Step 3. Schedule matmul - y_kernel_size = config.vthread_y * config.block_size_y * config.micro_size_y - x_kernel_size = config.vthread_x * config.block_size_x * config.micro_size_x - if config.inner_x: - sch.pad_einsum( - main_block, - [1, y_kernel_size, x_kernel_size, config.micro_size_k], - ) - batch, y, x, k = sch.get_loops(main_block) - else: - sch.pad_einsum( - main_block, - [1, x_kernel_size, y_kernel_size, config.micro_size_k], - ) - batch, x, y, k = sch.get_loops(main_block) - by, vy, ty, yi = sch.split( - y, [None, config.vthread_y, config.block_size_y, config.micro_size_y]) - bx, vx, tx, xi = sch.split( - x, [None, config.vthread_x, config.block_size_x, config.micro_size_x]) - ko, ki = sch.split(k, factors=[None, config.micro_size_k]) - sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi) - by = sch.fuse(batch, by) - sch.bind(bx, "blockIdx.x") - sch.bind(by, "blockIdx.y") - sch.bind(vy, "vthread.y") - sch.bind(vx, "vthread.x") - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - inner_loop = config.micro_size_x if config.inner_x else config.micro_size_y - if inner_loop % config.vector_size == 0: - _, v = sch.split(xi, [None, config.vector_size]) - sch.vectorize(v) - - if config.unroll > 0: - sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=config.unroll) - sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1) - - l2g = sch.cache_write(main_block, 0, "local") - sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True) - if config.micro_size_x % config.vector_size == 0: - _, v = sch.split(sch.get_loops(l2g)[-1], [None, config.vector_size]) - sch.vectorize(v) - - if config.use_shared: - - def _cooperative_fetch(index, vec_len): - block = sch.cache_read(main_block, index, "shared") - num_loops = len(sch.get_loops(block)) - sch.compute_at(block, ko, preserve_unit_loops=True) - loops = sch.get_loops(block)[-num_loops:] - ty, tx, _, vec = sch.split( - sch.fuse(*loops), - factors=[config.block_size_y, config.block_size_x, None, vec_len], - ) - sch.vectorize(vec) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - if config.storage_align: - sch.storage_align(block, 0, axis=1, factor=8, offset=vec_len) - return block - - a_g2s = _cooperative_fetch(0, vec_len=config.vector_size) - b_g2s = _cooperative_fetch(1, vec_len=config.vector_size) - - auto_inline_producers(sch, a_g2s) - auto_inline_producers(sch, b_g2s) - else: - auto_inline_producers(sch, main_block) - - auto_inline_consumer_chain(sch, l2g) - sch.decompose_reduction(main_block, ko) - - # Step 4. Check if there are unbound blocks. Execute fallback scheduling to them. - def is_scheduled(block: tir.schedule.BlockRV) -> bool: - loops = sch.get_loops(block) - loop_kinds = {sch.get(loop).kind for loop in loops} - return loop_kinds != {ForKind.SERIAL} - - blocks = sch.get_child_blocks(root_block) - max_threads_per_block = utils.max_threads_per_block(target) # noqa: F841 - for block in blocks: - if is_scheduled(block): - continue - # no axis of the block is bound to thread or block - s_loops = sch.get_loops(block) - bx, tx = sch.split( - sch.fuse(*s_loops), - factors=[ - None, - 256, - ], - ) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - - return sch - - def apply_config( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - config, - ) -> tir.Schedule: - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - # in some case conv template will use this rule, but the tile config is not - # analyzed by matmul expr. - if len(config.block) != 2: - logger.debug(f"Warning: block config {config.block} is not valid for matmul, skip.") - return None - - main_block = reduction_blocks[0] - - block_stmt = sch.get(main_block) - - # cuda core prefer b is [k, j] layout without swizzling. - index_maps = get_index_map(block_stmt, ["n", "n", "n"]) - if index_maps is None: - return None - matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps - - # Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] - block = sch.reindex(main_block, ("read", 0)) - sch.transform_layout(block, ("write", 0), a_index_map) - block = sch.reindex(main_block, ("read", 1)) - sch.transform_layout(block, ("write", 0), b_index_map) - block = sch.reindex(main_block, ("write", 0)) - sch.transform_layout(block, ("read", 0), c_index_map) - sch.transform_block_layout(main_block, matmul_index_map) - - # Step 2. Get schedule config. - block_row_warps = config.block[0] // (config.thread[0] * config.step[0]) - block_col_warps = config.block[1] // (config.thread[1] * config.step[1]) - thread_row_tiles = config.thread[1] // (config.step[0] * 2) - thread_col_tiles = config.thread[1] // (config.step[1] * 2) - vthread_row_tiles = (config.step[0] * 2) # expand vtrhead to avoid load band conflict - vthread_col_tiles = (config.step[1] * 2) # expand vtrhead to avoid load band conflict - chunk = config.rstep[0] - - # Step 3. Schedule matmul - BM = block_row_warps * vthread_row_tiles * thread_row_tiles - BN = block_col_warps * vthread_col_tiles * thread_col_tiles - BK = chunk - - sch.pad_einsum( - main_block, - [1, BM, BN, BK], - ) - batch, y, x, k = sch.get_loops(main_block) - by, vy, ty, yi = sch.split(y, [None, vthread_row_tiles, block_row_warps, thread_row_tiles]) - bx, vx, tx, xi = sch.split(x, [None, vthread_col_tiles, block_col_warps, thread_col_tiles]) - ko, ki = sch.split(k, factors=[None, BK]) - sch.reorder(by, bx, vy, vx, ty, tx, ko, ki, yi, xi) - by = sch.fuse(batch, by) - sch.bind(bx, "blockIdx.x") - sch.bind(by, "blockIdx.y") - sch.bind(vy, "vthread.y") - sch.bind(vx, "vthread.x") - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - def prod(iterable): - return reduce(lambda x, y: x * y, iterable, 1) - - l2g = sch.cache_write(main_block, 0, "local") - sch.reverse_compute_at(l2g, tx, preserve_unit_loops=True) - - def _cooperative_fetch(index, vec_len): - block = sch.cache_read(main_block, index, "shared") - num_loops = len(sch.get_loops(block)) - block_local = sch.cache_read(main_block, index, "local") - sch.compute_at(block_local, ki, preserve_unit_loops=True) - sch.compute_at(block, ko, preserve_unit_loops=True) - loops = sch.get_loops(block)[-num_loops:] - _, ty, tx, vec = sch.split( - sch.fuse(*loops), - factors=[None, block_row_warps, block_col_warps, vec_len], - ) - - auto_inline_producers(sch, block) - - def is_trivial_load(block): - # avoid vectorize under global[v2, v1]] shared[v1, v2] case - reads = sch.get(block).reads - writes = sch.get(block).writes - if len(reads) != 1 or len(writes) != 1: - return False - return all( - read.region[-1] == write.region[-1] for read, write in zip(reads, writes)) - - if is_trivial_load(block): - sch.vectorize(vec) - - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - - _, vec = sch.split( - sch.fuse(*sch.get_loops(block_local)[-2:]), - [None, vec_len // prod(config.step)], - ) - sch.vectorize(vec) - - return block - - for i, input_region in enumerate(sch.get(main_block).reads): - _buffer_name = input_region.buffer.name.replace("_reindex", "").replace("_pad", "") - if _buffer_name not in config.cached_tensors: - logger.warning( - f"Warning: {_buffer_name} is not in cached_tensors {config.cached_tensors}, skip." - ) - continue - - # otherwise cooperative fetch in shared memory. - vectorize = config.vectorize.get(_buffer_name, 1) - - _cooperative_fetch(i, vec_len=vectorize) - - auto_inline_consumer_chain(sch, l2g) - - _, vec = sch.split( - sch.fuse(*sch.get_loops(l2g)[-2:]), [None, vectorize // prod(config.step)]) - sch.vectorize(vec) - - sch.decompose_reduction(main_block, ko) - return sch diff --git a/python/bitblas/gpu/matmul_analysis.py b/python/bitblas/gpu/matmul_analysis.py deleted file mode 100644 index 6537a555a..000000000 --- a/python/bitblas/gpu/matmul_analysis.py +++ /dev/null @@ -1,786 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -# pylint: disable=missing-docstring, invalid-name -"""A GEMM schedule rule for GPU operators.""" -from dataclasses import dataclass -from enum import Enum -from typing import List, Optional, Set, Union, Tuple, Dict -from tvm import tir -from tvm.ir import Range -from tvm.tir import IterVar, PrimExpr, Var, BufferRegion, IndexMap -from tvm.tir.analysis import undefined_vars -from tvm.tir.schedule.schedule import BlockRV -from ..base.analysis import ( - collect_block_iter_vars_used_in_access_region, - get_root_block, - get_reduction_blocks, -) -from tvm.target.target import Target -from tvm.tir.stmt_functor import pre_order_visit -import logging - -logger = logging.getLogger(__name__) - - -def collect_vars_from_expr(prim_expr): - vars = [] - - def callback(node): - if isinstance(node, Var): - vars.append(node) - return True - - pre_order_visit(prim_expr, callback) - - return vars - - -def _is_one(x: PrimExpr) -> bool: - return isinstance(x, tir.IntImm) and x.value == 1 - - -def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV): - result = [] - for producer in sch.get_producers(block): - result.append(producer) - result.extend(_collect_producers(sch, producer)) - return result - - -def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV): - result = [] - for consumer in sch.get_consumers(block): - result.append(consumer) - result.extend(_collect_consumers(sch, consumer)) - return result - - -def auto_inline_producers( - sch: tir.Schedule, - block: tir.schedule.BlockRV, - skip_blocks: Optional[List[tir.schedule.BlockRV]] = None, -): - skip_blocks = skip_blocks or [] - while True: - inlined_cnt = 0 - producers = _collect_producers(sch, block) - for producer in producers: - if any(sch.get(producer) == sch.get(skip_block) for skip_block in skip_blocks): - continue - try: - sch.compute_inline(producer) - inlined_cnt += 1 - except Exception: # pylint: disable=bare-except - continue - if inlined_cnt == 0: - return - - -def auto_inline_consumers( - sch: tir.Schedule, - block: tir.schedule.BlockRV, -): - while True: - inlined_cnt = 0 - consumers = _collect_consumers(sch, block) - for consumer in consumers: - try: - sch.compute_inline(consumer) - inlined_cnt += 1 - except Exception: # pylint: disable=bare-except - continue - for consumer in consumers: - try: - sch.reverse_compute_inline(consumer) - inlined_cnt += 1 - except Exception: # pylint: disable=bare-except - continue - if inlined_cnt == 0: - return - - -def auto_inline_consumer_chain( - sch: tir.Schedule, - block: tir.schedule.BlockRV, -): - auto_inline_consumers(sch, block) - remaining_consumers = sch.get_consumers(block) - - if len(remaining_consumers) != 0: - # Some blocks have failed to be inlined to the producer cache-write stage. - # This could be due to another producer block that has not been scheduled. - for c in remaining_consumers: - for p in sch.get_producers(c): - if sch.get(p) != sch.get(block): - sch.compute_inline(p) - - # Try inlining into the cache-write stage again, this time it should succeed. - auto_inline_consumers(sch, block) - - -# used to match the similar region with dequantize op. -def find_first_similar_region(regions: List[BufferRegion], buffer: tir.Buffer): - for region in regions: - if len(region.buffer.shape) == len(buffer.shape): - return region - return None - - -# used to match the similar buffer with dequantize op. -def find_first_similar_buffer(regions: List[BufferRegion], buffer: tir.Buffer): - for region in regions: - if len(region.buffer.shape) == len(buffer.shape): - return region.buffer - return None - - -# find the block that required to be reindex and scope. -def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Optional[BlockRV]: - # block that most near to the arguments - block = main_block - buffer = buffer - - while True: - last_buffer = buffer - producers = sch.get_producers(block) - - if len(producers) == 0: - # do not have any producer means it is the first block - break - - for producer in producers: - for write in sch.get(producer).writes: - if write.buffer == buffer: - block = producer - buffer = find_first_similar_buffer(sch.get(producer).reads, last_buffer) - if buffer == last_buffer: - break - return block - - -def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV, - buffer: tir.Buffer) -> int: - """traverse to find the arg index from the buffer""" - producers = sch.get_producers(main_block) - - # a head buffer has no producer blocks - def find_args_index(sch: tir.Schedule, buffer: tir.Buffer): - for i, param in enumerate(sch.mod["main"].params): - if sch.mod["main"].buffer_map[param] == buffer: - return i - return None - - is_head_buffer = len(producers) == 0 - if is_head_buffer: - return find_args_index(sch, buffer) - for block in sch.get_producers(main_block): - if len(sch.get(block).reads) != 1 or len(sch.get(block).writes) != 1: - continue - for write in sch.get(block).writes: - if write.buffer == buffer: - return find_arg_idx_from_buffer_chain(sch, block, buffer) - - # if no buffer producer block found, it means the buffer is an input buffer - return find_args_index(sch, buffer) - - -class IterKind(Enum): - """Iter kinds for GEMM-liked programs. - We can simplify the computation to C[S, I, J] += A[S, I, K] * B[S, J, K], - where `I, J, K` are fundamental axes for gemm and `S` represents all - other spatial axes (e.g. batches) - kIter_S: spatial axes - kIter_I: I axes - kIter_J: J axes - kIter_K: K axes - kIter_T: trivial axes (i.e. with extent 1) - """ - - kIter_S = 0 - kIter_I = 1 - kIter_J = 2 - kIter_K = 3 - kIter_T = 4 - - -@dataclass -class IterTrait: - kind: IterKind - extent: PrimExpr - - -def make_iter_fusion_index_map( - traits: List[IterTrait], - kind_order: List[IterKind], -) -> tir.IndexMap: - fused_iters: Dict[IterKind, PrimExpr] = {} - input_iters: List[tir.Var] = [] - for i, trait in enumerate(traits): - v_i = tir.Var(f"i{i}", trait.extent.dtype) - input_iters.append(v_i) - if trait.kind == IterKind.kIter_T: - continue - if trait.kind not in kind_order: - raise ValueError(f"Unknown iter kind {trait.kind}") - if trait.kind in fused_iters: - fused_iters[trait.kind] = fused_iters[trait.kind] * trait.extent + v_i - else: - fused_iters[trait.kind] = v_i - - final_indices: List[tir.PrimExpr] = [ - fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order - ] - - return tir.IndexMap(input_iters, final_indices, None) - - -def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]: - """Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K] - - Parameters - ---------- - block : tir.Block - The block to be analyzed - - Returns - ------- - traits : Optional[Tuple[List[IterTrait]]] - The detected iter traits for axes in A, B and C. None if the block - does not match the pattern. - - """ - - if len(block.reads) != 2 or len(block.writes) != 1: - return None - - def get_access_axes(region: List[Range]) -> Set[Var]: - axes: Set[Var] = set() - for r in region: - if not _is_one(r.extent): - raise ValueError("Expect elemwise block access") - axes = axes.union(set(undefined_vars(r.min))) - return axes - - try: - A_axes = get_access_axes(block.reads[0].region) - B_axes = get_access_axes(block.reads[1].region) - C_axes = get_access_axes(block.writes[0].region) - except ValueError: - return None - - traits: Dict[Var, IterTrait] = {} - for iter_var in block.iter_vars: - var = iter_var.var - kind: IterKind - if _is_one(iter_var.dom.extent): - if iter_var.iter_type == tir.IterVar.CommReduce: - # for simplified case (e.g. 1x1 conv kernel) - kind = IterKind.kIter_K - else: - kind = IterKind.kIter_T - elif iter_var.iter_type == iter_var.DataPar: - if var in A_axes and var in B_axes and var in C_axes: - kind = IterKind.kIter_S - elif var in A_axes and var in C_axes: - kind = IterKind.kIter_I - elif var in B_axes and var in C_axes: - kind = IterKind.kIter_J - else: - return None - elif iter_var.iter_type == tir.IterVar.CommReduce: - if var in A_axes and var in B_axes and var not in C_axes: - kind = IterKind.kIter_K - else: - return None - else: - return None - traits[var] = IterTrait(kind, iter_var.dom.extent) - - # A Gemm-kernel requires have I, J and K axes - gemm_traits = {IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K} - if {x.kind for x in traits.values()}.intersection(gemm_traits) != gemm_traits: - return None - - A_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes] - B_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes] - C_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes] - block_traits = [traits[i.var] for i in block.iter_vars] - return A_traits, B_traits, C_traits, block_traits - - -def get_index_map(block: tir.Block, - layout: Optional[List[str]] = None) -> Optional[Tuple[tir.IndexMap, ...]]: - """Get index maps for the block - - Parameters - ---------- - block : tir.Block - The block to be analyzed - - layout : List[str] - the target layout index map to be used. - 'n' for [i, k] layout - 't' for [k, j] layout - 'a' for auto inference based on whether the last axis is reduction. - - Returns - ------- - index_maps : Optional[Tuple[tir.IndexMap]] - The index maps for the block, or None if the block is not a gemm-liked kernel - """ - if layout is None: - layout = ["n", "t", "n"] - traits = detect_iter_traits(block) - if traits is None: - return None - A_traits, B_traits, C_traits, block_traits = traits - - def get_ordered_axes(region: List[Range]) -> Set[Var]: - axes: List[Var] = [] - for r in region: - if not _is_one(r.extent): - raise ValueError("Expect elemwise block access") - axes.append(r.min) - return axes - - def is_common_reduce(var: Var) -> bool: - for iter_var in block.iter_vars: - if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce: - return True - return False - - def has_common_reduce(var: Var) -> bool: - vars = collect_vars_from_expr(var) - return any(is_common_reduce(v) for v in vars) - - def check_last_trait(region: List[Range]): - axes = get_ordered_axes(region) - return has_common_reduce(axes[-1]) - - def infer_layout(layout: str, region: List[Range], kind: str = "A"): - """ - Infer the layout based on the region and the kind of buffer - kind: "A", "B", "C" - """ - primary_iter, secondary_iter, reduction_iter = { - "A": (IterKind.kIter_I, IterKind.kIter_K, IterKind.kIter_K), - "B": (IterKind.kIter_K, IterKind.kIter_J, IterKind.kIter_K), - "C": (IterKind.kIter_I, IterKind.kIter_J, None), - }[kind] - - spatial_iter = { - "A": IterKind.kIter_I, - "B": IterKind.kIter_J, - "C": None, - }[kind] - - if layout == "n": - return [IterKind.kIter_S, primary_iter, secondary_iter] - elif layout == "t": - return [IterKind.kIter_S, secondary_iter, primary_iter] - elif layout == "a": - # auto inference layout - # for buffer with reduction axis, we put it as the last axis - # otherwise, we put it as the first axis - if kind == "C": - return [IterKind.kIter_S, primary_iter, secondary_iter] - else: - return ([IterKind.kIter_S, spatial_iter, reduction_iter] if check_last_trait(region) - else [IterKind.kIter_S, reduction_iter, spatial_iter]) - else: - raise ValueError(f"Unknown layout {layout}") - - A_index_map = make_iter_fusion_index_map( - A_traits, infer_layout(layout[0], block.reads[0].region, kind="A")) - B_index_map = make_iter_fusion_index_map( - B_traits, infer_layout(layout[1], block.reads[1].region, kind="B")) - C_index_map = make_iter_fusion_index_map( - C_traits, infer_layout(layout[2], block.writes[0].region, kind="C")) - - matmul_index_map = make_iter_fusion_index_map( - block_traits, - [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K], - ) - - return ( - matmul_index_map, - A_index_map, - B_index_map, - C_index_map, - ) - - -def get_in_out_dtypes(block: tir.Block) -> Tuple[str]: - """ - Detect In/Out data types for the given block based on the analysis if read/write buffers. - """ - assert len(block.reads) > 0 and len(block.writes) > 0 - in_dtype = block.reads[0].buffer.dtype - out_dtype = block.writes[0].buffer.dtype - return (in_dtype, out_dtype) - - -def get_dequantize_block(sch, blocks) -> Optional[BlockRV]: - # check at least two input and one output - # at lease one input has uint dtype, and the output dtype is float - def is_dequantize(block: BlockRV) -> bool: - block_stmt = sch.get(block) - if len(block_stmt.reads) < 2: - return False - has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads) - if not has_uint_input: - return False - if len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype): - return False - return True - - dequantize_blocks = [block for block in blocks if is_dequantize(block)] - return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None - - -def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool: - iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} - if iter_types != {IterVar.DataPar}: - return False, False - if not isinstance(block_stmt.body, tir.BufferStore): - return False, False - if not isinstance(block_stmt.body.value, tir.BufferLoad): - return False, False - - def get_access_vars(region: List[Range]) -> List[Var]: - axes: List[Var] = [] - for r in region: - if not _is_one(r.extent): - return None - axes.extend(undefined_vars(r.min)) - # remove trivial axis - trivial_vars = set( - iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent)) - axes = [axis for axis in axes if axis not in trivial_vars] - # remove duplicate axis - axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]] - return axes - - lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:] - rhs_access_vars = get_access_vars(block_stmt.writes[0].region)[-2:] - is_identity = list(lhs_access_vars) == list(rhs_access_vars) - is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set( - rhs_access_vars) - return is_identity, is_transpose - - -def is_identity_block(block_stmt: tir.Block) -> bool: - return is_identity_or_transpose_block(block_stmt)[0] - - -def is_transpose_block(block_stmt: tir.Block) -> bool: - return is_identity_or_transpose_block(block_stmt)[1] - - -def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]): - result_blocks = [] - for block in blocks: - if not is_transpose_block(sch.get(block)): - result_blocks.append(block) - continue - try: - sch.compute_inline(block) - except Exception: - try: - sch.reverse_compute_inline(block) - except Exception: - result_blocks.append(block) - return result_blocks - - -def normalize_to_matmul(sch: tir.Schedule, - main_block: BlockRV, - layout: Optional[List[str]] = None) -> Optional[tir.Schedule]: - if layout is None: - layout = ["n", "t", "n"] - block_stmt = sch.get(main_block) - - # let layout be 'a' to auto inference the layout - index_maps = get_index_map(block_stmt, layout=layout) - if index_maps is None: - logger.debug("Cannot find the appropriate index map for tensorcore") - return None - - matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps - - # `skip_simplify` to avoid the bug in the 1x1 conv - block = sch.reindex(main_block, ("read", 0), skip_simplify=True) - sch.transform_layout(block, ("write", 0), a_index_map) - block = sch.reindex(main_block, ("read", 1), skip_simplify=True) - sch.transform_layout(block, ("write", 0), b_index_map) - block = sch.reindex(main_block, ("write", 0), skip_simplify=True) - sch.transform_layout(block, ("read", 0), c_index_map) - sch.transform_block_layout(main_block, matmul_index_map) - sch.mod["main"] = sch.mod["main"].with_attr("dlight.tensorcore_prenormlized", True) - return sch - - -def get_tensorized_func_and_tags( - func: tir.PrimFunc, - target: Target, - layout: Optional[List[str]] = None, - skip_normalize: bool = False, - allow_gemv: bool = False, -) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]: - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group,) - """ - transform function to matmul if necessary (e.g. transform conv2d with im2col) - """ - if layout is None: - layout = ["a", "a", "a"] - # step1. detect whether the function can utilize tensorcore - sch = tir.Schedule(func) - root_block = get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - reduction_blocks = get_reduction_blocks(sch, blocks) - if not reduction_blocks or len(reduction_blocks) != 1: - return func, None - - def _can_be_tensorized(sch: tir.Schedule, block: BlockRV) -> bool: - block_stmt = sch.get(block) - conditions = [] - conditions.append(len(block_stmt.reads) == 2) - conditions.append(len(block_stmt.writes) == 1) - conditions.append( - len( - collect_block_iter_vars_used_in_access_region(block_stmt, - block_stmt.writes[0].region)) > 0) - if not all(conditions): - return False - return True - - # step2. transform function to tensorcore matmul (e.g. conv2d with im2col) - def check_sm_version(arch: str) -> int: - sm_version = arch.replace("sm_", "") - return int(sm_version) if sm_version.isdigit() else -1 - - def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool: - tags: Dict[str, Union[List[int], int]] = {} - block_stmt = sch.get(block) - - # analysis tensorcore axis - # todo(lei): maybe we can remove this in the future - (write_buffer_region,) = block_stmt.writes - out_axis = len(write_buffer_region.buffer.shape) - tags["tensorcore_config"] = [out_axis - 2, out_axis - 1] - - # analysis pipeline stage - # todo(lei): maybe we can integrate this into policy in the future - tags["pipeline_stage"] = 1 - if target.kind.name == "cuda" and check_sm_version(target.arch) == 80: - # enable pipeline stage only for sm_80 devices - tags["pipeline_stage"] = 2 - - # analysis async copy - # todo(lei): maybe we can integrate this into policy in the future - tags["use_async_copy"] = False - if tags["pipeline_stage"] == 2 and check_sm_version(target.arch) >= 80: - # async copy only works in software pipeline. - tags["use_async_copy"] = True - - # analysis intrin information - def get_ordered_axes(region: List[Range]) -> Set[Var]: - axes: List[Var] = [] - for r in region: - if not _is_one(r.extent): - raise ValueError("Expect elemwise block access") - axes.append(r.min) - return axes - - def is_common_reduce(var: Var) -> bool: - for iter_var in block_stmt.iter_vars: - if iter_var.var == var and iter_var.iter_type == IterVar.CommReduce: - return True - return False - - def has_common_reduce(var: Var) -> bool: - vars = collect_vars_from_expr(var) - return any(is_common_reduce(v) for v in vars) - - def check_last_trait(region: List[Range]): - axes = get_ordered_axes(region) - return has_common_reduce(axes[-1]) - - intrin_info: dict = {} - in_dtype, out_dtype = get_in_out_dtypes(block_stmt) - intrin_info["in_dtype"] = in_dtype - intrin_info["out_dtype"] = out_dtype - # if the last dimension is reduce axis, the B is transposed - intrin_info["trans_b"] = check_last_trait(block_stmt.reads[1].region) - if func.attrs is not None and "input_transform_kind" in func.attrs: - intrin_info["input_transform_kind"] = func.attrs["input_transform_kind"] - if func.attrs is not None and "weight_transform_kind" in func.attrs: - intrin_info["weight_transform_kind"] = func.attrs["weight_transform_kind"] - tags["intrin_info"] = intrin_info - # Analysis Block Reduction Optimization - # Currently, we only support block reduction depth 2 for small M - # When the func is a dequantize like ops, we should consider the M - if hasattr(func.attrs, "dequantize_info"): - for arg in func.params: - inp_shape = func.buffer_map[arg].shape - M = inp_shape[0] - if isinstance(M, tir.IntImm) and M <= 128: - tags["block_reduction_depth"] = 2 - break - - return tags - - (main_block,) = reduction_blocks - if _can_be_tensorized(sch, main_block) is None: - return func, None - - block_stmt = sch.get(main_block) - if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70: - # TODO(lei): we should consider the dtype of the input a and b - # instead of assuming both a and b share the same dtype. - # As the tensorcore may supports e4m3_float8 * e5m2_float8 - in_dtype, out_dtype = get_in_out_dtypes(block_stmt) - try: - _ = get_mma_intrin_group( - a_dtype=in_dtype, - b_dtype=in_dtype, - out_dtype=out_dtype, - ) - except Exception: - logger.debug("Cannot find the corresponding mma intrin group") - return func, None - - # reindex and transform functions - # Normalize tensor functions to C[S, I, J] += A[S, I, K] * B[S, J, K] - # or C[S, I, J] += A[S, I, K] * B[S, K, J] - # skip normalize when we want to detect tags only. - if not skip_normalize: - sch = normalize_to_matmul(sch, main_block, layout) - if sch is None: - return func, None - - block_stmt = sch.get(main_block) - - minimal_tensorize_threshold = 16 - # the batch dimension is not taken into consideration. - extent = block_stmt.iter_vars[1].dom.extent - if isinstance(extent, - tir.expr.IntImm) and (extent.value < - (1 if allow_gemv else minimal_tensorize_threshold)): - return func, None - for item_var in block_stmt.iter_vars[2:]: - extent = item_var.dom.extent - if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold): - return func, None - tags = analysis_tensorcore_tags(sch, main_block, target) - return sch.mod["main"], tags - - return func, None - - -def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"): - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout, - ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, - ) - - assert dtype in [ - "float16", - "int8", - "e4m3_float8", - "e5m2_float8", - ], "Only support float16, int8, e4m3_float8, e5m2_float8" - if dtype == "float16": - ldmatrix_layout = ldmatrix_32x8_to_shared_16x16_layout - ldmatrix_layout_trans = ldmatrix_trans_32x8_to_shared_16x16_layout - elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]: - # int8 mma only support 32x16 to 16x32 layout - if matrix_name == "A" and trans is False: - ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_a - elif matrix_name == "B" and trans is True: - ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_b - else: - raise ValueError("Unknown matrix name ", matrix_name) - - # IntraWarp memory layout was occurred by ldmatrix, we should lift the ld_matrix out - def ldmatrix_permutation_16x16_32x8_16x16(kernel_i, kernel_j): - thread_id = kernel_i * 2 + kernel_j // 8 - local_id = kernel_j % 8 - return ldmatrix_layout(thread_id, local_id) - - def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j): - thread_id = kernel_i * 2 + kernel_j // 8 - local_id = kernel_j % 8 - return ldmatrix_layout_trans(thread_id, local_id) - - def ldmatrix_permutation_16x32_32x16_32x16(kernel_i, kernel_j): - thread_id = kernel_i * 2 + kernel_j // 16 - local_id = kernel_j % 16 - return ldmatrix_layout(thread_id, local_id) - - if dtype == "float16": - ldmatrix_index_map = ( - ldmatrix_trans_permutation_16x16_32x8_16x16 - if trans else ldmatrix_permutation_16x16_32x8_16x16) - else: - ldmatrix_index_map = ldmatrix_permutation_16x32_32x16_32x16 - - ldmatrix_index_map = IndexMap.from_func(ldmatrix_index_map, index_dtype=index_dtype) - # TODO(lei): index_dtype should be analyzed from the schedule - row, col = [16, 16] if dtype == "float16" else [16, 32] - inversed_index_map = ldmatrix_index_map.inverse([row, col]) - return ldmatrix_index_map, inversed_index_map - - -def layout_propagate_chain( - sch: tir.Schedule, - start_block: BlockRV, - start_buffer: tir.Buffer, - end_block: BlockRV, - index_map: IndexMap, -): - # some layout transformation may only apply to the last n dimensions - # propagate the layout transformation to the chain of blocks - block = start_block - buffer = start_buffer - index_map = index_map - while True: - last_buffer = buffer - producers = sch.get_producers(block) - if len(producers) == 0: - break - for producer in producers: - if len(sch.get(producer).writes) != 1: - return index_map - if sch.get(producer) == sch.get(end_block): - return index_map - (write,) = sch.get(producer).writes - - read = find_first_similar_region(sch.get(producer).reads, last_buffer) - if write.buffer == buffer: - block = producer - buffer = read.buffer - write_indices = [r.min for r in write.region] - read_indices = [r.min for r in read.region] - # reverse index map from [vi // x] -> [vi * x] to match the inconsistent layout - tmp_index_map = IndexMap(write_indices, read_indices, None) - tmp_index_map = tmp_index_map.non_surjective_inverse(write.buffer.shape)[0] - - # if dequantize like ops are used, the scaling factor should be considered - # to be applied to the final indices - scaling_factor = 1 - for i, j in zip(write.buffer.shape, read.buffer.shape): - scaling_factor *= i // j - final_indices = list( - index_map.map_indices(tmp_index_map.map_indices(write_indices))) - final_indices[-1] = final_indices[-1] // scaling_factor - index_map = IndexMap( - write_indices, - final_indices, - None, - ) - if buffer == last_buffer: - break - return index_map diff --git a/python/bitblas/gpu/matmul_mma.py b/python/bitblas/gpu/matmul_mma.py deleted file mode 100644 index 4bf8be4e6..000000000 --- a/python/bitblas/gpu/matmul_mma.py +++ /dev/null @@ -1,1069 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -# pylint: disable=missing-docstring, invalid-name -"""A GEMM schedule rule for GPU operators.""" -from typing import Literal, Optional, List - -from tvm import tir, DataType -from tvm.target import Target - -from ..base.roller import Hint -from ..base.roller.rasterization import NoRasterization -from ..base import analysis -from .base import GPUScheduleRule -from .matmul_mma_dequantize import MatmulTensorizationMMAWithDequantizeInfo -from ..base.analysis import get_coalesced_veclen -from .matmul_analysis import ( - auto_inline_consumer_chain, - is_transpose_block, - is_identity_block, - _collect_producers, - inline_transpose_block, - auto_inline_producers, - get_index_map, - get_reduction_blocks, - get_dequantize_block, - normalize_to_matmul, - get_propagate_map, -) - - -def get_index_map_3d(index_map, l=16, r=16): # noqa: E741 - - def index_map_3d(b, i, j): - return ( - b, - i // l, - j // r, - *index_map(i % l, j % r), - ) - - return index_map_3d - - -def get_index_map_5d(index_map): - """ - for layout transformed gemm, the index map should be 5d - """ - - def index_map_5d(b, i, j, ii, jj): - return ( - b, - i, - j, - *index_map(ii, jj), - ) - - return index_map_5d - - -def get_warp_index_map(index_map, l=16, r=16, is_5d=False): # noqa: E741 - if is_5d: - return get_index_map_5d(index_map) - return get_index_map_3d(index_map, l, r) - - -class MatmulTensorizationMMA(GPUScheduleRule): - """ - The schedule rule for float16 tensor core matmul computation. - func with attr 'dlight.do_not_tensorize' will not be tensorized. - """ - - def apply( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ) -> Optional[tir.Schedule]: - if "dequantize_info" in func.attrs: - dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo() - return dequantize_rule.apply(func, target, False) - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - # We first inline all transpose blocks for later analysis of transposed A and B - blocks = inline_transpose_block(sch, blocks) - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - dequantize_block = get_dequantize_block(sch, blocks) - - main_block = reduction_blocks[0] - main_block_stmt = sch.get(main_block) - - # Supported data types: - # fp16, fp16, fp16: fp16 precision - # fp16, fp16, fp32: fp16 mixed precision - dtype_a = main_block_stmt.reads[0].buffer.dtype - dtype_b = main_block_stmt.reads[1].buffer.dtype - dtype_c = main_block_stmt.writes[0].buffer.dtype - if dtype_a != dtype_b: - return None - - # Get index maps - index_maps = get_index_map(main_block_stmt) - if index_maps is None: - return None - matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps - - # Start Schedule - # Step 0. Get schedule config. - # NOTE: we can analyze the config by the hardware spec in the future - - # Tensorization by hardware intrinsics - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group, shared_16x16_to_mma_32x8_layout, - ) - - # tile size - block_m, block_n, block_k = 128, 128, 32 - - # tensor core intrinsic size - micro_size_m, micro_size_n, micro_size_k = 16, 16, 16 - - # thread size - # thread_x == warp_size - thread_z, thread_y, thread_x = 2, 2, 32 - - vector_size = 8 - unroll_depth = 4 # noqa: F841 - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] - block = sch.reindex(main_block, ("read", 0)) - sch.transform_layout(block, ("write", 0), a_index_map) - is_transpose_a = is_transpose_block(sch.get(block)) - block = sch.reindex(main_block, ("read", 1)) - sch.transform_layout(block, ("write", 0), b_index_map) - is_transpose_b = is_identity_block(sch.get(block)) - block = sch.reindex(main_block, ("write", 0)) - sch.transform_layout(block, ("read", 0), c_index_map) - sch.transform_block_layout(main_block, matmul_index_map) - - batch, i, j, k = sch.get_loops(main_block) - - swizzle_factor_for_l2_m = [1, None] - swizzle_factor_for_l2_n = [1, None] - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - swizzle_factor_for_l2_m[0] * block_m, - swizzle_factor_for_l2_n[0] * block_n, - block_k, - ], - ) - - # Step 3. Reorder loops for tiling - - # Step 3.1 inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_m]) - j, j_inner = sch.split(j, factors=[None, micro_size_n]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = main_block - block_outer = sch.blockize(i_inner) - - # Step 3.2 outer loops for tiling - # split factors for i, j, and k - micro_block_cnt_in_warp_m = block_m // thread_z // micro_size_m - micro_block_cnt_in_warp_n = block_n // thread_y // micro_size_n - micro_block_cnt_in_warp_k = block_k // micro_size_k - - i_factors = swizzle_factor_for_l2_m + [thread_z, micro_block_cnt_in_warp_m] - j_factors = swizzle_factor_for_l2_n + [thread_y, micro_block_cnt_in_warp_n] - k_factors = [None, micro_block_cnt_in_warp_k] - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, factors=k_factors) - - sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) - - block_axis = sch.fuse(batch, i0, j0, i1, j1) - sch.bind(block_axis, "blockIdx.x") - - sch.bind(i2, "threadIdx.z") - sch.bind(j2, "threadIdx.y") - - # Step 4. Read/write to shared mem and register - def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], is_transpose): - # 1) Read to shared memory - block_read_smem = sch.cache_read(block_outer, read_buffer_idx, "shared.dyn") - sch.compute_at(block_read_smem, k0) - auto_inline_producers(sch, block_read_smem, - [dequantize_block] if dequantize_block else []) - - # For transposed read, we directly load transposed tensor from global - # Then use ldmatrix.trans to handle transpose later - if (tensor_name == "A" and is_transpose) or (tensor_name == "B" and not is_transpose): - # specifical handle transpose read (for NN matmul or TT matmul) - v0, v1 = sch.get_loops(block_read_smem)[-2:] - sch.reorder(v1, v0) - sch.transform_layout(block_read_smem, ("write", 0), lambda b, i, j: (b, j, i)) - - # bind loops - fused = sch.fuse(*sch.get_loops(block_read_smem)[-2:]) - f0, f1, f2, f3, f4 = sch.split(fused, [None, thread_z, thread_y, thread_x, vector_size]) - sch.bind(f1, "threadIdx.z") - sch.bind(f2, "threadIdx.y") - sch.bind(f3, "threadIdx.x") - sch.vectorize(f4) - - # swizzling - sch.annotate(block_read_smem, ann_key="permuted_layout", ann_val=1) - - # 2) Read to register - block_read_reg = sch.cache_read(block_outer, read_buffer_idx, "warp") - sch.compute_at(block_read_reg, k1) - - # bind_loops - micro_size_spatial = micro_size_m if tensor_name == "A" else micro_size_n - micro_size_1, micro_size_2 = ((micro_size_spatial, - micro_size_k) if not is_transpose else - (micro_size_k, micro_size_spatial)) - v00, v01 = sch.split(sch.get_loops(block_read_reg)[-2], [None, micro_size_1]) - v10, v11 = sch.split(sch.get_loops(block_read_reg)[-1], [None, micro_size_2]) - sch.reorder(v00, v10, v01, v11) - - # reorder read axis to match the layout of ldmatrix - sch.transform_layout( - block_read_reg, - ("write", 0), - lambda v0, v1, v2: ( - v0, - v1 // micro_size_1, - v2 // micro_size_2, - *shared_16x16_to_mma_32x8_layout(v1 % micro_size_1, v2 % micro_size_2), - ), - ) - - # swizzling - mma_read_block = sch.blockize(sch.get_loops(block_read_reg)[-2]) - sch.annotate(mma_read_block, ann_key="permuted_layout", ann_val=1) - - return block_read_smem, block_read_reg - - block_read_a, block_read_reg_a = fetch_input(block_outer, 0, "A", is_transpose_a) - block_read_b, block_read_reg_b = fetch_input(block_outer, 1, "B", is_transpose_b) - - # Write to register, and then smem - def store_output(block_outer, write_buffer_idx): - # 1) Write to shared memory - block_write_smem = sch.cache_write(block_outer, write_buffer_idx, "shared.dyn") - sch.reverse_compute_at(block_write_smem, block_axis) - auto_inline_consumer_chain(sch, block_write_smem) - - # bind loops - fused = sch.fuse(*sch.get_loops(block_write_smem)[-2:]) - f0, f1, f2 = sch.split(fused, [None, thread_x, vector_size]) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - - # 2) Write to register - block_write_reg = sch.cache_write(block_outer, write_buffer_idx, "warp") - - # bind loops - v0, v1, v2 = sch.get_loops(block_write_reg)[-3:] - v11, v12, v13 = sch.split(v1, factors=[thread_z, None, micro_size_m]) - v21, v22, v23 = sch.split(v2, factors=[thread_y, None, micro_size_n]) - sch.reorder(v11, v21, v12, v22, v13, v23) - sch.bind(v11, "threadIdx.z") - sch.bind(v21, "threadIdx.y") - - # reorder write axis to match the layout of ldmatrix - sch.transform_layout( - block_write_reg, - ("read", 0), - lambda v0, v1, v2: ( - v0, - v1 // micro_size_m, - v2 // micro_size_n, - *shared_16x16_to_mma_32x8_layout(v1 % micro_size_m, v2 % micro_size_n), - ), - ) - - return block_write_smem, block_write_reg - - _, block_write_reg = store_output(block_outer, 0) - - # Step 5. Schedule tensor core computation - block_init = sch.decompose_reduction(block_outer, k0) - block_init_inner = sch.get_child_blocks(block_init)[0] - - intrin_group = get_mma_intrin_group( - load_scope="shared.dyn", - store_scope="shared.dyn", - a_dtype=str(dtype_a), - b_dtype=str(dtype_b), - out_dtype=str(dtype_c), - trans_a=is_transpose_a, - trans_b=is_transpose_b, - not_use_mma_store_intrinic=False, - ) - - sch.tensorize(sch.get_loops(block_init_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(block_read_reg_a)[-2], intrin_group["load_a"]) - sch.tensorize(sch.get_loops(block_read_reg_b)[-2], intrin_group["load_b"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - sch.tensorize(sch.get_loops(block_write_reg)[-2], intrin_group["store"]) - - # Step 6. Async pipeline - sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3]) - sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) - sch.annotate(k0, ann_key="software_pipeline_async_stages", ann_val=[0]) - - # Step 7. Handle dequantize block - # Now we just add a dummy kernel to compute dequantize - if dequantize_block is not None: - auto_inline_producers(sch, dequantize_block) - loops = sch.get_loops(dequantize_block) - loop = sch.fuse(*loops) - v0, v1, v2, v3 = sch.split(loop, [None, 128, 2, 4]) - sch.bind(v0, "blockIdx.x") - sch.bind(v1, "threadIdx.x") - sch.unroll(v2) - sch.vectorize(v3) - return sch - - def apply_config( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - config: Hint, - ) -> Optional[tir.Schedule]: - if "dequantize_info" in func.attrs: - dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo() - return dequantize_rule.apply_config(func, config) - - if hasattr(config, "block_reduction_depth") and config.block_reduction_depth is not None: - return self.apply_block_reduction_with_config(func, config) - - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group,) - - import_source: List[str] = [] - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - - output_blocks = [sch.get(block) for block in sch.get_output_blocks(root_block)] - - def check_require_cache(func: tir.PrimFunc, config): - conditions: List[bool] = [] - - # check if has dynamic symbolic - def check_has_dynamic(func: tir.PrimFunc): - for param in func.params: - if param not in func.buffer_map: - continue - arg = func.buffer_map[param] - for i in arg.shape: - if isinstance(i, tir.Var): - return True - return False - - conditions.append(check_has_dynamic(func)) - # check if has post process - conditions.append(sch.get(main_block) not in output_blocks) - # check if not use async copy - conditions.append(config.use_async is False) - return any(conditions) - - cache_write_required = check_require_cache(func, config=config) - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): - sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) - - shared_scope = config.shared_scope - - intrin_info = config.intrin_info - intrin_group = get_mma_intrin_group( - load_scope=shared_scope, - store_scope=shared_scope if cache_write_required else "global", - a_dtype=intrin_info.in_dtype, - b_dtype=intrin_info.in_dtype, - out_dtype=intrin_info.out_dtype, - trans_a=intrin_info.trans_a, - trans_b=intrin_info.trans_b, - smooth_a=intrin_info.smooth_a, - smooth_b=intrin_info.smooth_b, - not_use_mma_store_intrinic=False, - ) - - # Start Schedule - # Step 0. Get schedule config. - # NOTE: we can analyze the config by the hardware spec in the future - - warp_row_tiles = config.warp[0] - warp_col_tiles = config.warp[1] - block_row_warps = config.block[0] // warp_row_tiles - block_col_warps = config.block[1] // warp_col_tiles - stage = config.pipeline_stage - use_async = config.use_async - chunk = config.rstep[0] - - # tensor core intrinsic size - micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] - - # get the axis for layout transform - def get_axis(l, r, trans): # noqa: E741 - return (r, l) if trans else (l, r) # noqa: E741 - - a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) - b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) - - def can_enable_swizzle(dtype: str, smooth: bool): - # inject_permuted_layout only support float16 currently - if dtype == "float16" or dtype == "int8": - if chunk * DataType(dtype).bits != 512: - # currently the swizzle rule only support 512 bit. - return False - # if we use smooth layout, we don't need to do swizzling - return not smooth - return False - - can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) - can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) - - warp_size = 32 - - i_factors, j_factors, k_factors = ( - [None, 1, block_row_warps, warp_row_tiles // micro_size_x], - [1, None, block_col_warps, warp_col_tiles // micro_size_y], - [None, chunk // micro_size_k], - ) - - num_ty = i_factors[2] - num_tz = j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - - sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) - - block_idy = sch.fuse(i0, j0) - block_idx = sch.fuse(i1, j1) - thread_idy = i2 - thread_idz = j2 - - sch.bind(batch, "blockIdx.z") - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - sch.bind(thread_idz, "threadIdx.z") - - # rewrite smooth layout of shared memory - def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True): # noqa: E741 - if not enable: - return - sch.transform_layout( - block, - scope, - lambda b, i, j: ( - b, - i // l, - j // r, - i % l, - j % r, - ), - ) - - smooth_smem_layout_rewrite( - block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) - smooth_smem_layout_rewrite( - block_outer, ("read", 1), *b_lr, enable=intrin_info.inter_transform_b) - smooth_smem_layout_rewrite(block_outer, ("write", 0), enable=True) - - def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, trans=False): - block_read = sch.cache_read(block, idx, shared_scope) - sch.compute_at(block_read, k0, preserve_unit_loops=True) - ndim = len(sch.get(block_read).iter_vars) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - f_0, f_1, f_2, f_3, f_4 = sch.split( - fused, factors=[num_ty, num_tz, None, warp_size, vec_len]) - - sch.bind(f_3, "threadIdx.x") - sch.bind(f_1, "threadIdx.z") - sch.bind(f_0, "threadIdx.y") - sch.vectorize(f_4) - sch.unroll(f_2) - # Apply Swizzling - sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) - # if not, apply padding to alleviate bank conflict - if not (can_swizzle or is_smooth): - pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) - sch.annotate(f_2, "pragma_unroll_explicit", False) - return block_read - - if len(config.vectorize.values()) < 2: - return None - - a_g2s = fetch_to_shared( - block_outer, - 0, - vec_len=list(config.vectorize.values())[0], - can_swizzle=can_swizzle_a, - is_smooth=intrin_info.smooth_a, - trans=intrin_info.trans_a, - ) - b_g2s = fetch_to_shared( - block_outer, - 1, - vec_len=list(config.vectorize.values())[1], - can_swizzle=can_swizzle_b, - is_smooth=intrin_info.smooth_b, - trans=intrin_info.trans_b, - ) - - # rewrite global smooth layout - def smooth_gmem_layout_rewrite(sch, block, enable=True, trans=False, matrix_name="A"): - if not enable: - return - # step1: find the first producer block - # Notes: we assume the layout propagate happens in the first producer block - # otherwise, the layout transform will have no effect as it will transform both - # read and write buffer - producers = _collect_producers(sch, block) - g2s_block = a_g2s if matrix_name == "A" else b_g2s - propagate_block: tir.Block = (producers[-1] if len(producers) > 0 else g2s_block) - - # step2: transform the layout with inverse permutation - intra_indexmap, _ = get_propagate_map( - trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) - - def inverse_permutation(i, j, ii, jj): - return (i, j, *intra_indexmap.map_indices([ii, jj])) - - sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) - - smooth_gmem_layout_rewrite( - sch, a_g2s, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") - smooth_gmem_layout_rewrite( - sch, b_g2s, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") - auto_inline_producers(sch, a_g2s) - auto_inline_producers(sch, b_g2s) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "warp") - B_mat = sch.cache_read(block_outer, 1, "warp") - sch.compute_at(A_mat, k1) - sch.compute_at(B_mat, k1) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - if cache_write_required: - accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) - - store = sch.cache_write(block_outer, 0, "warp") - sch.reverse_compute_at(store, j2) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x], preserve_unit_iters=False) - j0, j1 = sch.split(j, factors=[None, micro_size_y], preserve_unit_iters=False) - sch.reorder(i0, j0, i1, j1) - - if cache_write_required: - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - sch.reverse_compute_at( - accumulator_shared_to_global, - sch.get_loops(store)[-5], - preserve_unit_loops=True, - ) - vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) - f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - sch.unroll(f0) - sch.annotate(f0, "pragma_unroll_explicit", False) - else: - auto_inline_consumer_chain(sch, store) - - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - index_map_a, index_map_b, index_map_c = intrin_group["index_map"] - - sch.transform_layout( - A_mat, - ("write", 0), - get_warp_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), - ) - sch.transform_layout( - B_mat, - ("write", 0), - get_warp_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), - ) - sch.transform_layout( - store, - ("read", 0), - get_warp_index_map(index_map_c, is_5d=True), - ) - - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, a_lr[0]]) - j0, j1 = sch.split(j, factors=[None, a_lr[1]]) - sch.reorder(i0, j0, i1, j1) - ba = sch.blockize(i1) - sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) - sch.tensorize(ba, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, b_lr[0]]) - j0, j1 = sch.split(j, factors=[None, b_lr[1]]) - sch.reorder(i0, j0, i1, j1) - bb = sch.blockize(i1) - sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) - sch.tensorize(bb, intrin_group["load_b"]) - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - tensorize_init_store_compute() - - if stage > 1: - sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) - sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) - if use_async: - sch.annotate(k0, "software_pipeline_async_stages", [0]) - - # plan rasteration - if not isinstance(config.rasterization_plan, NoRasterization): - device_func, invoke_func = config.rasterization_plan.get_code() - import_source.append(device_func) - sch.annotate( - sch.get_loops(block_init_c)[-2], - ann_key="inject_customized_code_prepend", - ann_val=invoke_func, - ) - # plan import source - if len(import_source) > 0: - sch.annotate( - thread_idz, - ann_key="pragma_import_c", - ann_val=("\n").join(import_source), - ) - return sch - - def apply_block_reduction_with_config( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - config: Hint, - ) -> Optional[tir.Schedule]: - if "dequantize_info" in func.attrs: - dequantize_rule = MatmulTensorizationMMAWithDequantizeInfo() - return dequantize_rule.sch_shared_memory_prefetch_block_reduction_with_config( - func, config) - - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group,) - - import_source: List[str] = [] - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - - output_blocks = [sch.get(block) for block in sch.get_output_blocks(root_block)] - - def check_require_cache(func: tir.PrimFunc, config): - conditions: List[bool] = [] - - # check if has dynamic symbolic - def check_has_dynamic(func: tir.PrimFunc): - for param in func.params: - if param not in func.buffer_map: - continue - arg = func.buffer_map[param] - for i in arg.shape: - if isinstance(i, tir.Var): - return True - return False - - conditions.append(check_has_dynamic(func)) - # check if has post process - conditions.append(sch.get(main_block) not in output_blocks) - # check if not use async copy - conditions.append(config.use_async is False) - return any(conditions) - - cache_write_required = check_require_cache(func, config=config) - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): - sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) - - shared_scope = config.shared_scope - - intrin_info = config.intrin_info - intrin_group = get_mma_intrin_group( - load_scope=shared_scope, - store_scope=shared_scope if cache_write_required else "global", - a_dtype=intrin_info.in_dtype, - b_dtype=intrin_info.in_dtype, - out_dtype=intrin_info.out_dtype, - trans_a=intrin_info.trans_a, - trans_b=intrin_info.trans_b, - smooth_a=intrin_info.smooth_a, - smooth_b=intrin_info.smooth_b, - not_use_mma_store_intrinic=False, - ) - - # Start Schedule - # Step 0. Get schedule config. - warp_row_tiles = config.warp[0] - warp_col_tiles = config.warp[1] - block_row_warps = config.block[0] // warp_row_tiles - block_col_warps = config.block[1] // warp_col_tiles - stage = config.pipeline_stage - use_async = config.use_async - assert (config.block_reduction_depth is not None), "block_reduction_depth is required" - reduce_k = config.block_reduction_depth - chunk = config.rstep[0] // reduce_k - - # tensor core intrinsic size - micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] - - # get the axis for layout transform - def get_axis(l, r, trans): # noqa: E741 - return (r, l) if trans else (l, r) # noqa: E741 - - a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) - b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) - - def can_enable_swizzle(dtype: str, smooth: bool): - # inject_permuted_layout only support float16 currently - if dtype == "float16" or dtype == "int8": - # introduce the constraint of reduce_k because reduce_k will doubling the size of - if (chunk * reduce_k) * DataType(dtype).bits != (512): - # currently the swizzle rule only support 512 bit. - return False - # if we use smooth layout, we don't need to do swizzling - return not smooth - return False - - can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) - can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) - - warp_size = 32 - - i_factors, j_factors, k_factors = ( - [None, 1, block_row_warps, warp_row_tiles // micro_size_x], - [1, None, block_col_warps, warp_col_tiles // micro_size_y], - [None, (reduce_k * chunk) // micro_size_k], - ) - - num_ty = i_factors[2] - num_tz = j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - k0, kr = sch.split(k0, [None, reduce_k]) - - sch.reorder(i0, j0, i1, j1, i2, j2, kr, i3, j3, k0, k1) - - block_idy = sch.fuse(i0, j0) - block_idx = sch.fuse(i1, j1) - thread_idy = i2 - thread_idz = j2 - - sch.bind(batch, "blockIdx.z") - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz) - sch.bind(thread_idy, "threadIdx.y") - sch.bind(kr, "threadIdx.z") - - # rewrite smooth layout of shared memory - def smooth_smem_layout_rewrite(block, scope, l=16, r=16, enable=True): # noqa: E741 - if not enable: - return - sch.transform_layout( - block, - scope, - lambda b, i, j: ( - b, - i // l, - j // r, - i % l, - j % r, - ), - ) - - smooth_smem_layout_rewrite( - block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) - smooth_smem_layout_rewrite( - block_outer, ("read", 1), *b_lr, enable=intrin_info.inter_transform_b) - smooth_smem_layout_rewrite(block_outer, ("write", 0), enable=True) - - def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False, trans=False): - block_read = sch.cache_read(block, idx, shared_scope) - sch.compute_at(block_read, k0, preserve_unit_loops=True) - ndim = len(sch.get(block_read).iter_vars) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - f_r, f_0, f_1, f_2, f_3, f_4 = sch.split( - fused, factors=[reduce_k, num_ty, num_tz, None, warp_size, vec_len]) - - sch.bind(f_3, "threadIdx.x") - f_0 = f_1 = sch.fuse(f_0, f_1) - sch.bind(f_0, "threadIdx.y") - sch.bind(f_r, "threadIdx.z") - sch.vectorize(f_4) - sch.unroll(f_2) - # Apply Swizzling - sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) - # if not, apply padding to alleviate bank conflict - if not (can_swizzle or is_smooth): - pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) - sch.annotate(f_2, "pragma_unroll_explicit", False) - return block_read - - if len(config.vectorize.values()) < 2: - return None - - a_g2s = fetch_to_shared( - block_outer, - 0, - vec_len=list(config.vectorize.values())[0], - can_swizzle=can_swizzle_a, - is_smooth=intrin_info.smooth_a, - trans=intrin_info.trans_a, - ) - b_g2s = fetch_to_shared( - block_outer, - 1, - vec_len=list(config.vectorize.values())[1], - can_swizzle=can_swizzle_b, - is_smooth=intrin_info.smooth_b, - trans=intrin_info.trans_b, - ) - - # rewrite global smooth layout - def smooth_gmem_layout_rewrite(sch, block, enable=True, trans=False, matrix_name="A"): - if not enable: - return - # step1: find the first producer block - # Notes: we assume the layout propagate happens in the first producer block - # otherwise, the layout transform will have no effect as it will transform both - # read and write buffer - producers = _collect_producers(sch, block) - g2s_block = a_g2s if matrix_name == "A" else b_g2s - propagate_block: tir.Block = (producers[-1] if len(producers) > 0 else g2s_block) - - # step2: transform the layout with inverse permutation - intra_indexmap, _ = get_propagate_map( - trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) - - def inverse_permutation(i, j, ii, jj): - return (i, j, *intra_indexmap.map_indices([ii, jj])) - - sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) - - smooth_gmem_layout_rewrite( - sch, a_g2s, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") - smooth_gmem_layout_rewrite( - sch, b_g2s, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") - auto_inline_producers(sch, a_g2s) - auto_inline_producers(sch, b_g2s) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "warp") - B_mat = sch.cache_read(block_outer, 1, "warp") - sch.compute_at(A_mat, k1) - sch.compute_at(B_mat, k1) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - if cache_write_required: - accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) - - store = sch.cache_write(block_outer, 0, "warp") - sch.reverse_compute_at(store, j2) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x], preserve_unit_iters=False) - j0, j1 = sch.split(j, factors=[None, micro_size_y], preserve_unit_iters=False) - sch.reorder(i0, j0, i1, j1) - - if cache_write_required: - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - sch.reverse_compute_at( - accumulator_shared_to_global, - sch.get_loops(store)[-5], - preserve_unit_loops=True, - ) - vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) - f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - sch.unroll(f0) - sch.annotate(f0, "pragma_unroll_explicit", False) - else: - auto_inline_consumer_chain(sch, store) - - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - index_map_a, index_map_b, index_map_c = intrin_group["index_map"] - - sch.transform_layout( - A_mat, - ("write", 0), - get_warp_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), - ) - sch.transform_layout( - B_mat, - ("write", 0), - get_warp_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), - ) - sch.transform_layout( - store, - ("read", 0), - get_warp_index_map(index_map_c, is_5d=True), - ) - - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, a_lr[0]]) - j0, j1 = sch.split(j, factors=[None, a_lr[1]]) - sch.reorder(i0, j0, i1, j1) - ba = sch.blockize(i1) - # sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) - sch.tensorize(ba, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, b_lr[0]]) - j0, j1 = sch.split(j, factors=[None, b_lr[1]]) - sch.reorder(i0, j0, i1, j1) - bb = sch.blockize(i1) - # sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) - sch.tensorize(bb, intrin_group["load_b"]) - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - tensorize_init_store_compute() - - if stage > 1: - sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) - sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) - if use_async: - sch.annotate(k0, "software_pipeline_async_stages", [0]) - - # plan rasteration - if not isinstance(config.rasterization_plan, NoRasterization): - device_func, invoke_func = config.rasterization_plan.get_code() - import_source.append(device_func) - sch.annotate( - sch.get_loops(block_init_c)[-2], - ann_key="inject_customized_code_prepend", - ann_val=invoke_func, - ) - # plan import source - if len(import_source) > 0: - sch.annotate( - thread_idz, - ann_key="pragma_import_c", - ann_val=("\n").join(import_source), - ) - return sch diff --git a/python/bitblas/gpu/matmul_mma_dequantize.py b/python/bitblas/gpu/matmul_mma_dequantize.py deleted file mode 100644 index 679e84395..000000000 --- a/python/bitblas/gpu/matmul_mma_dequantize.py +++ /dev/null @@ -1,2295 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -# pylint: disable=missing-docstring, invalid-name -"""A GEMM schedule rule for GPU operators.""" -from typing import Optional, List -from contextlib import suppress - -from tvm import tir, DataType - -from ..base.roller.hint import Hint, IntrinInfo -from tvm.target import Target -from ..base.roller.rasterization import NoRasterization -from ..base import analysis -from .base import GPUScheduleRule -from ..base.analysis import get_coalesced_veclen -from .matmul_analysis import ( - auto_inline_consumer_chain, - auto_inline_producers, - get_reduction_blocks, - normalize_to_matmul, - get_propagate_map, - layout_propagate_chain, - find_last_producer_from_buffer, - _collect_producers, - get_in_out_dtypes, -) - - -def _bind_thread_based_on_config(sch, block, block_row_warps, block_col_warps, warp_size): - # assume the block loops has been fused - last_loop = sch.get_loops(block)[-1] - loop_extent = sch.get(last_loop).extent - vec = get_coalesced_veclen(sch.get(block)) - - if loop_extent // (vec * warp_size) == 0: - last_loop, B_shared_tx, B_shared_vi = sch.split(last_loop, factors=[1, warp_size, None]) - sch.bind(B_shared_tx, "threadIdx.x") - sch.vectorize(B_shared_vi) - loop_extent = sch.get(last_loop).extent - - # warp_size - 1 handling for 32x7 alike case, which may cause unaligned threadIdx.x mapping. - if loop_extent // vec >= 1 and loop_extent & (vec - 1) == 0 and (loop_extent // - vec) & (warp_size - 1) == 0: - last_loop, B_shared_vi = sch.split(last_loop, factors=[None, vec]) - sch.vectorize(B_shared_vi) - loop_extent = sch.get(last_loop).extent - - if loop_extent // warp_size >= 1 and loop_extent % warp_size == 0: - last_loop, B_shared_tx = sch.split(last_loop, factors=[None, warp_size]) - sch.bind(B_shared_tx, "threadIdx.x") - loop_extent = sch.get(last_loop).extent - - if loop_extent // block_row_warps >= 1 and loop_extent % block_row_warps == 0: - last_loop, B_shared_ty = sch.split(last_loop, factors=[None, block_row_warps]) - sch.bind(B_shared_ty, "threadIdx.y") - loop_extent = sch.get(last_loop).extent - - if loop_extent // block_col_warps >= 1 and loop_extent % block_col_warps == 0: - last_loop, B_shared_tz = sch.split(last_loop, factors=[None, block_col_warps]) - sch.bind(B_shared_tz, "threadIdx.z") - loop_extent = sch.get(last_loop).extent - - sch.unroll(last_loop) - sch.annotate(last_loop, "pragma_unroll_explicit", False) - - -def _bind_thread_based_with_block_reduce_on_config(sch, block, block_row_warps, block_col_warps, - warp_size, reduce_k): - # assume the block loops has been fused - last_loop = sch.get_loops(block)[-1] - loop_extent = sch.get(last_loop).extent - vec = get_coalesced_veclen(sch.get(block)) - - if loop_extent // (vec * warp_size) == 0: - last_loop, B_shared_tx, B_shared_vi = sch.split(last_loop, factors=[1, warp_size, None]) - sch.bind(B_shared_tx, "threadIdx.x") - sch.vectorize(B_shared_vi) - loop_extent = sch.get(last_loop).extent - - # warp_size - 1 handling for 32x7 alike case, which may cause unaligned threadIdx.x mapping. - if loop_extent // vec >= 1 and loop_extent & (vec - 1) == 0 and (loop_extent // - vec) & (warp_size - 1) == 0: - last_loop, B_shared_vi = sch.split(last_loop, factors=[None, vec]) - sch.vectorize(B_shared_vi) - loop_extent = sch.get(last_loop).extent - - if loop_extent // warp_size >= 1 and loop_extent % warp_size == 0: - last_loop, B_shared_tx = sch.split(last_loop, factors=[None, warp_size]) - sch.bind(B_shared_tx, "threadIdx.x") - loop_extent = sch.get(last_loop).extent - - B_shared_ty = None - if loop_extent // block_row_warps >= 1 and loop_extent % block_row_warps == 0: - last_loop, B_shared_ty = sch.split(last_loop, factors=[None, block_row_warps]) - loop_extent = sch.get(last_loop).extent - - B_shared_tz = None - if loop_extent // block_col_warps >= 1 and loop_extent % block_col_warps == 0: - last_loop, B_shared_tz = sch.split(last_loop, factors=[None, block_col_warps]) - loop_extent = sch.get(last_loop).extent - - if B_shared_ty and B_shared_tz: - B_shared_ty = sch.fuse(B_shared_tz, B_shared_ty) - sch.bind(B_shared_ty, "threadIdx.y") - elif B_shared_ty: - sch.bind(B_shared_ty, "threadIdx.y") - - if loop_extent // reduce_k >= 1 and loop_extent % reduce_k == 0: - last_loop, B_shared_tk = sch.split(last_loop, factors=[None, reduce_k]) - sch.bind(B_shared_tk, "threadIdx.z") - loop_extent = sch.get(last_loop).extent - - sch.unroll(last_loop) - sch.annotate(last_loop, "pragma_unroll_explicit", False) - - -def get_index_map_3d(index_map, l=16, r=16): # noqa: E741 - - def index_map_3d(b, i, j): - return ( - b, - i // l, - j // r, - *index_map(i % l, j % r), - ) - - return index_map_3d - - -def get_index_map_5d(index_map): - """ - for layout transformed gemm, the index map should be 5d - """ - - def index_map_5d(b, i, j, ii, jj): - return ( - b, - i, - j, - *index_map(ii, jj), - ) - - return index_map_5d - - -def get_index_map(index_map, l=16, r=16, is_5d=False): # noqa: E741 - if is_5d: - return get_index_map_5d(index_map) - return get_index_map_3d(index_map, l, r) - - -class MatmulTensorizationMMAWithDequantizeInfo(GPUScheduleRule): - """ - The schedule rule for float16 tensor core matmul computation. - func with attr 'dlight.do_not_tensorize' will not be tensorized. - """ - - def apply( - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ): - """ - For devices without async copy, we can use a simple dequantize schedule without shared memory prefetch. - quantized weight - | - V - dequantized in register - | - V - save into shared memory - | - V - compute - """ - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group,) - from .intrin import get_lop3_intrin_group - - import_source: List[str] = [] - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - # always enable shared memory rewrite - cache_write_required = True - - # Check Dequantize Info - dequantize_info = func.attrs["dequantize_info"] - - def check_dequantize_info(dequantize_info): - conditions = [] - # currently only support weight only dequantization - conditions.append(len(dequantize_info) == 1) - # TODO(@lei) check if the dequantize value name is weight - return all(conditions) - - assert check_dequantize_info(dequantize_info) - - (weight_decode_info,) = list(dequantize_info.values()) - - def check_weight_decode_info(weight_decode_info): - conditions = [] - # check source format in ["int", "fp", "nf"] - conditions.append("source_format" in weight_decode_info) - conditions.append(weight_decode_info["source_format"]["format"] in - ["uint", "int", "fp", "nf", "fp_e4m3"]) - # check source bits in [1, 2, 4, 8] - conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) - # check target format in ["float16", "int8"] - conditions.append("target_format" in weight_decode_info) - conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) - return all(conditions) - - assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info" - - # Start Schedule - # Step 1. Get default schedule config. - - # tensor core intrinsic size - in_dtype, out_dtype = get_in_out_dtypes(sch.get(main_block)) - intrin_info = IntrinInfo( - in_dtype=in_dtype, - out_dtype=out_dtype, - trans_b=True, - ) - if "weight_transform_kind" in func.attrs: - intrin_info.weight_transform_kind = int(func.attrs["weight_transform_kind"]) - - if "input_transform_kind" in func.attrs: - intrin_info.input_transform_kind = int(func.attrs["input_transform_kind"]) - # default Hint - config = Hint().from_dict({ - "block": [128, 128], - "warp": [64, 64], - "rstep": [32], - "pipeline_stage": 1, - "use_async": False, - "intrin_info": intrin_info, - "shared_scope": "shared.dyn", - }) - shared_scope = config.shared_scope - - intrin_group = get_mma_intrin_group( - load_scope=shared_scope, - store_scope=shared_scope if cache_write_required else "global", - a_dtype=intrin_info.in_dtype, - b_dtype=intrin_info.in_dtype, - out_dtype=intrin_info.out_dtype, - trans_a=intrin_info.trans_a, - trans_b=intrin_info.trans_b, - smooth_a=intrin_info.smooth_a, - smooth_b=intrin_info.smooth_b, - not_use_mma_store_intrinic=False, - ) - - warp_row_tiles = config.warp[0] - warp_col_tiles = config.warp[1] - block_row_warps = config.block[0] // warp_row_tiles - block_col_warps = config.block[1] // warp_col_tiles - stage = config.pipeline_stage - use_async = config.use_async - chunk = config.rstep[0] - - micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] - - # get the axis for layout transform - def get_axis(l, r, trans): # noqa: E741 - return (r, l) if trans else (l, r) # noqa: E741 - - a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) - b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) - - def can_enable_swizzle(dtype: str, smooth: bool): - # inject_permuted_layout only support float16 currently - if dtype == "float16" or dtype == "int8": - if chunk * DataType(dtype).bits != 512: - # currently the swizzle rule only support 512 bit. - return False - # if we use smooth layout, we don't need to do swizzling - return not smooth - return False - - can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) - can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) - - # rewrite global smooth layout, for dequantize, currently only support weight only recover. - def smooth_gmem_layout_rewrite(sch, main_block, enable=True, trans=False, matrix_name="A"): - if not enable: - return - - # normalized block may have three read buffers, while the first one is the write buffer. - buffer_offset = (1 if sch.get(main_block).reads[0].buffer - == sch.get(main_block).writes[0].buffer else 0) - buffer_idx = 0 if matrix_name == "A" else 1 - source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer - - # step1: find the first producer block - # Notes: we assume the layout propagate happens in the first producer block - # otherwise, the layout transform will have no effect as it will transform both - # read and write buffer - propagate_block: tir.Block = find_last_producer_from_buffer( - sch, main_block, source_buffer) - # some trick impl may not have reindex block - (weight_dequantize_info,) = dequantize_info.values() - if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]): - return - - # step2: transform the layout with inverse permutation - intra_indexmap, _ = get_propagate_map( - trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) - - # step3: propagate the matmul layout to the first reindex block - - intra_indexmap = layout_propagate_chain( - sch, - start_block=main_block, - start_buffer=source_buffer, - end_block=propagate_block, - index_map=intra_indexmap, - ) - - def inverse_permutation(i, j, ii, jj): - return (i, j, *intra_indexmap.map_indices([ii, jj])) - - sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) - - smooth_gmem_layout_rewrite( - sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") - - smooth_gmem_layout_rewrite( - sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") - - warp_size = 32 - - i_factors, j_factors, k_factors = ( - [None, 1, block_row_warps, warp_row_tiles // micro_size_x], - [1, None, block_col_warps, warp_col_tiles // micro_size_y], - [None, chunk // micro_size_k], - ) - - num_ty = i_factors[2] - num_tz = j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): - sch = normalize_to_matmul(sch, main_block, ["n", "t", "n"]) - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - - sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) - - block_idy = sch.fuse(i0, j0) - block_idx = sch.fuse(i1, j1) - thread_idy = i2 - thread_idz = j2 - - sch.bind(batch, "blockIdx.z") - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - sch.bind(thread_idz, "threadIdx.z") - - def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 - if not enable: - return - sch.transform_layout( - block, - scope, - lambda b, i, j: ( - b, - i // l, - j // r, - i % l, - j % r, - ), - ) - - smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) - smooth_layout_recover( - block_outer, - ("read", 1), - *b_lr, - enable=intrin_info.inter_transform_b, - ) - smooth_layout_recover(block_outer, ("write", 0), enable=True) - - def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): - block_read = sch.cache_read(block, idx, shared_scope) - sch.compute_at(block_read, k0, preserve_unit_loops=True) - ndim = len(sch.get(block_read).iter_vars) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - f_0, f_1, f_2, f_3, f_4 = sch.split( - fused, factors=[None, num_ty, num_tz, warp_size, vec_len]) - - sch.bind(f_3, "threadIdx.x") - sch.bind(f_2, "threadIdx.z") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_4) - sch.unroll(f_0) - # Apply Swizzling - sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) - # if not, apply padding to alleviate bank conflict - if not (can_swizzle or is_smooth): - pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) - sch.annotate(f_0, "pragma_unroll_explicit", False) - return block_read - - a_g2s = fetch_to_shared( - block_outer, - 0, - vec_len=4, - can_swizzle=can_swizzle_a, - is_smooth=intrin_info.smooth_a, - ) - - auto_inline_producers(sch, a_g2s) - - def decode_fetch_to_shared(block, idx): - # step1. create memory hierarchy - # global -> local -> shared - block_shared = sch.cache_read(block, idx, shared_scope) - sch.compute_at(block_shared, k0, preserve_unit_loops=True) - - decode_factor = get_coalesced_veclen(sch.get(block_shared)) - _, B_shared_vi, _ = sch.split( - sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor]) - block_shared_local = sch.cache_read(block_shared, 0, "local") - # global -> dequantzed_local -> shared - # step2. inline to local block - weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"]) - weight_producers = _collect_producers(sch, weight_dequantize_block) - auto_inline_producers(sch, block_shared_local, weight_producers) - - # get target dequantize buffer's idx - def get_idx(): - # for LUT dequantize, the expr is LUT(w), the idx is 1 - # maybe we can use a more general and structural based way - # to analysis the idx - if weight_decode_info["source_format"]["format"] == "nf": - return 1 - return 0 - - b_idx = get_idx() - # global -> prefetch_local -> dequantzed_local -> shared - block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local") - - sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True) - sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True) - - dequantize_block_local = block_shared_local - if ("zeros_mode" in weight_decode_info and - weight_decode_info["zeros_mode"] == "quantized"): - if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]): - block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local") - sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True) - # pop the scale block - auto_inline_producers(sch, block_local_scales) - - if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]): - block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local") - sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True) - auto_inline_producers(sch, block_local_zeros) - - for producer in weight_producers: - with suppress(Exception): - auto_inline_producers(sch, producer) - sch.compute_inline(producer) - - # fast type conversion - if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): - source_bit = weight_decode_info["source_format"]["bits"] - out_dtype = weight_decode_info["target_format"] - lop3_intrin_info = get_lop3_intrin_group( - out_dtype=out_dtype, - storage_dtype=weight_decode_info["storage_dtype"], - source_format=weight_decode_info["source_format"]["format"], - source_bit=source_bit, - with_scaling=weight_decode_info["with_scaling"], - with_zeros=weight_decode_info["with_zeros"], - zeros_mode=weight_decode_info["zeros_mode"], - ) - sch.tensorize( - sch.get_loops(dequantize_block_local)[-1], - lop3_intrin_info["compute"], - ) - import_source.append(lop3_intrin_info["c_source"]) - - sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b) - union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2) - B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2]) - _, B_shared_ty, B_shared_tz, B_shared_tx = sch.split( - B_shared_fused, factors=[None, num_ty, num_tz, warp_size]) - if not (can_swizzle_b or intrin_info.smooth_b): - pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset) - sch.bind(B_shared_tx, "threadIdx.x") - sch.bind(B_shared_ty, "threadIdx.y") - sch.bind(B_shared_tz, "threadIdx.z") - sch.vectorize(sch.get_loops(block_shared)[-1]) - sch.vectorize(sch.get_loops(block_shared_local_local)[-1]) - - # cache small tensors, e.g. LUT - if b_idx: - block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope) - sch.reverse_compute_at(block_shared_lut, j2) - _, B_shared_tx = sch.split( - sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size]) - sch.bind(B_shared_tx, "threadIdx.x") - return block_shared_local - - _ = decode_fetch_to_shared(block_outer, 1) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "warp") - B_mat = sch.cache_read(block_outer, 1, "warp") - sch.compute_at(A_mat, k1, preserve_unit_loops=True) - sch.compute_at(B_mat, k1, preserve_unit_loops=True) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - if cache_write_required: - accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) - - store = sch.cache_write(block_outer, 0, "warp") - sch.reverse_compute_at(store, j2) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x]) - j0, j1 = sch.split(j, factors=[None, micro_size_y]) - sch.reorder(i0, j0, i1, j1) - - if cache_write_required: - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - sch.reverse_compute_at( - accumulator_shared_to_global, - sch.get_loops(store)[-5], - preserve_unit_loops=True, - ) - vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) - f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - sch.unroll(f0) - sch.annotate(f0, "pragma_unroll_explicit", False) - else: - auto_inline_consumer_chain(sch, store) - - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - - index_map_a, index_map_b, index_map_c = intrin_group["index_map"] - - sch.transform_layout( - A_mat, - ("write", 0), - get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), - ) - sch.transform_layout( - B_mat, - ("write", 0), - get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), - ) - sch.transform_layout( - store, - ("read", 0), - get_index_map(index_map_c, is_5d=True), - ) - - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, a_lr[0]]) - j0, j1 = sch.split(j, factors=[None, a_lr[1]]) - sch.reorder(i0, j0, i1, j1) - ba = sch.blockize(i1) - sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) - sch.tensorize(ba, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, b_lr[0]]) - j0, j1 = sch.split(j, factors=[None, b_lr[1]]) - sch.reorder(i0, j0, i1, j1) - bb = sch.blockize(i1) - sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) - sch.tensorize(bb, intrin_group["load_b"]) - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - tensorize_init_store_compute() - - if stage > 1: - sch.annotate( - k0, - ann_key="software_pipeline_stage", - ann_val=[0, 0, stage - 1], - ) - sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) - if use_async: - sch.annotate(k0, "software_pipeline_async_stages", [0]) - # plan rasteration - if not isinstance(config.rasterization_plan, NoRasterization): - device_func, invoke_func = config.rasterization_plan.get_code() - import_source.append(device_func) - sch.annotate( - sch.get_loops(block_init_c)[-2], - ann_key="inject_customized_code_prepend", - ann_val=invoke_func, - ) - # plan import source - if len(import_source) > 0: - sch.annotate( - thread_idz, - ann_key="pragma_import_c", - ann_val=("\n").join(import_source), - ) - return sch - - def sch_dequantize_in_register_with_config( - self, - func: tir.PrimFunc, - config: Hint, - ): - """ - For devices without async copy, we can use a simple dequantize schedule without shared memory prefetch. - quantized weight - | - V - dequantized in register - | - V - save into shared memory - | - V - compute - """ - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group,) - from .intrin import get_lop3_intrin_group - - import_source: List[str] = [] - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - # always enable shared memory rewrite - cache_write_required = True - - # Check Dequantize Info - dequantize_info = func.attrs["dequantize_info"] - - def check_dequantize_info(dequantize_info): - conditions = [] - # currently only support weight only dequantization - conditions.append(len(dequantize_info) == 1) - # TODO(@lei) check if the dequantize value name is weight - return all(conditions) - - assert check_dequantize_info(dequantize_info) - - (weight_decode_info,) = list(dequantize_info.values()) - - def check_weight_decode_info(weight_decode_info): - conditions = [] - # check source format in ["int", "fp", "nf"] - conditions.append("source_format" in weight_decode_info) - conditions.append(weight_decode_info["source_format"]["format"] in - ["uint", "int", "fp", "nf", "fp_e4m3"]) - # check source bits in [1, 2, 4, 8] - conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) - # check target format in ["float16", "int8"] - conditions.append("target_format" in weight_decode_info) - conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) - return all(conditions) - - assert check_weight_decode_info(weight_decode_info), "Invalid Weight Decode Info" - - # Start Schedule - # Step 0. Get schedule config. - # NOTE: we can analyze the config by the hardware spec in the future - - # tensor core intrinsic size - intrin_info = config.intrin_info - shared_scope = config.shared_scope - - intrin_group = get_mma_intrin_group( - load_scope=shared_scope, - store_scope=shared_scope if cache_write_required else "global", - a_dtype=intrin_info.in_dtype, - b_dtype=intrin_info.in_dtype, - out_dtype=intrin_info.out_dtype, - trans_a=intrin_info.trans_a, - trans_b=intrin_info.trans_b, - smooth_a=intrin_info.smooth_a, - smooth_b=intrin_info.smooth_b, - not_use_mma_store_intrinic=False, - ) - - warp_row_tiles = config.warp[0] - warp_col_tiles = config.warp[1] - block_row_warps = config.block[0] // warp_row_tiles - block_col_warps = config.block[1] // warp_col_tiles - stage = config.pipeline_stage - use_async = config.use_async - chunk = config.rstep[0] - - micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] - - # get the axis for layout transform - def get_axis(l, r, trans): # noqa: E741 - return (r, l) if trans else (l, r) # noqa: E741 - - a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) - b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) - - def can_enable_swizzle(dtype: str, smooth: bool): - # inject_permuted_layout only support float16 currently - if dtype == "float16" or dtype == "int8": - if chunk * DataType(dtype).bits != 512: - # currently the swizzle rule only support 512 bit. - return False - # if we use smooth layout, we don't need to do swizzling - return not smooth - return False - - can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) - can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) - - # rewrite global smooth layout, for dequantize, currently only support weight only recover. - def smooth_gmem_layout_rewrite(sch, main_block, enable=True, trans=False, matrix_name="A"): - if not enable: - return - - # normalized block may have three read buffers, while the first one is the write buffer. - buffer_offset = (1 if sch.get(main_block).reads[0].buffer - == sch.get(main_block).writes[0].buffer else 0) - buffer_idx = 0 if matrix_name == "A" else 1 - source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer - - # step1: find the first producer block - # Notes: we assume the layout propagate happens in the first producer block - # otherwise, the layout transform will have no effect as it will transform both - # read and write buffer - propagate_block: tir.Block = find_last_producer_from_buffer( - sch, main_block, source_buffer) - # some trick impl may not have reindex block - (weight_dequantize_info,) = dequantize_info.values() - if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]): - return - - # step2: transform the layout with inverse permutation - intra_indexmap, _ = get_propagate_map( - trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) - - # step3: propagate the matmul layout to the first reindex block - - intra_indexmap = layout_propagate_chain( - sch, - start_block=main_block, - start_buffer=source_buffer, - end_block=propagate_block, - index_map=intra_indexmap, - ) - - def inverse_permutation(i, j, ii, jj): - return (i, j, *intra_indexmap.map_indices([ii, jj])) - - sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) - - smooth_gmem_layout_rewrite( - sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") - - smooth_gmem_layout_rewrite( - sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") - - warp_size = 32 - - i_factors, j_factors, k_factors = ( - [None, 1, block_row_warps, warp_row_tiles // micro_size_x], - [1, None, block_col_warps, warp_col_tiles // micro_size_y], - [None, chunk // micro_size_k], - ) - - num_ty = i_factors[2] - num_tz = j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): - sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - - sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) - - block_idy = sch.fuse(i0, j0) - block_idx = sch.fuse(i1, j1) - thread_idy = i2 - thread_idz = j2 - - sch.bind(batch, "blockIdx.z") - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - sch.bind(thread_idz, "threadIdx.z") - - def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 - if not enable: - return - sch.transform_layout( - block, - scope, - lambda b, i, j: ( - b, - i // l, - j // r, - i % l, - j % r, - ), - ) - - smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) - smooth_layout_recover( - block_outer, - ("read", 1), - *b_lr, - enable=intrin_info.inter_transform_b, - ) - smooth_layout_recover(block_outer, ("write", 0), enable=True) - - def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): - block_read = sch.cache_read(block, idx, shared_scope) - sch.compute_at(block_read, k0, preserve_unit_loops=True) - ndim = len(sch.get(block_read).iter_vars) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - f_0, f_1, f_2, f_3, f_4 = sch.split( - fused, factors=[None, num_ty, num_tz, warp_size, vec_len]) - - sch.bind(f_3, "threadIdx.x") - sch.bind(f_2, "threadIdx.z") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_4) - sch.unroll(f_0) - sch.annotate(f_0, "pragma_unroll_explicit", False) - - # Apply Swizzling - sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) - # if not, apply padding to alleviate bank conflict - if not (can_swizzle or is_smooth): - pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) - return block_read - - a_g2s = fetch_to_shared( - block_outer, - 0, - vec_len=list(config.vectorize.values())[0], - can_swizzle=can_swizzle_a, - is_smooth=intrin_info.smooth_a, - ) - - auto_inline_producers(sch, a_g2s) - - def decode_fetch_to_shared(block, idx): - # step1. create memory hierarchy - # global -> local -> shared - block_shared = sch.cache_read(block, idx, shared_scope) - sch.compute_at(block_shared, k0, preserve_unit_loops=True) - - decode_factor = get_coalesced_veclen(sch.get(block_shared)) - _, B_shared_vi, _ = sch.split( - sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor]) - block_shared_local = sch.cache_read(block_shared, 0, "local") - # global -> dequantzed_local -> shared - # step2. inline to local block - weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"]) - weight_producers = _collect_producers(sch, weight_dequantize_block) - auto_inline_producers(sch, block_shared_local, weight_producers) - - # get target dequantize buffer's idx - def get_idx(): - # for LUT dequantize, the expr is LUT(w), the idx is 1 - # maybe we can use a more general and structural based way - # to analysis the idx - if weight_decode_info["source_format"]["format"] == "nf": - return 1 - return 0 - - b_idx = get_idx() - # global -> prefetch_local -> dequantzed_local -> shared - block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local") - - sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True) - sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True) - - dequantize_block_local = block_shared_local - if ("zeros_mode" in weight_decode_info and - weight_decode_info["zeros_mode"] == "quantized"): - if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]): - block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local") - sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True) - # pop the scale block - auto_inline_producers(sch, block_local_scales) - - if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]): - block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local") - sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True) - auto_inline_producers(sch, block_local_zeros) - - for producer in weight_producers: - with suppress(Exception): - auto_inline_producers(sch, producer) - sch.compute_inline(producer) - - # fast type conversion - if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): - source_bit = weight_decode_info["source_format"]["bits"] - out_dtype = weight_decode_info["target_format"] - lop3_intrin_info = get_lop3_intrin_group( - out_dtype=out_dtype, - storage_dtype=weight_decode_info["storage_dtype"], - source_format=weight_decode_info["source_format"]["format"], - source_bit=source_bit, - with_scaling=weight_decode_info["with_scaling"], - with_zeros=weight_decode_info["with_zeros"], - zeros_mode=weight_decode_info["zeros_mode"], - ) - sch.tensorize( - sch.get_loops(dequantize_block_local)[-1], - lop3_intrin_info["compute"], - ) - import_source.append(lop3_intrin_info["c_source"]) - - sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b) - union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2) - B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2]) - _, B_shared_ty, B_shared_tz, B_shared_tx = sch.split( - B_shared_fused, factors=[None, num_ty, num_tz, warp_size]) - if not (can_swizzle_b or intrin_info.smooth_b): - pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset) - sch.bind(B_shared_tx, "threadIdx.x") - sch.bind(B_shared_ty, "threadIdx.y") - sch.bind(B_shared_tz, "threadIdx.z") - sch.vectorize(sch.get_loops(block_shared)[-1]) - sch.vectorize(sch.get_loops(block_shared_local_local)[-1]) - - # cache small tensors, e.g. LUT - if b_idx: - block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope) - sch.reverse_compute_at(block_shared_lut, j2) - _, B_shared_tx = sch.split( - sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size]) - sch.bind(B_shared_tx, "threadIdx.x") - return block_shared_local - - _ = decode_fetch_to_shared(block_outer, 1) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "warp") - B_mat = sch.cache_read(block_outer, 1, "warp") - sch.compute_at(A_mat, k1, preserve_unit_loops=True) - sch.compute_at(B_mat, k1, preserve_unit_loops=True) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - if cache_write_required: - accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) - - store = sch.cache_write(block_outer, 0, "warp") - sch.reverse_compute_at(store, j2) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x]) - j0, j1 = sch.split(j, factors=[None, micro_size_y]) - sch.reorder(i0, j0, i1, j1) - - if cache_write_required: - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - sch.reverse_compute_at( - accumulator_shared_to_global, - sch.get_loops(store)[-5], - preserve_unit_loops=True, - ) - vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) - f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - sch.unroll(f0) - sch.annotate(f0, "pragma_unroll_explicit", False) - else: - auto_inline_consumer_chain(sch, store) - - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - - index_map_a, index_map_b, index_map_c = intrin_group["index_map"] - - sch.transform_layout( - A_mat, - ("write", 0), - get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), - ) - sch.transform_layout( - B_mat, - ("write", 0), - get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), - ) - sch.transform_layout( - store, - ("read", 0), - get_index_map(index_map_c, is_5d=True), - ) - - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, a_lr[0]]) - j0, j1 = sch.split(j, factors=[None, a_lr[1]]) - sch.reorder(i0, j0, i1, j1) - ba = sch.blockize(i1) - sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) - sch.tensorize(ba, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, b_lr[0]]) - j0, j1 = sch.split(j, factors=[None, b_lr[1]]) - sch.reorder(i0, j0, i1, j1) - bb = sch.blockize(i1) - sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) - sch.tensorize(bb, intrin_group["load_b"]) - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - tensorize_init_store_compute() - - if stage > 1: - sch.annotate( - k0, - ann_key="software_pipeline_stage", - ann_val=[0, 0, stage - 1], - ) - sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) - if use_async: - sch.annotate(k0, "software_pipeline_async_stages", [0]) - # plan rasteration - if not isinstance(config.rasterization_plan, NoRasterization): - device_func, invoke_func = config.rasterization_plan.get_code() - import_source.append(device_func) - sch.annotate( - sch.get_loops(block_init_c)[-2], - ann_key="inject_customized_code_prepend", - ann_val=invoke_func, - ) - # plan import source - if len(import_source) > 0: - sch.annotate( - thread_idz, - ann_key="pragma_import_c", - ann_val=("\n").join(import_source), - ) - return sch - - def sch_shared_memory_prefetch_with_config( - self, - func: tir.PrimFunc, - config: Hint, - ): - """ - For A100 Like devices, the shared memory prefetch(async) is required - to achieve optimal performance. - quantized weight - | - V - shared memory prefetch (with async copy) - | - V - dequantized into shared memory - | - V - compute - """ - if hasattr(config, "block_reduction_depth") and config.block_reduction_depth is not None: - return self.sch_shared_memory_prefetch_block_reduction_with_config(func, config) - - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group,) - from .intrin import get_lop3_intrin_group - - import_source: List[str] = [] - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - # always enable shared memory rewrite - cache_write_required = True - - # Check Dequantize Info - # TODO(leiwang): this is a hack to get the configuration, can be improved by writing a pass to analysis the dequantize block. - dequantize_info = func.attrs["dequantize_info"] - - def check_dequantize_info(dequantize_info): - conditions = [] - # currently only support weight only dequantization - conditions.append(len(dequantize_info) == 1) - # TODO(@lei) check if the dequantize value name is weight - return all(conditions) - - assert check_dequantize_info(dequantize_info) - - (weight_decode_info,) = list(dequantize_info.values()) - - def check_weight_decode_info(weight_decode_info): - conditions = [] - # check source format in ["int", "fp", "nf"] - conditions.append("source_format" in weight_decode_info) - conditions.append(weight_decode_info["source_format"]["format"] in - ["uint", "int", "fp", "nf", "fp_e4m3"]) - # check source bits in [1, 2, 4, 8] - conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) - # check target format in ["float16", "int8"] - conditions.append("target_format" in weight_decode_info) - conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) - return all(conditions) - - assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info" - - # Start Schedule - # Step 0. Get schedule config. - # tensor core intrinsic size - shared_scope = config.shared_scope - - intrin_info = config.intrin_info - intrin_group = get_mma_intrin_group( - load_scope=shared_scope, - store_scope=shared_scope if cache_write_required else "global", - a_dtype=intrin_info.in_dtype, - b_dtype=intrin_info.in_dtype, - out_dtype=intrin_info.out_dtype, - trans_a=intrin_info.trans_a, - trans_b=intrin_info.trans_b, - smooth_a=intrin_info.smooth_a, - smooth_b=intrin_info.smooth_b, - not_use_mma_store_intrinic=False, - ) - - warp_row_tiles = config.warp[0] - warp_col_tiles = config.warp[1] - block_row_warps = config.block[0] // warp_row_tiles - block_col_warps = config.block[1] // warp_col_tiles - stage = config.pipeline_stage - use_async = config.use_async - chunk = config.rstep[0] - - micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] - - # get the axis for layout transform - def get_axis(l, r, trans): # noqa: E741 - return (r, l) if trans else (l, r) # noqa: E741 - - a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) - b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) - - def can_enable_swizzle(dtype: str, smooth: bool): - # inject_permuted_layout only support float16 currently - if dtype == "float16" or dtype == "int8": - if chunk * DataType(dtype).bits != 512: - # currently the swizzle rule only support 512 bit. - return False - # if we use smooth layout, we don't need to do swizzling - return not smooth - return False - - can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) - can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) - - # rewrite global smooth layout, for dequantize, currently only support weight only recover. - def smooth_gmem_layout_rewrite( - sch, - main_block, - enable=True, - trans=False, - matrix_name="A", - intrin_group=intrin_group, - ): - if not enable: - return - - # normalized block may have three read buffers, while the first one is the write buffer. - buffer_offset = (1 if sch.get(main_block).reads[0].buffer - == sch.get(main_block).writes[0].buffer else 0) - buffer_idx = 0 if matrix_name == "A" else 1 - source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer - - # step1: find the first producer block - # Notes: we assume the layout propagate happens in the first producer block - # otherwise, the layout transform will have no effect as it will transform both - # read and write buffer - propagate_block: tir.Block = find_last_producer_from_buffer( - sch, main_block, source_buffer) - # some trick impl may not have reindex block - (weight_dequantize_info,) = dequantize_info.values() - if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]): - return - # step2: transform the layout with inverse permutation - intra_indexmap, _ = get_propagate_map( - trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) - - # step3: propagate the matmul layout to the first reindex block - - intra_indexmap = layout_propagate_chain( - sch, - start_block=main_block, - start_buffer=source_buffer, - end_block=propagate_block, - index_map=intra_indexmap, - ) - - def inverse_permutation(i, j, ii, jj): - return (i, j, *intra_indexmap.map_indices([ii, jj])) - - sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) - - intra_index_map, _ = get_propagate_map( - trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) - - # get target dequantize buffer's offset - def get_offset(): - # for LUT dequantize, the expr is LUT(w), the idx is 1 - # maybe we can use a more general and structural based way - # to analysis the idx - if weight_dequantize_info["source_format"]["format"] == "nf": - return 1 - return 0 - - offset = get_offset() - dequantize_block = sch.get_block(weight_dequantize_info["decode_block"]) - group_size = weight_dequantize_info["group_size"] - - _, mn, mk = intrin_group["micro_kernel"] - - def get_param_indices( - indexmap, - l=mn, - r=mk, - group_size=group_size # noqa: E741 - ): # noqa: E741 - # assume the param layout is n, k - rl, rr = [x.var for x in sch.get(dequantize_block).iter_vars] - warp_i, warp_j = rl % l, rr % r - spatial_i, spatial_j = rl // l, rr // r - warp_i, warp_j = indexmap.map_indices([warp_i, warp_j]) - new_indices = ( - spatial_i * l + warp_i, - (spatial_j * r + warp_j) // group_size, - ) - return new_indices - - with_scaling = bool(weight_dequantize_info["with_scaling"]) - if with_scaling: - sch.unsafe_rewrite_buffer_region( - dequantize_block, - ("read", offset + 1), - get_param_indices(intra_index_map), - ) - with_zeros = bool(weight_dequantize_info["with_zeros"]) - if with_zeros: - sch.unsafe_rewrite_buffer_region( - dequantize_block, - ("read", offset + 2), - get_param_indices(intra_index_map), - ) - - smooth_gmem_layout_rewrite( - sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") - - smooth_gmem_layout_rewrite( - sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") - - warp_size = 32 - - i_factors, j_factors, k_factors = ( - [None, 1, block_row_warps, warp_row_tiles // micro_size_x], - [1, None, block_col_warps, warp_col_tiles // micro_size_y], - [None, chunk // micro_size_k], - ) - - num_ty = i_factors[2] - num_tz = j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): - sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - - sch.reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3) - - block_idy = sch.fuse(i0, j0) - block_idx = sch.fuse(i1, j1) - thread_idy = i2 - thread_idz = j2 - - sch.bind(batch, "blockIdx.z") - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - sch.bind(thread_idz, "threadIdx.z") - - def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 - if not enable: - return - sch.transform_layout( - block, - scope, - lambda b, i, j: ( - b, - i // l, - j // r, - i % l, - j % r, - ), - ) - - smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) - smooth_layout_recover( - block_outer, - ("read", 1), - *b_lr, - enable=intrin_info.inter_transform_b, - ) - smooth_layout_recover(block_outer, ("write", 0), enable=True) - - def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): - block_read = sch.cache_read(block, idx, shared_scope) - sch.compute_at(block_read, k0, preserve_unit_loops=True) - ndim = len(sch.get(block_read).iter_vars) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - f_0, f_1, f_2, f_3, f_4 = sch.split( - fused, factors=[None, num_ty, num_tz, warp_size, vec_len]) - - sch.bind(f_3, "threadIdx.x") - sch.bind(f_2, "threadIdx.z") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_4) - sch.unroll(f_0) - sch.annotate(f_0, "pragma_unroll_explicit", False) - - # Apply Swizzling - sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) - # if not, apply padding to alleviate bank conflict - if not (can_swizzle or is_smooth): - pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) - return block_read - - a_g2s = fetch_to_shared( - block_outer, - 0, - vec_len=list(config.vectorize.values())[0], - can_swizzle=can_swizzle_a, - is_smooth=intrin_info.smooth_a, - ) - - auto_inline_producers(sch, a_g2s) - - def decode_fetch_to_shared(block, idx): - # step1. create memory hierarchy - # global -> local -> shared - block_shared = sch.cache_read(block, idx, shared_scope) - sch.compute_at(block_shared, k0, preserve_unit_loops=True) - - # TODO(lei): the factor should be analyzed more deeper. - decode_factor = get_coalesced_veclen(sch.get(block_shared)) - _, B_shared_vi, _ = sch.split( - sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor]) - block_shared_local = sch.cache_read(block_shared, 0, "local") - # global -> dequantzed_local -> shared - # step2. inline to local block, should skip qzeros - is_qzeros = ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"] and - weight_decode_info["zeros_mode"] == "quantized") - weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"]) - weight_producers = ( - _collect_producers(sch, weight_dequantize_block) if is_qzeros else []) - auto_inline_producers(sch, block_shared_local, weight_producers) - - # get target dequantize buffer's idx - def get_idx(): - # for LUT dequantize, the expr is LUT(w), the idx is 1 - # maybe we can use a more general and structural based way - # to analysis the idx - if weight_decode_info["source_format"]["format"] == "nf": - return 1 - return 0 - - b_idx = get_idx() - # global -> prefetch_local -> dequantzed_local -> shared - block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local") - # global -> prefetch_shared -> vector load -> dequantzed_local -> shared - block_shared_local_local_shared = sch.cache_read(block_shared_local_local, 0, - shared_scope) - sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True) - sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True) - - dequantize_block_local = block_shared_local - if is_qzeros: - if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]): - block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local") - sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True) - auto_inline_producers(sch, block_local_scales) - - if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]): - block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local") - sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True) - auto_inline_producers(sch, block_local_zeros) - - for producer in weight_producers: - with suppress(Exception): - auto_inline_producers(sch, producer) - sch.compute_inline(producer) - - # fast type conversion - if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): - source_bit = weight_decode_info["source_format"]["bits"] - out_dtype = weight_decode_info["target_format"] - lop3_intrin_info = get_lop3_intrin_group( - out_dtype=out_dtype, - storage_dtype=weight_decode_info["storage_dtype"], - source_format=weight_decode_info["source_format"]["format"], - source_bit=source_bit, - with_scaling=weight_decode_info["with_scaling"], - with_zeros=weight_decode_info["with_zeros"], - zeros_mode=weight_decode_info["zeros_mode"], - ) - sch.tensorize( - sch.get_loops(dequantize_block_local)[-1], - lop3_intrin_info["compute"], - ) - import_source.append(lop3_intrin_info["c_source"]) - - sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b) - union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2) - B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2]) - _, B_shared_ty, B_shared_tz, B_shared_tx = sch.split( - B_shared_fused, factors=[None, num_ty, num_tz, warp_size]) - if not (can_swizzle_b or intrin_info.smooth_b): - pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset) - sch.bind(B_shared_tx, "threadIdx.x") - sch.bind(B_shared_ty, "threadIdx.y") - sch.bind(B_shared_tz, "threadIdx.z") - sch.vectorize(sch.get_loops(block_shared)[-1]) - sch.vectorize(sch.get_loops(block_shared_local_local)[-1]) - - sch.compute_at(block_shared_local_local_shared, k0, preserve_unit_loops=True) - ndim = len(sch.get(block_shared_local_local_shared).iter_vars) - _ = sch.fuse(*sch.get_loops(block_shared_local_local_shared)[-ndim:]) - - _bind_thread_based_on_config(sch, block_shared_local_local_shared, num_ty, num_tz, - warp_size) - - # cache small tensors, e.g. LUT - if b_idx: - block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope) - sch.reverse_compute_at(block_shared_lut, j2) - _, B_shared_tx = sch.split( - sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size]) - sch.bind(B_shared_tx, "threadIdx.x") - return block_shared_local - - _ = decode_fetch_to_shared(block_outer, 1) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "warp") - B_mat = sch.cache_read(block_outer, 1, "warp") - sch.compute_at(A_mat, k1, preserve_unit_loops=True) - sch.compute_at(B_mat, k1, preserve_unit_loops=True) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - if cache_write_required: - accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) - - store = sch.cache_write(block_outer, 0, "warp") - sch.reverse_compute_at(store, j2) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x]) - j0, j1 = sch.split(j, factors=[None, micro_size_y]) - sch.reorder(i0, j0, i1, j1) - - if cache_write_required: - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - sch.reverse_compute_at( - accumulator_shared_to_global, - sch.get_loops(store)[-5], - preserve_unit_loops=True, - ) - vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) - f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - sch.unroll(f0) - sch.annotate(f0, "pragma_unroll_explicit", False) - else: - auto_inline_consumer_chain(sch, store) - - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - - index_map_a, index_map_b, index_map_c = intrin_group["index_map"] - - sch.transform_layout( - A_mat, - ("write", 0), - get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), - ) - sch.transform_layout( - B_mat, - ("write", 0), - get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), - ) - sch.transform_layout( - store, - ("read", 0), - get_index_map(index_map_c, is_5d=True), - ) - - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, a_lr[0]]) - j0, j1 = sch.split(j, factors=[None, a_lr[1]]) - sch.reorder(i0, j0, i1, j1) - ba = sch.blockize(i1) - sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) - sch.tensorize(ba, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, b_lr[0]]) - j0, j1 = sch.split(j, factors=[None, b_lr[1]]) - sch.reorder(i0, j0, i1, j1) - bb = sch.blockize(i1) - sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) - sch.tensorize(bb, intrin_group["load_b"]) - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - tensorize_init_store_compute() - - if stage > 1: - sch.annotate( - k0, - ann_key="software_pipeline_stage", - ann_val=[0, 0, stage - 1, stage - 1], - ) - sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2, 3]) - if use_async: - sch.annotate(k0, "software_pipeline_async_stages", [0]) - # plan rasteration - if not isinstance(config.rasterization_plan, NoRasterization): - device_func, invoke_func = config.rasterization_plan.get_code() - import_source.append(device_func) - sch.annotate( - sch.get_loops(block_init_c)[-2], - ann_key="inject_customized_code_prepend", - ann_val=invoke_func, - ) - # plan import source - if len(import_source) > 0: - sch.annotate( - thread_idz, - ann_key="pragma_import_c", - ann_val=("\n").join(import_source), - ) - return sch - - def sch_shared_memory_prefetch_block_reduction_with_config( - self, - func: tir.PrimFunc, - config: Hint, - ): - """ - For A100 Like devices, the shared memory prefetch(async) is required - to achieve optimal performance. - quantized weight - | - V - shared memory prefetch (with async copy) - | - V - dequantized into shared memory - | - V - compute - """ - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group,) - from .intrin import get_lop3_intrin_group - - import_source: List[str] = [] - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - # always enable shared memory rewrite - cache_write_required = True - - # Check Dequantize Info - # TODO(leiwang): this is a hack to get the configuration, can be improved by writing a pass to analysis the dequantize block. - dequantize_info = func.attrs["dequantize_info"] - - def check_dequantize_info(dequantize_info): - conditions = [] - # currently only support weight only dequantization - conditions.append(len(dequantize_info) == 1) - # TODO(@lei) check if the dequantize value name is weight - return all(conditions) - - assert check_dequantize_info(dequantize_info) - - (weight_decode_info,) = list(dequantize_info.values()) - - def check_weight_decode_info(weight_decode_info): - conditions = [] - # check source format in ["int", "fp", "nf"] - conditions.append("source_format" in weight_decode_info) - conditions.append(weight_decode_info["source_format"]["format"] in - ["uint", "int", "fp", "nf", "fp_e4m3"]) - # check source bits in [1, 2, 4, 8] - conditions.append(weight_decode_info["source_format"]["bits"] in [1, 2, 4, 8]) - # check target format in ["float16", "int8"] - conditions.append("target_format" in weight_decode_info) - conditions.append(weight_decode_info["target_format"] in ["float16", "int8"]) - return all(conditions) - - assert check_weight_decode_info(weight_decode_info), "Invalid B_decode_info" - - # Start Schedule - # Step 0. Get schedule config. - # tensor core intrinsic size - shared_scope = config.shared_scope - - intrin_info = config.intrin_info - intrin_group = get_mma_intrin_group( - load_scope=shared_scope, - store_scope=shared_scope if cache_write_required else "global", - a_dtype=intrin_info.in_dtype, - b_dtype=intrin_info.in_dtype, - out_dtype=intrin_info.out_dtype, - trans_a=intrin_info.trans_a, - trans_b=intrin_info.trans_b, - smooth_a=intrin_info.smooth_a, - smooth_b=intrin_info.smooth_b, - not_use_mma_store_intrinic=False, - ) - - warp_row_tiles = config.warp[0] - warp_col_tiles = config.warp[1] - block_row_warps = config.block[0] // warp_row_tiles - block_col_warps = config.block[1] // warp_col_tiles - stage = config.pipeline_stage - use_async = config.use_async - assert (config.block_reduction_depth is not None), "block_reduction_depth is required" - reduce_k = config.block_reduction_depth - chunk = config.rstep[0] // reduce_k - - micro_size_x, micro_size_y, micro_size_k = intrin_group["micro_kernel"] - - # get the axis for layout transform - def get_axis(l, r, trans): # noqa: E741 - return (r, l) if trans else (l, r) # noqa: E741 - - a_lr = get_axis(micro_size_x, micro_size_k, intrin_info.trans_a) - b_lr = get_axis(micro_size_k, micro_size_y, intrin_info.trans_b) - - def can_enable_swizzle(dtype: str, smooth: bool): - # inject_permuted_layout only support float16 currently - if dtype == "float16" or dtype == "int8": - if (chunk * reduce_k) * DataType(dtype).bits != 512: - # currently the swizzle rule only support 512 bit. - return False - # if we use smooth layout, we don't need to do swizzling - return not smooth - return False - - can_swizzle_a = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_a) - can_swizzle_b = can_enable_swizzle(intrin_info.in_dtype, intrin_info.inter_transform_b) - - # rewrite global smooth layout, for dequantize, currently only support weight only recover. - def smooth_gmem_layout_rewrite( - sch, - main_block, - enable=True, - trans=False, - matrix_name="A", - intrin_group=intrin_group, - ): - if not enable: - return - - # normalized block may have three read buffers, while the first one is the write buffer. - buffer_offset = (1 if sch.get(main_block).reads[0].buffer - == sch.get(main_block).writes[0].buffer else 0) - buffer_idx = 0 if matrix_name == "A" else 1 - source_buffer = sch.get(main_block).reads[buffer_offset + buffer_idx].buffer - - # step1: find the first producer block - # Notes: we assume the layout propagate happens in the first producer block - # otherwise, the layout transform will have no effect as it will transform both - # read and write buffer - propagate_block: tir.Block = find_last_producer_from_buffer( - sch, main_block, source_buffer) - # some trick impl may not have reindex block - (weight_dequantize_info,) = dequantize_info.values() - if (sch.get(propagate_block).name_hint == weight_dequantize_info["decode_block"]): - return - # step2: transform the layout with inverse permutation - intra_indexmap, _ = get_propagate_map( - trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) - - # step3: propagate the matmul layout to the first reindex block - - intra_indexmap = layout_propagate_chain( - sch, - start_block=main_block, - start_buffer=source_buffer, - end_block=propagate_block, - index_map=intra_indexmap, - ) - - def inverse_permutation(i, j, ii, jj): - return (i, j, *intra_indexmap.map_indices([ii, jj])) - - sch.transform_layout(propagate_block, ("read", 0), inverse_permutation) - - intra_index_map, _ = get_propagate_map( - trans=trans, dtype=intrin_info.in_dtype, matrix_name=matrix_name) - - # get target dequantize buffer's offset - def get_offset(): - # for LUT dequantize, the expr is LUT(w), the idx is 1 - # maybe we can use a more general and structural based way - # to analysis the idx - if weight_dequantize_info["source_format"]["format"] == "nf": - return 1 - return 0 - - offset = get_offset() - dequantize_block = sch.get_block(weight_dequantize_info["decode_block"]) - group_size = weight_dequantize_info["group_size"] - - _, mn, mk = intrin_group["micro_kernel"] - - def get_param_indices( - indexmap, - l=mn, - r=mk, - group_size=group_size # noqa: E741 - ): # noqa: E741 - # assume the param layout is n, k - rl, rr = [x.var for x in sch.get(dequantize_block).iter_vars] - warp_i, warp_j = rl % l, rr % r - spatial_i, spatial_j = rl // l, rr // r - warp_i, warp_j = indexmap.map_indices([warp_i, warp_j]) - new_indices = ( - spatial_i * l + warp_i, - (spatial_j * r + warp_j) // group_size, - ) - return new_indices - - with_scaling = bool(weight_dequantize_info["with_scaling"]) - if with_scaling: - sch.unsafe_rewrite_buffer_region( - dequantize_block, - ("read", offset + 1), - get_param_indices(intra_index_map), - ) - with_zeros = bool(weight_dequantize_info["with_zeros"]) - if with_zeros: - sch.unsafe_rewrite_buffer_region( - dequantize_block, - ("read", offset + 2), - get_param_indices(intra_index_map), - ) - - smooth_gmem_layout_rewrite( - sch, main_block, intrin_info.smooth_a, intrin_info.trans_a, matrix_name="A") - - smooth_gmem_layout_rewrite( - sch, main_block, intrin_info.smooth_b, intrin_info.trans_b, matrix_name="B") - - warp_size = 32 - - i_factors, j_factors, k_factors = ( - [None, 1, block_row_warps, warp_row_tiles // micro_size_x], - [1, None, block_col_warps, warp_col_tiles // micro_size_y], - [None, chunk // micro_size_k], - ) - - num_ty = i_factors[2] - num_tz = j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): - sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - k0, kr = sch.split(k0, [None, reduce_k]) - - sch.reorder(i0, j0, i1, j1, i2, j2, kr, i3, j3, k0, k1) - - block_idy = sch.fuse(i0, j0) - block_idx = sch.fuse(i1, j1) - thread_idy = i2 - thread_idz = j2 - - sch.bind(batch, "blockIdx.z") - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - thread_idz = j2 = thread_idy = sch.fuse(thread_idy, thread_idz) - sch.bind(thread_idy, "threadIdx.y") - sch.bind(kr, "threadIdx.z") - - def smooth_layout_recover(block, scope, l=16, r=16, enable=True): # noqa: E741 - if not enable: - return - sch.transform_layout( - block, - scope, - lambda b, i, j: ( - b, - i // l, - j // r, - i % l, - j % r, - ), - ) - - smooth_layout_recover(block_outer, ("read", 0), *a_lr, enable=intrin_info.inter_transform_a) - smooth_layout_recover( - block_outer, - ("read", 1), - *b_lr, - enable=intrin_info.inter_transform_b, - ) - smooth_layout_recover(block_outer, ("write", 0), enable=True) - - def fetch_to_shared(block, idx, vec_len, can_swizzle=False, is_smooth=False): - block_read = sch.cache_read(block, idx, shared_scope) - sch.compute_at(block_read, k0, preserve_unit_loops=True) - ndim = len(sch.get(block_read).iter_vars) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - f_0, f_r, f_1, f_2, f_3, f_4 = sch.split( - fused, factors=[None, reduce_k, num_ty, num_tz, warp_size, vec_len]) - - sch.bind(f_3, "threadIdx.x") - f_1 = f_2 = sch.fuse(f_1, f_2) - sch.bind(f_1, "threadIdx.y") - sch.bind(f_r, "threadIdx.z") - sch.vectorize(f_4) - sch.unroll(f_0) - sch.annotate(f_0, "pragma_unroll_explicit", False) - - # Apply Swizzling - sch.annotate(block_read, ann_key="permuted_layout", ann_val=can_swizzle) - # if not, apply padding to alleviate bank conflict - if not (can_swizzle or is_smooth): - pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=pad_offset) - return block_read - - a_g2s = fetch_to_shared( - block_outer, - 0, - vec_len=list(config.vectorize.values())[0], - can_swizzle=can_swizzle_a, - is_smooth=intrin_info.smooth_a, - ) - - auto_inline_producers(sch, a_g2s) - - def decode_fetch_to_shared(block, idx): - # step1. create memory hierarchy - # global -> local -> shared - block_shared = sch.cache_read(block, idx, shared_scope) - sch.compute_at(block_shared, k0, preserve_unit_loops=True) - - # TODO(lei): the factor should be analyzed more deeper. - decode_factor = get_coalesced_veclen(sch.get(block_shared)) - _, B_shared_vi, _ = sch.split( - sch.get_loops(block_shared)[-1], factors=[None, 1, decode_factor]) - block_shared_local = sch.cache_read(block_shared, 0, "local") - # global -> dequantzed_local -> shared - # step2. inline to local block, should skip qzeros - is_qzeros = ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"] and - weight_decode_info["zeros_mode"] == "quantized") - weight_dequantize_block = sch.get_block(weight_decode_info["decode_block"]) - weight_producers = ( - _collect_producers(sch, weight_dequantize_block) if is_qzeros else []) - auto_inline_producers(sch, block_shared_local, weight_producers) - - # get target dequantize buffer's idx - def get_idx(): - # for LUT dequantize, the expr is LUT(w), the idx is 1 - # maybe we can use a more general and structural based way - # to analysis the idx - if weight_decode_info["source_format"]["format"] == "nf": - return 1 - return 0 - - b_idx = get_idx() - # global -> prefetch_local -> dequantzed_local -> shared - block_shared_local_local = sch.cache_read(block_shared_local, b_idx, "local") - # global -> prefetch_shared -> vector load -> dequantzed_local -> shared - block_shared_local_local_shared = sch.cache_read(block_shared_local_local, 0, - shared_scope) - sch.compute_at(block_shared_local, B_shared_vi, preserve_unit_loops=True) - sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True) - - dequantize_block_local = block_shared_local - if is_qzeros: - if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]): - block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local") - sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True) - auto_inline_producers(sch, block_local_scales) - - if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]): - block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local") - sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True) - auto_inline_producers(sch, block_local_zeros) - - for producer in weight_producers: - with suppress(Exception): - auto_inline_producers(sch, producer) - sch.compute_inline(producer) - - # fast type conversion - if ("fast_decoding" in weight_decode_info and weight_decode_info["fast_decoding"]): - source_bit = weight_decode_info["source_format"]["bits"] - out_dtype = weight_decode_info["target_format"] - lop3_intrin_info = get_lop3_intrin_group( - out_dtype=out_dtype, - storage_dtype=weight_decode_info["storage_dtype"], - source_format=weight_decode_info["source_format"]["format"], - source_bit=source_bit, - with_scaling=weight_decode_info["with_scaling"], - with_zeros=weight_decode_info["with_zeros"], - zeros_mode=weight_decode_info["zeros_mode"], - ) - sch.tensorize( - sch.get_loops(dequantize_block_local)[-1], - lop3_intrin_info["compute"], - ) - import_source.append(lop3_intrin_info["c_source"]) - - sch.annotate(block_shared, ann_key="permuted_layout", ann_val=can_swizzle_b) - union_len = (2 + 4) if intrin_info.smooth_b else (2 + 2) - B_shared_fused = sch.fuse(*sch.get_loops(block_shared)[-union_len:-2]) - _, B_shared_rk, B_shared_ty, B_shared_tz, B_shared_tx = sch.split( - B_shared_fused, factors=[None, reduce_k, num_ty, num_tz, warp_size]) - if not (can_swizzle_b or intrin_info.smooth_b): - pad_offset = 8 if intrin_info.in_dtype == "float16" else 16 - sch.storage_align(block_shared, 0, axis=-2, factor=16, offset=pad_offset) - sch.bind(B_shared_tx, "threadIdx.x") - B_shared_tz = B_shared_ty = sch.fuse(B_shared_ty, B_shared_tz) - sch.bind(B_shared_ty, "threadIdx.y") - sch.bind(B_shared_rk, "threadIdx.z") - sch.vectorize(sch.get_loops(block_shared)[-1]) - sch.vectorize(sch.get_loops(block_shared_local_local)[-1]) - - sch.compute_at(block_shared_local_local_shared, k0, preserve_unit_loops=True) - ndim = len(sch.get(block_shared_local_local_shared).iter_vars) - _ = sch.fuse(*sch.get_loops(block_shared_local_local_shared)[-ndim:]) - - _bind_thread_based_with_block_reduce_on_config(sch, block_shared_local_local_shared, - num_ty, num_tz, warp_size, reduce_k) - - # cache small tensors, e.g. LUT - if b_idx: - block_shared_lut = sch.cache_read(dequantize_block_local, 0, shared_scope) - sch.reverse_compute_at(block_shared_lut, j2) - _, B_shared_tx = sch.split( - sch.get_loops(block_shared_lut)[-1], factors=[None, warp_size]) - sch.bind(B_shared_tx, "threadIdx.x") - return block_shared_local - - _ = decode_fetch_to_shared(block_outer, 1) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "warp") - B_mat = sch.cache_read(block_outer, 1, "warp") - sch.compute_at(A_mat, k1, preserve_unit_loops=True) - sch.compute_at(B_mat, k1, preserve_unit_loops=True) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - if cache_write_required: - accumulator_shared_to_global = sch.cache_write(block_outer, 0, shared_scope) - - store = sch.cache_write(block_outer, 0, "warp") - sch.reverse_compute_at(store, j2) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x]) - j0, j1 = sch.split(j, factors=[None, micro_size_y]) - sch.reorder(i0, j0, i1, j1) - - if cache_write_required: - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - sch.reverse_compute_at( - accumulator_shared_to_global, - sch.get_loops(store)[-5], - preserve_unit_loops=True, - ) - vec_len = get_coalesced_veclen(sch.get(accumulator_shared_to_global)) - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-5:]) - f0, f1, f2 = sch.split(fused, factors=[None, warp_size, vec_len]) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - sch.unroll(f0) - sch.annotate(f0, "pragma_unroll_explicit", False) - else: - auto_inline_consumer_chain(sch, store) - - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - - index_map_a, index_map_b, index_map_c = intrin_group["index_map"] - - sch.transform_layout( - A_mat, - ("write", 0), - get_index_map(index_map_a, *a_lr, intrin_info.inter_transform_a), - ) - sch.transform_layout( - B_mat, - ("write", 0), - get_index_map(index_map_b, *b_lr, intrin_info.inter_transform_b), - ) - sch.transform_layout( - store, - ("read", 0), - get_index_map(index_map_c, is_5d=True), - ) - - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, a_lr[0]]) - j0, j1 = sch.split(j, factors=[None, a_lr[1]]) - sch.reorder(i0, j0, i1, j1) - ba = sch.blockize(i1) - sch.annotate(ba, ann_key="permuted_layout", ann_val=can_swizzle_a) - sch.tensorize(ba, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, b_lr[0]]) - j0, j1 = sch.split(j, factors=[None, b_lr[1]]) - sch.reorder(i0, j0, i1, j1) - bb = sch.blockize(i1) - sch.annotate(bb, ann_key="permuted_layout", ann_val=can_swizzle_b) - sch.tensorize(bb, intrin_group["load_b"]) - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - tensorize_init_store_compute() - - if stage > 1: - sch.annotate( - k0, - ann_key="software_pipeline_stage", - ann_val=[0, 0, stage - 1, stage - 1], - ) - sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2, 3]) - if use_async: - sch.annotate(k0, "software_pipeline_async_stages", [0]) - # plan rasteration - if not isinstance(config.rasterization_plan, NoRasterization): - device_func, invoke_func = config.rasterization_plan.get_code() - import_source.append(device_func) - sch.annotate( - sch.get_loops(block_init_c)[-2], - ann_key="inject_customized_code_prepend", - ann_val=invoke_func, - ) - # plan import source - if len(import_source) > 0: - sch.annotate( - j2, - ann_key="pragma_import_c", - ann_val=("\n").join(import_source), - ) - return sch - - def apply_config( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - config: Hint, - ) -> Optional[tir.Schedule]: - - def check_sm_version(arch: str) -> int: - sm_version = arch.replace("sm_", "") - return int(sm_version) if sm_version.isdigit() else -1 - - if check_sm_version(config.arch.target.arch) < 80: - """MMA Template only support sm_80 and above""" - return None - - if (config.arch.target.kind.name == "cuda" and - check_sm_version(config.arch.target.arch) == 80): - return self.sch_shared_memory_prefetch_with_config(func, config) - else: - return self.sch_dequantize_in_register_with_config(func, config) diff --git a/python/bitblas/gpu/matmul_wmma.py b/python/bitblas/gpu/matmul_wmma.py deleted file mode 100644 index 60817258f..000000000 --- a/python/bitblas/gpu/matmul_wmma.py +++ /dev/null @@ -1,892 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -# pylint: disable=missing-docstring, invalid-name -"""A GEMM schedule rule for GPU operators.""" -from typing import Literal, Optional - -from tvm import DataType, tir -from tvm.target import Target - -from ..base.roller.rasterization import NoRasterization -from ..base import analysis -from .base import GPUScheduleRule -from .matmul_analysis import ( - auto_inline_consumer_chain, - auto_inline_producers, - get_index_map, - get_reduction_blocks, - normalize_to_matmul, -) - - -class MatmulTensorizationWMMA(GPUScheduleRule): - """ - The schedule rule for float16 tensor core matmul computation. - func with attr 'dlight.do_not_tensorize' will not be tensorized. - """ - - def apply( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ) -> Optional[tir.Schedule]: - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - block_stmt = sch.get(main_block) - index_maps = get_index_map(block_stmt) - if index_maps is None: - return None - matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps - - # Start Schedule - # Step 0. Get schedule config. - # NOTE: we can analyze the config by the hardware spec in the future - - block_m = 128 - block_n = 128 - block_k = 32 - - # tensor core intrinsic size - micro_size_m = 16 - micro_size_n = 16 - micro_size_k = 16 - - thread_z = 2 - thread_y = 2 - warp_size = 32 - - vector_size = 8 - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] - block = sch.reindex(main_block, ("read", 0)) - sch.transform_layout(block, ("write", 0), a_index_map) - block = sch.reindex(main_block, ("read", 1)) - sch.transform_layout(block, ("write", 0), b_index_map) - block = sch.reindex(main_block, ("write", 0)) - sch.transform_layout(block, ("read", 0), c_index_map) - sch.transform_block_layout(main_block, matmul_index_map) - - # Step 2. Padding for dynamic shape kernels - - # # Step 2.1 Swizzle for l2, for better performance on inputs exceeding l2 size - # # Get input shape - batch, i, j, k = sch.get_loops(main_block) - # input_b, input_m, input_n, input_k = [sch.get(loop).extent for loop in [batch, i, j, k]] - - # # Get input/output dtype - dtype_a, dtype_b = [DataType(region.buffer.dtype) for region in sch.get(main_block).reads] - dtype_c = DataType(sch.get(main_block).writes[0].buffer.dtype) - # dtype_a_bytes, dtype_b_bytes = [math.ceil(d.bits / 8) for d in [dtype_a, dtype_b]] - - # # Get l2 size - # l2_size = target.l2_cache_size_bytes - - # # Analyse swizzle factor - # def get_swizzle_factor(l2_size, input_k, dtype_bytes, input_spatial, block_size): - # if l2_size != 0 and isinstance(input_k, (int, tir.IntImm)): - # # div by 3: suppose the two inputs and the output uses the same amount of l2 - # swizzle_factor = l2_size / 3 / int(input_k) / dtype_bytes / block_size - # # optimization: try find the best swizzle factor (aka the least additional padding) - # if isinstance(input_spatial, (int, tir.IntImm)): - # block_cnt = math.ceil(int(input_spatial) / block_size) - # swizzle_factor = math.ceil(block_cnt / math.ceil(block_cnt / swizzle_factor)) - # else: - # swizzle_factor = math.floor(swizzle_factor) - # return [None, swizzle_factor] - # else: - # return [4, None] - - # swizzle_factor_m = get_swizzle_factor(l2_size, input_k, dtype_a_bytes, input_m, block_m) - # swizzle_factor_n = get_swizzle_factor(l2_size, input_k, dtype_b_bytes, input_n, block_n) - - swizzle_factor_m = [4, None] - swizzle_factor_n = [4, None] - - # Step 2.2 Add padding - sch.pad_einsum( - main_block, - [ - 1, - (swizzle_factor_m[0] or swizzle_factor_m[1]) * block_m, - (swizzle_factor_n[0] or swizzle_factor_n[1]) * block_n, - block_k, - ], - ) - - # Step 3. Reorder loops for tiling - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_m]) - j, j_inner = sch.split(j, factors=[None, micro_size_n]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = main_block - block_outer = sch.blockize(i_inner) - - # split factors for i, j, and k - in_wrap_block_cnt_m = block_m // thread_z // micro_size_m - in_wrap_block_cnt_n = block_n // thread_y // micro_size_n - in_wrap_block_cnt_k = block_k // micro_size_k - - i_factors = swizzle_factor_m + [thread_z, in_wrap_block_cnt_m] - j_factors = swizzle_factor_n + [thread_y, in_wrap_block_cnt_n] - k_factors = [None, in_wrap_block_cnt_k] - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, factors=k_factors) - - sch.reorder(i0, j0, i1, j1, k0, i2, j2, k1, i3, j3) - block_axis = sch.fuse(batch, i0, j0, i1, j1) - - sch.bind(block_axis, "blockIdx.x") - sch.bind(i2, "threadIdx.z") - sch.bind(j2, "threadIdx.y") - - # Step 4. Read to/write from shared mem, and from/to wmma fragments - def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], wmma_name): - block_read = sch.cache_read(block_outer, read_buffer_idx, "shared.dyn") - sch.compute_at(block_read, k0) - fused = sch.fuse(*sch.get_loops(block_read)[-2:]) - - f0, f1, f2, f3, f4 = sch.split(fused, - [None, thread_z, thread_y, warp_size, vector_size]) - - sch.bind(f1, "threadIdx.z") - sch.bind(f2, "threadIdx.y") - sch.bind(f3, "threadIdx.x") - sch.vectorize(f4) - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8) - - auto_inline_producers(sch, block_read) - - wmma_read = sch.cache_read(block_outer, read_buffer_idx, wmma_name) - sch.compute_at(wmma_read, k1) - - micro_size_spatial = micro_size_m if tensor_name == "A" else micro_size_n - v0, v1 = sch.get_loops(wmma_read)[-2:] - sch.split(v0, factors=[None, micro_size_spatial]) - - return wmma_read - - wmma_read_a = fetch_input(block_outer, 0, [block_m, block_k, micro_size_m, micro_size_k], - "wmma.matrix_a") - wmma_read_b = fetch_input(block_outer, 1, [block_n, block_k, micro_size_n, micro_size_k], - "wmma.matrix_b") - - def store_output(block_outer, write_buffer_idx, wmma_name): - block_write = sch.cache_write(block_outer, write_buffer_idx, "shared.dyn") - sch.reverse_compute_at(block_write, block_axis) - - fused = sch.fuse(*sch.get_loops(block_write)[-2:]) - - f0, f1, f2, f3, f4 = sch.split(fused, - [None, thread_z, thread_y, warp_size, vector_size]) - - sch.bind(f1, "threadIdx.z") - sch.bind(f2, "threadIdx.y") - sch.bind(f3, "threadIdx.x") - sch.vectorize(f4) - # sch.storage_align(block_write, 0, axis=-2, factor=128, offset=16) - - auto_inline_consumer_chain(sch, block_write) - - wmma_store = sch.cache_write(block_outer, write_buffer_idx, wmma_name) - v0, v1 = sch.get_loops(wmma_store)[-2:] - v00, v01, v02 = sch.split(v0, factors=[thread_z, None, micro_size_m]) - v10, v11, v12 = sch.split(v1, factors=[thread_y, None, micro_size_n]) - sch.reorder(v00, v10, v01, v11, v02, v12) - sch.bind(v00, "threadIdx.z") - sch.bind(v10, "threadIdx.y") - return wmma_store - - wmma_store = store_output(block_outer, 0, "wmma.accumulator") - - block_init = sch.decompose_reduction(block_outer, k0) - block_init_inner = sch.get_child_blocks(block_init)[0] - - # unroll k - sch.unroll(k0) - - # Step 5. Schedule tensor core computation - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_wmma_intrin_group,) - - intrin_group = get_wmma_intrin_group( - load_scope="shared.dyn", - store_scope="shared.dyn", - in_dtype=str(dtype_a), - out_dtype=str(dtype_c), - trans_b=True, - ) - - sch.tensorize(sch.get_loops(block_init_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(wmma_read_a)[-2], intrin_group["load_a"]) - sch.tensorize(sch.get_loops(wmma_read_b)[-2], intrin_group["load_b"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - sch.tensorize(sch.get_loops(wmma_store)[-2], intrin_group["store"]) - - return sch - - -class MatmulInt8Tensorization(GPUScheduleRule): - """ - The schedule rule for int8 tensor core matmul computation. - func with attr 'dlight.do_not_tensorize' will not be tensorized. - """ - - def apply( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ) -> Optional[tir.Schedule]: - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_wmma_intrin_group,) - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - block_stmt = sch.get(main_block) - index_maps = get_index_map(block_stmt) - if index_maps is None: - return None - matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps - - # Start Schedule - # Step 0. Get schedule config. - # NOTE: we can analyze the config by the hardware spec in the future - - # tensor core intrinsic size - micro_size_x = 16 - micro_size_y = 16 - micro_size_k = 16 - - warp_size = 32 - vector_size = 4 - - i_factors, j_factors, k_factors = ( - [None, 1, 4, 2], - [1, None, 4, 2], - [None, 1], - ) - - num_ty = i_factors[2] * j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] - block = sch.reindex(main_block, ("read", 0)) - sch.transform_layout(block, ("write", 0), a_index_map) - block = sch.reindex(main_block, ("read", 1)) - sch.transform_layout(block, ("write", 0), b_index_map) - block = sch.reindex(main_block, ("write", 0)) - sch.transform_layout(block, ("read", 0), c_index_map) - sch.transform_block_layout(main_block, matmul_index_map) - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6]) - sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1]) - sch.annotate(k1, "software_pipeline_order", [0, 1, 2]) - sch.annotate(k1, "software_pipeline_stage", [0, 0, 1]) - - sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) - - block_idx = sch.fuse(i0, j0) - block_idy = sch.fuse(i1, j1) - thread_idy = sch.fuse(j2, i2) - sch.bind(batch, "blockIdx.z") - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - - def fetch_to_shared(block, idx, ndim): - block_read = sch.cache_read(block, idx, "shared.dyn") - sch.compute_at(block_read, k0) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) - - sch.bind(f_2, "threadIdx.x") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_3) - - sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16) - sch.annotate(block_read, "tir.manifest_shared_memory_local_stage", 1) - sch.annotate(block_read, "double_buffer_scope", 0) - return block_read - - a_g2s = fetch_to_shared(block_outer, 0, 2) - b_g2s = fetch_to_shared(block_outer, 1, 2) - - auto_inline_producers(sch, a_g2s) - auto_inline_producers(sch, b_g2s) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a") - B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b") - sch.compute_at(A_mat, k1) - sch.compute_at(B_mat, k1) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") - sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) - - store = sch.cache_write(block_outer, 0, "wmma.accumulator") - sch.reverse_compute_at(store, thread_idy) - sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - intrin_group = get_wmma_intrin_group( - load_scope="shared.dyn", - store_scope="shared.dyn", - in_dtype="int8", - out_dtype="int32", - trans_b=True, - ) - - try: - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - sch.tensorize(i1, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - sch.tensorize(i1, intrin_group["load_b"]) - except Exception: # pylint: disable=bare-except - return None - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - try: - tensorize_init_store_compute() - except Exception: # pylint: disable=bare-except - return None - - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) - _, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size]) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - - return sch - - -class MatmulTensorizationLegacy(GPUScheduleRule): - """ - The schedule rule for float16 tensor core matmul computation. - func with attr 'dlight.do_not_tensorize' will not be tensorized. - """ - - def apply( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ) -> Optional[tir.Schedule]: - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_wmma_intrin_group,) - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - block_stmt = sch.get(main_block) - index_maps = get_index_map(block_stmt) - if index_maps is None: - return None - matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps - - # Start Schedule - # Step 0. Get schedule config. - # NOTE: we can analyze the config by the hardware spec in the future - - # tensor core intrinsic size - micro_size_x = 16 - micro_size_y = 16 - micro_size_k = 16 - - warp_size = 32 - vector_size = 4 - - i_factors, j_factors, k_factors = ( - [None, 1, 4, 2], - [1, None, 4, 2], - [None, 4], - ) - - num_ty = i_factors[2] * j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] - block = sch.reindex(main_block, ("read", 0)) - sch.transform_layout(block, ("write", 0), a_index_map) - block = sch.reindex(main_block, ("read", 1)) - sch.transform_layout(block, ("write", 0), b_index_map) - block = sch.reindex(main_block, ("write", 0)) - sch.transform_layout(block, ("read", 0), c_index_map) - sch.transform_block_layout(main_block, matmul_index_map) - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6]) - sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1]) - sch.annotate(k1, "software_pipeline_order", [0, 1, 2]) - sch.annotate(k1, "software_pipeline_stage", [0, 0, 1]) - - sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) - - block_idx = sch.fuse(i0, j0) - block_idy = sch.fuse(i1, j1) - thread_idy = sch.fuse(j2, i2) - sch.bind(batch, "blockIdx.z") - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - - def fetch_to_shared(block, idx, ndim): - block_read = sch.cache_read(block, idx, "shared.dyn") - sch.compute_at(block_read, k0) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) - - sch.bind(f_2, "threadIdx.x") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_3) - - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8) - sch.annotate(block_read, "tir.manifest_shared_memory_local_stage", 1) - sch.annotate(block_read, "double_buffer_scope", 0) - return block_read - - a_g2s = fetch_to_shared(block_outer, 0, 2) - b_g2s = fetch_to_shared(block_outer, 1, 2) - - auto_inline_producers(sch, a_g2s) - auto_inline_producers(sch, b_g2s) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a") - B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b") - sch.compute_at(A_mat, k1) - sch.compute_at(B_mat, k1) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") - sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) - - store = sch.cache_write(block_outer, 0, "wmma.accumulator") - sch.reverse_compute_at(store, thread_idy) - sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - intrin_group = get_wmma_intrin_group( - load_scope="shared.dyn", - store_scope="shared.dyn", - in_dtype="float16", - out_dtype="float32", - trans_b=True, - ) - - try: - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - sch.tensorize(i1, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - sch.tensorize(i1, intrin_group["load_b"]) - except Exception: # pylint: disable=bare-except - return None - - # Try to tensorize the init, store and compute block with f16 or f32 intrinsics - tensorize_success: bool = False - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - try: - tensorize_init_store_compute() - tensorize_success = True - except Exception: # pylint: disable=bare-except - intrin_group = get_wmma_intrin_group( - load_scope="shared.dyn", - store_scope="shared.dyn", - in_dtype="float16", - out_dtype="float16", - trans_b=True, - ) - - if not tensorize_success: - try: - tensorize_init_store_compute() - tensorize_success = True - except Exception: # pylint: disable=bare-except - return None - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) - _, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size]) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - - return sch if tensorize_success else None - - def apply_config( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - config, - ) -> Optional[tir.Schedule]: - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_wmma_intrin_group,) - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - - # Start Schedule - # Step 0. Get schedule config. - # NOTE: we can analyze the config by the hardware spec in the future - - # tensor core intrinsic size - intrin_info = config.intrin_info - warp_row_tiles = config.warp[0] - warp_col_tiles = config.warp[1] - block_row_warps = config.block[0] // warp_row_tiles - block_col_warps = config.block[1] // warp_col_tiles - stage = config.pipeline_stage - use_async = config.use_async - chunk = config.rstep[0] - - micro_size_x = 16 - micro_size_y = 16 - micro_size_k = 16 - - warp_size = 32 - - i_factors, j_factors, k_factors = ( - [None, 1, block_row_warps, warp_row_tiles // micro_size_x], - [1, None, block_col_warps, warp_col_tiles // micro_size_y], - [None, chunk // micro_size_k], - ) - - num_ty = i_factors[2] * j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): - sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - - sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) - - block_idx = sch.fuse(i0, j0) - block_idy = sch.fuse(i1, j1) - thread_idy = sch.fuse(j2, i2) - # plan rasteration - if (not isinstance(config.rasterization_plan, NoRasterization) and - sch.get(batch).extent.value == 1): - device_func, invoke_func = config.rasterization_plan.get_code() - factor = config.rasterization_plan.panel_width_ - - # TODO(lei): this is a trick for rasterization implementation - # wait for https://github.com/apache/tvm/pull/16113 to be merged - # require a solution for general block rasterization - factor = 8 # should be divisible by block_idy - if sch.get(block_idy).extent.value % factor == 0: - block_k, block_idy = sch.split(block_idy, factors=[None, factor]) - sch.bind(block_k, "blockIdx.z") - else: - sch.bind(batch, "blockIdx.z") - - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - - def fetch_to_shared(block, idx, ndim, vec_len, dtype="float16"): - block_read = sch.cache_read(block, idx, "shared.dyn") - sch.compute_at(block_read, k0) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vec_len]) - - sch.bind(f_2, "threadIdx.x") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_3) - offset: int = 0 - if dtype == "float16": - offset = 8 - elif dtype == "int8": - offset = 16 - # todo(lei): the pad value should be varied according to the data type - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=offset) - return block_read - - a_g2s = fetch_to_shared( - block_outer, - 0, - 2, - vec_len=list(config.vectorize.values())[0], - dtype=intrin_info.in_dtype, - ) - b_g2s = fetch_to_shared( - block_outer, - 1, - 2, - vec_len=list(config.vectorize.values())[1], - dtype=intrin_info.in_dtype, - ) - - auto_inline_producers(sch, a_g2s) - auto_inline_producers(sch, b_g2s) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a") - B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b") - sch.compute_at(A_mat, k1) - sch.compute_at(B_mat, k1) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") - sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) - - store = sch.cache_write(block_outer, 0, "wmma.accumulator") - sch.reverse_compute_at(store, thread_idy) - sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - intrin_group = get_wmma_intrin_group( - load_scope="shared.dyn", - store_scope="shared.dyn", - in_dtype=intrin_info.in_dtype, - out_dtype=intrin_info.out_dtype, - trans_b=intrin_info.trans_b, - ) - - try: - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - sch.tensorize(i1, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - sch.tensorize(i1, intrin_group["load_b"]) - except Exception: # pylint: disable=bare-except - return None - - # Try to tensorize the init, store and compute block with f16 or f32 intrinsics - tensorize_success: bool = False - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - try: - tensorize_init_store_compute() - tensorize_success = True - except Exception: # pylint: disable=bare-except - return None - - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) - _, f1, f2 = sch.split( - fused, factors=[None, warp_size, max(list(config.vectorize.values()))]) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - - if stage > 1: - sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) - sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) - if use_async: - sch.annotate(k0, "software_pipeline_async_stages", [0]) - - return sch if tensorize_success else None diff --git a/python/bitblas/gpu/reduction.py b/python/bitblas/gpu/reduction.py deleted file mode 100644 index 9d6aada75..000000000 --- a/python/bitblas/gpu/reduction.py +++ /dev/null @@ -1,301 +0,0 @@ -# Copyright 2018 The apache/tvm Authors. All Rights Reserved. -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# Modifications Copyright (c) Microsoft. -# The code below is mostly copied from apache/tvm reduction.py in dlight. -"""A rule for reduction. """ -from typing import List, Optional, Tuple, Union - -from tvm import arith, ir, tir -from tvm.target import Target - -from ..base import ( - BlockInfo, - normalize_prim_func, - try_inline_contiguous_spatial, - detect_dominant_read, - is_broadcast_epilogue, -) -from . import utils -from .base import GPUScheduleRule - - -def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: - # Detect and return `Y` in `X[...] = X[...] + Y` - buffer_store = block.body - if not isinstance(buffer_store, tir.BufferStore): - return None - if not isinstance(buffer_store.value, tir.Add): - return None - if not ir.structural_equal( - buffer_store.value.a, - tir.BufferLoad(buffer_store.buffer, block.body.indices), - map_free_vars=True, - ): - return None - return buffer_store.value.b - - -class Reduction(GPUScheduleRule): - """A rule for Reduction.""" - - def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ) -> Union[None, tir.Schedule, List[tir.Schedule]]: - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): - return None - sch = tir.Schedule(func) - block_infos = normalize_prim_func(sch) - if block_infos is None: - return None - block_infos = try_inline_contiguous_spatial(sch, block_infos) - if len(block_infos) == 1: - epilogue = None - elif len(block_infos) == 2: - epilogue = block_infos[1] - if not epilogue.is_injective(): - return None - else: - return None - - block_info = block_infos[0] - block = block_info.block_rv - block_stmt = sch.get(block) - - # Step 1. Check reduction block - if ( - (not block_info.is_reduction()) - or len(block_stmt.writes) != 1 - or _get_reduction_expr(block_stmt) is None - ): - return None - # Step 2. Normalize the block, merge spatial and reduction iters - is_inner_reduction, c_factor, loop_order, s_split_index = self._normalize( - sch, - block_info, - arith.normalize_to_iter_sum( - detect_dominant_read(block_stmt), - input_iters={i.var: i.dom for i in block_stmt.iter_vars}, - ), - ) - if is_inner_reduction is None and c_factor is None: - return None - # Step 3. Do the scheduling - if is_inner_reduction: - self._sch_inner_reduction( - sch, target, block, c_factor, epilogue, loop_order, s_split_index - ) - else: - self._sch_inner_spatial( - sch, target, block, block_info, c_factor, epilogue, loop_order, s_split_index - ) - return sch - - def _normalize( # pylint: disable=too-many-branches - self, - sch: tir.Schedule, - block_info: BlockInfo, - access: arith.IterSumExpr, - ) -> Tuple[Optional[bool], Optional[int]]: - if access.base != 0: - return None, None, None, None - iter_to_info = {i.var: i for i in block_info.iters} - s_loops, r_loops, c_loops, c_factor = [], [], [], None - s_split_loop, s_split_index = None, None - for split_expr in access.args: - var = split_expr.source.source - info = iter_to_info.pop(var) - loop = info.loop_rv - is_inner_reduction = info.kind == "R" - if split_expr.lower_factor > 1: - if c_loops: - return None, None, None, None - s_split_loop = loop - s_split_index = len(s_loops) - loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) - c_loops.append(c_loop) - if not is_inner_reduction: - c_factor = split_expr.lower_factor - if is_inner_reduction: - r_loops.append(loop) - else: - s_loops.append(loop) - - if iter_to_info: - for var, info in iter_to_info.items(): - if info.kind == "S" and info.dom.extent == 1: - s_loops.append(info.loop_rv) - else: - return None, None, None, None - - loop_order = {} - s_block_var_loops = [] - for i in block_info.iters: - if i.loop_rv in s_loops or i.loop_rv == s_split_loop: - s_block_var_loops.append(i.loop_rv) - - for i in range(len(s_block_var_loops)): - for j in range(len(s_loops)): - if s_block_var_loops[i] == s_loops[j]: - loop_order[i] = j - break - if s_block_var_loops[i] == s_split_loop: - loop_order[i] = s_split_index - break - - assert s_loops - assert r_loops - if len(s_loops) != len([i for i in block_info.iters if i.kind == "S"]): - return None, None - if not c_loops: - c_loops = [sch.add_unit_loop(block_info.block_rv)] - sch.reorder(*s_loops, *r_loops, *c_loops) - sch.fuse(*s_loops) - sch.fuse(*r_loops) - return is_inner_reduction, c_factor, loop_order, s_split_index - - def _sch_inner_reduction( # pylint: disable=too-many-arguments - self, - sch: tir.Schedule, - target: Target, - block: tir.schedule.BlockRV, - unroll_spatial_factor: Optional[int], - epilogue_info: Optional[BlockInfo], - loop_order, - s_split_index, - ): - # pylint: disable=invalid-name - _, r, _ = sch.get_loops(block) - (len_tx,) = utils.suggest_threads_per_block( # pylint: disable=unbalanced-tuple-unpacking - target, [sch.get(r)] - ) - - _, tx = sch.split(r, factors=[None, len_tx]) - # Schedule the RF block - rf = sch.rfactor(tx, 0) - bx, r, tx, _ = sch.get_loops(rf) - sch.reorder(bx, tx, r) - sch.bind(bx, "blockIdx.x") - sch.bind(tx, "threadIdx.x") - sch.annotate(tx, ann_key="pragma_auto_unroll_max_step", ann_val=256) - sch.annotate(tx, ann_key="pragma_unroll_explicit", ann_val=1) - sch.set_scope(rf, 0, "local") - sch.decompose_reduction(rf, r) - # Schedule the write back block - sch.reverse_compute_at(block, bx, preserve_unit_loops=True) - _, tx, *s = sch.get_loops(block) - - if unroll_spatial_factor: - assert len(s) == len(loop_order) - new_order_s = [s[loop_order[i]] for i in range(len(s))] - sch.reorder(*new_order_s) - new_order_s[s_split_index], c = sch.split( - new_order_s[s_split_index], factors=[None, unroll_spatial_factor] - ) - sch.reorder(*new_order_s, c) - s = sch.fuse(*new_order_s) - sch.reorder(s, tx, c) - else: - s = sch.fuse(*s) - sch.reorder(s, tx) - sch.bind(tx, "threadIdx.x") - # Schedule epilogue - if epilogue_info is not None: - epilogue = epilogue_info.block_rv - sch.reverse_compute_at(epilogue, bx) - if is_broadcast_epilogue(sch, block, epilogue): - sch.set_scope(block, 0, "shared") - _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name - _, tx = sch.split(sch.fuse(*s), factors=[None, len_tx]) - sch.bind(tx, "threadIdx.x") - else: - sch.set_scope(block, 0, "local") - # pylint: enable=invalid-name - - def _sch_inner_spatial( - self, - sch: tir.Schedule, - _: Target, - block: tir.schedule.BlockRV, - block_info: BlockInfo, - unroll_spatial_factor: Optional[int], - epilogue_info: Optional[BlockInfo], - loop_order, - s_split_index, - ): - # pylint: disable=invalid-name - s, r, _ = sch.get_loops(block) - len_tx, len_ty = 16, 16 - s_factor = [i.dom.extent for i in block_info.iters if i.kind == "S"][-1] - # get perfect spatial factor, spatial factor should be divide the innermost spatial loop so - # that the block after r_factor and be reversed compute at the original scope - while len_tx > 1: - if s_factor % len_tx == 0: - break - len_tx -= 1 - _, _ = sch.split(s, factors=[None, len_tx]) - _, ty = sch.split(r, factors=[None, len_ty]) - # Schedule the RF block - rf = sch.rfactor(ty, 0) - bx, tx, r, ty, _ = sch.get_loops(rf) - sch.reorder(bx, tx, ty, r) - sch.bind(tx, "threadIdx.x") - sch.bind(ty, "threadIdx.y") - sch.bind(bx, "blockIdx.x") - sch.set_scope(rf, 0, "local") - sch.decompose_reduction(rf, r) - # Schedule the write back block - sch.reverse_compute_at(block, bx, preserve_unit_loops=True) - _, r, *s = sch.get_loops(block) - if unroll_spatial_factor: - assert len(s) == len(loop_order) - new_order_s = [s[loop_order[i]] for i in range(len(s))] - sch.reorder(*new_order_s) - new_order_s[s_split_index], c = sch.split( - new_order_s[s_split_index], factors=[None, unroll_spatial_factor] - ) - sch.reorder(*new_order_s, c) - s = sch.fuse(*new_order_s) - sch.reorder(s, c, r) - else: - s = sch.fuse(*s) - sch.reorder(s, r) - sch.bind(s, "threadIdx.x") - sch.bind(r, "threadIdx.y") - - # Schedule epilogue - if epilogue_info is not None: - epilogue = epilogue_info.block_rv - sch.reverse_compute_at(epilogue, bx) - if is_broadcast_epilogue(sch, block, epilogue): - sch.set_scope(block, 0, "shared") - _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name - _, tx, ty = sch.split(sch.fuse(*s), factors=[None, len_tx, len_ty]) - sch.bind(tx, "threadIdx.x") - sch.bind(ty, "threadIdx.y") - else: - # The epilogue is element-wise without broadcasting. - # Thus the remaining spatial part should be bind to tx. - sch.set_scope(block, 0, "local") - _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name - tx, _ = sch.split(sch.fuse(*s), factors=[len_tx, None]) - sch.bind(tx, "threadIdx.x") - # pylint: enable=invalid-name diff --git a/python/bitblas/gpu/rmsnorm.py b/python/bitblas/gpu/rmsnorm.py deleted file mode 100644 index 6e6d3e247..000000000 --- a/python/bitblas/gpu/rmsnorm.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2018 The apache/tvm Authors. All Rights Reserved. -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# Modifications Copyright (c) Microsoft. -# The code below is mostly copied from apache/tvm rmsnorm.py in dlight. -# pylint: disable=missing-docstring -"""A RMS norm schedule rule for GPU operators.""" - -import tvm -from tvm import tir -from tvm.tir import Block, BufferStore -from tvm.tir.expr import Cast, BufferLoad, Call -from tvm.target import Target - -from ..base import ScheduleRule - - -def identify_cast_or_load_block(block: Block) -> bool: - if len(block.reads) != 1 or len(block.writes) != 1: - return False - - if not isinstance(block.body, BufferStore): - return False - store = block.body - - # check types - if isinstance(store.value, BufferLoad): - load = store.value - elif isinstance(store.value, Cast): - load = store.value.value - if not isinstance(load, BufferLoad): - return False - else: - return False - - # check indices - if len(load.indices) != len(store.indices): - return False - - for lhs, rhs in zip(load.indices, store.indices): - if not lhs.same_as(rhs): - return False - - return True - - -def identify_rsqrt_block(block: Block) -> bool: - if len(block.reads) != 1 or len(block.writes) != 1: - return False - - if not isinstance(block.body, BufferStore): - return False - store = block.body - - if not isinstance(store.value, Call): - return False - call = store.value - op = call.op - - return op == tvm.ir.op.Op.get("tir.rsqrt") - - -class RMSNorm(ScheduleRule): - """A rule for RMS norm.""" - - def apply( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ) -> tir.Schedule: - if target.kind.name == "cuda": - num_tx = 512 - else: - num_tx = 64 - - sch = tir.Schedule(func) - root = sch.get_block(name="root", func_name="main") - - blocks = sch.get_child_blocks(root) - - if not any([identify_rsqrt_block(sch.get(block)) for block in blocks]): - return None - - read = sch.cache_read(block=blocks[0], read_buffer_index=0, storage_scope="local") - write = sch.cache_write(block=blocks[-1], write_buffer_index=0, storage_scope="local") - - for block in blocks: - if identify_cast_or_load_block(sch.get(block)): - sch.compute_inline(block) - - blocks = sch.get_child_blocks(root) - - read, sqr, redsum, rsqrt, norm, write = blocks - - if not identify_rsqrt_block(sch.get(rsqrt)): - return None - - for name in [read, sqr, redsum, rsqrt, norm, write]: - loops = sch.get_loops(name) - sch.fuse(*loops[:-1]) - - block_loop, loops = sch.get_loops(block=read) - thread_loop, _, _ = sch.split( - loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True - ) - sch.bind(block_loop, thread_axis="blockIdx.x") - sch.bind(thread_loop, thread_axis="threadIdx.x") - sch.vectorize(sch.get_loops(block=read)[-1]) - sch.reverse_compute_at(block=sqr, loop=thread_loop) - sch.reverse_compute_at(block=redsum, loop=thread_loop) - - sch.reverse_compute_at(block=rsqrt, loop=block_loop, index=-1) - sch.reverse_compute_at(block=norm, loop=block_loop, index=-1) - block_loop, loops = sch.get_loops(block=norm) - thread_loop, _, _ = sch.split( - loop=loops, factors=[num_tx, None, 8], preserve_unit_iters=True - ) - sch.bind(thread_loop, thread_axis="threadIdx.x") - - sch.reverse_compute_at(block=write, loop=thread_loop, index=-1) - sch.vectorize(sch.get_loops(block=write)[-1]) - - sch.set_scope(block=sqr, buffer_index=0, storage_scope="local") - sch.set_scope(block=redsum, buffer_index=0, storage_scope="local") - sch.set_scope(block=rsqrt, buffer_index=0, storage_scope="shared") - sch.set_scope(block=norm, buffer_index=0, storage_scope="local") - - return sch diff --git a/python/bitblas/gpu/transpose.py b/python/bitblas/gpu/transpose.py deleted file mode 100644 index 6dc025c07..000000000 --- a/python/bitblas/gpu/transpose.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2018 The apache/tvm Authors. All Rights Reserved. -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# Modifications Copyright (c) Microsoft. -# The code below is mostly copied from apache/tvm transpose.py in dlight. -"""Reduction rule for operators including softmax, layer norm, RMS norm, etc""" -from typing import List, Union - -from tvm import arith, tir -from tvm.target import Target -from tvm.tir import Schedule -from tvm.tir.schedule import BlockRV - -from ..base import ( - detect_dominant_read, - normalize_prim_func, - try_inline_contiguous_spatial, -) -from .base import GPUScheduleRule - - -class Transpose(GPUScheduleRule): - """Schedule rule for transpose""" - - def is_transpose(self, sch: Schedule, block_rv: BlockRV): - block = sch.get(block_rv) - if isinstance(block.body, tir.BufferStore): - rhs = block.body.value - if isinstance(rhs, tir.BufferLoad): - lhs_indices = block.body.indices - rhs_indices = rhs.indices - if list(lhs_indices) != list(rhs_indices) and set(lhs_indices) == set(rhs_indices): - return True - return False - - def apply( # pylint: disable=too-many-locals - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ) -> Union[None, tir.Schedule, List[tir.Schedule]]: - # pylint: disable=invalid-name - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): - return None - if target.kind.name == "cuda": - len_tx = 16 - len_ty = 8 - unroll_depth = 256 - else: - len_tx = 8 - len_ty = 4 - unroll_depth = 64 - len_vec = 4 - - sch = tir.Schedule(func) - blocks = normalize_prim_func(sch) - transpose_block_idx = -1 - for idx, block in reversed(list(enumerate(blocks))): - if self.is_transpose(sch, block.block_rv): - transpose_block_idx = idx - break - if not block.is_injective(): - return None - if transpose_block_idx == -1: - return None - transpose_block = blocks[transpose_block_idx].block_rv - - prologue = None # the optional decoding block - if transpose_block_idx > 0: - spatials = try_inline_contiguous_spatial(sch, blocks[: transpose_block_idx - 1]) - assert len(spatials) == 0 - prologue = blocks[transpose_block_idx - 1].block_rv - - loops = sch.get_loops(transpose_block) - if len(loops) != 2: - # transpose with more than 2 axes is not supported - return None - - c_factor = 1 - if prologue is not None: - block_stmt = sch.get(prologue) - result = arith.normalize_to_iter_sum( - detect_dominant_read(block_stmt), - input_iters={i.var: i.dom.extent for i in block_stmt.iter_vars}, - ) - if len(result.args) > 0: - c_factor = int(result.args[0].lower_factor) - - i, j = loops - i, vi = sch.split(i, factors=[None, c_factor], preserve_unit_iters=True) - bi, ti = sch.split(i, factors=[None, len_ty], preserve_unit_iters=True) - bj, tj = sch.split(j, factors=[None, len_tx], preserve_unit_iters=True) - sch.reorder(bi, bj, ti, tj, vi) - sch.bind(bi, "blockIdx.y") - sch.bind(bj, "blockIdx.x") - sch.bind(ti, "threadIdx.y") - sch.bind(tj, "threadIdx.x") - len_vec = min(len_vec, c_factor) - _, vi = sch.split(vi, factors=[None, len_vec]) - if len_vec > 1: - sch.vectorize(vi) - - cache_read = sch.cache_read(transpose_block, read_buffer_index=0, storage_scope="shared") - sch.compute_at(cache_read, bj) - loops = sch.get_loops(cache_read)[2:] - fused = sch.fuse(*loops) - _, ty, tx, v = sch.split(fused, factors=[None, len_ty, len_tx, c_factor]) - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - sch.unroll(v) - sch.storage_align(block=cache_read, buffer_index=0, axis=0, factor=32, offset=1) - - sch.annotate(bi, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth) - sch.annotate(bi, ann_key="pragma_unroll_explicit", ann_val=1) - - if prologue is not None: - sch.compute_inline(prologue) - return sch diff --git a/python/bitblas/gpu/utils.py b/python/bitblas/gpu/utils.py deleted file mode 100644 index e3a5b6098..000000000 --- a/python/bitblas/gpu/utils.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -# pylint: disable=missing-docstring -"""Utility methods for generic GPU.""" -from typing import List, Optional - -from tvm import tir -from tvm.target import Target - - -def max_threads_per_block(target: Target) -> int: - """Get the maximum number of threads per block for a given target. - - Parameters - ---------- - target : Target - The target to get the maximum number of threads per block for. - - Returns - ------- - max_threads_per_block : int - The maximum number of threads per block for the given target. - """ - for name in ["max_threads_per_block", "max_num_threads"]: - result = target.attrs.get(name, None) - if result is not None: - return result - if target.kind.name == "cuda": - return 1024 - return 256 - - -def suggest_threads_per_block( - target: Target, - loops: List[tir.For], - max_threads_for_dynamic_loop: int = 32, -) -> List[int]: - if target.kind.name == "cuda": - threads = 1024 - elif target.kind.name == "rocm": - threads = 256 - elif target.kind.name == "metal": - threads = 256 - else: - threads = 64 - results: List[Optional[int]] = [] - dynamic: List[int] = [] - for i, loop in enumerate(loops): - loop_extent = loop.extent - if isinstance(loop_extent, tir.IntImm): - loop_extent = loop_extent.value - extent = 1 - while extent <= loop_extent and extent <= threads: - extent *= 2 - extent //= 2 - assert extent >= 1 - assert threads % extent == 0 - threads //= extent - results.append(extent) - else: - results.append(None) - dynamic.append(i) - - for i in dynamic: - extent = 1 - while extent <= max_threads_for_dynamic_loop and extent <= threads: - extent *= 2 - extent //= 2 - assert extent >= 1 - assert threads % extent == 0 - threads //= extent - results[i] = extent - - if dynamic: - results[dynamic[0]] *= threads - - return results - - -def get_sm_version(target: Target) -> int: - if target.kind.name != "cuda": - return -1 - arch = target.arch - sm_version = arch.replace("sm_", "") - return int(sm_version) if sm_version.isdigit() else -1 diff --git a/python/bitblas/module/__init__.py b/python/bitblas/module/__init__.py deleted file mode 100644 index f353228a5..000000000 --- a/python/bitblas/module/__init__.py +++ /dev/null @@ -1,305 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import ctypes -import operator -from functools import reduce -from logging import getLogger - -import torch -import torch.nn as nn - -logger = getLogger(__name__) - -from typing import List, Union, Optional - -from bitblas.cache import global_operator_cache, get_database_path -from bitblas import Matmul, MatmulConfig -from bitblas.quantization.utils import general_compress -from bitblas import auto_detect_nvidia_target - -BITBLAS_TARGET = auto_detect_nvidia_target() -BITBLAS_DATABASE_PATH = get_database_path() - - -def unpack_qzeros(qzeros, bits): - qzeros = qzeros.view(torch.int32) - elems_per_int32 = 32 // bits - unpacked_zeros = torch.zeros( - (qzeros.shape[0], qzeros.shape[1] * elems_per_int32), - dtype=torch.int8, - device=qzeros.device, - requires_grad=False, - ) - for col in range(unpacked_zeros.shape[1]): - i = col % elems_per_int32 - unpacked_zeros[:, col] = (qzeros[:, col // elems_per_int32] >> (bits * i)) - - # Follow the instruction in AutoGPTQ qlinear_cuda_old.py line 303 - # NOTE: It appears that casting after the `unpacked_zeros + 1` is important. - return torch.bitwise_and(unpacked_zeros + 1, 2**bits - 1) - - -class Linear(nn.Module): - opt_M = [1, 16, 32, 64, 128, 256, 512] - STORAGE_DTYPE = "int8" # assume int8 storage - TORCH_STORAGE_DTYPE = getattr(torch, STORAGE_DTYPE) - BITBLAS_DTYPES = { - torch.float32: "float32", - torch.float16: "float16", - torch.half: "float16", - torch.int8: "int8", - } - - def __init__( - self, - in_features: int, - out_features: int, - bias: bool = False, - A_dtype: str = "float16", - W_dtype: str = "float16", - accum_dtype: str = "float16", - out_dtype: str = "float16", - # configs for weight only quantization - group_size: int = -1, - with_scaling: bool = None, - with_zeros: bool = False, - zeros_mode: str = None, - opt_M: Union[int, List[int]] = opt_M, - # performance related configs - enable_tuning: bool = True, - fast_decoding: Optional[bool] = None, - propagate_b: bool = False, - ): - """ - @opt_M: optimize range of the input shape for dynamic symbolic - if the input shape is a range, we will optimize the matmul with dynamic symbolic. - if the input shape is int, we will optimize the matmul with static symbolic. - """ - super().__init__() - - self.in_features = in_features - self.out_features = out_features - self.opt_M = opt_M - self.group_size = self._set_group_size(group_size, in_features) - self.torch_dtype = getattr(torch, A_dtype) - self.is_consitent = A_dtype == W_dtype - self.zeros_mode = zeros_mode - self._validate_parameters(self.group_size, in_features, out_features) - self._configure_bitblas_matmul( - A_dtype, - W_dtype, - accum_dtype, - out_dtype, - with_scaling, - with_zeros, - zeros_mode, - enable_tuning, - fast_decoding, - bias, - propagate_b, - ) - self._initialize_buffers(in_features, out_features, bias) - - def init_params(self): - # eliminate runtime overhead like exllama state - if self.is_consitent: - param_list = [self.weight] - if self.bitblas_matmul.config.with_bias: - param_list.append(self.bias) - self.q_params = [ctypes.c_void_p(arr.data_ptr()) for arr in param_list] - else: - param_list = [self.qweight] - if self.bitblas_matmul.config.with_scaling: - param_list.append(self.scales) - if self.bitblas_matmul.config.with_zeros: - param_list.append(self.zeros) - if self.bitblas_matmul.config.with_bias: - param_list.append(self.bias) - self.q_params = [ctypes.c_void_p(arr.data_ptr()) for arr in param_list] - - def _validate_parameters(self, group_size, in_features, out_features): - if in_features % 16 != 0 or out_features % 16 != 0: - raise ValueError("`in_features` and `out_features` must be divisible by 16.") - if in_features % group_size != 0: - raise ValueError("`in_features` must be divisible by `group_size`.") - - def _set_group_size(self, group_size, in_features): - return in_features if (group_size == -1 or group_size is None) else group_size - - def _initialize_buffers(self, in_features, out_features, bias): - if self.consistent: - self.register_buffer( - "weight", - torch.zeros((out_features, in_features // self.group_size), dtype=self.torch_dtype), - ) - else: - self.register_buffer( - "qweight", - torch.zeros( - self.bitblas_matmul.retrieve_weight_shape(), - dtype=self.TORCH_STORAGE_DTYPE, - ), - ) - self.register_buffer( - "scales", - torch.zeros((out_features, in_features // self.group_size), dtype=self.torch_dtype), - ) - if self.zeros_mode == "quantized": - storage_nbit = int("".join(c for c in self.STORAGE_DTYPE if c.isdigit())) - self.register_buffer( - "zeros", - torch.zeros( - ( - in_features // self.group_size, - out_features // storage_nbit * self.bits, - ), - dtype=self.TORCH_STORAGE_DTYPE, - ), - ) - else: - self.register_buffer( - "zeros", - torch.zeros( - (out_features, in_features // self.group_size), - dtype=self.torch_dtype, - ), - ) - if bias: - self.register_buffer("bias", torch.zeros((out_features), dtype=self.torch_dtype)) - else: - self.bias = None - - def _configure_bitblas_matmul( - self, - A_dtype, - W_dtype, - accum_dtype, - out_dtype, - with_scaling, - with_zeros, - zeros_mode, - enable_tuning, - fast_decoding, - bias, - propagate_b, - ): - matmul_config = MatmulConfig( - M=self.opt_M, - N=self.out_features, - K=self.in_features, - A_dtype=A_dtype, - W_dtype=W_dtype, - accum_dtype=accum_dtype, - out_dtype=out_dtype, - storage_dtype=self.STORAGE_DTYPE, - with_scaling=with_scaling, - with_zeros=with_zeros, - group_size=self.group_size, - fast_decoding=fast_decoding, - with_bias=bias, - propagate_b=propagate_b, - zeros_mode=zeros_mode, - ) - self.bitblas_matmul = self._get_or_create_bitblas_operator(matmul_config, enable_tuning) - self.bits = self.bitblas_matmul.bit - self.source_format = self.bitblas_matmul.source_format - - def _get_or_create_bitblas_operator(self, config, enable_tuning): - if global_operator_cache.size() == 0: - global_operator_cache.load_from_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET) - logger.info(f"Loaded {global_operator_cache.size()} operators from database.") - - bitblas_matmul = global_operator_cache.get(config) - if bitblas_matmul is None: - # should disable tuning for the first time because we may require loading bitblas operator from database. - bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=False) - if enable_tuning: - bitblas_matmul.hardware_aware_finetune(topk=20) - global_operator_cache.add(config, bitblas_matmul) - global_operator_cache.save_into_database(BITBLAS_DATABASE_PATH, BITBLAS_TARGET) - print("BitBLAS Tuning done, appended operator to global_operator_cache.") - else: - print("BitBLAS Operator created.") - else: - print("BitBLAS Operator found in global_operator_cache.") - return bitblas_matmul - - def warmup(self, topk=20): - self.bitblas_matmul.hardware_aware_finetune(topk=topk) - - def forward(self, A, output=None): - if A.dtype != torch.float16: - A = A.half() - # can be lifted to post init. - self.init_params() - - if output is None: - output = torch.empty( - A.shape[:-1] + (self.out_features,), dtype=A.dtype, device=A.device) - m = ctypes.c_int32(reduce(operator.mul, A.shape[:-1], 1)) - A = self.bitblas_matmul.transform_input(A) - stream = torch.cuda.current_stream() - - A_void = ctypes.c_void_p(A.data_ptr()) - stream_handle = ctypes.c_void_p(stream.cuda_stream) - # m is the product of the last n - 1 dimensions of A - self.bitblas_matmul.lib.call(A_void, *self.q_params, ctypes.c_void_p(output.data_ptr()), m, - stream_handle) - - return output - - def load_and_transform_weight( - self, - weight: torch.Tensor, - scales: torch.Tensor = None, - zeros: torch.Tensor = None, - bias: torch.Tensor = None, - ): - if self.consistent: - assert scales is None, "scales should be None for consistent mode." - assert zeros is None, "zeros should be None for consistent mode." - weight = self.bitblas_matmul.transform_weight(weight) - self.weight = nn.Parameter(weight) - if bias is not None: - self.bias = bias - else: - weight = self.bitblas_matmul.transform_weight(weight) - self.qweight = weight - if scales is not None: - self.scales = scales - if zeros is not None: - self.zeros = zeros - if bias is not None: - self.bias = bias - - def repack_from_gptq(self, gptq_module): - # qweight in gptq old quant linear stored with (out_features, in_features), should be transposed. - qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE) - if self.bitblas_matmul.weight_transform is not None: - qweight = self.bitblas_matmul.weight_transform(qweight.cpu()).cuda() - self.qweight = qweight - # scales in gptq old quant linear stored with (in_features // group_size, out_features), should be transposed. - scales = gptq_module.scales.T.contiguous().view(self.torch_dtype) - self.scales = scales - # qzeros should be dequantized to int zeros. - intzeros = unpack_qzeros(gptq_module.qzeros, self.bits).T.contiguous() - if self.bitblas_matmul.config.zeros_mode == "original": - self.zeros = intzeros.to(torch.float16).contiguous() - elif self.bitblas_matmul.config.zeros_mode == "rescale": - self.zeros[:, :] = intzeros.to(torch.float16)[:, :] * self.scales[:, :] - elif self.bitblas_matmul.config.zeros_mode == "quantized": - self.zeros = ( - torch.Tensor(general_compress(intzeros.T.contiguous().cpu().numpy(), self.bits)).to( - self.qweight.device).to(self.zeros.dtype).contiguous()) - else: - raise ValueError(f"Unsupported zeros type: {self.bitblas_matmul.config.zeros_mode}") - if self.bias is not None: - self.bias = gptq_module.bias.data.to(torch.float16).contiguous() - - @property - def consistent(self): - return self.is_consitent - - -__all__ = ["Linear"] diff --git a/python/bitblas/ops/__init__.py b/python/bitblas/ops/__init__.py deleted file mode 100644 index cdacc5bad..000000000 --- a/python/bitblas/ops/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from .operator import Operator # noqa: F401 -from .matmul import Matmul, MatmulConfig # noqa: F401 -from .matmul_dequantize import MatmulWeightOnlyDequantize, MatmulWeightOnlyDequantizeConfig # noqa: F401 -from .ladder_permutate import LadderPermutate, LadderPermutateConfig # noqa: F401 -from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig # noqa: F401 diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py deleted file mode 100644 index af2da3f02..000000000 --- a/python/bitblas/ops/general_matmul.py +++ /dev/null @@ -1,588 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import tvm -from tvm.target import Target -import operator -from functools import reduce -from bitblas.base.roller.arch.cuda import CUDA -from typing import Any, Literal, Optional, Tuple, Union -from .operator import Operator, TransformKind, OPExecutorCPU -from .impl.matmul_dequantize_impl import ( - select_implementation as weight_dequantize_implementation,) -from .impl.matmul_impl import select_implementation as consistent_implementation -from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 -from bitblas.utils.target_detector import auto_detect_nvidia_target -from dataclasses import dataclass -from .ladder_permutate import LadderPermutate, LadderPermutateConfig -from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig -import logging -import torch - -logger = logging.getLogger(__name__) - -WORKSPACE_SIZE = 1024 * 1024 * 256 - -# TODO(lei): This should be improved into a general -# Method to get the consistent compute patterns. -NATIVE_COMPUTE_PATTERNS = [ - # A_dtype, W_dtype - ("float64", "float64"), - ("float32", "float32"), - ("float16", "float16"), - ("int8", "int8"), - ("e4m3_float8", "e4m3_float8"), - ("e4m3_float8", "e5m2_float8"), - ("e5m2_float8", "e4m3_float8"), - ("e5m2_float8", "e5m2_float8"), -] - - -def is_native_compute(A_dtype, W_dtype) -> bool: - return (A_dtype, W_dtype) in NATIVE_COMPUTE_PATTERNS - - -@dataclass(frozen=True) -class MatmulConfig: - M: Union[int, Tuple[int]] = None - N: int = None - K: int = None - A_dtype: str = "float16" - # is a wrapper for source_format and bit - W_dtype: str = A_dtype # W_dtype is the same as A_dtype by default - out_dtype: str = "float16" - accum_dtype: str = "float16" - layout: Literal["nn", "nt", "tn", "tt"] = "nt" - with_bias: bool = False - group_size: int = -1 - with_scaling: bool = False - with_zeros: bool = False - # documents for zeros_mode: - # original: target = (dequantize_weight - zero_point) * scale - # rescale: target = dequantize_weight * scale - zero_point - # quantized: target = (dequantize_weight - dequantize_zeros) * scale - # The auto-gptq framework prefer "quantized" and "original" for alignment with cuda. - zeros_mode: Literal["original", "rescale", "quantized"] = "original" - storage_dtype: str = "int8" - - # weight transform related flags - fast_decoding: Optional[bool] = ( - None # enable fast decoding by default, if not specified, it is enabled by a rule. - ) - propagate_a: Optional[TransformKind] = ( - None # propagate_a is a flag to control the ladder permutation. - ) - propagate_b: Optional[TransformKind] = ( - None # propagate_b is a flag to control the ladder permutation - ) - - def __legalize_dynamic_symbolic(self, M): - return tuple(self.M) if isinstance(self.M, list) else self.M - - def __legalize_propagate(self, propagate): - if isinstance(propagate, bool): - return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform) - elif isinstance(propagate, int): - return TransformKind(propagate) - - return propagate - - def __initialize_propagate(self, propagate_a: Optional[TransformKind], - propagate_b: Optional[TransformKind]): - MICRO_KERNEL_SIZE = 16 - if (isinstance(self.M, int) and (self.M % MICRO_KERNEL_SIZE) == 0 and - (self.K % MICRO_KERNEL_SIZE) == 0): - object.__setattr__(self, "propagate_a", TransformKind.IntraWarpTransform) - else: - object.__setattr__(self, "propagate_a", TransformKind.NonTransform) - - if (self.M == 1 or (self.N % MICRO_KERNEL_SIZE) != 0 or (self.K % MICRO_KERNEL_SIZE) != 0 or - isinstance(self.M, Tuple) or (self.with_zeros and self.zeros_mode == "quantized")): - object.__setattr__(self, "propagate_a", TransformKind.NonTransform) - object.__setattr__(self, "propagate_b", TransformKind.NonTransform) - else: - object.__setattr__(self, "propagate_b", TransformKind.IntraWarpTransform) - - # set a and b value if is not None - if propagate_a is not None: - object.__setattr__(self, "propagate_a", propagate_a) - if propagate_b is not None: - object.__setattr__(self, "propagate_b", propagate_b) - - # TODO(lei): This is a limitation arose by pytorch and llvm - # Should be removed in the future. - if self.A_dtype in ["e4m3_float8", "e5m2_float8"]: - object.__setattr__(self, "propagate_a", TransformKind.NonTransform) - object.__setattr__(self, "propagate_b", TransformKind.NonTransform) - - def __initialize_zeros_mode(self, zeros_mode: Optional[str]): - if zeros_mode is None: - object.__setattr__(self, "zeros_mode", "original") - - def __initialize_fast_decoding(self, fast_decoding: Optional[bool]): - - def is_not_fast_decoding_supported(): - conditions = [] - conditions.append("int" not in self.W_dtype) - conditions.append(self.W_dtype == self.A_dtype) - # int8,uint8 also do not implement and also do not require fast decoding - conditions.append(self.W_dtype in ["int8", "uint8"]) - return any(conditions) - - if fast_decoding is not None: - object.__setattr__(self, "fast_decoding", fast_decoding) - elif is_not_fast_decoding_supported(): - object.__setattr__(self, "fast_decoding", False) - else: - object.__setattr__(self, "fast_decoding", True) - - def __post_init__(self): - # set M to default dynamic range if it is None - if self.M is None: - object.__setattr__(self, "M", [1, 16, 32, 64, 128, 256, 512, 1024]) - if self.N is None: - raise ValueError("N should be specified currently.") - if self.K is None: - raise ValueError("K should be specified currently.") - - # set M to tuple if it is list - # otherwise, M is not hashable - object.__setattr__(self, "M", self.__legalize_dynamic_symbolic(self.M)) - - # set propagate_a and propagate_b to default value if it is None - object.__setattr__(self, "propagate_a", self.__legalize_propagate(self.propagate_a)) - object.__setattr__(self, "propagate_b", self.__legalize_propagate(self.propagate_b)) - - # This is hack to legalize propagate_a and b - # TODO(lei): should be removed in the future when tc+br template is ready. - self.__initialize_propagate(self.propagate_a, self.propagate_b) - - self.__initialize_zeros_mode(self.zeros_mode) - - self.__initialize_fast_decoding(self.fast_decoding) - - if self.with_bias is None: - object.__setattr__(self, "with_bias", False) - - if self.group_size is None: - object.__setattr__(self, "group_size", -1) - - if self.with_scaling is None: - object.__setattr__(self, "with_scaling", False) - - if self.with_zeros is None: - object.__setattr__(self, "with_zeros", False) - - if self.A_dtype == self.W_dtype and self.W_dtype in [ - "float16", - "int8", - "e4m3_float8", - "e5m2_float8", - ]: - object.__setattr__(self, "storage_dtype", self.W_dtype) - - -class Matmul(Operator): - - # TODO(lei): This should be improved into a general datatype. - BITBLAS_TRICK_DTYPE_MAP = { - "float64": ("fp", 64), - "float32": ("fp", 32), - "float16": ("fp", 16), - "int32": ("int", 32), - "uint32": ("uint", 32), - "int16": ("int", 16), - "uint16": ("uint", 16), - "int8": ("int", 8), - "uint8": ("uint", 8), - "int4": ("int", 4), - "uint4": ("uint", 4), - "int2": ("int", 2), - "uint2": ("uint", 2), - "int1": ("int", 1), - "uint1": ("uint", 1), - "nf4": ("nf", 4), - "fp4_e2m1": ("fp", 4), - "e4m3_float8": ("fp_e4m3", 8), # "e4m3_float8" is a trick for "float8_e4m3fn" - "e5m2_float8": ("fp_e5m2", 8), - } - - def __init__( - self, - config: MatmulConfig, - name: str = "matmul", - target: Optional[Union[str, Target]] = None, - enable_tuning: bool = True, - from_database: bool = False, - ): - # if from database, we should disable default schedule - # to save compilation time - if target is None: - target = auto_detect_nvidia_target() - logger.info(f"Auto detected target: {target}") - assert (config.A_dtype - in self.BITBLAS_TRICK_DTYPE_MAP), f"Unsupported input dtype {config.A_dtype}" - source_format, bit = self.BITBLAS_TRICK_DTYPE_MAP[config.W_dtype] - - self.source_format = source_format - self.bit = bit - super().__init__(name, config, target) - - if source_format == "int" and self.with_zeros: - logger.warning( - "[BitBLAS][Warning] with_zeros is not supported for int source format as int has a constant zeropoints already." - ) - - target = self.target - if target.kind.name != "cuda": - raise ValueError("Currently only support cuda target") - - self.arch = CUDA(target) - - if isinstance(self.M, Tuple): - self.dynamic_range = {"m": self.M} - self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs( - {"opt_shapes": self.dynamic_range}) - else: - self.dynamic_range = None - - if not from_database: - self._build_default_module(target) - - self.workspace = None - if self.propagate_a: - # for general purpose, we use propagate_a to control the ladder permutation. - ladder_permutate_config = LadderPermutateConfig( - M=self.M, - N=self.K, - datatype=self.A_dtype, - storage_dtype=self.A_dtype, - propagate_kind="A", - transpose_matrix=False, - transform_kind=self.propagate_a, - ) - self.ladder_permutate_a = LadderPermutate( - config=ladder_permutate_config, - target=target, - enable_tuning=enable_tuning, - ) - self.workspace = torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda() - else: - self.ladder_permutate_a = None - - if self.propagate_b: - ladder_permutate_config = LadderPermutateConfig( - M=self.N, - N=self.K, - datatype=self.A_dtype, - dequantize_bits=self.bit, - storage_dtype=self.storage_dtype, - propagate_kind="B", - transpose_matrix=self.layout == "nt", - transform_kind=self.propagate_b, - ) - self.ladder_permutate_b = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_b = None - - if self.fast_decoding: - assert self.source_format in ["int", "uint"] - lop3_permutate_config = LOP3PermutateConfig( - M=self.N, - N=self.K, - datatype=self.A_dtype, - dequantize_bits=self.bit, - storage_dtype=self.storage_dtype, - ) - self.lop3_permutate = LOP3Permutate( - config=lop3_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.lop3_permutate = None - - input_executors = OPExecutorCPU() - if self.ladder_permutate_a is not None: - input_executors.append(self.ladder_permutate_a) - self.input_executors = input_executors - - weight_executors = OPExecutorCPU() - if self.lop3_permutate is not None: - weight_executors.append(self.lop3_permutate) - - if self.ladder_permutate_b is not None: - weight_executors.append(self.ladder_permutate_b) - - self.weight_executors = weight_executors - - if enable_tuning: - self.hardware_aware_finetune() - - if source_format == "nf": - self.lut = torch.tensor( - [ - -1.0, - -0.6961928009986877, - -0.5250730514526367, - -0.39491748809814453, - -0.28444138169288635, - -0.18477343022823334, - -0.09105003625154495, - 0.0, - 0.07958029955625534, - 0.16093020141124725, - 0.24611230194568634, - 0.33791524171829224, - 0.44070982933044434, - 0.5626170039176941, - 0.7229568362236023, - 1.0, - ], - dtype=getattr(torch, self.A_dtype), - ).cuda() - else: - self.lut = None - - # output data type - self.torch_output_dtype = getattr(torch, self.out_dtype) - - def _build_default_module(self, target: Target): - try: - self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) - except Exception: - self.optimized_func = None - logger.warning( - "[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." - ) - - self._build_runtime_module(target) - - def _select_implementation(self): - if is_native_compute(self.A_dtype, self.W_dtype): - return consistent_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.A_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - with_bias=self.with_bias, - layout=self.layout, - propagate_a=self.propagate_a, - propagate_b=self.propagate_b, - ) - else: - return weight_dequantize_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.A_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - bit=self.bit, - storage_dtype=self.storage_dtype, - source_format=self.source_format, - with_scaling=self.with_scaling, - with_zeros=self.with_zeros, - group_size=self.group_size, - fast_decoding=self.fast_decoding, - with_bias=self.with_bias, - layout=self.layout, - zeros_mode=self.zeros_mode, - propagate_a=self.propagate_a, - propagate_b=self.propagate_b, - ) - - def post_process(self, code: str) -> str: - code = tensor_replace_dp4a(code) - code = tensor_remove_make_int4(code) - code = tensor_remove_make_int2(code) - return code - - def retrieve_weight_shape(self): - return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape] - - def transform_weight(self, weight, scale=None, zeros=None, bias=None): - """ - Transforms the given weight tensor based on the specified quantization parameters and - returns the transformed weight along with optional scale, zeros, and bias. - - Parameters: - - weight: The input weight tensor to be transformed. - - scale: Optional scaling factor for the weight tensor. - - zeros: Optional zero-point adjustment for the weight tensor. - - bias: Optional bias to be added to the weight tensor. - - Returns: - A list containing the transformed weight tensor and optionally the scale, zeros, and bias. - """ - weight = weight.contiguous() - if self.W_dtype == self.A_dtype: - if self.weight_transform is not None: - return self.weight_transform(weight.cpu()).cuda().contiguous() - return weight - - from bitblas.quantization import general_compress - import torch - import numpy as np - - source_format, bit = self.source_format, self.bit - - # Process integer source format - if source_format == "int" and bit < 8: - assert not self.with_scaling, "scale should be False for int source format" - assert not self.with_zeros, "zeros should be False for int source format" - maxq = 2**(bit - 1) - # Clamp weight values to be within the quantizable range and adjust - weight = torch.clamp(weight, -maxq, maxq).int() + maxq - elif source_format in ["fp_e5m2", "fp_e4m3"]: - weight = weight.view(torch.int8) - weight = weight.int() - else: - # For non-integer formats, simply convert weights to integers - weight = weight.int() - - np_storage_dtype = getattr(np, self.storage_dtype) - - weight = general_compress( - weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype) - - weight = torch.from_numpy(weight).cuda().contiguous() - - # Apply an optional weight transformation if specified - if self.weight_transform is not None: - weight = self.weight_transform(weight.cpu()).cuda().contiguous() - - # Prepare the return list with the transformed weight and optionally include scale, zeros, and bias - result = [weight] - if scale is not None: - result.append(scale) - if zeros is not None: - result.append(zeros) - if bias is not None: - result.append(bias) - - return next(iter(result), result) - - def transform_input(self, input_tensor): - if self.propagate_a is not TransformKind.NonTransform: - # check workspace size - if input_tensor.numel() > WORKSPACE_SIZE: - raise ValueError( - f"Input size {input_tensor.numel()} is larger than the workspace size {WORKSPACE_SIZE}, please increase the workspace size." - ) - self.ladder_permutate_a._forward_from_prebuild_lib(input_tensor, self.workspace) - return self.workspace - return input_tensor - - def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: - args = [] - args.append(self.transform_input(A)) - args.append(W) - - if self.lut is not None: - args.append(self.lut) - - if output is None: - output = torch.empty( - A.shape[:-1] + (self.N,), dtype=self.torch_output_dtype, device=A.device) - if scale is not None: - args.append(scale) - if zeros is not None: - args.append(zeros) - if bias is not None: - args.append(bias) - args.append(output) - - if self.dynamic_range is not None: - m = reduce(operator.mul, A.shape[:-1], 1) - args.append(m) - - stream = torch.cuda.current_stream() - - if self.lib is None: - self._forward_from_torch_func(*args) - self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream) - - return output - - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.forward(*args, **kwds) - - @property - def M(self): - return self.config.M - - @property - def N(self): - return self.config.N - - @property - def K(self): - return self.config.K - - @property - def A_dtype(self): - return self.config.A_dtype - - @property - def W_dtype(self): - return self.config.W_dtype - - @property - def out_dtype(self): - return self.config.out_dtype - - @property - def accum_dtype(self): - return self.config.accum_dtype - - @property - def storage_dtype(self): - return self.config.storage_dtype - - @property - def with_scaling(self): - return self.config.with_scaling - - @property - def with_zeros(self): - return self.config.with_zeros - - @property - def group_size(self): - return self.config.group_size - - @property - def fast_decoding(self): - return self.config.fast_decoding - - @property - def with_bias(self): - return self.config.with_bias - - @property - def propagate_a(self): - return self.config.propagate_a - - @property - def propagate_b(self): - return self.config.propagate_b - - @property - def layout(self): - return self.config.layout - - @property - def zeros_mode(self): - return self.config.zeros_mode - - @property - def input_transform(self): - return self.input_executors if self.input_executors.size else None - - @property - def weight_transform(self): - return self.weight_executors if self.weight_executors.size else None diff --git a/python/bitblas/ops/general_matmul_splitk.py b/python/bitblas/ops/general_matmul_splitk.py deleted file mode 100644 index 28e3cbbf2..000000000 --- a/python/bitblas/ops/general_matmul_splitk.py +++ /dev/null @@ -1,199 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from tvm.target import Target -import operator -from functools import reduce -from typing import Any, Optional, Union -from .operator import TransformKind -from .impl.matmul_splitk_impl import select_implementation as consistent_implementation -from .impl.matmul_dequantize_splitk_impl import select_implementation as weight_dequantize_implementation -from dataclasses import dataclass -import logging -import torch -from .general_matmul import MatmulConfig, Matmul -from .general_matmul import is_native_compute - -logger = logging.getLogger(__name__) - -WORKSPACE_SIZE = 1024 * 1024 * 256 - - -@dataclass(frozen=True) -class MatmulConfigWithSplitK(MatmulConfig): - k_split: int = 1 # split K dimension - - -class MatmulWithSplitK(Matmul): - - def __init__( - self, - config: MatmulConfig, - name: str = "matmul", - target: Optional[Union[str, Target]] = None, - enable_tuning: bool = True, - from_database: bool = False, - ): - super().__init__(config, name, target, enable_tuning, from_database) - - def _select_implementation(self): - # the major implementation - if is_native_compute(self.A_dtype, self.W_dtype): - return consistent_implementation( - SplitK=self.k_split, - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.A_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - with_bias=self.with_bias, - layout=self.layout, - propagate_a=self.propagate_a, - propagate_b=self.propagate_b, - ) - else: - return weight_dequantize_implementation( - SplitK=self.k_split, - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.A_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - bit=self.bit, - storage_dtype=self.storage_dtype, - source_format=self.source_format, - with_scaling=self.with_scaling, - with_zeros=self.with_zeros, - group_size=self.group_size, - fast_decoding=self.fast_decoding, - with_bias=self.with_bias, - layout=self.layout, - zeros_mode=self.zeros_mode, - propagate_a=self.propagate_a, - propagate_b=self.propagate_b, - ) - - def retrieve_weight_shape(self): - return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape] - - def transform_weight(self, weight, scale=None, zeros=None, bias=None): - """ - Transforms the given weight tensor based on the specified quantization parameters and - returns the transformed weight along with optional scale, zeros, and bias. - - Parameters: - - weight: The input weight tensor to be transformed. - - scale: Optional scaling factor for the weight tensor. - - zeros: Optional zero-point adjustment for the weight tensor. - - bias: Optional bias to be added to the weight tensor. - - Returns: - A list containing the transformed weight tensor and optionally the scale, zeros, and bias. - """ - weight = weight.contiguous() - if self.W_dtype == self.A_dtype: - if self.weight_transform is not None: - return self.weight_transform(weight.cpu()).cuda().contiguous() - return weight - - from bitblas.quantization import general_compress - import torch - import numpy as np - - source_format, bit = self.source_format, self.bit - - # Process integer source format - if source_format == "int" and bit < 8: - assert not self.with_scaling, "scale should be False for int source format" - assert not self.with_zeros, "zeros should be False for int source format" - maxq = 2**(bit - 1) - # Clamp weight values to be within the quantizable range and adjust - weight = torch.clamp(weight, -maxq, maxq).int() + maxq - elif source_format in ["fp_e5m2", "fp_e4m3"]: - weight = weight.view(torch.int8) - weight = weight.int() - else: - # For non-integer formats, simply convert weights to integers - weight = weight.int() - - np_storage_dtype = getattr(np, self.storage_dtype) - - weight = general_compress( - weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype) - - weight = torch.from_numpy(weight).cuda().contiguous() - - # Apply an optional weight transformation if specified - if self.weight_transform is not None: - weight = self.weight_transform(weight.cpu()).cuda().contiguous() - - # Prepare the return list with the transformed weight and optionally include scale, zeros, and bias - result = [weight] - if scale is not None: - result.append(scale) - if zeros is not None: - result.append(zeros) - if bias is not None: - result.append(bias) - - return next(iter(result), result) - - def transform_input(self, input_tensor): - if self.propagate_a is not TransformKind.NonTransform: - # check workspace size - if input_tensor.numel() > WORKSPACE_SIZE: - raise ValueError( - f"Input size {input_tensor.numel()} is larger than the workspace size {WORKSPACE_SIZE}, please increase the workspace size." - ) - self.ladder_permutate_a._forward_from_prebuild_lib(input_tensor, self.workspace) - return self.workspace - return input_tensor - - def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: - args = [] - args.append(self.transform_input(A)) - args.append(W) - - if self.lut is not None: - args.append(self.lut) - - if output is None: - output = torch.empty( - A.shape[:-1] + (self.N,), - dtype=self.torch_output_dtype, - device=A.device) - if scale is not None: - args.append(scale) - if zeros is not None: - args.append(zeros) - if bias is not None: - args.append(bias) - - sk_output = torch.empty((self.k_split,) + - A.shape[:-1] + (self.N,), - dtype=self.torch_output_dtype, - device=A.device) - args.append(sk_output) - - if self.dynamic_range is not None: - m = reduce(operator.mul, A.shape[:-1], 1) - args.append(m) - - stream = torch.cuda.current_stream() - - if self.lib is None: - self._forward_from_torch_func(*args) - self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream) - torch.sum(sk_output, dim=0, out=output) - return output - - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.forward(*args, **kwds) - - @property - def k_split(self): - return self.config.k_split - - -__all__ = ["MatmulConfigWithSplitK", "MatmulWithSplitK"] diff --git a/python/bitblas/ops/impl/__init__.py b/python/bitblas/ops/impl/__init__.py deleted file mode 100644 index a254dc7fb..000000000 --- a/python/bitblas/ops/impl/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from .lop3_permutate_impl import tir_interleave_weight diff --git a/python/bitblas/ops/impl/batch_matmul_dequantize_impl.py b/python/bitblas/ops/impl/batch_matmul_dequantize_impl.py deleted file mode 100644 index a3ab5ebef..000000000 --- a/python/bitblas/ops/impl/batch_matmul_dequantize_impl.py +++ /dev/null @@ -1,392 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# pre-transformed tir expression of matmul -import tvm -from tvm import te, DataType -from tvm.tir import IndexMap -from bitblas.ops.operator import TransformKind -from bitblas.gpu.matmul_analysis import get_propagate_map -from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, - _tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16, - _tir_u8_to_f8_e4m3_to_f16) - - -def matmul_nt_dequantize_b( - Batch, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - bit=4, - storage_dtype="int8", - source_format="uint", - with_scaling=False, - with_zeros=False, - group_size=-1, - fast_decoding=False, - with_bias=False, - zeros_mode="original", -): - assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) - if not isinstance(M, int): - M = tvm.te.var("m") - - storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) - storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) - n_float_per_elem = storage_nbit // bit - if group_size == -1: - group_size = K - A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype) - B = te.placeholder((Batch, N, K // storage_nbit * bit), name="B", dtype=storage_dtype) - LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) - Scale = te.placeholder((Batch, N, K // group_size), name="Scale", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - - def decode_func(b, n, k): - if source_format == "uint": - if bit == 8: - # 8 bit does not need to be compressed - w = B[b, n, k].astype(in_dtype) - else: - w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, B[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) - elif source_format == "int": - if bit == 1: - # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. - w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( - bit, B[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) - elif bit == 8: - # 8 bit does not need to be compressed - w = B[b, n, k].astype(in_dtype) - else: - w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( - bit, B[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) - elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( - bit, B[b, n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) - elif source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16(bit, B[b, n, k], dtype=in_dtype) - elif source_format == "nf": - w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - B[b, n, k // n_float_per_elem], - k % n_float_per_elem, - dtype="int32", # assume the index data type is int32 - )] - else: - raise ValueError("Unsupported source_format: {}".format(source_format)) - - if not with_scaling: - return w - - if not with_zeros: - return w * Scale[b, n, k // group_size] - - return w - - B_decode = te.compute((Batch, N, K), decode_func, name="B_decode") - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (Batch, M, N), - lambda b, i, j: te.sum( - A[b, i, k].astype(accum_dtype) * B_decode[b, j, k].astype(accum_dtype), axis=k), - name="C", - ) - D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") - args = [A, B] - last_output = D - if source_format == "nf": - args.append(LUT) - if with_scaling: - args.append(Scale) - if with_bias: - E = te.compute((Batch, M, N), lambda b, i, j: D[b, i, j] + Bias[j], name="E") - last_output = E - args.append(Bias) - args.append(last_output) - - func = te.create_prim_func(args).with_attr( - "dequantize_info", - { - "B_decode": { - "decode_block": "B_decode", - "fast_decoding": fast_decoding, - "source_format": { - "bits": bit, - "format": source_format, - }, - "storage_dtype": storage_dtype, - "target_format": in_dtype, - "with_scaling": with_scaling, - "with_zeros": with_zeros, - "zeros_mode": zeros_mode, - "group_size": group_size, - } - }, - ) - return tvm.IRModule.from_expr(func) - - -def matmul_nt_dequantize_b_propagate_b( - Batch, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - bit=4, - storage_dtype="int8", - source_format="uint", - with_scaling=False, - with_zeros=False, - group_size=-1, - fast_decoding=False, - with_bias=False, - zeros_mode="original", - transform_kind: TransformKind = TransformKind.IntraWarpTransform, -): - assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) - if not isinstance(M, int): - M = tvm.te.var("m") - - l = r = 16 # noqa: E741 - if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: - l, r = 16, 32 # noqa: E741 - - _, inverse_indexmap = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") - target_dtype = DataType(in_dtype) - scaling_factor = 1 - if bit > 0 and bit < target_dtype.bits: - scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // - target_dtype.bits) - initial_indices = inverse_indexmap.initial_indices - scaling_final_indices = inverse_indexmap.map_indices(initial_indices[:-1] + - [initial_indices[-1] * scaling_factor]) - scaling_final_indices = scaling_final_indices[:-1] + [ - scaling_final_indices[-1] // scaling_factor - ] - inverse_indexmap = IndexMap( - initial_indices, - scaling_final_indices, - None, - ) - - storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) - storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) - n_float_per_elem = storage_nbit // bit - if group_size == -1: - group_size = K - qr = r * bit // storage_nbit - A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype) - B = te.placeholder((Batch, N // l, (K // scaling_factor) // qr, l, qr), - name="B", - dtype=storage_dtype) - LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) - Scale = te.placeholder((Batch, N, K // group_size), name="Scale", dtype=in_dtype) - Zeros = te.placeholder((Batch, N, K // group_size), name="Zeros", dtype=in_dtype) - Bias = te.placeholder(( - Batch, - N, - ), name="Bias", dtype=in_dtype) - - def fcompute(b, i, j): - warp_i, warp_j = i % l, j % qr - spatial_args = i // l, j // qr - if transform_kind >= TransformKind.IntraWarpTransform: - warp_i, warp_j = inverse_indexmap.map_indices([warp_i, warp_j]) - new_index = (b, *spatial_args, warp_i, warp_j) - return B[new_index] - - B_reindex = te.compute( - (Batch, N, K // storage_nbit * bit), - fcompute, - name="B_reindex", - ) - - def decode_func(b, n, k): - if source_format == "uint": - if bit == 8: - # 8 bit does not need to be compressed - w = B_reindex[b, n, k].astype(in_dtype) - else: - w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - B_reindex[b, n, k // n_float_per_elem], - k % n_float_per_elem, - dtype=in_dtype, - ) - elif source_format == "int": - if bit == 1: - # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. - w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( - bit, - B_reindex[b, n, k // n_float_per_elem], - k % n_float_per_elem, - dtype=in_dtype) - elif bit == 8: - # 8 bit does not need to be compressed - w = B_reindex[b, n, k].astype(in_dtype) - else: - w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( - bit, - B_reindex[b, n, k // n_float_per_elem], - k % n_float_per_elem, - dtype=in_dtype, - ) - elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( - bit, - B_reindex[b, n, k // n_float_per_elem], - k % n_float_per_elem, - dtype=in_dtype, - ) - elif source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[b, n, k], dtype=in_dtype) - elif source_format == "nf": - w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - B_reindex[b, n, k // n_float_per_elem], - k % n_float_per_elem, - dtype="int32", # assume the index data type is int32 - )] - else: - raise ValueError("Unsupported source_format: {}".format(source_format)) - - if not with_scaling: - return w - - if not with_zeros: - return w * Scale[b, n, k // group_size] - - if zeros_mode == "original": - w = (w - Zeros[b, n, k // group_size]) * Scale[b, n, k // group_size] - elif zeros_mode == "rescale": - w = w * Scale[b, n, k // group_size] - Zeros[b, n, k // group_size] - else: - raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) - - return w - - B_decode = te.compute((Batch, N, K), decode_func, name="B_decode") - - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (Batch, M, N), - lambda b, i, j: te.sum( - A[b, i, k].astype(accum_dtype) * B_decode[b, j, k].astype(accum_dtype), axis=k), - name="C", - ) - D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") - args = [A, B] - last_output = D - if source_format == "nf": - args.append(LUT) - if with_scaling: - args.append(Scale) - if with_zeros: - args.append(Zeros) - if with_bias: - E = te.compute((Batch, M, N), lambda b, i, j: D[b, i, j] + Bias[j], name="E") - last_output = E - args.append(Bias) - args.append(last_output) - - func = te.create_prim_func(args).with_attr( - "dequantize_info", - { - "B_decode": { - "decode_block": "B_decode", - "fast_decoding": fast_decoding, - "source_format": { - "bits": bit, - "format": source_format, - }, - "storage_dtype": storage_dtype, - "target_format": in_dtype, - "with_zeros": with_zeros, - "zeros_mode": zeros_mode, - "with_scaling": with_scaling, - "group_size": group_size, - } - }, - ) - func = func.with_attr("weight_transform_kind", transform_kind.value) - return tvm.IRModule.from_expr(func) - - -def select_implementation( - Batch=1, - M=None, - N=1024, - K=1024, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - bit=4, - storage_dtype="int8", - source_format="uint", - with_scaling=False, - with_zeros=False, - group_size=-1, - fast_decoding=False, - with_bias=False, - layout="nt", - zeros_mode="original", - propagate_a=False, - propagate_b=False, -): - if layout == "nn": - raise ValueError( - "Currently only support propagate_a=False and propagate_b=False for layout=nn in Dequantize Implementation" - ) - elif layout == "nt": - if propagate_a and propagate_b: - raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") - elif propagate_a: - raise ValueError("Currently only support propagate_a=False for layout=nt") - elif propagate_b: - return matmul_nt_dequantize_b_propagate_b( - Batch, - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - bit, - storage_dtype, - source_format, - with_scaling, - with_zeros, - group_size, - fast_decoding, - with_bias, - zeros_mode, - transform_kind=propagate_b, - ) - else: - return matmul_nt_dequantize_b( - Batch, - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - bit, - storage_dtype, - source_format, - with_scaling, - with_zeros, - group_size, - fast_decoding, - with_bias, - zeros_mode, - ) - else: - raise ValueError(f"Unsupported layout: {layout}") diff --git a/python/bitblas/ops/impl/batch_matmul_impl.py b/python/bitblas/ops/impl/batch_matmul_impl.py deleted file mode 100644 index 1828ed15d..000000000 --- a/python/bitblas/ops/impl/batch_matmul_impl.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# pre-transformed tir expression of matmul -import tvm -from tvm import te -from bitblas.ops.operator import TransformKind - - -def matmul_nt( - Batch, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, -): - if not isinstance(M, int): - M = tvm.te.var("m") - A = te.placeholder((Batch, M, K), name="A", dtype=in_dtype) - B = te.placeholder((Batch, N, K), name="B", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (Batch, M, N), - lambda b, i, j: te.sum( - A[b, i, k].astype(accum_dtype) * B[b, j, k].astype(accum_dtype), axis=k), - name="C", - ) - last_output = C - if accum_dtype != out_dtype: - D = te.compute((Batch, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") - last_output = D - - if with_bias: - E = te.compute((Batch, M, N), lambda b, i, j: last_output[b, i, j] + Bias[j], name="E") - last_output = E - - args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] - - func = te.create_prim_func(args) - - return tvm.IRModule.from_expr(func) - - -def matmul( - Batch, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, - layout="nt", -): - if layout == "nn": - raise ValueError("Currently only support layout=nt") - return matmul_nt(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) - - -def select_implementation( - Batch=1, - M=None, - N=16384, - K=16384, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, - layout="nt", - propagate_a: TransformKind = TransformKind.NonTransform, - propagate_b: TransformKind = TransformKind.NonTransform, -): - if layout == "nn": - if propagate_a or propagate_b: - raise ValueError( - "Currently only support propagate_a=False and propagate_b=False for layout=nn") - return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) - elif layout == "nt": - if propagate_a and propagate_b: - raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") - elif propagate_a: - raise ValueError("Currently only support propagate_a=False for layout=nt") - elif propagate_b: - raise ValueError("Currently only support propagate_b=False for layout=nt") - else: - return matmul(Batch, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) - else: - raise ValueError(f"Unsupported layout: {layout}") diff --git a/python/bitblas/ops/impl/convolution2d_impl.py b/python/bitblas/ops/impl/convolution2d_impl.py deleted file mode 100644 index d77d8f573..000000000 --- a/python/bitblas/ops/impl/convolution2d_impl.py +++ /dev/null @@ -1,190 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# pre-transformed tir expression of matmul -import tvm -from tvm import te, tir - - -def conv2d_nhwc_ohwi( - n, - f, - h, - w, - c, - kh, - kw, - s, - d, - p, - in_dtype="float16", - accum_dtype="float16", - out_dtype="float16", -): - - A = te.placeholder((n, h, w, c), name="input", dtype=in_dtype) - B = te.placeholder((f, kh, kw, c), name="weight", dtype=in_dtype) - - pad_shape = (n, h + 2 * p, w + 2 * p, c) - pad_value = tir.const(0.0, A.dtype) - pad = te.compute( - pad_shape, - lambda n, h, w, c: te.if_then_else( - tir.all( - h >= p, - w >= p, - h < pad_shape[1] - p, - w < pad_shape[2] - p, - ), - A[n, h - p, w - p, c], - pad_value, - ), - name="pad", - ) - kernel_h, kernel_w = kh, kw - stride_h, stride_w = s, s - dilation_h, dilation_w = d, d - out_h = (h + 2 * p - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 - out_w = (w + 2 * p - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 - out_shape = (n, out_h, out_w, f) - kh = te.reduce_axis((0, kernel_h), name="kh") - kw = te.reduce_axis((0, kernel_w), name="kw") - c = te.reduce_axis((0, c), name="c") - C = te.compute( - out_shape, - lambda n, h, w, f: te.sum( - pad[n, h * stride_h + kh * tir.any(dilation_h), w * stride_w + kw * tir.any(dilation_w), - c,].astype(accum_dtype) * B[f, kh - 1 - tir.any(dilation_h), kw - 1 - tir.any( - dilation_w), c].astype(accum_dtype), - axis=[kh, kw, c], - ), - name="C", - ) - args = [A, B] - last_output = C - if accum_dtype != out_dtype: - D = te.compute(out_shape, lambda n, h, w, c: C[n, h, w, c].astype(out_dtype), name="D") - last_output = D - args.append(last_output) - func = te.create_prim_func(args) - - return tvm.IRModule.from_expr(func) - - -def conv2d_nhwc_hwio( - n, - f, - h, - w, - c, - kh, - kw, - s, - d, - p, - in_dtype="float16", - accum_dtype="float16", - out_dtype="float16", -): - - A = te.placeholder((n, h, w, c), name="input", dtype=in_dtype) - B = te.placeholder((kh, kw, c, f), name="weight", dtype=in_dtype) - - pad_shape = (n, h + 2 * p, w + 2 * p, c) - pad_value = tir.const(0.0, A.dtype) - pad = te.compute( - pad_shape, - lambda n, h, w, c: te.if_then_else( - tir.all( - h >= p, - w >= p, - h < pad_shape[1] - p, - w < pad_shape[2] - p, - ), - A[n, h - p, w - p, c], - pad_value, - ), - name="pad", - ) - kernel_h, kernel_w = kh, kw - stride_h, stride_w = s, s - dilation_h, dilation_w = d, d - out_h = (h + 2 * p - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1 - out_w = (w + 2 * p - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1 - out_shape = (n, out_h, out_w, f) - kh = te.reduce_axis((0, kernel_h), name="kh") - kw = te.reduce_axis((0, kernel_w), name="kw") - c = te.reduce_axis((0, c), name="c") - C = te.compute( - out_shape, - lambda n, h, w, f: te.sum( - pad[n, h * stride_h + kh * tir.any(dilation_h), w * stride_w + kw * tir.any(dilation_w), - c,].astype(accum_dtype) * B[kh - 1 - tir.any(dilation_h), kw - 1 - tir.any( - dilation_w), c, f].astype(accum_dtype), - axis=[kh, kw, c], - ), - name="C", - ) - args = [A, B] - last_output = C - if accum_dtype != out_dtype: - D = te.compute(out_shape, lambda n, h, w, c: C[n, h, w, c].astype(out_dtype), name="D") - last_output = D - args.append(last_output) - func = te.create_prim_func(args) - - return tvm.IRModule.from_expr(func) - - -def select_implementation( - n, - f, - h, - w, - c, - kh, - kw, - s, - d, - p, - in_dtype="float16", - accum_dtype="float16", - out_dtype="float16", - input_layout="nhwc", - weight_layout="ohwi", -): - assert input_layout in ["nhwc", "nchw"] - if input_layout == "nhwc" and weight_layout == "ohwi": - return conv2d_nhwc_ohwi( - n, - f, - h, - w, - c, - kh, - kw, - s, - d, - p, - in_dtype, - accum_dtype, - out_dtype, - ) - elif input_layout == "nhwc" and weight_layout == "hwio": - return conv2d_nhwc_hwio( - n, - f, - h, - w, - c, - kh, - kw, - s, - d, - p, - in_dtype, - accum_dtype, - out_dtype, - ) - else: - raise ValueError("Unsupported input_layout: {} and weight_layout: {}".format( - input_layout, weight_layout)) diff --git a/python/bitblas/ops/impl/ladder_permutate_impl.py b/python/bitblas/ops/impl/ladder_permutate_impl.py deleted file mode 100644 index 8086bf584..000000000 --- a/python/bitblas/ops/impl/ladder_permutate_impl.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from bitblas.gpu.matmul_analysis import get_propagate_map -from typing import Literal -from tvm import te, IRModule, DataType -from tvm.tir import IndexMap - - -def select_implementation( - M: int, - N: int, - datatype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"] = "float16", - dequantize_bits: int = -1, - storage_dtype: Literal["float16", "int8", "uint8", "int32", "uint32"] = "float16", - propagate_kind: Literal["A", "B"] = "B", - transpose_matrix: bool = False, - transform_kind: int = 0, - target_instruction: Literal["nvidia-mma"] = "nvidia-mma", -): - if target_instruction != "nvidia-mma": - raise ValueError("Currently only support nvidia-mma instruction") - - # This is trick to get the basic tile size for the current datatype - # as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8 - l = r = 16 # noqa: E741 - if datatype in ["int8", "e4m3_float8", "e5m2_float8"]: - l, r = 16, 32 # noqa: E741 - intra_index_map, _ = get_propagate_map( - transpose_matrix, dtype=datatype, matrix_name=propagate_kind) - - target_dtype = DataType(datatype) - scaling_factor = 1 - if dequantize_bits > 0 and dequantize_bits < target_dtype.bits: - scaling_factor = ((target_dtype.bits // dequantize_bits) * DataType(storage_dtype).bits // - target_dtype.bits) - r = r // scaling_factor - initial_indices = intra_index_map.initial_indices - scaling_final_indices = intra_index_map.map_indices(initial_indices[:-1] + - [initial_indices[-1] * scaling_factor]) - scaling_final_indices = scaling_final_indices[:-1] + [ - scaling_final_indices[-1] // scaling_factor - ] - intra_index_map = IndexMap( - initial_indices, - scaling_final_indices, - None, - ) - - inp = te.placeholder((M, N // scaling_factor), name="inp", dtype=storage_dtype) - args = [inp] - - if transform_kind >= 1: - arg = args[-1] - - inter_warp = te.compute( - (M // l, (N // scaling_factor) // r, l, r), - lambda i, j, ii, jj: arg[i * l + ii, j * r + jj], - name="inter_warp_permutate", - ) - args.append(inter_warp) - if transform_kind >= 2: - arg = args[-1] - - def fcompute(*args): - warp_i, warp_j = args[-2:] - spatial_args = args[:-2] - permutate_i, permutate_j = intra_index_map.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, permutate_i, permutate_j) - return arg[new_index] - - intra_warp = te.compute( - (M // l, (N // scaling_factor) // r, l, r), - fcompute, - name="intra_warp_permutate", - ) - args.append(intra_warp) - args = [args[0], args[-1]] - - func = te.create_prim_func(args) - - return IRModule.from_expr(func) diff --git a/python/bitblas/ops/impl/lop3_permutate_impl.py b/python/bitblas/ops/impl/lop3_permutate_impl.py deleted file mode 100644 index 07d8f4f0c..000000000 --- a/python/bitblas/ops/impl/lop3_permutate_impl.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from typing import Literal -from tvm import DataType -from tvm import IRModule -from tvm.ir import GlobalVar -from tvm.script import tir as T - - -# fmt: off -# TIR interleave weight impl-> 2D implementation -def tir_interleave_weight( - N: int = 2, - K: int = 16, - bits: int = 4, - QK: int = -1, - target_dtype: str = "float16", - storage_dtype: str = "int32", -): - if QK == -1: - QK = K * bits // 32 - bits_stride = DataType(target_dtype).bits - mask = (1 << bits) - 1 # for 4bit the val is 0x0000000f - num_groups = 32 // bits_stride - elems_per_group = bits_stride // bits - - @T.prim_func - def interleave_weight(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), storage_dtype)): - for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): - with T.block("B"): - v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits - B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) - - @T.prim_func - def interleave_weight_f16_2b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), - storage_dtype)): - B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): - with T.block("B_tmp"): - v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits - B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) - - for ax0, ax1 in T.grid(N, QK): - with T.block("B"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xFF0000FF) - B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x00FF0000)) << 8) >> 16 - B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x0000FF00)) << 16) >> 8 - B[v0, v1] = B_tmp_1[v0, v1] | B_tmp_2[v0, v1] | B_tmp_3[v0, v1] - - @T.prim_func - def interleave_weight_f16_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), - storage_dtype)): - B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - B_tmp_4 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - B_tmp_5 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - B_tmp_6 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - B_tmp_7 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): - with T.block("B_tmp"): - v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits - B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) - - for ax0, ax1 in T.grid(N, QK): - with T.block("B"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xF000000F) - B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x000000F0)) >> 4) << 8 - B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x00000F00)) >> 8) << 16 - B_tmp_4[v0, v1] = ((B[v0, v1] & T.uint32(0x0000F000)) >> 12) << 24 - B_tmp_5[v0, v1] = ((B[v0, v1] & T.uint32(0x000F0000)) >> 16) << 8 - B_tmp_6[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 20) << 12 - B_tmp_7[v0, v1] = ((B[v0, v1] & T.uint32(0x00F00000)) >> 24) << 20 - B[v0, v1] = ( - B_tmp_1[v0, v1] - | B_tmp_2[v0, v1] - | B_tmp_3[v0, v1] - | B_tmp_4[v0, v1] - | B_tmp_5[v0, v1] - | B_tmp_6[v0, v1] - | B_tmp_7[v0, v1]) - - @T.prim_func - def interleave_weight_int8_1b(A: T.Buffer((N, QK), storage_dtype), B: T.Buffer((N, QK), - storage_dtype)): - B_tmp_1 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - B_tmp_2 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - B_tmp_3 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - B_tmp_4 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - B_tmp_5 = T.alloc_buffer((N, QK), storage_dtype, scope="local") - for ax0, ax1, ax2, ax3 in T.grid(N, QK, num_groups, elems_per_group): - with T.block("B_tmp"): - v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) - offset = v2 * elems_per_group + v3 - shift = (offset % num_groups) * bits_stride + (offset // num_groups) * bits - B[v0, v1] = B[v0, v1] | (((A[v0, v1] >> (bits * offset)) & mask) << shift) - - for ax0, ax1 in T.grid(N, QK): - with T.block("B"): - v0, v1 = T.axis.remap("SS", [ax0, ax1]) - B_tmp_1[v0, v1] = B[v0, v1] & T.uint32(0xF0F00F0F) - B_tmp_2[v0, v1] = ((B[v0, v1] & T.uint32(0x000000F0)) >> 4) << 16 - B_tmp_3[v0, v1] = ((B[v0, v1] & T.uint32(0x0000F000)) >> 12) << 24 - B_tmp_4[v0, v1] = ((B[v0, v1] & T.uint32(0x000F0000)) >> 16) << 4 - B_tmp_5[v0, v1] = ((B[v0, v1] & T.uint32(0x0F000000)) >> 24) << 12 - B[v0, v1] = ( - B_tmp_1[v0, v1] - | B_tmp_2[v0, v1] - | B_tmp_3[v0, v1] - | B_tmp_4[v0, v1] - | B_tmp_5[v0, v1]) - - if target_dtype == "float16" and bits == 2: - return interleave_weight_f16_2b - elif target_dtype == "float16" and bits == 1: - return interleave_weight_f16_1b - elif target_dtype == "int8" and bits == 1: - return interleave_weight_int8_1b - - return interleave_weight - - -# fmt: on - - -def select_implementation( - M: int, - N: int, - datatype: Literal["float16", "int8"] = "float16", - storage_dtype: Literal["int8", "uint8", "int32", "uint32"] = "int32", - dequantize_bits: int = 4, -): - func = tir_interleave_weight( - N=M, - K=N, - bits=dequantize_bits, - target_dtype=datatype, - storage_dtype=storage_dtype, - ) - mod = IRModule() - mod.update_func(GlobalVar("main"), func) - return mod diff --git a/python/bitblas/ops/impl/matmul_dequantize_impl.py b/python/bitblas/ops/impl/matmul_dequantize_impl.py deleted file mode 100644 index d4aa02c84..000000000 --- a/python/bitblas/ops/impl/matmul_dequantize_impl.py +++ /dev/null @@ -1,644 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# pre-transformed tir expression of matmul -import tvm -from tvm import te, DataType -from tvm.tir import IndexMap -from bitblas.ops.operator import TransformKind -from bitblas.gpu.matmul_analysis import get_propagate_map -from bitblas.quantization import ( - _tir_packed_int_to_int_convert, - _tir_packed_to_signed_convert, - _tir_packed_to_unsigned_convert, - _tir_u32_to_f4_to_f16, - _tir_u8_to_f8_e4m3_to_f16, - _tir_packed_to_unsigned_convert_with_zeros, -) - - -def matmul_nt_dequantize_b( - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - bit=4, - storage_dtype="int8", - source_format="uint", - with_scaling=False, - with_zeros=False, - group_size=-1, - fast_decoding=False, - with_bias=False, - zeros_mode="original", -): - assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) - if not isinstance(M, int): - M = tvm.te.var("m") - - storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) - storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) - n_float_per_elem = storage_nbit // bit - if group_size == -1: - group_size = K - A = te.placeholder((M, K), name="A", dtype=in_dtype) - B = te.placeholder((N, K // storage_nbit * bit), name="B", dtype=storage_dtype) - LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) - Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) - Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) - QZeros = te.placeholder(((K // group_size), N // storage_nbit * bit), - name="QZeros", - dtype=storage_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - - def qzeros_dequantize(k, n): - return _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - QZeros[k, n // n_float_per_elem], - n % n_float_per_elem, - dtype=storage_dtype, - ) - - Dequantize_qzeros = None - if with_zeros and zeros_mode == "quantized": - Dequantize_qzeros = te.compute( - (K // group_size, N), - qzeros_dequantize, - name="Dequantize_zeros", - ) - - def decode_func(n, k): - if with_zeros and zeros_mode == "quantized": - assert Dequantize_qzeros is not None, "Dequantize_zeros is None" - w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)( - bit, - B[n, k // n_float_per_elem], - k % n_float_per_elem, - Dequantize_qzeros[k // group_size, n], - dtype=in_dtype, - ) - elif source_format == "uint": - if bit == 8: - # 8 bit does not need to be compressed - w = B[n, k].astype(in_dtype) - else: - w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) - elif source_format == "int": - if bit == 1: - # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. - w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( - bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) - elif bit == 8: - # 8 bit does not need to be compressed - w = B[n, k].astype(in_dtype) - else: - w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( - bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) - elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( - bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) - elif source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) - elif source_format == "nf": - w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - B[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype="int32", # assume the index data type is int32 - )] - else: - raise ValueError("Unsupported source_format: {}".format(source_format)) - - if not with_scaling: - return w - - if not with_zeros: - return w * Scale[n, k // group_size] - - if zeros_mode == "original": - w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] - elif zeros_mode == "rescale": - w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] - elif zeros_mode == "quantized": - w = w * Scale[n, k // group_size] - else: - raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) - - return w - - B_decode = te.compute((N, K), decode_func, name="B_decode") - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (M, N), - lambda i, j: te.sum( - A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), - name="C", - ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") - args = [A, B] - last_output = D - if source_format == "nf": - args.append(LUT) - if with_scaling: - args.append(Scale) - if with_zeros: - if zeros_mode == "quantized": - args.append(QZeros) - else: - args.append(Zeros) - if with_bias: - E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") - last_output = E - args.append(Bias) - args.append(last_output) - - func = te.create_prim_func(args).with_attr( - "dequantize_info", - { - "B_decode": { - "decode_block": "B_decode", - "fast_decoding": fast_decoding, - "source_format": { - "bits": bit, - "format": source_format, - }, - "storage_dtype": storage_dtype, - "target_format": in_dtype, - "with_scaling": with_scaling, - "with_zeros": with_zeros, - "zeros_mode": zeros_mode, - "group_size": group_size, - } - }, - ) - return tvm.IRModule.from_expr(func) - - -def matmul_nt_dequantize_b_propagate_b( - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - bit=4, - storage_dtype="int8", - source_format="uint", - with_scaling=False, - with_zeros=False, - group_size=-1, - fast_decoding=False, - with_bias=False, - zeros_mode="original", - transform_kind: TransformKind = TransformKind.IntraWarpTransform, -): - assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) - if not isinstance(M, int): - M = tvm.te.var("m") - - l = r = 16 # noqa: E741 - if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: - l, r = 16, 32 # noqa: E741 - - _, inverse_indexmap = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") - target_dtype = DataType(in_dtype) - scaling_factor = 1 - if bit > 0 and bit < target_dtype.bits: - scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // - target_dtype.bits) - initial_indices = inverse_indexmap.initial_indices - scaling_final_indices = inverse_indexmap.map_indices(initial_indices[:-1] + - [initial_indices[-1] * scaling_factor]) - scaling_final_indices = scaling_final_indices[:-1] + [ - scaling_final_indices[-1] // scaling_factor - ] - inverse_indexmap = IndexMap( - initial_indices, - scaling_final_indices, - None, - ) - - storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) - storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) - n_float_per_elem = storage_nbit // bit - if group_size == -1: - group_size = K - qr = r * bit // storage_nbit - A = te.placeholder((M, K), name="A", dtype=in_dtype) - B = te.placeholder((N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) - LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) - Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) - Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - - def fcompute(i, j): - warp_i, warp_j = i % l, j % qr - spatial_args = i // l, j // qr - if transform_kind >= TransformKind.IntraWarpTransform: - warp_i, warp_j = inverse_indexmap.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, warp_i, warp_j) - return B[new_index] - - B_reindex = te.compute( - (N, K // storage_nbit * bit), - fcompute, - name="B_reindex", - ) - - def decode_func(n, k): - if source_format == "uint": - if bit == 8: - # 8 bit does not need to be compressed - w = B_reindex[n, k].astype(in_dtype) - else: - w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - B_reindex[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype=in_dtype, - ) - elif source_format == "int": - if bit == 1: - # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. - w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( - bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) - elif bit == 8: - # 8 bit does not need to be compressed - w = B_reindex[n, k].astype(in_dtype) - else: - w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( - bit, - B_reindex[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype=in_dtype, - ) - elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( - bit, - B_reindex[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype=in_dtype, - ) - elif source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype) - elif source_format == "nf": - w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - B_reindex[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype="int32", # assume the index data type is int32 - )] - else: - raise ValueError("Unsupported source_format: {}".format(source_format)) - - if not with_scaling: - return w - - if not with_zeros: - return w * Scale[n, k // group_size] - - if zeros_mode == "original": - w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] - elif zeros_mode == "rescale": - w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] - else: - raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) - - return w - - B_decode = te.compute((N, K), decode_func, name="B_decode") - - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (M, N), - lambda i, j: te.sum( - A[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), axis=k), - name="C", - ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") - args = [A, B] - last_output = D - if source_format == "nf": - args.append(LUT) - if with_scaling: - args.append(Scale) - if with_zeros: - args.append(Zeros) - if with_bias: - E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") - last_output = E - args.append(Bias) - args.append(last_output) - - func = te.create_prim_func(args).with_attr( - "dequantize_info", - { - "B_decode": { - "decode_block": "B_decode", - "fast_decoding": fast_decoding, - "source_format": { - "bits": bit, - "format": source_format, - }, - "storage_dtype": storage_dtype, - "target_format": in_dtype, - "with_zeros": with_zeros, - "zeros_mode": zeros_mode, - "with_scaling": with_scaling, - "group_size": group_size, - } - }, - ) - func = func.with_attr("weight_transform_kind", transform_kind.value) - return tvm.IRModule.from_expr(func) - - -def matmul_nt_dequantize_b_propagate_a_propagate_b( - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - bit=4, - storage_dtype="int8", - source_format="uint", - with_scaling=False, - with_zeros=False, - group_size=-1, - fast_decoding=False, - with_bias=False, - zeros_mode="original", - transform_kind_input: TransformKind = TransformKind.IntraWarpTransform, - transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform, -): - assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) - if not isinstance(M, int): - M = tvm.te.var("m") - - l = r = 16 # noqa: E741 - if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: - l, r = 16, 32 # noqa: E741 - _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") - A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) - - def fcompute(i, j): - warp_i, warp_j = i % l, j % r - spatial_args = i // l, j // r - if transform_kind_input >= TransformKind.IntraWarpTransform: - warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, warp_i, warp_j) - return A[new_index] - - A_reindex = te.compute( - (M, K), - fcompute, - name="A_reindex", - ) - - _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") - target_dtype = DataType(in_dtype) - scaling_factor = 1 - if bit > 0 and bit < target_dtype.bits: - scaling_factor = ((target_dtype.bits // bit) * DataType(storage_dtype).bits // - target_dtype.bits) - initial_indices = inversed_index_map.initial_indices - scaling_final_indices = inversed_index_map.map_indices( - initial_indices[:-1] + [initial_indices[-1] * scaling_factor]) - scaling_final_indices = scaling_final_indices[:-1] + [ - scaling_final_indices[-1] // scaling_factor - ] - inversed_index_map = IndexMap( - initial_indices, - scaling_final_indices, - None, - ) - - storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) - storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) - n_float_per_elem = storage_nbit // bit - if group_size == -1: - group_size = K - qr = r * bit // storage_nbit - B = te.placeholder((N // l, (K // scaling_factor) // qr, l, qr), name="B", dtype=storage_dtype) - LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) - Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) - Zeros = te.placeholder((N, K // group_size), name="Zeros", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - - def fcompute(i, j): - warp_i, warp_j = i % l, j % qr - spatial_args = i // l, j // qr - if transform_kind_weight >= TransformKind.IntraWarpTransform: - warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, warp_i, warp_j) - return B[new_index] - - B_reindex = te.compute( - (N, K // storage_nbit * bit), - fcompute, - name="B_reindex", - ) - - def decode_func(n, k): - if source_format == "uint": - if bit == 8: - # 8 bit does not need to be compressed - w = B_reindex[n, k].astype(in_dtype) - else: - w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - B_reindex[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype=in_dtype, - ) - elif source_format == "int": - # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. - if bit == 1: - w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( - bit, B_reindex[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) - elif bit == 8: - # 8 bit does not need to be compressed - w = B_reindex[n, k].astype(in_dtype) - else: - w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( - bit, - B_reindex[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype=in_dtype, - ) - elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( - bit, - B_reindex[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype=in_dtype, - ) - elif source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16(bit, B_reindex[n, k], dtype=in_dtype) - elif source_format == "nf": - w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - B_reindex[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype="int32", # assume the index data type is int32 - )] - else: - raise ValueError("Unsupported source_format: {}".format(source_format)) - - if not with_scaling: - return w - - if not with_zeros: - return w * Scale[n, k // group_size] - - if zeros_mode == "original": - w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size] - elif zeros_mode == "rescale": - w = w * Scale[n, k // group_size] - Zeros[n, k // group_size] - else: - raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode)) - - return w - - B_decode = te.compute((N, K), decode_func, name="B_decode") - - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (M, N), - lambda i, j: te.sum( - A_reindex[i, k].astype(accum_dtype) * B_decode[j, k].astype(accum_dtype), - axis=k, - ), - name="C", - ) - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") - args = [A, B] - last_output = D - if source_format == "nf": - args.append(LUT) - if with_scaling: - args.append(Scale) - if with_zeros: - args.append(Zeros) - if with_bias: - E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E") - last_output = E - args.append(Bias) - args.append(last_output) - - func = te.create_prim_func(args).with_attr( - "dequantize_info", - { - "B_decode": { - "decode_block": "B_decode", - "fast_decoding": fast_decoding, - "source_format": { - "bits": bit, - "format": source_format, - }, - "storage_dtype": storage_dtype, - "target_format": in_dtype, - "with_zeros": with_zeros, - "zeros_mode": zeros_mode, - "with_scaling": with_scaling, - "group_size": group_size, - } - }, - ) - func = func.with_attr("input_transform_kind", transform_kind_input.value) - func = func.with_attr("weight_transform_kind", transform_kind_weight.value) - return tvm.IRModule.from_expr(func) - - -def select_implementation( - M=None, - N=1024, - K=1024, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - bit=4, - storage_dtype="int8", - source_format="uint", - with_scaling=False, - with_zeros=False, - group_size=-1, - fast_decoding=False, - with_bias=False, - layout="nt", - zeros_mode="original", - propagate_a=False, - propagate_b=False, -): - if layout == "nn": - raise ValueError( - "Currently only support propagate_a=False and propagate_b=False for layout=nn in Dequantize Implementation" - ) - elif layout == "nt": - if propagate_a and propagate_b: - return matmul_nt_dequantize_b_propagate_a_propagate_b( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - bit, - storage_dtype, - source_format, - with_scaling, - with_zeros, - group_size, - fast_decoding, - with_bias, - zeros_mode, - transform_kind_input=propagate_a, - transform_kind_weight=propagate_b, - ) - elif propagate_a: - raise NotImplementedError - elif propagate_b: - return matmul_nt_dequantize_b_propagate_b( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - bit, - storage_dtype, - source_format, - with_scaling, - with_zeros, - group_size, - fast_decoding, - with_bias, - zeros_mode, - transform_kind=propagate_b, - ) - else: - return matmul_nt_dequantize_b( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - bit, - storage_dtype, - source_format, - with_scaling, - with_zeros, - group_size, - fast_decoding, - with_bias, - zeros_mode, - ) - else: - raise ValueError(f"Unsupported layout: {layout}") diff --git a/python/bitblas/ops/impl/matmul_dequantize_splitk_impl.py b/python/bitblas/ops/impl/matmul_dequantize_splitk_impl.py deleted file mode 100644 index afe241b65..000000000 --- a/python/bitblas/ops/impl/matmul_dequantize_splitk_impl.py +++ /dev/null @@ -1,184 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# pre-transformed tir expression of matmul -import tvm -from tvm import te -from bitblas.quantization import (_tir_packed_int_to_int_convert, _tir_packed_to_signed_convert, - _tir_packed_to_unsigned_convert, _tir_u32_to_f4_to_f16, - _tir_u8_to_f8_e4m3_to_f16) - - -def matmul_nt_dequantize_b( - SplitK, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - bit=4, - storage_dtype="int8", - source_format="uint", - with_scaling=False, - with_zeros=False, - group_size=-1, - fast_decoding=False, - with_bias=False, - zeros_mode="original", -): - assert bit in [1, 2, 4, 8], "Unsupported bit: {}".format(bit) - if not isinstance(M, int): - M = tvm.te.var("m") - - storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) - storage_type = str("".join(c for c in storage_dtype if not c.isdigit())) - n_float_per_elem = storage_nbit // bit - if group_size == -1: - group_size = K - A = te.placeholder((M, K), name="A", dtype=in_dtype) - B = te.placeholder((N, K // storage_nbit * bit), name="B", dtype=storage_dtype) - LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype) - Scale = te.placeholder((N, K // group_size), name="Scale", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - - def decode_func(n, k): - if source_format == "uint": - if bit == 8: - # 8 bit does not need to be compressed - w = B[n, k].astype(in_dtype) - else: - w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) - elif source_format == "int": - if bit == 1: - # Dequantize int1 to -1 and 1. Without this step, the values would be 0 and 1, identical to uint1. - w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)( - bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) - elif bit == 8: - # 8 bit does not need to be compressed - w = B[n, k].astype(in_dtype) - else: - w = _tir_packed_to_signed_convert(storage_type, storage_nbit)( - bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) - elif source_format == "fp": - w = _tir_u32_to_f4_to_f16( - bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype) - elif source_format == "fp_e4m3": - w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype) - elif source_format == "nf": - w = LUT[_tir_packed_to_unsigned_convert(storage_type, storage_nbit)( - bit, - B[n, k // n_float_per_elem], - k % n_float_per_elem, - dtype="int32", # assume the index data type is int32 - )] - else: - raise ValueError("Unsupported source_format: {}".format(source_format)) - - if not with_scaling: - return w - - if not with_zeros: - return w * Scale[n, k // group_size] - - return w - - B_decode = te.compute((N, K), decode_func, name="B_decode") - # Describe the matrix multiplication in TE - RK = K // SplitK - k = te.reduce_axis((0, RK), name="k") - C = te.compute( - (SplitK, M, N), - lambda sk, i, j: te.sum( - A[i, sk * RK + k].astype(accum_dtype) * B_decode[j, sk * RK + k].astype(accum_dtype), - axis=k), - name="C", - ) - D = te.compute((SplitK, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") - args = [A, B] - last_output = D - if source_format == "nf": - args.append(LUT) - if with_scaling: - args.append(Scale) - if with_bias: - E = te.compute((SplitK, M, N), lambda b, i, j: D[b, i, j] + Bias[j], name="E") - last_output = E - args.append(Bias) - args.append(last_output) - - func = te.create_prim_func(args).with_attr( - "dequantize_info", - { - "B_decode": { - "decode_block": "B_decode", - "fast_decoding": fast_decoding, - "source_format": { - "bits": bit, - "format": source_format, - }, - "storage_dtype": storage_dtype, - "target_format": in_dtype, - "with_scaling": with_scaling, - "with_zeros": with_zeros, - "zeros_mode": zeros_mode, - "group_size": group_size, - } - }, - ) - return tvm.IRModule.from_expr(func) - - -def select_implementation( - SplitK=1, - M=None, - N=1024, - K=1024, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - bit=4, - storage_dtype="int8", - source_format="uint", - with_scaling=False, - with_zeros=False, - group_size=-1, - fast_decoding=False, - with_bias=False, - layout="nt", - zeros_mode="original", - propagate_a=False, - propagate_b=False, -): - if layout == "nn": - raise ValueError( - "Currently only support propagate_a=False and propagate_b=False for layout=nn in Dequantize Implementation" - ) - elif layout == "nt": - if propagate_a and propagate_b: - raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") - elif propagate_a: - raise ValueError("Currently only support propagate_a=False for layout=nt") - elif propagate_b: - raise ValueError("Currently only support propagate_b=False for layout=nt") - else: - return matmul_nt_dequantize_b( - SplitK, - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - bit, - storage_dtype, - source_format, - with_scaling, - with_zeros, - group_size, - fast_decoding, - with_bias, - zeros_mode, - ) - else: - raise ValueError(f"Unsupported layout: {layout}") diff --git a/python/bitblas/ops/impl/matmul_impl.py b/python/bitblas/ops/impl/matmul_impl.py deleted file mode 100644 index 69b426354..000000000 --- a/python/bitblas/ops/impl/matmul_impl.py +++ /dev/null @@ -1,356 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# pre-transformed tir expression of matmul -import tvm -from tvm import te -from bitblas.gpu.matmul_analysis import get_propagate_map -from bitblas.ops.operator import TransformKind - - -def matmul_nn( - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, -): - if not isinstance(M, int): - M = tvm.te.var("m") - A = te.placeholder((M, K), name="A", dtype=in_dtype) - B = te.placeholder((K, N), name="B", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (M, N), - lambda i, j: te.sum(A[i, k].astype(accum_dtype) * B[k, j].astype(accum_dtype), axis=k), - name="C", - ) - last_output = C - if accum_dtype != out_dtype: - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") - last_output = D - - if with_bias: - E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") - last_output = E - - args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] - - func = te.create_prim_func(args) - - return tvm.IRModule.from_expr(func) - - -def matmul_nt( - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, -): - if not isinstance(M, int): - M = tvm.te.var("m") - A = te.placeholder((M, K), name="A", dtype=in_dtype) - B = te.placeholder((N, K), name="B", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (M, N), - lambda i, j: te.sum(A[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k), - name="C", - ) - last_output = C - if accum_dtype != out_dtype: - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") - last_output = D - - if with_bias: - E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") - last_output = E - - args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] - - func = te.create_prim_func(args) - - return tvm.IRModule.from_expr(func) - - -def matmul( - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, - layout="nt", -): - if layout == "nn": - return matmul_nn(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) - return matmul_nt(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) - - -def matmul_nt_propagate_a( - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, - transform_kind: TransformKind = TransformKind.IntraWarpTransform, -): - if not isinstance(M, int): - M = tvm.te.var("m") - l = r = 16 # noqa: E741 - if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: - l, r = 16, 32 # noqa: E741 - - _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") - - A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) - B = te.placeholder((N, K), name="B", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - - def fcompute(i, j): - warp_i, warp_j = i % l, j % r - spatial_args = i // l, j // r - if transform_kind >= TransformKind.IntraWarpTransform: - warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, warp_i, warp_j) - return A[new_index] - - A_reindex = te.compute( - (M, K), - fcompute, - name="A_reindex", - ) - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (M, N), - lambda i, j: te.sum( - A_reindex[i, k].astype(accum_dtype) * B[j, k].astype(accum_dtype), axis=k), - name="C", - ) - last_output = C - if accum_dtype != out_dtype: - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") - last_output = D - - if with_bias: - E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") - last_output = E - - args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] - - func = te.create_prim_func(args) - func = func.with_attr("input_transform_kind", transform_kind.value) - - return tvm.IRModule.from_expr(func) - - -def matmul_nt_propagate_b( - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, - transform_kind: TransformKind = TransformKind.IntraWarpTransform, -): - if not isinstance(M, int): - M = tvm.te.var("m") - l = r = 16 # noqa: E741 - if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: - l, r = 16, 32 # noqa: E741 - - _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") - - A = te.placeholder((M, K), name="A", dtype=in_dtype) - B = te.placeholder((N // l, K // r, l, r), name="B", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - - def fcompute(i, j): - warp_i, warp_j = i % l, j % r - spatial_args = i // l, j // r - if transform_kind >= TransformKind.IntraWarpTransform: - warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, warp_i, warp_j) - return B[new_index] - - B_reindex = te.compute( - (N, K), - fcompute, - name="B_reindex", - ) - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (M, N), - lambda i, j: te.sum( - A[i, k].astype(accum_dtype) * B_reindex[j, k].astype(accum_dtype), axis=k), - name="C", - ) - last_output = C - if accum_dtype != out_dtype: - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") - last_output = D - - if with_bias: - E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") - last_output = E - - args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] - - func = te.create_prim_func(args) - func = func.with_attr("weight_transform_kind", transform_kind.value) - - return tvm.IRModule.from_expr(func) - - -def matmul_nt_propagate_a_propagate_b( - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, - transform_kind_input: TransformKind = TransformKind.IntraWarpTransform, - transform_kind_weight: TransformKind = TransformKind.IntraWarpTransform, -): - if not isinstance(M, int): - M = tvm.te.var("m") - l = r = 16 # noqa: E741 - if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: - l, r = 16, 32 # noqa: E741 - - A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) - B = te.placeholder((N // l, K // r, l, r), name="B", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - - _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") - - def fcompute(i, j): - warp_i, warp_j = i % l, j % r - spatial_args = i // l, j // r - if transform_kind_input >= TransformKind.IntraWarpTransform: - warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, warp_i, warp_j) - return A[new_index] - - A_reindex = te.compute( - (M, K), - fcompute, - name="A_reindex", - ) - - _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") - - def fcompute(i, j): - warp_i, warp_j = i % l, j % r - spatial_args = i // l, j // r - if transform_kind_weight >= TransformKind.IntraWarpTransform: - warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j]) - new_index = (*spatial_args, warp_i, warp_j) - return B[new_index] - - B_reindex = te.compute( - (N, K), - fcompute, - name="B_reindex", - ) - # Describe the matrix multiplication in TE - k = te.reduce_axis((0, K), name="k") - C = te.compute( - (M, N), - lambda i, j: te.sum( - A_reindex[i, k].astype(accum_dtype) * B_reindex[j, k].astype(accum_dtype), - axis=k, - ), - name="C", - ) - last_output = C - if accum_dtype != out_dtype: - D = te.compute((M, N), lambda i, j: C[i, j].astype(out_dtype), name="D") - last_output = D - - if with_bias: - E = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E") - last_output = E - - args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] - - func = te.create_prim_func(args) - func = func.with_attr("input_transform_kind", transform_kind_input.value) - func = func.with_attr("weight_transform_kind", transform_kind_weight.value) - - return tvm.IRModule.from_expr(func) - - -def select_implementation( - M=None, - N=16384, - K=16384, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, - layout="nt", - propagate_a: TransformKind = TransformKind.NonTransform, - propagate_b: TransformKind = TransformKind.NonTransform, -): - if layout == "nn": - if propagate_a or propagate_b: - raise ValueError( - "Currently only support propagate_a=False and propagate_b=False for layout=nn") - return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) - elif layout == "nt": - if propagate_a and propagate_b: - return matmul_nt_propagate_a_propagate_b( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - with_bias, - transform_kind_input=propagate_a, - transform_kind_weight=propagate_b, - ) - elif propagate_a: - return matmul_nt_propagate_a( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - with_bias, - transform_kind=propagate_a, - ) - elif propagate_b: - return matmul_nt_propagate_b( - M, - N, - K, - in_dtype, - out_dtype, - accum_dtype, - with_bias, - transform_kind=propagate_b, - ) - else: - return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) - else: - raise ValueError(f"Unsupported layout: {layout}") diff --git a/python/bitblas/ops/impl/matmul_splitk_impl.py b/python/bitblas/ops/impl/matmul_splitk_impl.py deleted file mode 100644 index c437f64cb..000000000 --- a/python/bitblas/ops/impl/matmul_splitk_impl.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -# pre-transformed tir expression of matmul -import tvm -from tvm import te -from bitblas.ops.operator import TransformKind - - -def matmul_nt( - SplitK, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, -): - if not isinstance(M, int): - M = tvm.te.var("m") - A = te.placeholder((M, K), name="A", dtype=in_dtype) - B = te.placeholder((N, K), name="B", dtype=in_dtype) - Bias = te.placeholder((N,), name="Bias", dtype=in_dtype) - - # Describe the matrix multiplication in TE - RK = K // SplitK - k = te.reduce_axis((0, RK), name="k") - C = te.compute( - (SplitK, M, N), - lambda sk, i, j: te.sum( - A[i, sk * RK + k].astype(accum_dtype) * B[j, sk * RK + k].astype(accum_dtype), axis=k), - name="C", - ) - last_output = C - if accum_dtype != out_dtype: - D = te.compute((SplitK, M, N), lambda b, i, j: C[b, i, j].astype(out_dtype), name="D") - last_output = D - - if with_bias: - E = te.compute((SplitK, M, N), lambda b, i, j: last_output[b, i, j] + Bias[j], name="E") - last_output = E - - args = [A, B, Bias, last_output] if with_bias else [A, B, last_output] - - func = te.create_prim_func(args) - - return tvm.IRModule.from_expr(func) - - -def matmul( - SplitK, - M, - N, - K, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, - layout="nt", -): - if layout == "nn": - raise ValueError("Currently only support layout=nt") - return matmul_nt(SplitK, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias) - - -def select_implementation( - SplitK=1, - M=None, - N=16384, - K=16384, - in_dtype="float16", - out_dtype="float16", - accum_dtype="float16", - with_bias=False, - layout="nt", - propagate_a: TransformKind = TransformKind.NonTransform, - propagate_b: TransformKind = TransformKind.NonTransform, -): - if layout == "nn": - if propagate_a or propagate_b: - raise ValueError( - "Currently only support propagate_a=False and propagate_b=False for layout=nn") - return matmul(M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) - elif layout == "nt": - if propagate_a and propagate_b: - raise ValueError("Currently only support propagate_a or propagate_b for layout=nt") - elif propagate_a: - raise ValueError("Currently only support propagate_a=False for layout=nt") - elif propagate_b: - raise ValueError("Currently only support propagate_b=False for layout=nt") - else: - return matmul(SplitK, M, N, K, in_dtype, out_dtype, accum_dtype, with_bias, layout) - else: - raise ValueError(f"Unsupported layout: {layout}") diff --git a/python/bitblas/ops/impl/param_permutate_impl.py b/python/bitblas/ops/impl/param_permutate_impl.py deleted file mode 100644 index 4ecb17709..000000000 --- a/python/bitblas/ops/impl/param_permutate_impl.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from bitblas.gpu.matmul_analysis import get_propagate_map -from ..operator import TransformKind -from typing import Literal -from tvm import te, IRModule - - -def select_implementation( - M: int, - N: int, - datatype: Literal["float16"] = "float16", - transpose_matrix: bool = True, - group_size: int = -1, - propagate_kind: TransformKind = TransformKind.NonTransform, - target_instruction: Literal["nvidia-mma"] = "nvidia-mma", -): - if target_instruction != "nvidia-mma": - raise ValueError("Currently only support nvidia-mma instruction") - if propagate_kind < TransformKind.IntraWarpTransform: - raise ValueError("Currently only support propagate_kind >= IntraWarpTransform") - if transpose_matrix is not True: - raise ValueError("Currently only support transpose_matrix == True") - # This is trick to get the basic tile size for the current datatype - # as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8 - l = r = 16 # noqa: E741 - if datatype in ["int8", "e4m3_float8", "e5m2_float8"]: - l, r = 16, 32 # noqa: E741 - if group_size == -1: - group_size = N - - intra_index_map, inverse_indexmap = get_propagate_map( - transpose_matrix, dtype=datatype, matrix_name=propagate_kind) - - inp = te.placeholder((M, N // group_size), name="inp", dtype=datatype) - - def fcompute(n, k): - rl, rr = n, k - warp_i, warp_j = rl % l, rr % r - spatial_i, spatial_j = rl // l, rr // r - if propagate_kind >= TransformKind.IntraWarpTransform: - warp_i, warp_j = intra_index_map.map_indices([warp_i, warp_j]) - new_index = (spatial_i * l + warp_i, (spatial_j * r + warp_j) // group_size) - return inp[new_index] - - inp_prmt = te.compute( - (M, N // group_size), - fcompute, - name="intra_warp_permutate", - ) - - args = [inp, inp_prmt] - - func = te.create_prim_func(args) - - return IRModule.from_expr(func) diff --git a/python/bitblas/ops/ladder_permutate.py b/python/bitblas/ops/ladder_permutate.py deleted file mode 100644 index 70999b09d..000000000 --- a/python/bitblas/ops/ladder_permutate.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from tvm.target import Target -from typing import Literal, Union -from .operator import Operator -from .impl.ladder_permutate_impl import select_implementation -from dataclasses import dataclass - - -@dataclass(frozen=True) -class LadderPermutateConfig: - M: int - N: int - datatype: Literal["int8", "e4m3_float8", "e5m2_float8"] = "float16" - dequantize_bits: int = -1 - storage_dtype: Literal["float16", "int8", "uint8", "int32", "uint32"] = "float16" - propagate_kind: Literal["A", "B"] = "B" # "A" or "B" - transpose_matrix: bool = False - transform_kind: int = 2 # 0: none, 1: inter_warp 2: intra_warp - target_instruction: Literal["nvidia-mma"] = ( - "nvidia-mma" # maybe extend to "cdna-mfma" in future. - ) - - -class LadderPermutate(Operator): - - def __init__( - self, - config: LadderPermutateConfig, - name: str = "permutate", - target: Union[str, Target] = "llvm", # assume to do permutation on cpu. - enable_tuning: bool = False, - from_database: bool = False, - ): - # consider to warp the arguments to MatmulConfig - super().__init__(name, config, target) - - target = self.target - if target.kind.name == "cuda": - self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) - if enable_tuning: - self.hardware_aware_finetune() - if not from_database: - self._build_runtime_module(target) - - # select implementation based on the Operator config - def _select_implementation(self): - return select_implementation( - M=self.M, - N=self.N, - datatype=self.datatype, - dequantize_bits=self.dequantize_bits, - storage_dtype=self.storage_dtype, - propagate_kind=self.propagate_kind, - transpose_matrix=self.transpose_matrix, - transform_kind=self.transform_kind, - target_instruction=self.target_instruction, - ) - - @property - def M(self): - return self.config.M - - @property - def N(self): - return self.config.N - - @property - def datatype(self): - return self.config.datatype - - @property - def dequantize_bits(self): - return self.config.dequantize_bits - - @property - def storage_dtype(self): - return self.config.storage_dtype - - @property - def propagate_kind(self): - return self.config.propagate_kind - - @property - def transpose_matrix(self): - return self.config.transpose_matrix - - @property - def transform_kind(self): - return self.config.transform_kind - - @property - def target_instruction(self): - return self.config.target_instruction - - -__all__ = ["LadderPermutate", "LadderPermutateConfig"] diff --git a/python/bitblas/ops/lop3_permutate.py b/python/bitblas/ops/lop3_permutate.py deleted file mode 100644 index 867432a5e..000000000 --- a/python/bitblas/ops/lop3_permutate.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from tvm.target import Target -from typing import Literal, Union -from .operator import Operator -from .impl.lop3_permutate_impl import select_implementation -from dataclasses import dataclass -import torch - - -@dataclass(frozen=True) -class LOP3PermutateConfig: - M: int - N: int - datatype: Literal["float16", "int8"] = "float16" - storage_dtype: Literal["int8", "uint8", "int32", "uint32"] = "int32" - dequantize_bits: int = 4 - - -class LOP3Permutate(Operator): - - def __init__( - self, - config: LOP3PermutateConfig, - name: str = "permutate", - target: Union[str, Target] = "llvm", # assume to do permutation on cpu. - ): - # consider to warp the arguments to MatmulConfig - super().__init__(name, config, target) - - if target.kind.name != "llvm": - raise ValueError("Currently only support llvm target for Permutation") - - self.target = target - self._build_runtime_module(target) - - def _select_implementation(self): - return select_implementation( - M=self.M, - N=self.N, - datatype=self.datatype, - dequantize_bits=self.dequantize_bits, - ) - - def forward(self, weight, res): - # reinterpret the input tensor to int32 format - args = [arg.view(torch.int32) for arg in [weight, res]] - self.torch_func(*args) - return args[-1].view(weight.dtype) - - @property - def M(self): - return self.config.M - - @property - def N(self): - return self.config.N - - @property - def datatype(self): - return self.config.datatype - - @property - def storage_dtype(self): - return self.config.storage_dtype - - @property - def dequantize_bits(self): - return self.config.dequantize_bits - - -__all__ = ["LOP3Permutate", "LOP3PermutateConfig"] diff --git a/python/bitblas/ops/matmul.py b/python/bitblas/ops/matmul.py deleted file mode 100644 index 7783c4972..000000000 --- a/python/bitblas/ops/matmul.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import tvm -import numpy as np -from tvm.target import Target -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch -from typing import List, Union, Optional, Any, Tuple -from .operator import Operator, TransformKind -from .impl.matmul_impl import select_implementation -from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 -from dataclasses import dataclass -from .ladder_permutate import LadderPermutate, LadderPermutateConfig -import logging - -logger = logging.getLogger(__name__) - - -class TransformExecutorCPU: - - def __init__(self, operators: Optional[List[Operator]] = None): - if operators is None: - operators = [] - self.operators = operators - - def append(self, op): - self.operators.append(op) - - def is_none(self): - return len(self.operators) == 0 - - def forward(self, weight): - inputs = [weight] - for op in self.operators: - inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) - inputs = [op.forward(*inputs)] - return inputs[-1] - - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.forward(*args, **kwds) - - @property - def size(self): - return len(self.operators) - - -@dataclass(frozen=True) -class MatmulConfig: - M: Union[int, Tuple[int]] - N: int - K: int - in_dtype: str = "float16" - out_dtype: str = "float16" - accum_dtype: str = "float16" - with_bias: bool = False - # layout of matrix A and B - # "nn": C[i, j] = A[i, k] * B[k, j] - # "nt": C[i, j] = A[i, k] * B[j, k] - layout: str = "nt" - # weight transformation kind of matrix A - propagate_a: TransformKind = TransformKind.NonTransform - # weight transformation kind of matrix B - propagate_b: TransformKind = TransformKind.NonTransform - - def __post_init__(self): - # set M to tuple if it is list - # otherwise, M is not hashable - object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M) - if isinstance(self.propagate_a, bool): - object.__setattr__( - self, - "propagate_a", - (TransformKind.IntraWarpTransform - if self.propagate_a else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_a, int): - object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a)) - - if isinstance(self.propagate_b, bool): - object.__setattr__( - self, - "propagate_b", - (TransformKind.IntraWarpTransform - if self.propagate_b else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_b, int): - object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b)) - - -class Matmul(Operator): - - def __init__( - self, - config: MatmulConfig, - name: str = "matmul", - target: Union[str, Target] = "cuda", - enable_tuning: bool = False, - from_database: bool = False, - ): - super().__init__(name, config, target) - target = self.target - if target.kind.name != "cuda": - raise ValueError("Currently only support cuda target") - - if isinstance(self.M, Tuple): - self.dynamic_range = {"m": self.M} - self.update_func(self.prim_func.with_attrs({"opt_shapes": self.dynamic_range})) - else: - self.dynamic_range = None - - if not from_database: - self._build_default_module(target) - - if self.propagate_a: - assert (self.propagate_a is - TransformKind.NonTransform), "Currently only support NonTransform for input" - ladder_permutate_config = LadderPermutateConfig( - M=self.M, - N=self.K, - datatype=self.in_dtype, - storage_dtype=self.in_dtype, - propagate_kind="A", - transpose_matrix=False, - transform_kind=self.propagate_a, - ) - self.ladder_permutate_a = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_a = None - - if self.propagate_b: - ladder_permutate_config = LadderPermutateConfig( - M=self.N, - N=self.K, - datatype=self.in_dtype, - storage_dtype=self.in_dtype, - propagate_kind="B", - transpose_matrix=(self.layout == "nt"), - transform_kind=self.propagate_b, - ) - self.ladder_permutate_b = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_b = None - - input_executors = TransformExecutorCPU() - if self.ladder_permutate_a is not None: - input_executors.append(self.ladder_permutate_b) - - self.input_executors = input_executors - - weight_executors = TransformExecutorCPU() - if self.ladder_permutate_b is not None: - weight_executors.append(self.ladder_permutate_b) - - self.weight_executors = weight_executors - - if enable_tuning: - self.hardware_aware_finetune() - - def _build_default_module(self, target: Target): - try: - self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) - except Exception: - self.optimized_func = None - logger.warning( - "[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." - ) - - self._build_runtime_module(target) - - def _select_implementation(self): - return select_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.in_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - with_bias=self.with_bias, - layout=self.layout, - propagate_a=self.propagate_a, - propagate_b=self.propagate_b, - ) - - def post_process(self, code: str) -> str: - code = tensor_replace_dp4a(code) - code = tensor_remove_make_int4(code) - code = tensor_remove_make_int2(code) - return code - - def _profile_latency_with_dynamic_range(self) -> List: - func = self.prim_func_mod["main"] - device = self.arch.device - - def var_warpper(v, m): - if isinstance(v, tvm.tir.Var): - assert "opt_shapes" in func.attrs - assert v.name in func.attrs["opt_shapes"] - return m - elif isinstance(v, tvm.tir.IntImm): - return v.value - else: - raise RuntimeError("Not supported type: ", type(v)) - - benchmark_latencies = [] - for m in self.dynamic_range["m"]: - profile_tensors = [] - for param in func.params: - if param not in func.buffer_map: - # in case of dynamic symbolic may in params - continue - arg = func.buffer_map[param] - profile_tensors.append( - tvm.nd.array( - np.random.uniform(0, 1, - [var_warpper(i, m) for i in arg.shape]).astype(arg.dtype), - device=device, - )) - self.profile_tensors = profile_tensors - latency = self.time_evaluator(*profile_tensors).mean * 1e3 - benchmark_latencies.append({"m": m, "latency": latency}) - # ms - return benchmark_latencies - - def forward(self, *args) -> Any: - if self.lib is None: - self._forward_from_torch_func(*args) - dynamic_symbolic = [] - if self.dynamic_range is not None: - # assume we only have one dynamic range - m = args[0].shape[0] - dynamic_symbolic.append(m) - self._forward_from_prebuild_lib(*args, *dynamic_symbolic) - - @property - def M(self): - return self.config.M - - @property - def N(self): - return self.config.N - - @property - def K(self): - return self.config.K - - @property - def in_dtype(self): - return self.config.in_dtype - - @property - def out_dtype(self): - return self.config.out_dtype - - @property - def accum_dtype(self): - return self.config.accum_dtype - - @property - def layout(self): - return self.config.layout - - @property - def with_bias(self): - return self.config.with_bias - - @property - def propagate_a(self): - return self.config.propagate_a - - @property - def propagate_b(self): - return self.config.propagate_b - - @property - def input_transform(self): - return self.input_executors if self.input_executors.size else None - - @property - def weight_transform(self): - return self.weight_executors if self.weight_executors.size else None - - -__all__ = ["Matmul", "MatmulConfig"] diff --git a/python/bitblas/ops/matmul_dequantize.py b/python/bitblas/ops/matmul_dequantize.py deleted file mode 100644 index 25c68b121..000000000 --- a/python/bitblas/ops/matmul_dequantize.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import tvm -from tvm.target import Target -from bitblas.base.roller.arch.cuda import CUDA -from typing import Any, List, Literal, Optional, Tuple, Union -from .operator import Operator, TransformKind -from .impl.matmul_dequantize_impl import select_implementation -from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch -from dataclasses import dataclass -from .ladder_permutate import LadderPermutate, LadderPermutateConfig -from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig -import logging - -logger = logging.getLogger(__name__) - - -class OPExecutorCPU: - - def __init__(self, operators: Optional[List[Operator]] = None): - if operators is None: - operators = [] - self.operators = operators - - def append(self, op): - self.operators.append(op) - - def is_none(self): - return len(self.operators) == 0 - - def forward(self, weight): - inputs = [weight] - for op in self.operators: - inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) - inputs = [op.forward(*inputs)] - return inputs[-1] - - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.forward(*args, **kwds) - - @property - def size(self): - return len(self.operators) - - -@dataclass(frozen=True) -class MatmulWeightOnlyDequantizeConfig: - M: Union[int, Tuple[int]] - N: int - K: int - in_dtype: str = "float16" - out_dtype: str = "float16" - accum_dtype: str = "float16" - bit: int = 4 - storage_dtype: str = "int8" - # documents for source_format: - # the format of the source data, which can be "int", "uint", "fp", "nf" - # "int": dequantize_weight = (target)((int)(quantize_weight - fixed_zero_point)) * scale - # where the fixed_zero_point is 2^(bit - 1) - 1 - # "uint": dequantize_weight = (target)((uint)(quantize_weight - zero_point)) * scale - # where the zero_point is manually set by zeros tensor - # "fp": dequantize_weight = (quantize_weight - zero_point) * scale - # "nf": dequantize_weight = (lut[quantize_weight] - zero_point) * scale - source_format: Literal["int", "uint", "fp", "nf"] = "int" - with_scaling: bool = False - with_zeros: bool = False - group_size: int = -1 - fast_decoding: bool = False - with_bias: bool = False - propagate_a: TransformKind = TransformKind.NonTransform - propagate_b: TransformKind = TransformKind.NonTransform - layout: str = "nt" - # documents for zeros_mode: - # original: target = (dequantize_weight - zero_point) * scale - # rescale: target = dequantize_weight * scale - zero_point - # quantized: target = (dequantize_weight - dequantize_zeros) * scale - # The auto-gptq framework prefer "quantized" and "original" for alignment with cuda. - zeros_mode: Literal["original", "rescale", "quantized"] = "original" - - def __post_init__(self): - # set M to tuple if it is list - # otherwise, M is not hashable - object.__setattr__(self, "M", tuple(self.M) if isinstance(self.M, list) else self.M) - if isinstance(self.propagate_a, bool): - object.__setattr__( - self, - "propagate_a", - (TransformKind.IntraWarpTransform - if self.propagate_a else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_a, int): - object.__setattr__(self, "propagate_a", TransformKind(self.propagate_a)) - - if isinstance(self.propagate_b, bool): - object.__setattr__( - self, - "propagate_b", - (TransformKind.IntraWarpTransform - if self.propagate_b else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_b, int): - object.__setattr__(self, "propagate_b", TransformKind(self.propagate_b)) - - -class MatmulWeightOnlyDequantize(Operator): - - def __init__( - self, - config: MatmulWeightOnlyDequantizeConfig, - name: str = "matmul_weight_only_dequantize", - target: Target = "cuda", - enable_tuning: bool = False, - from_database: bool = False, - ): - super().__init__(name, config, target) - - target = self.target - if target.kind.name != "cuda": - raise ValueError("Currently only support cuda target") - - self.arch = CUDA(target) - - if isinstance(self.M, Tuple): - self.dynamic_range = {"m": self.M} - self.prim_func_mod["main"] = self.prim_func_mod["main"].with_attrs( - {"opt_shapes": self.dynamic_range}) - else: - self.dynamic_range = None - - if not from_database: - self._build_default_module(target) - - if self.propagate_a: - ladder_permutate_config = LadderPermutateConfig( - M=self.M, - N=self.K, - datatype=self.in_dtype, - storage_dtype=self.in_dtype, - propagate_kind="A", - transpose_matrix=False, - transform_kind=self.propagate_a, - ) - self.ladder_permutate_a = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_a = None - - if self.propagate_b: - ladder_permutate_config = LadderPermutateConfig( - M=self.N, - N=self.K, - datatype=self.in_dtype, - dequantize_bits=self.bit, - storage_dtype=self.storage_dtype, - propagate_kind="B", - transpose_matrix=self.layout == "nt", - transform_kind=self.propagate_b, - ) - self.ladder_permutate_b = LadderPermutate( - config=ladder_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.ladder_permutate_b = None - - if self.fast_decoding: - lop3_permutate_config = LOP3PermutateConfig( - M=self.N, - N=self.K, - datatype=self.in_dtype, - dequantize_bits=self.bit, - storage_dtype=self.storage_dtype, - ) - self.lop3_permutate = LOP3Permutate( - config=lop3_permutate_config, - target=tvm.target.Target("llvm"), - ) - else: - self.lop3_permutate = None - - input_executors = OPExecutorCPU() - if self.ladder_permutate_a is not None: - input_executors.append(self.ladder_permutate_a) - self.input_executors = input_executors - - weight_executors = OPExecutorCPU() - if self.lop3_permutate is not None: - weight_executors.append(self.lop3_permutate) - - if self.ladder_permutate_b is not None: - weight_executors.append(self.ladder_permutate_b) - - self.weight_executors = weight_executors - - if enable_tuning: - self.hardware_aware_finetune() - - def _build_default_module(self, target: Target): - try: - self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target) - except Exception: - self.optimized_func = None - logger.warning( - "[BitBLAS][Warning] Apply default schedule failed, should do hardware-aware optimization manually." - ) - - self._build_runtime_module(target) - - def _select_implementation(self): - return select_implementation( - M=self.M, - N=self.N, - K=self.K, - in_dtype=self.in_dtype, - out_dtype=self.out_dtype, - accum_dtype=self.accum_dtype, - bit=self.bit, - storage_dtype=self.storage_dtype, - source_format=self.source_format, - with_scaling=self.with_scaling, - with_zeros=self.with_zeros, - group_size=self.group_size, - fast_decoding=self.fast_decoding, - with_bias=self.with_bias, - layout=self.layout, - zeros_mode=self.zeros_mode, - propagate_a=self.propagate_a, - propagate_b=self.propagate_b, - ) - - def post_process(self, code: str) -> str: - code = tensor_replace_dp4a(code) - code = tensor_remove_make_int4(code) - code = tensor_remove_make_int2(code) - return code - - def retrieve_weight_shape(self): - return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape] - - def forward(self, *args) -> Any: - if self.lib is None: - self._forward_from_torch_func(*args) - dynamic_symbolic = [] - if self.dynamic_range is not None: - # assume we only have one dynamic range - m = args[0].shape[0] - dynamic_symbolic.append(m) - self._forward_from_prebuild_lib(*args, *dynamic_symbolic) - - @property - def M(self): - return self.config.M - - @property - def N(self): - return self.config.N - - @property - def K(self): - return self.config.K - - @property - def in_dtype(self): - return self.config.in_dtype - - @property - def out_dtype(self): - return self.config.out_dtype - - @property - def accum_dtype(self): - return self.config.accum_dtype - - @property - def bit(self): - return self.config.bit - - @property - def storage_dtype(self): - return self.config.storage_dtype - - @property - def source_format(self): - return self.config.source_format - - @property - def with_scaling(self): - return self.config.with_scaling - - @property - def with_zeros(self): - return self.config.with_zeros - - @property - def group_size(self): - return self.config.group_size - - @property - def fast_decoding(self): - return self.config.fast_decoding - - @property - def with_bias(self): - return self.config.with_bias - - @property - def propagate_a(self): - return self.config.propagate_a - - @property - def propagate_b(self): - return self.config.propagate_b - - @property - def layout(self): - return self.config.layout - - @property - def zeros_mode(self): - return self.config.zeros_mode - - @property - def input_transform(self): - return self.input_executors if self.input_executors.size else None - - @property - def weight_transform(self): - return self.weight_executors if self.weight_executors.size else None diff --git a/python/bitblas/ops/operator.py b/python/bitblas/ops/operator.py deleted file mode 100644 index 90930d6d3..000000000 --- a/python/bitblas/ops/operator.py +++ /dev/null @@ -1,367 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from abc import ABC, abstractmethod -import tvm -from tvm import IRModule -from tvm.target import Target -from tvm.tir import PrimFunc -from tvm.contrib.dlpack import to_pytorch_func -from tvm._ffi.base import _LIB, raise_last_ffi_error -from tvm._ffi._ctypes.types import TVMValue, ArgTypeCode -import bitblas -import ctypes -from typing import List, Dict, Any, Optional -import numpy as np -from ..base import fast_tune, fast_tune_with_dynamic_range -from copy import deepcopy -from bitblas.base.roller.arch import get_arch -from bitblas.utils.tensor_adapter import tvm_tensor_to_torch -from bitblas.wrapper import CUDASourceWrapper, CUDASourceWrapperWithDynamic -from dataclasses import dataclass -from enum import IntEnum -import logging - -logger = logging.getLogger(__name__) - - -class TransformKind(IntEnum): - NonTransform = 0 - InterWarpTransform = 1 - IntraWarpTransform = 2 - - -@dataclass -class OperatorConfig: - """Base class for operator configurations. Used for typing.""" - - pass - - -class Operator(ABC): - - def __init__(self, name, config: OperatorConfig, target: Target = None): - if isinstance(target, str): - target = Target(target) - self.name = name - self.config = config - self.target = target - self.prim_func_mod = self._select_implementation() - self.optimized_func = None - self.rt_mod = None - self.time_evaluator = None - self.profile_tensors = None - self.arch = get_arch(target) if target else None - self.dynamic_range = None - self.pass_context: Dict = {} - self.num_args = len(self.prim_func.params) - self.function_handle = None - self.num_output_args: int = ( - 1 # todo(lei): should be analyzed from the prim_func. - ) - self.wrapper = None - self.src_name = None - self.lib_name = None - self.lib = None - - def get_source(self, target: Target = None) -> str: - if target is None: - target = self.target - if self.rt_mod is None: - self._build_runtime_module(target) - return self.rt_mod.imported_modules[0].get_source() if self.rt_mod else None - - def _build_runtime_module(self, target: Target): - """ - Builds the runtime module based on the architecture platform. - - This function attempts to build a runtime module (rt_mod) for the specified target. - If the platform is CUDA and an optimized function is available, it tries to build - using the optimized function with a specific pass context. Otherwise, it falls back - to building with the primary function. After successful build, it initializes a - time evaluator for performance measurement. - - Args: - target (Target): The compilation target specification. - - Returns: - The compiled runtime module or None if the build was unsuccessful. - """ - - # Initialize rt_mod as None to handle cases where build fails or is skipped - rt_mod = None - - # Check if the platform is CUDA and we have an optimized function - if self.arch.platform == "CUDA": - if self.optimized_func is None: - return None - - @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) - def tvm_callback_cuda_postproc(code, _): - return self.post_process(code) - - try: - # Use a specific TVM pass context for CUDA platforms - with tvm.transform.PassContext(config={ - "tir.use_async_copy": True, - **self.pass_context - }): - rt_mod = tvm.build(self.optimized_func, target=target, name=self.name) - except Exception as e: - rt_build_error = e # noqa - logger.debug( - "Failed to build optimized function for CUDA target with default schedule, Please consider enable hardware aware tuning!" - ) - else: - # For non-CUDA platforms or when no optimized function is available, build with the primary function - rt_mod = tvm.build(self.prim_func, target=target, name=self.name) - - # If the runtime module was successfully built, set up for evaluation - if rt_mod: - self.rt_mod = rt_mod - # Initialize a time evaluator with the built module, specifying the device and the number of runs - self.time_evaluator = rt_mod.time_evaluator( - rt_mod.entry_name, self.arch.device, number=10) - self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle - self.torch_func = to_pytorch_func(rt_mod) - if self.arch.platform == "CUDA": - try: - if (self.dynamic_range is not None and len(self.optimized_func.functions) > 1): - wrapper = CUDASourceWrapperWithDynamic(self.optimized_func, - self.get_source(target), self.arch) - else: - wrapper = CUDASourceWrapper(self.optimized_func, self.get_source(target), - self.arch) - wrapper.compile_lib() - self.wrapper = wrapper - self.src_name = self.wrapper.src_name - self.lib_name = self.wrapper.lib_name - self.lib = self.wrapper.load_lib() - self.lib.init() - except Exception as e: - build_runtime_library_error = e - logger.debug( - "Failed to build runtime library {}".format(build_runtime_library_error)) - - return rt_mod - - def apply_default_schedule(self, func_mod: IRModule, target: Target) -> IRModule: - mod_for_opt = deepcopy(func_mod) - with target: - optimized_mod = ( - bitblas.ApplyDefaultSchedule( # pylint: disable=not-callable - bitblas.gpu.Matmul(), - bitblas.gpu.GEMV(), - bitblas.gpu.Reduction(), - bitblas.gpu.GeneralReduction(), - bitblas.gpu.Fallback(), - )(mod_for_opt)) - - if optimized_mod is not None: - return optimized_mod - return None - - def post_process(self, code: str) -> str: - return code - - def apply_fast_tuning(self, - func: PrimFunc, - target: Target, - topk: int = 20, - parallel_build=True) -> IRModule: - _, best = fast_tune(func, target, topk=topk, parallel_build=parallel_build) - if best is not None: - return best.sch.mod - self.pass_context = best.config.pass_context - return None - - def apply_fast_tuning_with_dynamic_range( - self, - func: PrimFunc, - target: Target, - topk: int = 20, - dynamic_range: Dict[str, List[int]] = None, - ): - optimized_mod = fast_tune_with_dynamic_range( - func, target, topk=topk, parallel_build=True, dynamic_range=dynamic_range) - if optimized_mod is not None: - return optimized_mod - return None - - def hardware_aware_finetune(self, - topk: int = 20, - target: tvm.target.Target = None, - parallel_build=True): - if target is None: - target = self.target - dynamic_range = self.dynamic_range - func = self.prim_func - if dynamic_range is not None: - self.optimized_func = self.apply_fast_tuning_with_dynamic_range( - func, target, topk, dynamic_range) - else: - self.optimized_func = self.apply_fast_tuning( - func, target, topk, parallel_build=parallel_build) - self._build_runtime_module(self.target) - - def get_profile_tensors(self, dynamic_symbolic_constrains: Optional[Dict] = None): - if dynamic_symbolic_constrains is None: - dynamic_symbolic_constrains = {} - func = self.prim_func - device = self.arch.device - - def var_warpper(v): - if isinstance(v, tvm.tir.Var): - if v.name in dynamic_symbolic_constrains: - return dynamic_symbolic_constrains[v.name] - assert "opt_shapes" in func.attrs - assert v.name in func.attrs["opt_shapes"] - return func.attrs["opt_shapes"][v.name].value - elif isinstance(v, tvm.tir.IntImm): - return v.value - else: - raise RuntimeError("Not supported type: ", type(v)) - - def map_numpy_type(intype): - typemap = { - 'e4m3_float8': 'float8_e4m3fn', - 'e5m2_float8': 'float8_e5m2', - } - if intype in typemap: - return typemap[intype] - else: - return intype - - profile_tensors = [] - for param in func.params: - if param not in func.buffer_map: - # in case of dynamic symbolic may in params - continue - arg = func.buffer_map[param] - numpy_dtype = map_numpy_type(arg.dtype) - profile_tensors.append( - tvm.nd.array( - np.random.uniform(0, 1, - [var_warpper(i) for i in arg.shape]).astype(numpy_dtype), - device=device, - )) - self.profile_tensors = profile_tensors - return profile_tensors - - def profile_latency(self, dynamic_symbolic_constrains: Optional[Dict] = None) -> str: - if dynamic_symbolic_constrains is None: - dynamic_symbolic_constrains = {} - profile_tensors = self.get_profile_tensors(dynamic_symbolic_constrains) - latency = self.time_evaluator(*profile_tensors).mean * 1e3 - return latency - - def _tensor_adapter(self, tensor, device): - import torch - from torch.utils.dlpack import to_dlpack - - if isinstance(tensor, tvm.te.Tensor): - return tensor - elif isinstance(tensor, torch.Tensor): - return tvm.runtime.ndarray.from_dlpack(to_dlpack(tensor)) - elif isinstance(tensor, np.ndarray): - return tvm.nd.array(tensor, device=device) - else: - raise RuntimeError("Not supported type: ", type(tensor)) - - def _forward_from_tvm_args(self, *args): - _tvm_args = [self._tensor_adapter(arg, self.arch.device) for arg in args] - self.rt_mod(*_tvm_args) - - def _forward_from_tvm_nd_array(self, *args): - self.rt_mod(*args) - - def _forward_from_torch_func(self, *args): - # torch func is not reliable as some datatypes they don't support - # like float8. - self.torch_func(*args) - return args[-1] - - def forward(self, *args): - return self._forward_from_torch_func(*args) - - def _forward_from_prebuild_lib(self, *args, stream=0): - ctypes_args = [ - ctypes.c_void_p(arr.data_ptr()) if not isinstance(arr, int) else arr for arr in args - ] - ctypes_args.append(ctypes.c_void_p(stream)) - self.lib.call(*ctypes_args) - - def call_lib(self, *args, stream=0): - self.lib.call(*args, ctypes.c_void_p(stream)) - - def _forward_from_tvm_lib_func(self, values): - tcodes = (ctypes.c_int * self.num_args)() - ret_val = TVMValue() - ret_tcode = ctypes.c_int() - for i in range(self.num_args): - tcodes[i] = ArgTypeCode.NDARRAY_HANDLE - if (_LIB.TVMFuncCall( - self.function_handle, - values, - tcodes, - ctypes.c_int(self.num_args), - ctypes.byref(ret_val), - ctypes.byref(ret_tcode), - ) != 0): - raise_last_ffi_error() - - def __call__(self, *args: Any) -> Any: - return self.forward(*args) - - def update_func(self, func: PrimFunc): - self.prim_func_mod["main"] = func - - def update_runtime_module(self, rt_mod, src_name=None, lib_name=None): - self.rt_mod = rt_mod - self.time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, self.arch.device, number=10) - self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle - self.torch_func = to_pytorch_func(rt_mod) - if src_name is not None: - self.src_name = src_name - if lib_name is not None: - self.lib_name = lib_name - self.lib = ctypes.CDLL(lib_name) - self.lib.init() - - @abstractmethod - def _select_implementation(self) -> IRModule: - pass - - @property - def prim_func(self): - return self.prim_func_mod["main"] - - -class OPExecutorCPU: - """ - A class to execute a sequence of operators on the CPU. - """ - - def __init__(self, operators: Optional[List[Operator]] = None): - if operators is None: - operators = [] - self.operators = operators - - def append(self, op): - self.operators.append(op) - - def is_none(self): - return len(self.operators) == 0 - - def forward(self, weight): - inputs = [weight] - for op in self.operators: - inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu()) - inputs = [op.forward(*inputs)] - return inputs[-1] - - def __call__(self, *args: Any, **kwds: Any) -> Any: - return self.forward(*args, **kwds) - - @property - def size(self): - return len(self.operators) diff --git a/python/bitblas/ops/param_permutate.py b/python/bitblas/ops/param_permutate.py deleted file mode 100644 index ca28c86eb..000000000 --- a/python/bitblas/ops/param_permutate.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from tvm.target import Target -from typing import Literal, Union -from .operator import Operator, TransformKind -from .impl.param_permutate_impl import select_implementation -from dataclasses import dataclass - - -@dataclass(frozen=True) -class ParamPermutateConfig: - M: int - N: int - datatype: Literal["float16"] = "float16" - transpose_matrix: bool = True - group_size: int = -1 - propagate_kind: TransformKind = TransformKind.NonTransform - target_instruction: Literal["nvidia-mma"] = ( - "nvidia-mma" # maybe extend to "cdna-mfma" in future. - ) - - def __post_init__(self): - if isinstance(self.propagate_kind, bool): - object.__setattr__( - self, - "propagate_kind", - (TransformKind.IntraWarpTransform - if self.propagate_kind else TransformKind.NonTransform), - ) - elif isinstance(self.propagate_kind, int): - object.__setattr__(self, "propagate_kind", TransformKind(self.propagate_kind)) - - -class ParamPermutate(Operator): - - def __init__( - self, - config: ParamPermutateConfig, - name: str = "permutate", - target: Union[str, Target] = "llvm", # assume to do permutation on cpu. - ): - super().__init__(name, config, target) - - if target.kind.name != "llvm": - raise ValueError("Currently only support llvm target for Permutation") - - self.target = target - self._build_runtime_module(target) - - # select implementation based on the Operator config - def _select_implementation(self): - return select_implementation( - M=self.M, - N=self.N, - datatype=self.datatype, - transpose_matrix=self.transpose_matrix, - group_size=self.group_size, - propagate_kind=self.propagate_kind, - target_instruction=self.target_instruction, - ) - - @property - def M(self): - return self.config.M - - @property - def N(self): - return self.config.N - - @property - def datatype(self): - return self.config.datatype - - @property - def propagate_kind(self): - return self.config.propagate_kind - - @property - def transpose_matrix(self): - return self.config.transpose_matrix - - @property - def group_size(self): - return self.config.group_size - - @property - def target_instruction(self): - return self.config.target_instruction - - -__all__ = ["ParamPermutate", "ParamPermutateConfig"] diff --git a/python/bitblas/quantization/__init__.py b/python/bitblas/quantization/__init__.py deleted file mode 100644 index d29cb679a..000000000 --- a/python/bitblas/quantization/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from .quantization import ( - _tir_packed_int_to_int_convert, # noqa: F401 - _tir_packed_to_signed_convert, # noqa: F401 - _tir_packed_to_unsigned_convert, # noqa: F401 - _tir_u32_to_f4_to_f16, # noqa: F401 - _tir_u8_to_f8_e4m3_to_f16, # noqa: F401 - _tir_packed_to_unsigned_convert_with_zeros, # noqa: F401 -) - -from .utils import gen_quant4, general_compress # noqa: F401 diff --git a/python/bitblas/quantization/quantization.py b/python/bitblas/quantization/quantization.py deleted file mode 100644 index 71ef224d7..000000000 --- a/python/bitblas/quantization/quantization.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright 2018 The apache/tvm Authors. All Rights Reserved. -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -# Modifications Copyright (c) Microsoft. -# The code below is mostly copied from mlc.ai quantization.py in mlc-llm. -# pylint: disable=invalid-name,missing-function-docstring,unused-variable -"""TIR computation utilities for quantization.""" - -import tvm -from tvm import tir - - -# fmt: off -def _tir_f32x2_to_bf16x2_to_u32(v0: tir.PrimExpr, v1: tir.PrimExpr, round_to_even: bool = True): - mask = tir.const((1 << 16) - 1, "uint32") - res = [] - for data in [v0, v1]: - u32_val = tir.reinterpret("uint32", data) - if round_to_even: - rounding_bias = ((u32_val >> tir.const(16, "uint32")) - & tir.const(1, "uint32")) + tir.const(0x7FFF, "uint32") - u32_val += rounding_bias - res.append((u32_val >> tir.const(16, "uint32")) & mask) - return res[0] | (res[1] << tir.const(16, "uint32")) - - -def _tir_u32_to_bf16x2_to_f32x2(x: tir.PrimExpr): - mask = tir.const((1 << 16) - 1, "uint32") - x0 = x & mask - x1 = (x >> 16) & mask - return (tir.reinterpret("float32", x << tir.const(16, "uint32")) for x in [x0, x1]) - - -def _tir_u32_to_int_to_float(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): - assert val.dtype == "uint32" - mask = tvm.tir.const((1 << nbit) - 1, "uint32") - return tir.Cast(dtype, (val >> (pos * nbit).astype("uint32")) & mask) - - -def _tir_packed_uint_to_uint_to_float(storage_nbit: int): - storage_dtype = "uint" + str(storage_nbit) - - def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): - assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" - max_int_value = (1 << (nbit - 1)) - 1 - return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const( - (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) - - return f_convert - - -def _tir_packed_int_to_int_to_float(storage_nbit: int): - storage_dtype = "int" + str(storage_nbit) - - def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): - assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" - mask = tir.const((1 << nbit) - 1, "int32") - unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask - return tir.Cast( - dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) - - return f_convert - - -def _tir_f32_to_uint_to_f4(val: tir.PrimExpr): - assert val.dtype == "float32" - val_u32 = tir.reinterpret("uint32", val) - # e_f32 > 120 -> e_f4 = min(e_f32 - 120 + M_h, 7) - # e_f32 == 120 -> e_f4 = 1 - # e_f32 < 120 -> e_f4 = 0 - m_h = (val_u32 >> tir.const(22, "uint32")) & tir.const(1, "uint32") - e_f32 = (val_u32 >> tir.const(23, "uint32")) & tir.const(255, "uint32") - s = (val_u32 >> tir.const(31, "uint32")) - e_f4 = tir.Select( - e_f32 > tir.const(120, "uint32"), - tir.Min(e_f32 - tir.const(120, "uint32") + m_h, tir.const(7, "uint32")), - tir.Select(e_f32 == tir.const(120, "uint32"), tir.const(1, "uint32"), - tir.const(0, "uint32"))) - return (s << tir.const(3, "uint32")) | e_f4 - - -def _tir_f16_to_uint_to_f4(val: tir.PrimExpr): - assert val.dtype == "float16" - val_u32 = tir.Cast("uint32", tir.reinterpret("uint16", val)) - m_h = (val_u32 >> tir.const(9, "uint32")) & tir.const(1, "uint32") - e_f16 = (val_u32 >> tir.const(10, "uint32")) & tir.const(31, "uint32") - s = (val_u32 >> tir.const(15, "uint32")) - e_f4 = tir.Select( - e_f16 > tir.const(8, "uint32"), - tir.Min(e_f16 - tir.const(8, "uint32") + m_h, tir.const(7, "uint32")), - tir.Select(e_f16 == tir.const(8, "uint32"), tir.const(1, "uint32"), tir.const(0, "uint32"))) - return (s << tir.const(3, "uint32")) | e_f4 - - -def _tir_u32_to_f4_to_f32(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): - assert nbit == 4 - assert dtype == "float32" - assert val.dtype == "uint32" - # e_f4 == 0 -> e_f32 = 0 - # e_f4 != 0 -> e_f32 = e_f4 + 120 = e_f4 | (1111000)_2 - mask = tvm.tir.const((1 << nbit) - 1, "uint32") - f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask - s = f4 >> tir.const(3, "uint32") - e_f4 = f4 & tir.const(7, "uint32") - e_f32 = e_f4 | tir.const(120, "uint32") - val_f32 = tir.reinterpret("float32", - (e_f32 | (s << tir.const(8, "uint32"))) << tir.const(23, "uint32")) - return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float32"), val_f32) - - -def _tir_u32_to_f4_to_f16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): - assert nbit == 4 - assert dtype == "float16" - assert val.dtype == "uint32" - # e_f4 == 0 -> e_f16 = 0 - # e_f4 != 0 -> e_f16 = e_f4 + 8 = e_f4 | (1000)_2 - mask = tvm.tir.const((1 << nbit) - 1, "uint32") - f4 = (val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & mask - s = f4 >> tir.const(3, "uint32") - e_f4 = f4 & tir.const(7, "uint32") - e_f16 = e_f4 | tir.const(8, "uint32") - val_f16 = tir.reinterpret("float16", - (e_f16 | (s << tir.const(5, "uint32"))) << tir.const(10, "uint32")) - return tir.Select(e_f4 == tir.const(0, "uint32"), tir.const(0, "float16"), val_f16) - - -def _tir_u8_to_f8_e4m3_to_f16_naive(nbit: int, val: tir.PrimExpr, dtype: str): - assert nbit == 8 - assert dtype == "float16" - s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16") - e4 = val & tir.const(0x40, "uint16") - prefix = tir.Select(e4 == tir.const(0, "uint16"), tir.const(0x2000, "uint16"), - tir.const(0x4000, "uint16")) - e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | prefix - return tir.reinterpret("float16", s_f16 | e_f16) - - -def _tir_u8_to_f8_e4m3_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): - assert nbit == 8 - assert dtype == "float16" - s_f16 = (val >> tir.const(7, "uint16")) << tir.const(15, "uint16") - e4 = val & tir.const(0x40, "uint16") - e_f16 = (((val & tir.const(63, "uint16")) << tir.const(7, "uint16"))) | (e4 << tir.const(8, "uint16")) | (e4 << tir.const(7, "uint16")) - e_f16 = e_f16 ^ tir.const(0x2000, "uint16") - return tir.reinterpret("float16", s_f16 | e_f16) - - -def _tir_u8_to_f8_e5m2_to_f16(nbit: int, val: tir.PrimExpr, dtype: str): - assert nbit == 8 - assert dtype == "float16" - return tir.reinterpret("e5m2_float8", val).astype("float16") - - -def _tir_packed_to_signed_convert(storage_type="uint", storage_nbit=8): - storage_dtype = storage_type + str(storage_nbit) - - def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): - assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" - max_int_value = (1 << (nbit - 1)) - return ((val >> (pos.astype("uint32") * tir.const(nbit, "uint32"))) & tir.const( - (1 << nbit) - 1, "uint32")).astype(dtype) - tir.const(max_int_value, dtype) - - return f_convert - - -def _tir_packed_to_unsigned_convert(storage_type="uint", storage_nbit=8): - storage_dtype = storage_type + str(storage_nbit) - - def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, dtype: str): - assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" - mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) - return ((val >> (pos * nbit).astype(storage_dtype)) & mask).astype(dtype) - - return f_convert - - -def _tir_packed_to_unsigned_convert_with_zeros(storage_type="uint", storage_nbit=8): - storage_dtype = storage_type + str(storage_nbit) - - def f_convert(nbit: int, val: tvm.tir.PrimExpr, pos: tvm.tir.PrimExpr, zero: tvm.tir.PrimExpr, - dtype: str): - assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" - mask = tvm.tir.const((1 << nbit) - 1, storage_dtype) - return (((val >> (pos * nbit).astype(storage_dtype)) & mask) - zero).astype(dtype) - - return f_convert - - -def _tir_packed_int_to_int_convert(storage_type="uint", storage_nbit=8): - storage_dtype = storage_type + str(storage_nbit) - - def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): - assert val.dtype == storage_dtype, f"{val.dtype} != {storage_dtype}" - mask = tir.const((1 << nbit) - 1, "int32") - unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask - return tir.Cast( - dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) - - return f_convert - - -# fmt: on diff --git a/python/bitblas/quantization/utils.py b/python/bitblas/quantization/utils.py deleted file mode 100644 index 45890c3d8..000000000 --- a/python/bitblas/quantization/utils.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import numpy as np -import torch -import torch.nn as nn - - -def gen_quant4(k, n, groupsize=-1): - maxq = 2**4 - w = torch.randn((k, n), dtype=torch.half, device="cpu") - - original_w = w.clone() - - if groupsize == -1: - groupsize = k - - if groupsize != -1: - w = w.reshape((-1, groupsize, n)) - w = w.permute(1, 0, 2) - w = w.reshape((groupsize, -1)) - - s = torch.max(torch.abs(w), 0, keepdim=True)[0] - s *= 2 / maxq - - # Quantize. - w = torch.round(w / s).int() - - # Unsigned storage. - w += (maxq) // 2 - - w = torch.clamp(w, 0, maxq) - - # Dequantize. - ref = (w - (maxq) // 2).half() * s - - if groupsize != -1: - - def reshape(w): - w = w.reshape((groupsize, -1, n)) - w = w.permute(1, 0, 2) - w = w.reshape((k, n)).contiguous() - return w - - ref = reshape(ref) - w = reshape(w) - - s = s.reshape((-1, n)).contiguous() - linear = nn.Linear(k, n, bias=False) - linear.weight.data = ref.t() - - return original_w, linear, s, (w - (maxq) // 2) - - -def general_compress(lowprecision_weight, source_bits=4, storage_dtype=np.int8): - elems_per_byte = 8 // source_bits - if lowprecision_weight.dtype == np.float16: - lowprecision_weight = lowprecision_weight.astype(dtype=np.int8) - int8_weight = np.zeros( - ( - *lowprecision_weight.shape[:-1], - lowprecision_weight.shape[-1] // elems_per_byte, - ), - dtype=np.int8, - ) - for j in range(lowprecision_weight.shape[-1] // elems_per_byte): - for k in range(elems_per_byte): - int8_weight[:, j] |= lowprecision_weight[:, j * elems_per_byte + k] << (source_bits * k) - - return int8_weight.view(storage_dtype) - - -# interleave weight numpy implementation -def interleave_weight(qweight, nbits=4, target_dtype="float16"): - assert target_dtype in ["float16", "int8"] - # reinterpret the data type of qweight to int32 - qweight = qweight.view(np.int32) - new_qweight = np.zeros_like(qweight) - bits_stride = 8 if target_dtype == "int8" else 16 - mask = (1 << nbits) - 1 # for 4bit the val is 0x0000000f - num_groups = 32 // bits_stride - elems_per_group = bits_stride // nbits - for i in range(num_groups): - for j in range(elems_per_group): - offset = i * elems_per_group + j - shift = (offset % num_groups) * bits_stride + (offset // num_groups) * nbits - new_qweight |= ((qweight >> (nbits * offset)) & mask) << shift - - if nbits == 1 and target_dtype == "int8": - # special handling for 1b interleave - n16_weight = new_qweight & np.int32(0xF0F00F0F) - n16_weight |= ((new_qweight & np.int32(0x000000F0)) >> 4) << 16 - n16_weight |= ((new_qweight & np.int32(0x0000F000)) >> 12) << 24 - n16_weight |= ((new_qweight & np.int32(0x000F0000)) >> 16) << 4 - n16_weight |= ((new_qweight & np.int32(0x0F000000)) >> 24) << 12 - return n16_weight.view(np.int8) - elif nbits == 2 and target_dtype == "float16": - n8_weight = new_qweight & np.int32(0xFF0000FF) - n8_weight |= ((new_qweight & np.int32(0x0000FF00)) >> 8) << 16 - n8_weight |= ((new_qweight & np.int32(0x00FF0000)) >> 16) << 8 - return n8_weight.view(np.int8) - elif nbits == 1 and target_dtype == "float16": - n8_weight = new_qweight & 0xF000000F - n8_weight |= ((new_qweight & 0x000000F0) >> 4) << 8 - n8_weight |= ((new_qweight & 0x00000F00) >> 8) << 16 - n8_weight |= ((new_qweight & 0x0000F000) >> 12) << 24 - n8_weight |= ((new_qweight & 0x000F0000) >> 16) << 4 - n8_weight |= ((new_qweight & 0x00F00000) >> 20) << 12 - n8_weight |= ((new_qweight & 0x0F000000) >> 24) << 20 - - return new_qweight.view(np.int8) diff --git a/python/bitblas/relax/op/interleave_weight.py b/python/bitblas/relax/op/interleave_weight.py deleted file mode 100644 index 98b1f5cd4..000000000 --- a/python/bitblas/relax/op/interleave_weight.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from tvm.relax.block_builder import BlockBuilder -from tvm.relax.expr import Call, Expr -from tvm.relax.transform.legalize_ops.common import register_legalize - -from bitblas.ops.impl import tir_interleave_weight - - -@register_legalize("bitblas.interleave_weight") -def _interleave_weight(bb: BlockBuilder, call: Call) -> Expr: - nbits = call.attrs.nbits - target_dtype = call.attrs.target_dtype - out_dtype = call.attrs.out_dtype - - return bb.call_te( - tir_interleave_weight(nbits, target_dtype, out_dtype), - call.args[0], - primfunc_name_hint="interleave_weight", - ) - - -__all__ = ["_interleave_weight"] diff --git a/python/bitblas/relax/transform/__init__.py b/python/bitblas/relax/transform/__init__.py deleted file mode 100644 index b92f2c0b4..000000000 --- a/python/bitblas/relax/transform/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from .annotate_decode_block import AnnotateDecodeInformation -from .weight_only_propagate import WeightOnlyLayoutPropagation diff --git a/python/bitblas/relax/transform/annotate_decode_block.py b/python/bitblas/relax/transform/annotate_decode_block.py deleted file mode 100644 index 601647839..000000000 --- a/python/bitblas/relax/transform/annotate_decode_block.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from typing import Dict, Tuple -from tvm.ir import IRModule -from tvm.ir.transform import PassContext, module_pass -from tvm import tir -from tvm.tir.schedule import BlockRV -from mlc_llm.quantization import quantization_schemes, GroupQuantizationSpec -from bitblas.gpu.gemv import is_gemv -from bitblas.gpu.matmul_analysis import ( - get_reduction_blocks, - get_index_map, - get_root_block, - get_dequantize_block, -) -from bitblas.base import ( - normalize_prim_func, - try_inline_contiguous_spatial, -) - - -# Define a module pass to annotate dequantization information -@module_pass(opt_level=0, name="AnnotateDecodeInformation") -class AnnotateDecodeInformation: - - def __init__(self, spec: str = "q4f16_0"): - # Validate and store the specified quantization scheme - if spec not in quantization_schemes: - raise ValueError(f"Quantization scheme {spec} not found") - self.quantize_scheme = quantization_schemes[spec] - - def detect_matmul(self, func: tir.PrimFunc) -> bool: - """Detect if the given function represents a matrix multiplication.""" - sch = tir.Schedule(func) - root_block = get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - # Identify reduction blocks to infer matmul operations - reduction_blocks = get_reduction_blocks(sch, blocks) - if not reduction_blocks: - return False - - # Check for index map patterns typical of matmul operations - main_block = reduction_blocks[0] - main_block_stmt = sch.get(main_block) - index_maps = get_index_map(main_block_stmt) - _is_matmul = index_maps is not None - - block_infos = normalize_prim_func(sch) - block_infos = try_inline_contiguous_spatial(sch, block_infos) - block_info = block_infos[0] - _is_gemv = True - if len(block_info.iters) not in [2, 3]: - # either [B, S, R] = [B, S, R] * [B, R] - # or [S, R] = [S, R] * [R] - _is_gemv = False - if _is_gemv: - _is_gemv = is_gemv(sch, block_info) - return _is_matmul or _is_gemv - - def transform_module(self, mod: IRModule, _: PassContext) -> IRModule: - """Annotate dequantize information for all applicable functions in the module.""" - for g_var, func in mod.functions.items(): - if not isinstance(func, tir.PrimFunc) or g_var.name_hint == "main": - continue - - if not self.detect_matmul(func): - continue # Process only if matmul is detected - - sch = tir.Schedule(func) - root_block = get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - dequantize_block = get_dequantize_block(sch, blocks) - if dequantize_block is None: - continue # Skip if no dequantize block is found - - # Prepare dequantize info annotation - dequantize_info = self.prepare_dequantize_info(sch, dequantize_block) - - # Annotate function with dequantize information - mod[g_var] = func.with_attr("dequantize_info", dequantize_info) - return mod - - def prepare_dequantize_info(self, sch: tir.Schedule, dequantize_block: BlockRV) -> Dict: - """Generate dequantize information for a given block.""" - block_stmt = sch.get(dequantize_block) - block_name = block_stmt.name_hint - dequantize_info = {block_name: {"decode_block": block_name, "fast_decoding": False}} - - quantize_spec = self.quantize_scheme.linear_weight - if isinstance(quantize_spec, GroupQuantizationSpec): - dequantize_info[block_name].update({ - "with_scaling": True, - "group_size": quantize_spec.group_size, - }) - - # Determine source format based on quantization mode - quantize_mod = quantize_spec.mode - bits, source_format = self.parse_quantize_mode(quantize_mod) - dequantize_info[block_name]["source_format"] = { - "bits": bits, - "format": source_format, - } - - # Set storage and target data types - storage_dtype = self.get_storage_dtype(block_stmt, source_format) - dequantize_info[block_name]["storage_dtype"] = storage_dtype - dequantize_info[block_name]["target_format"] = quantize_spec.dtype - - return dequantize_info - - def parse_quantize_mode(self, quantize_mod: str) -> Tuple[int, str]: - """Extract bits and format from quantization mode.""" - if quantize_mod.startswith("int"): - return int(quantize_mod[3:]), "int" - elif quantize_mod.startswith("uint"): - return int(quantize_mod[4:]), "uint" - raise ValueError(f"Unsupported mode {quantize_mod}") - - def get_storage_dtype(self, block_stmt: BlockRV, source_format: str) -> str: - """Determine storage data type based on source format.""" - return (block_stmt.reads[0].buffer.dtype - if "nf" not in source_format else block_stmt.reads[1].buffer.dtype) diff --git a/python/bitblas/relax/transform/weight_only_propagate.py b/python/bitblas/relax/transform/weight_only_propagate.py deleted file mode 100644 index 709e02085..000000000 --- a/python/bitblas/relax/transform/weight_only_propagate.py +++ /dev/null @@ -1,432 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from typing import Optional, Tuple, Union, List, Dict -from tvm.ir import IRModule -from tvm.ir.transform import PassContext, module_pass -from tvm import relax -from tvm import tir -from enum import Enum -from tvm.ir import GlobalVar -from tvm.tir import IndexMap -from tvm.target import Target -from tvm.tir import IterVar -from tvm.tir.schedule.schedule import BlockRV -from tvm.relax import PyExprMutator -from tvm.relax.expr import Call -from bitblas.gpu.matmul_analysis import ( - get_tensorized_func_and_tags, - get_propagate_map, - find_last_producer_from_buffer, - find_arg_idx_from_buffer_chain, - layout_propagate_chain, -) -from tvm.dlight.base import ( - analysis,) -from dataclasses import dataclass - - -def get_reduction_blocks(sch, blocks) -> bool: - # Get the main computation block - def is_reduction(block: BlockRV) -> bool: - block_stmt = sch.get(block) - iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} - return iter_types == {IterVar.CommReduce, IterVar.DataPar} - - def is_spatial(block: BlockRV) -> bool: - block_stmt = sch.get(block) - iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} - return iter_types == {IterVar.DataPar} - - # NOTE: We assume there is only one reduction block in the function - # all blocks are required to be spatial or reduction - if not all([is_reduction(block) or is_spatial(block) for block in blocks]): - return None - - # There is only one reduction block - reduction_blocks = [block for block in blocks if is_reduction(block)] - if len(reduction_blocks) != 1: - return None - - return reduction_blocks - - -class TransformKind(Enum): - NonTransform = 0 - InterWarpTransform = 1 - IntraWarpTransform = 2 - - -def check_sm_version(arch: str) -> int: - sm_version = arch.replace("sm_", "") - return int(sm_version) if sm_version.isdigit() else -1 - - -def get_in_out_dtypes(block: tir.Block) -> Tuple[str]: - """ - Detect In/Out data types for the given block based on the analysis if read/write buffers. - """ - assert len(block.reads) > 0 and len(block.writes) > 0 - in_dtype = block.reads[0].buffer.dtype - out_dtype = block.writes[0].buffer.dtype - return (in_dtype, out_dtype) - - -@dataclass -class LayoutTransformHint: - """ - A dataclass to store the layout transformation hint. - """ - - transform_level: TransformKind - inter_warp_layout: IndexMap - intra_warp_layout: IndexMap - apply_arg_idx: int - - -@module_pass(opt_level=0, name="InsertLayoutTransform") -class WeightOnlyLayoutPropagation: - - def __init__( - self, - transform_level: Union[int, TransformKind] = TransformKind.InterWarpTransform, - target: Optional[Target] = None, - faster_conversion: bool = False, - ) -> None: - if isinstance(transform_level, int): - transform_level = TransformKind(transform_level) - assert transform_level in [ - TransformKind.NonTransform, - TransformKind.InterWarpTransform, - TransformKind.IntraWarpTransform, - ] - # transform_level 1: only transform the inter-warp memory layout - # transform_level 2: transform the inter-warp memory layout and the intra-warp memory layout - self.transform_level = transform_level - self.target = Target.current() if target is None else target - # fast type conversion on nvidia gpu also requires weight permutation - self.faster_conversion = faster_conversion - # layout transform info to sync the layout in both graph and tir - self.layout_transform_hints: Dict[str, List[LayoutTransformHint]] = {} - - def detect_propagate_matmul(self, func: tir.PrimFunc, target: Target): - _, tags = get_tensorized_func_and_tags(func, target, skip_normalize=True, allow_gemv=True) - if tags is None: - return False, None - return True, tags["intrin_info"] - - def transform_matmul(self, g_var: GlobalVar, func: tir.PrimFunc, intrin_info): - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group,) - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None or len(reduction_blocks) != 1: - return False - (main_block,) = reduction_blocks - - intrin_group = get_mma_intrin_group( - load_scope="shared", - store_scope="shared", - a_dtype=intrin_info["in_dtype"], - b_dtype=intrin_info["in_dtype"], - out_dtype=intrin_info["out_dtype"], - trans_a=False, - trans_b=intrin_info["trans_b"], - ) - - _, inter_j, inter_k = intrin_group["micro_kernel"] - - # weight only propagation - target_scope = ("read", 1) - weight_buffer = sch.get(main_block).reads[1].buffer - - # checkout whether the weight buffer has dynamic symbol - def check_dynamic_symbol(buffer): - return any([isinstance(axis, tir.Var) for axis in buffer.shape]) - - if check_dynamic_symbol(weight_buffer): - print("[BitBLAS] Weight buffer has dynamic symbol, skip weight propagation.") - return False - - transformed_block = find_last_producer_from_buffer(sch, main_block, weight_buffer) - if transformed_block is None: - return False - if transformed_block != main_block: - target_scope = ("read", 0) - - reindex_block = sch.cache_read(transformed_block, target_scope[1], "global") - - # create inter-warp memory layout index map - inter_warp_layout = IndexMap.from_func( - lambda i, j: (i // inter_j, j // inter_k, i % inter_j, j % inter_k)) - - inter_warp_layout = layout_propagate_chain( - sch, - main_block, - sch.get(main_block).reads[1].buffer, - reindex_block, - inter_warp_layout, - ) - - sch.transform_layout( - reindex_block, - ("read", 0), - lambda i, j: inter_warp_layout.map_indices([i, j]), - ) - arg_idx = find_arg_idx_from_buffer_chain(sch, reindex_block, - sch.get(reindex_block).reads[0].buffer) - - intra_warp_layout = None - if self.transform_level.value >= TransformKind.IntraWarpTransform.value: - intra_warp_layout, _ = get_propagate_map(intrin_info["trans_b"]) - intra_warp_layout = layout_propagate_chain( - sch, - main_block, - sch.get(main_block).reads[1].buffer, - reindex_block, - intra_warp_layout, - ) - sch.transform_layout( - reindex_block, - ("read", 0), - lambda i, j, ii, jj: ( - i, - j, - *intra_warp_layout.map_indices([ii, jj]), - ), - ) - - self.layout_transform_hints[g_var] = [ - LayoutTransformHint( - transform_level=self.transform_level, - inter_warp_layout=inter_warp_layout, - intra_warp_layout=intra_warp_layout, - apply_arg_idx=arg_idx, - ) - ] - - return sch.mod["main"] - - def transform_module( # pylint: disable=missing-function-docstring - self, - mod: IRModule, - _: PassContext, - ) -> IRModule: - if self.target.kind.name != "cuda": - # currently weight propagation only support nvidia gpus - return mod - - propagate_candidates = {} - propagated_funcs = {} # some funcs may not be able to transform - candidates_intrin_info = {} - decoded_funcs = {} - for g_var, func in mod.functions_items(): - if not isinstance(func, tir.PrimFunc): - continue - if g_var.name_hint != "main": - # Note: this can be applied to any function which can be transformed to matmul (e.g., conv2d) - # for mlc we only consider matmul - # detect the pattern - is_matmul, intrin_info = self.detect_propagate_matmul(func, self.target) - - if (func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys()): - # currently we only support tensorize propagation - continue - - if is_matmul: - if "dequantize_info" in func.attrs: - decoded_funcs[g_var] = func - if self.transform_level != TransformKind.NonTransform: - # lift tags to the function as it has intrinsic information that can be reused. - propagate_candidates[g_var] = func - candidates_intrin_info[g_var] = intrin_info - - for g_var, func in propagate_candidates.items(): - updated_func = self.transform_matmul(g_var, func, candidates_intrin_info[g_var]) - if updated_func: - updated_func = updated_func.with_attrs({ - "transform_kind": self.transform_level.value, - "weight_transform_kind": True, - }) - propagated_funcs[g_var] = updated_func - mod[g_var] = updated_func - - @relax.expr_functor.mutator - class TensorCoreLayoutMutator(PyExprMutator): - """Mutator that performs transformation.""" - - def __init__( - self, - transform_level: TransformKind = TransformKind.NonTransform, - layout_transform_hints: Optional[Dict[str, List[LayoutTransformHint]]] = None, - ): - if layout_transform_hints is None: - layout_transform_hints = {} - super().__init__() - self.transform_level = transform_level - self.layout_transform_hints = layout_transform_hints - - def tc_layout_transform(self, call_node: Call) -> Call: - if self.transform_level == TransformKind.NonTransform: - return super().visit_call_(call_node) - g_var = call_node.args[0] - if g_var not in propagated_funcs: - return super().visit_call_(call_node) - args = list(call_node.args[1]) - # assume we only have weight propagation currently - (weight_layout_hint,) = self.layout_transform_hints[g_var] - weight = args[weight_layout_hint.apply_arg_idx] - weight = self.builder_.emit( - relax.op.layout_transform( - weight, - index_map=lambda i, j: weight_layout_hint.inter_warp_layout.map_indices( - [i, j]), - )) - if self.transform_level.value >= TransformKind.IntraWarpTransform.value: - weight = self.builder_.emit( - relax.op.layout_transform( - weight, - index_map=lambda i, j, ii, jj: ( - i, - j, - *weight_layout_hint.intra_warp_layout.map_indices([ii, jj]), - ), - )) - - call_node = self.builder_.emit( - relax.call_tir( - g_var, - args[:weight_layout_hint.apply_arg_idx] + [weight] + - args[weight_layout_hint.apply_arg_idx + 1:], - out_sinfo=call_node.struct_info, - )) - return call_node - - def visit_call_(self, call_node: Call): - return self.tc_layout_transform(call_node) - - def transform( - self, - mod: IRModule, - ): - for gv, func in mod.functions_items(): - if isinstance(func, relax.Function): - updated_func = self.visit_expr(func) - self.builder_.update_func(gv, updated_func) - new_mod = self.builder_.get() - new_mod = new_mod.with_attrs(mod.attrs) if mod.attrs else new_mod - for gv, func in new_mod.functions_items(): - mod.update_func(gv, func) - return mod - - mod = TensorCoreLayoutMutator( - transform_level=self.transform_level, - layout_transform_hints=self.layout_transform_hints, - ).transform(mod) - - @relax.expr_functor.mutator - class FastTypeConversionLayoutMutator(PyExprMutator): - """Mutator that performs transformation.""" - - def __init__(self, faster_conversion: bool = False): - super().__init__() - self.faster_conversion = faster_conversion - - def lop3_layout_transform(self, call_node: Call) -> Call: - if not self.faster_conversion: - return super().visit_call_(call_node) - - from bitblas.ops.impl import tir_interleave_weight - - g_var = call_node.args[0] - if g_var not in decoded_funcs: - return super().visit_call_(call_node) - - args = list(call_node.args[1]) - func = decoded_funcs[g_var] - if "dequantize_info" not in func.attrs: - return super().visit_call_(call_node) - dequantize_info = dict(func.attrs["dequantize_info"]) - assert len(dequantize_info) == 1 - (weight_dequantize_info,) = dequantize_info.values() - - sch = tir.Schedule(func) - dequantize_block = sch.get_block(weight_dequantize_info["decode_block"]) - - # weight is the first read buffer if format in ["int", "uint"], otherwise the second read buffer, nf .etc - source_format = weight_dequantize_info["source_format"]["format"] - source_bits = weight_dequantize_info["source_format"]["bits"] - target_dtype = weight_dequantize_info["target_format"] - - if source_format in ["int", "uint"]: - weight_buffer = sch.get(dequantize_block).reads[0].buffer - elif source_format in ["nf"]: - weight_buffer = sch.get(dequantize_block).reads[1].buffer - else: - raise ValueError(f"Unsupported source format {source_format}") - - # update func with dequantize_info - dequantize_info["fast_decoding"] = True - self.builder_.update_func(g_var, - func.with_attrs({"dequantize_info": dequantize_info})) - - weight_idx = find_arg_idx_from_buffer_chain(sch, dequantize_block, weight_buffer) - weight = args[weight_idx] - - weight_shape = weight_buffer.shape - # reshape the weight shape to 2d - reshape_weight = self.builder_.emit( - relax.op.reshape(weight, (-1, weight_shape[-1]))) - # register g_var to the func - lop3_interleave_func = tir_interleave_weight( - N=reshape_weight.struct_info.shape[0], - QK=reshape_weight.struct_info.shape[1], - bits=source_bits, - target_dtype=target_dtype, - storage_dtype=reshape_weight.struct_info.dtype, - ) - interleave_gvar = self.builder_.add_func( - lop3_interleave_func.without_attr("global_symbol"), - "tir_interleave_weight", - ) - lop3_interleave_weight = self.builder_.emit( - relax.call_tir( - interleave_gvar, - [reshape_weight], - out_sinfo=reshape_weight.struct_info, - ),) - reshape_weight = self.builder_.emit( - relax.op.reshape(lop3_interleave_weight, weight_shape)) - call_node = self.builder_.emit( - relax.call_tir( - g_var, - args[:weight_idx] + [reshape_weight] + args[weight_idx + 1:], - out_sinfo=call_node.struct_info, - ),) - - return call_node - - def visit_call_(self, call_node: Call): - return self.lop3_layout_transform(call_node) - - def transform( - self, - mod: IRModule, - ): - for gv, func in mod.functions_items(): - if isinstance(func, relax.Function): - updated_func = self.visit_expr(func) - self.builder_.update_func(gv, updated_func) - new_mod = self.builder_.get() - new_mod = new_mod.with_attrs(mod.attrs) if mod.attrs else new_mod - for gv, func in new_mod.functions_items(): - mod.update_func(gv, func) - return mod - - mod = FastTypeConversionLayoutMutator( - faster_conversion=self.faster_conversion).transform(mod) - mod = relax.transform.LegalizeOps()(mod) - return mod diff --git a/python/bitblas/testing/__init__.py b/python/bitblas/testing/__init__.py deleted file mode 100644 index 24f896bd8..000000000 --- a/python/bitblas/testing/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import sys -import inspect -import pytest -from bitblas.base import DefaultPolicy, TensorCorePolicy -from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags - - -# pytest.main() wrapper to allow running single test file -def main(): - test_file = inspect.getsourcefile(sys._getframe(1)) - sys.exit(pytest.main([test_file] + sys.argv[1:])) - - -def debug_with_schedule(func, arch, sch_rule): - policy = DefaultPolicy(func=func, arch=arch) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except Exception: - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - configs = policy.emit_config(1) - return sch_rule.apply_config(func, configs[0]) diff --git a/python/bitblas/utils/__init__.py b/python/bitblas/utils/__init__.py deleted file mode 100644 index 00bddc2a5..000000000 --- a/python/bitblas/utils/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from .post_process import match_global_kernel, tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 # noqa: F401 -from .tensor_adapter import tvm_tensor_to_torch, lazy_tvm_tensor_to_torch, lazy_torch_to_tvm_tensor # noqa: F401 -from .target_detector import get_all_nvidia_targets, auto_detect_nvidia_target # noqa: F401 diff --git a/python/bitblas/utils/post_process.py b/python/bitblas/utils/post_process.py deleted file mode 100644 index cabee6be1..000000000 --- a/python/bitblas/utils/post_process.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import re - - -def match_global_kernel(source: str) -> int: - pattern = r"__global__\s+void\s+[__launch_bounds__\(\d+\)\s+]\w+" - matched = re.findall(pattern, source) - assert len(matched) > 1 # may have statement before kernel - return source.index(matched[0]) - - -def tensor_replace_dp4a(source: str) -> str: - # as under block reduction in tir dsl, the dp4a tensorize will fail, so we should do dp4a in post processor. - # TODO(lei): this is a stuff that should be fixed in the tvm in the future - pattern = r"""for\s*\(int\s*(?P\w+)\s*=\s*0;\s*\1\s*<\s*4;\s*\+\+\1\)\s*\{\s*(?P\w+)\[0\]\s*=\s*\(\2\[0\]\s*\+\s*\(\(\(int\)(?P\w+)\[\(\((?P\w+)\s*\*\s*4\)\s*\+\s*\1\)\]\)\s*\*\s*\(\(int\)(?P\w+)\[\(\((?P\w+)\s*\*\s*4\)\s*\+\s*\1\)\]\)\)\);\s*\}""" - replacement = (r"""\2[0] = __dp4a(*(int *)&\3[((\4 * 4))],*(int *)&\5[((\6 * 4))], \2[0]);""") - source = re.sub(pattern, replacement, source) - return source - - -def tensor_remove_make_int4(source: str) -> str: - # remove make_int4 with 16 signed char arguments - # TODO(lei): this is a stuff that should be fixed in the tvm in the future - source = source.replace( - "make_int4((signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0)", - "make_int4(0, 0, 0, 0)", - ) - return source - -def tensor_remove_make_int2(source: str) -> str: - # remove make_int4 with 16 signed char arguments - # TODO(lei): this is a stuff that should be fixed in the tvm in the future - source = source.replace( - "make_int2((signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0, (signed char)0)", - "make_int2(0, 0)", - ) - return source diff --git a/python/bitblas/utils/target_detector.py b/python/bitblas/utils/target_detector.py deleted file mode 100644 index 71d6dcc1f..000000000 --- a/python/bitblas/utils/target_detector.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import os -import subprocess -from typing import List -from thefuzz import process -from tvm.target import Target -from tvm.target.tag import list_tags - -import logging - -logger = logging.getLogger(__name__) - -TARGET_MISSING_ERROR = ( - "TVM target not found. Please set the TVM target environment variable using `export TVM_TARGET=`, " - "where is one of the available targets can be found in the output of `tools/get_available_targets.py`." -) - -# Nvidia produces non-public oem gpu models that are part of drivers but not mapped to correct tvm target -# Remap list to match the oem model name to the closest public model name -NVIDIA_GPU_REMAP = { - "NVIDIA PG506-230": "NVIDIA A100", - "NVIDIA PG506-232": "NVIDIA A100", -} - -def get_gpu_model_from_nvidia_smi(gpu_id: int = 0): - """ - Executes the 'nvidia-smi' command to fetch the name of the first available NVIDIA GPU. - - Returns: - str: The name of the GPU, or None if 'nvidia-smi' command fails. - """ - try: - # Execute nvidia-smi command to get the GPU name - output = subprocess.check_output( - ["nvidia-smi", "--query-gpu=gpu_name", "--format=csv,noheader"], - encoding="utf-8", - ).strip() - except subprocess.CalledProcessError as e: - logger.info("nvidia-smi failed with error: %s", e) - return None - - gpus = output.split("\n") - - # for multiple gpus, CUDA_DEVICE_ORDER=PCI_BUS_ID must be set to match nvidia-smi or else wrong - # gpu is returned for gpu_id - if len(gpus) > 1 and os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID": - raise EnvironmentError("Multi-gpu environment must set `CUDA_DEVICE_ORDER=PCI_BUS_ID`.") - - if gpu_id >= len(gpus) or gpu_id < 0: - raise ValueError(f"Passed gpu_id:{gpu_id} but there are {len(gpus)} detected Nvidia gpus.") - - return gpus[gpu_id] - -def find_best_match(tags, query): - """ - Finds the best match for a query within a list of tags using fuzzy string matching. - """ - MATCH_THRESHOLD = 25 - best_match, score = process.extractOne(query, tags) - - def check_target(best, default): - return best if Target(best).arch == Target(default).arch else default - - if check_target(best_match, "cuda") == best_match: - return best_match if score >= MATCH_THRESHOLD else "cuda" - else: - logger.warning(TARGET_MISSING_ERROR) - return "cuda" - - -def get_all_nvidia_targets() -> List[str]: - """ - Returns all available NVIDIA targets. - """ - all_tags = list_tags() - return [tag for tag in all_tags if "nvidia" in tag] - - -def auto_detect_nvidia_target(gpu_id: int = 0) -> str: - """ - Automatically detects the NVIDIA GPU architecture to set the appropriate TVM target. - - Returns: - str: The detected TVM target architecture. - """ - # Return a predefined target if specified in the environment variable - # if "TVM_TARGET" in os.environ: - # return os.environ["TVM_TARGET"] - - # Fetch all available tags and filter for NVIDIA tags - all_tags = list_tags() - nvidia_tags = [tag for tag in all_tags if "nvidia" in tag] - - # Get the current GPU model and find the best matching target - gpu_model = get_gpu_model_from_nvidia_smi(gpu_id=gpu_id) - - # Compat: remap oem devices to their correct non-oem model names for tvm target - if gpu_model in NVIDIA_GPU_REMAP: - gpu_model = NVIDIA_GPU_REMAP[gpu_model] - - target = find_best_match(nvidia_tags, gpu_model) if gpu_model else "cuda" - return target diff --git a/python/bitblas/utils/tensor_adapter.py b/python/bitblas/utils/tensor_adapter.py deleted file mode 100644 index 55b80d138..000000000 --- a/python/bitblas/utils/tensor_adapter.py +++ /dev/null @@ -1,130 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import tvm -from typing import Union -from enum import IntEnum -import numpy as np -import torch -from torch.utils.dlpack import from_dlpack, to_dlpack -from math import prod - -from tvm.relay import TensorType -from tvm._ffi.base import _LIB, c_str -from tvm._ffi._ctypes.types import TVMValue, check_call -from tvm._ffi.runtime_ctypes import ( - TVMArrayHandle,) -import ctypes - -TVMPyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p) -_c_str_dltensor = c_str("dltensor") -_c_str_used_dltensor = c_str("used_dltensor") - - -def get_values_from_torch_tensors(tensors, num_args): - values = (TVMValue * num_args)() - dlpack_tensors = [to_dlpack(torch_tensor) for torch_tensor in tensors] - for i, dltensor in enumerate(dlpack_tensors): - dltensor = ctypes.py_object(dltensor) - if ctypes.pythonapi.PyCapsule_IsValid(dltensor, _c_str_dltensor): - ptr = ctypes.pythonapi.PyCapsule_GetPointer(dltensor, _c_str_dltensor) - # enforce type to make sure it works for all ctypes - ptr = ctypes.cast(ptr, ctypes.c_void_p) - handle = TVMArrayHandle() - check_call(_LIB.TVMArrayFromDLPack(ptr, ctypes.byref(handle))) - # ndarray = tvm.runtime.ndarray._make_array(handle, False, False) - ctypes.pythonapi.PyCapsule_SetName(dltensor, _c_str_used_dltensor) - ctypes.pythonapi.PyCapsule_SetDestructor(dltensor, TVMPyCapsuleDestructor(0)) - values[i].v_handle = ctypes.cast(handle, ctypes.c_void_p) - else: - raise ValueError("Invalid DLTensor") - return values - - -class TensorSupplyType(IntEnum): - Integer = 1 - Uniform = 2 - Normal = 3 - Randn = 4 - Zero = 5 - One = 6 - - -def get_tensor_supply(supply_type: TensorSupplyType, opt_shapes: dict = None): - - def var_wrapper(v, opt_shapes): - if isinstance(v, tvm.tir.Var): - assert opt_shapes - assert v.name in opt_shapes - return opt_shapes[v.name] - elif isinstance(v, tvm.tir.IntImm): - return v.value - else: - raise RuntimeError("Not supported type: ", type(v)) - - def get_tensor(tensor: TensorType) -> torch.Tensor: - dtype = torch.__getattribute__(str(tensor.dtype)) - device = torch.cuda.current_device() - shape = [var_wrapper(i, opt_shapes) for i in tensor.shape] - if supply_type == TensorSupplyType.Integer: - return torch.randint(low=-2, high=3, size=shape, device=device, dtype=dtype) - elif supply_type == TensorSupplyType.Uniform: - return torch.empty(*shape, device=device, dtype=dtype).uniform_(-1.0, 1.0) - elif supply_type == TensorSupplyType.Normal: - return torch.empty(*shape, device=device, dtype=dtype).normal_(-1.0, 1.0) - elif supply_type == TensorSupplyType.Randn: - return torch.randn(*shape, device=device).to(dtype) - elif supply_type == TensorSupplyType.Zero: - return torch.zeros(*shape, device=device, dtype=dtype) - elif supply_type == TensorSupplyType.One: - return torch.ones(*shape, device=device, dtype=dtype) - else: - raise NotImplementedError(supply_type) - - return get_tensor - - -def tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]): - if isinstance(tensor, tvm.te.Tensor): - return torch.from_numpy(tensor.numpy()) - elif isinstance(tensor, tvm.nd.NDArray): - return from_dlpack(tensor) - else: - raise RuntimeError("Not supported type: ", type(tensor)) - -def lazy_tvm_tensor_to_torch(tensor: Union[tvm.te.Tensor, tvm.nd.NDArray]): - # It additionally needs the ctypes type as torch type - def as_tensor(address, shape, elems_inbytes, torch_type): - arr = (ctypes.c_int8 * elems_inbytes).from_address( - address) - return torch.frombuffer(arr, dtype=torch_type).view(*shape) - - if isinstance(tensor, tvm.nd.NDArray): - np_array = tensor.asnumpy() - shape = np_array.shape - dtype = np_array.dtype - torch_dtype = getattr(torch, str(dtype)) - num_elems_inbytes = prod(shape) * np_array.itemsize - data_ptr = np_array.ctypes.data - tensor = as_tensor(data_ptr, shape, num_elems_inbytes, torch_dtype) - return tensor - else: - raise RuntimeError("Not supported type: ", type(tensor)) - -def lazy_torch_to_tvm_tensor(tensor): - # It additionally needs the ctypes type as torch type - def as_tensor(address, shape, elems_inbytes, numpy_type): - arr = (ctypes.c_int8 * elems_inbytes).from_address( - address) - return np.frombuffer(arr, dtype=numpy_type).reshape(shape) - - if isinstance(tensor, torch.Tensor): - data_ptr = tensor.data_ptr() - shape = tensor.shape - torch_dtype = tensor.dtype - numpy_dtype = str(torch_dtype).replace("torch.", "") - num_elems_inbytes = prod(shape) * tensor.itemsize - np_tensor = as_tensor(data_ptr, shape, num_elems_inbytes, numpy_dtype) - tvm_tensor = tvm.nd.array(np_tensor) - return tvm_tensor - else: - raise RuntimeError("Not supported type: ", type(tensor)) diff --git a/python/bitblas/wrapper/__init__.py b/python/bitblas/wrapper/__init__.py deleted file mode 100644 index 1d87f8020..000000000 --- a/python/bitblas/wrapper/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -from .general import CUDASourceWrapper, CUDASourceWrapperWithDynamic # noqa: F401 diff --git a/python/bitblas/wrapper/general.py b/python/bitblas/wrapper/general.py deleted file mode 100644 index 58aa8d226..000000000 --- a/python/bitblas/wrapper/general.py +++ /dev/null @@ -1,518 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import tvm -from typing import Optional, List, Dict, Union -from tvm import IRModule -from bitblas import TileDevice -from tvm.runtime import ndarray -from bitblas.utils import match_global_kernel -import re -import ctypes -import os -import tempfile -import subprocess -import logging -from tvm.driver import lower -from tvm.target import Target - -logger = logging.getLogger(__name__) - -_TYPE_MAP = { - "float32": "float", - "float16": "half", - "bfloat16": "__nv_bfloat162", - "e4m3_float8": "__nv_fp8_e4m3", - "e5m2_float8": "__nv_fp8_e5m2", - "float64": "double", - "int64": "int64_t", - "int32": "int", - "uint32": "unsigned int", - "bool": "int8_t", - "int8": "int8_t", - "uint8": "uint8_t", - "int16": "int16_t", - "uchar": "uint8_t", -} - - -def get_annotated_device_mod(mod: IRModule, target: Target): - """ - Lower the given IRModule and create a device module for the specified target. - - Parameters: - - mod: The input IRModule. - - target: The compilation target. - - Returns: - - A device module ready for execution. - """ - input_mod = lower(mod) - target_input_mod = {target: input_mod} - annotated_mods = {} - runtime = None - target_host = None - for tgt, mod in target_input_mod.items(): - if not isinstance(tgt, (str, Target)): - raise ValueError("The key of inputs must be str or " - "Target when inputs is dict.") - if not isinstance(mod, tvm.IRModule): - raise ValueError("inputs must be Schedule, IRModule, " - "or dict of str to IRModule.") - annotated_mods[tgt] = mod.with_attr("runtime", runtime) - annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) - if not target_host: - for tar, _ in annotated_mods.items(): - device_type = ndarray.device(tar.kind.name, 0).device_type - if device_type == ndarray.cpu(0).device_type: - target_host = tar - break - if not target_host: - target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm" - annotated_mods, target_host = Target.canon_target_map_and_host(annotated_mods, target_host) - for target, mod in annotated_mods.items(): - mixed_mod_passes = tvm.get_global_func("driver.mixed_mod_passes") - device_mod_passes = tvm.get_global_func("driver.device_mod_passes") - mod = mixed_mod_passes(mod, target)(mod) - device_mod = device_mod_passes(mod, target)(mod) - return device_mod - - -def get_thread_block_information(mod: IRModule): - """ - Extracts the thread block and grid dimensions for the reduction block within a given IRModule. - - Parameters: - - mod: The input IRModule from which to extract thread block and grid information. - - Returns: - A tuple containing two lists: - - The first list contains the dimensions of the thread block (threadIdx.x, threadIdx.y, threadIdx.z). - - The second list contains the dimensions of the grid (blockIdx.x, blockIdx.y, blockIdx.z). - """ - - # Initialize the schedule from the IRModule - sch = tvm.tir.Schedule(mod) - - # Get the root block and its child blocks - root_block = sch.get_block("root") - child_blocks = sch.get_child_blocks(root_block) - - # Initialize default block and grid dimensions (1, 1, 1) - block_dims, grid_dims = [1, 1, 1], [1, 1, 1] - - for block in child_blocks: - # Get the loops surrounding the main block - loops = sch.get_loops(block) - - # Iterate over each loop to extract thread and block bindings - for loop in loops: - stmt = sch.get(loop) - thread_binding = stmt.thread_binding - extent = int(stmt.extent) - - # Skip loops without thread binding - if thread_binding: - if "threadIdx" in thread_binding.thread_tag: - block_dims["xyz".index(thread_binding.thread_tag[-1])] = extent - elif "blockIdx" in thread_binding.thread_tag: - grid_dims["xyz".index(thread_binding.thread_tag[-1])] = extent - - return block_dims, grid_dims - - -class CUDASourceWrapper(object): - - def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): - self.mod = optimized_mod - self.arch = arch - self.source = source - self.function_name: Optional[str] = None - self.dynamic_smem_buf: Optional[int] = None - self.block_info: Union[List[int], Dict] = [1, 1, 1] - self.grid_info: Union[List[int], Dict] = [1, 1, 1] - self.parse_source_information() - self.src_name: Optional[str] = None - self.lib_name: Optional[str] = None - self.lib_code: Optional[str] = self.update_lib_code(source) - - def load_lib(self): - return ctypes.CDLL(self.lib_name) - - def remove_lib(self): - if self.lib_name: - os.remove(self.lib_name) - self.lib_name = None - - def compile_lib(self, timeout: float = None): - arch = self.arch - src = tempfile.NamedTemporaryFile(mode="w", suffix=".cu", delete=False) - compute_version = arch.compute_capability - lib_name = src.name.replace(".cu", ".so") - - command = [ - "nvcc", - "-std=c++17", - "-Xcudafe", - "--diag_suppress=177", - "--compiler-options", - "'-fPIC'", - "-lineinfo", - "--shared", - src.name, - "-lcuda", - f"-gencode=arch=compute_{compute_version},code=compute_{compute_version}", - "-o", - lib_name, - ] - src.write(self.lib_code) - src.flush() - try: - ret = subprocess.run(command, timeout=timeout) - except subprocess.TimeoutExpired: - logger.warning(f"Compilation Timeout! {command}") - return None - if ret.returncode != 0: - logger.warning(f"Compilation Failed! {command}") - return None - self.src_name = src.name - self.lib_name = lib_name - - def parse_source_information(self): - device_mod = get_annotated_device_mod(self.mod, self.arch.target) - assert (len(device_mod.functions) == 1 - ), "Only support one function in the module for static shape kernel." - for g_var, func in device_mod.functions.items(): - self.function_name = g_var.name_hint - attrs = func.attrs - if "dyn_shared_memory_buf" in attrs: - self.dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) - if "thread_extent" in attrs: - thread_extent = attrs["thread_extent"] - for tag, extent in thread_extent.items(): - if "threadIdx" in tag: - self.block_info["xyz".index(tag[-1])] = extent - elif "blockIdx" in tag: - self.grid_info["xyz".index(tag[-1])] = extent - - def get_dynamic_symbolic_set(self, prim_func): - # Determine the set of dynamic symbols used in the function - dynamic_symbolic_set = set() - for param in prim_func.params: - buffer = prim_func.buffer_map[param] - for dim in buffer.shape: - if isinstance(dim, tvm.tir.Var): - dynamic_symbolic_set.add(dim.name) - return dynamic_symbolic_set - - def get_cuda_init_func(self): - # Initialize an empty string for the CUDA function call - call_str = """""" - # If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call - if self.dynamic_smem_buf is not None: - call_str = """ - cudaFuncSetAttribute({}, - cudaFuncAttributeMaxDynamicSharedMemorySize, {}); - """.format(self.function_name, self.dynamic_smem_buf) - # Format the initialization function using the call_str - init_funcs = """ - extern "C" void init() {{ - {} - }} - """.format(call_str) - return init_funcs - - def update_lib_code(self, code: str): - # Update the library code with the given code string - self.lib_code = code - # Find the index of the global kernel function in the code - index = match_global_kernel(code) - # Extract the declaration of the function starting from the found index - declaration = code[index:].split(";")[0] - - function_name = self.function_name - # Get the CUDA initialization function - init_func = self.get_cuda_init_func() - - # Locate the opening brace of the function to insert arguments - index = code.index("{", index) - function_args = [] - # Populate the function arguments from the primary function's parameters and buffers - for param in self.prim_func.params: - buffer = self.prim_func.buffer_map[param] - function_args.append({ - "name": buffer.name, - "type": _TYPE_MAP[buffer.dtype] + "* __restrict__", - }) - - dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) - # Add dynamic symbolic parameters as integers to the function arguments - for dyn_sym in dynamic_symbolic_set: - function_args.append({"name": dyn_sym, "type": "int"}) - - function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) - # Format the function arguments for declaration - def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) - - def func_call_args(s, function_args): - # Extract the function call arguments matching the function definition - pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" - matches = re.findall(pattern, s) - call_args = [] - for match in matches: - for arg in function_args: - if arg["name"] == match: - call_args.append(match) - return call_args - - call_args = ", ".join(func_call_args(declaration, function_args)) - block_info, grid_info = self.block_info, self.grid_info - - def legalize_c(p): - # Convert TIR expressions to legal C expressions - # Directly convert to string since the special case handling - # does not alter the string representation for `tvm.tir.Var` and `IntImm`. - # Replace Python's floor division operator with C's division operator - if isinstance(p, tvm.tir.IntImm): - p = int(p) - return str(p).replace("//", "/") - - # Prepare the block and grid dimensions for the CUDA kernel launch - block_str = "dim3({}, {}, {})".format( - legalize_c(block_info[0]), - legalize_c(block_info[1]), - legalize_c(block_info[2]), - ) - grid_str = "dim3({}, {}, {})".format( - legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2])) - # Determine the shared memory size, defaulting to 0 if not specified - smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf - # Format the CUDA kernel launch string - if len(dynamic_symbolic_set) != 0: - call_str = "if ({} == 0) return; \n\t\t".format(list(dynamic_symbolic_set)[0]) - else: - call_str = "" - call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str, - smem_str, call_args) - # Create the host function wrapper for the CUDA kernel - host_func = """ - extern "C" void call({}) {{ - {} - }} - """.format(def_args, call_str) - # Combine the source, initialization function, and host function to form the complete library code - lib_code = self.source + init_func + host_func - return lib_code - - @property - def prim_func(self): - return self.mod["main"] - - -class CUDASourceWrapperWithDynamic(CUDASourceWrapper): - - def __init__(self, optimized_mod: IRModule, source: str, arch: TileDevice): - super().__init__(optimized_mod, source, arch) - - def get_cuda_init_func(self): - # Initialize an empty string to accumulate CUDA function calls for setting dynamic shared memory - call_str = """""" - # Iterate over functions and their dynamic shared memory requirements - for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items(): - if dynamic_smem_buf is not None: - # Format the cudaFuncSetAttribute call for dynamic shared memory - call_str += """ - cudaFuncSetAttribute({}, - cudaFuncAttributeMaxDynamicSharedMemorySize, {}); - """.format(function_name, dynamic_smem_buf) - # Define the init function that will set the attributes for each kernel - init_funcs = """ -extern "C" void init() {{ - {} -}} - """.format(call_str) - return init_funcs - - def create_dispatch_func(self, code, function_informations): - # Extract the set of dynamic symbolic names used in the primary function - dynamic_symbolic_set = self.get_dynamic_symbolic_set(self.prim_func) - - # Find the location of the global kernel function in the code - index = match_global_kernel(code) - - # Analyze the function declaration to prepare for argument extraction - dummy_declaration = code[index:].split(";")[0] - - function_name = self.function_name - - # Identify the start of the function body to insert arguments - index = code.index("{", index) - function_args = [] - # Collect function arguments based on primary function's parameters and buffer mappings - for param in self.prim_func.params: - buffer = self.prim_func.buffer_map[param] - function_args.append({ - "name": buffer.name, - "type": _TYPE_MAP[buffer.dtype] + "* __restrict__", - }) - # Add dynamic symbols as integer arguments - for dyn_sym in dynamic_symbolic_set: - function_args.append({"name": dyn_sym, "type": "int"}) - - function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) - - # Format the argument definitions for function declaration - def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) - - def func_call_args(s: str, function_args): - # Extract and clean the function call arguments to match the declaration - pattern = r"[,\s]*(?:\w+\s*\*+\s*__restrict__\s+)?(\w+)" - matches = re.findall(pattern, s) - call_args = [] - for match in matches: - match = re.sub(r"\d+", "", match) # Remove numbers - match = re.sub(r"_", "", match) # Remove underscores - for arg in function_args: - if arg["name"] == match: - call_args.append(match) - return call_args - - call_args = ", ".join(func_call_args(dummy_declaration, function_args)) - - def legalize_c(p): - # Convert TIR expressions to legal C expressions - # Directly convert to string since the special case handling - # does not alter the string representation for `tvm.tir.Var` and `IntImm`. - # Replace Python's floor division operator with C's division operator - if isinstance(p, tvm.tir.IntImm): - p = int(p) - return str(p).replace("//", "/") - - last_range = 0 - num_items = len(function_informations) - _call_str = """""" - for function_name, info in function_informations.items(): - # Prepare block and grid configurations for kernel launches - block_info, grid_info = info["block_info"], info["grid_info"] - block_str = "dim3({}, {}, {})".format( - legalize_c(block_info[0]), - legalize_c(block_info[1]), - legalize_c(block_info[2]), - ) - grid_str = "dim3({}, {}, {})".format( - legalize_c(grid_info[0]), - legalize_c(grid_info[1]), - legalize_c(grid_info[2]), - ) - # Handle dynamic shared memory specification - smem_str = (0 if info["dynamic_smem_buf"] is None else info["dynamic_smem_buf"]) - opt_shapes = info["opt_shapes"] - # Generate conditional kernel launch code based on dynamic symbolic ranges - (symbolic,) = list(dynamic_symbolic_set) - range_str = opt_shapes[symbolic] - if last_range == 0: - call_str = "if ({} == 0) return; \n".format(symbolic,) - call_str += "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( - symbolic, - range_str, - function_name, - grid_str, - block_str, - smem_str, - call_args, - ) - else: - call_str = "\t\telse if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( - symbolic, - range_str, - function_name, - grid_str, - block_str, - smem_str, - call_args, - ) - if last_range == num_items - 1: - call_str += ( - "\t\telse {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format( - function_name, grid_str, block_str, smem_str, call_args)) - last_range += 1 - _call_str += call_str - - # Wrap the kernel dispatch logic in an external C function - host_func = """ -extern "C" void call({}) {{ - {} -}} - """.format(def_args, _call_str) - return host_func - - def parse_source_information(self): - # Parse device module to extract execution configurations for each function - device_mod = get_annotated_device_mod(self.mod, self.arch.target) - block_info_map = {} - grid_info_map = {} - dynamic_smem_buf_map = {} - for g_var, func in device_mod.functions.items(): - # Default block and grid configurations - block_info = [1, 1, 1] - grid_info = [1, 1, 1] - function_name = g_var.name_hint - attrs = func.attrs - dynamic_smem_buf = None - if "dyn_shared_memory_buf" in attrs: - dynamic_smem_buf = int(attrs["dyn_shared_memory_buf"]) - if "thread_extent" in attrs: - # Extract block and grid sizes from thread extents - thread_extent = attrs["thread_extent"] - for tag, extent in thread_extent.items(): - if "threadIdx" in tag: - block_info["xyz".index(tag[-1])] = extent - elif "blockIdx" in tag: - grid_info["xyz".index(tag[-1])] = extent - # Map the extracted configurations to each function - block_info_map[function_name] = block_info - grid_info_map[function_name] = grid_info - dynamic_smem_buf_map[function_name] = dynamic_smem_buf - # Store the mappings for use in code generation - self.block_info = block_info_map - self.grid_info = grid_info_map - self.dynamic_smem_buf = dynamic_smem_buf_map - - def update_lib_code(self, code: str): - # Organize function information for code generation - function_informations = {} - for g_var, func in self.mod.functions.items(): - if g_var.name_hint == "main": - continue - function_name = g_var.name_hint - attrs = func.attrs - assert "opt_shapes" in attrs - opt_shapes = attrs["opt_shapes"] - function_informations[function_name] = { - "function_name": function_name, - "opt_shapes": opt_shapes, - "block_info": self.block_info[function_name], - "grid_info": self.grid_info[function_name], - "dynamic_smem_buf": self.dynamic_smem_buf[function_name], - } - - def compare_map_objects(map_obj): - comparable_representation = list(map_obj.values()) - return comparable_representation - - function_informations = dict( - sorted( - function_informations.items(), - key=lambda item: compare_map_objects(item[1]["opt_shapes"]))) - - self.lib_code = code - - # Generate the initialization and dispatch functions - init_func = self.get_cuda_init_func() - host_func = self.create_dispatch_func(code, function_informations) - # Concatenate source code with generated code segments - lib_code = self.source + init_func + host_func - return lib_code - - @property - def prim_func(self): - return self.mod["main"] diff --git a/python/bitblas_cli.py b/python/bitblas_cli.py deleted file mode 100644 index 59e481eb9..000000000 --- a/python/bitblas_cli.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. diff --git a/testing/python/dsl/test_auto_normalized_tensorcore.py b/testing/python/dsl/test_auto_normalized_tensorcore.py deleted file mode 100644 index eb6e0baef..000000000 --- a/testing/python/dsl/test_auto_normalized_tensorcore.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -import numpy as np -import tvm -from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy -from bitblas.base.roller.arch import CUDA -from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags -from bitblas.gpu import Matmul -from bitblas.ops.impl.convolution2d_impl import conv2d_nhwc_hwio, conv2d_nhwc_ohwi -from bitblas.base.utils import apply_and_build -import time - -benchmark_sets = [ - # (prim_func, input_args, default_bitblas_schedule), - (conv2d_nhwc_hwio, (128, 64, 224, 224, 3, 7, 7, 2, 1, 3, "float16", "float16"), Matmul), - (conv2d_nhwc_ohwi, (128, 64, 56, 56, 64, 3, 3, 1, 1, 1, "float16", "float16"), Matmul), - (conv2d_nhwc_hwio, (128, 64, 56, 56, 64, 1, 1, 1, 1, 1, "float16", "float16"), Matmul), - (conv2d_nhwc_ohwi, (128, 64, 56, 56, 64, 1, 1, 1, 1, 1, "float16", "float16"), Matmul), - (conv2d_nhwc_ohwi, (128, 128, 28, 28, 128, 3, 3, 1, 1, 1, "float16", "float16"), Matmul), - (conv2d_nhwc_hwio, (128, 256, 14, 14, 128, 3, 3, 2, 1, 1, "float16", "float16"), Matmul), - (conv2d_nhwc_ohwi, (128, 256, 14, 14, 128, 1, 1, 2, 1, 1, "float16", "float16"), Matmul), -] -benchmark_results = {} -for get_prim_func, input_args, d_schedule in benchmark_sets: - ir_module = get_prim_func(*input_args) - func = ir_module["main"] - target = tvm.target.Target("nvidia/nvidia-a100") - arch = CUDA(target) - policy = DefaultPolicy(func=func, arch=arch) - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - try: - tensorized_func, tags = get_tensorized_func_and_tags(func, arch.target) - except Exception as e: - print(f"Failed to get tensorized function and tags: {e}") - tags = None - if tags: - policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - - configs = policy.emit_config(20) - - tune_start = time.time() - cpresults, best = apply_and_build(func, configs, arch, parallel_build=True) - fast_tune_time = time.time() - tune_start - print("[BitBLAS] The best latency of top 1 is {:.3f} ms".format(cpresults[0].latency * 1e3)) - print("[BitBLAS] The best latency of top 20 is {:.3f} ms".format(best.latency * 1e3)) - - # evaluate the performance of the default schedule - - rule = d_schedule() - default_tune_start = time.time() - sch_default = rule.apply(func, target, False) - with tvm.transform.PassContext(config={"tir.use_async_copy": True}): - mod_default = tvm.build(sch_default.mod["main"], target="cuda") - default_tune_time = time.time() - default_tune_start - - args = func.buffer_map.values() - - profile_tensors = [] - for arg in args: - profile_tensors.append( - tvm.nd.array( - np.random.uniform(0, 1, [int(i) for i in arg.shape]).astype(arg.dtype), - device=arch.device, - )) - - timer_cuda_mod = mod_default.time_evaluator(mod_default.entry_name, arch.device, number=5) - t = timer_cuda_mod(*profile_tensors).mean - - print("Time cost of BitBLAS default schedule: {:.3f} ms".format(t * 1e3)) - - profile_config = { - f"{get_prim_func.__name__}-{'-'.join([str(i) for i in input_args])}": { - "fast_bitblas_top20_tune_time": fast_tune_time, - "fast_bitblas_top1_latency": cpresults[0].latency * 1e3, - "fast_bitblas_top20_latency": best.latency * 1e3, - "default_bitblas_tune_time": default_tune_time, - "default_bitblas_latency": t * 1e3, - } - } - benchmark_results.update(profile_config) - -headers = [ - "PrimFunc", - "Input Arguments", - "FastDLight Top20 Tune Time", - "FastDLight Top1 Latency", - "FastDLight Top20 Latency", - "DefaultDLight Tune Time", - "DefaultDLight Latency", -] - -col_width = (max(len(word) for row in [headers] + list(profile_config.values()) for word in row) + 2 - ) # padding - -print("".join(word.ljust(col_width) for word in headers)) - -print("-" * col_width * len(headers)) - -for config, values in benchmark_results.items(): - args = config.split("-") - func_name = args[0] - input_args = "-".join(args[1:]) - row = [ - func_name, - input_args, - f" {str(values['fast_bitblas_top20_tune_time'])} s", - f"{values['fast_bitblas_top1_latency']:.3f} ms", - f"{values['fast_bitblas_top20_latency']:.3f} ms", - str(values["default_bitblas_tune_time"]), - f"{values['default_bitblas_latency']:.3f} ms", - ] - print("".join(word.ljust(col_width) for word in row))