Skip to content

Commit

Permalink
[Dev] Potentially improve performance through block reduction (#63)
Browse files Browse the repository at this point in the history
* improve e4m3 decoding.

* append fp16xint1

* Update submodule commit reference

* chore: Update shared memory scope for float32 output dtype

* BUGFIX: UINT8/INT8 Decoding

* feat: Add rasterization options for roller module

* Refactor tensorcore_legalization method to optimize tensor core usage

* feat: Add function to collect variables from expression, improve for splitk

* chore: Update typing import in __init__.py

* chore: Refactor CPU execution of operators

* Refactor matmul implementation for splitk layout

* Refactor matmul implementation for splitk layout

* Refactor matmul implementation for splitk layout

* chore: Update version to 0.0.1.dev8

* chore: Enable debug output in bitblas.set_debug_level()

* Refactor Linear module matmul implementation for splitk layout

* Refactor matmul implementation for splitk layout

* Refactor CUDA kernel launch string for dynamic symbolic set

* Bump version to v0.0.1.dev9

* Refactor CUDA kernel launch string for dynamic symbolic set

* Bump version to v0.0.1.dev10

* Refactor CUDA kernel launch string for dynamic symbolic set

* Bump version to v0.0.1.dev12 and add MatmulConfigWithSplitK and MatmulWithSplitK

* fix the typo

* Refactor CUDA kernel launch string for dynamic symbolic set

* Refactor CUDA kernel launch string for dynamic symbolic set

* Refactor CUDA kernel launch string for dynamic symbolic set

* Refactor CUDA kernel launch string for dynamic symbolic set

---------

Co-authored-by: LeiWang199 <leiwang199>
  • Loading branch information
LeiWang1999 authored Jun 30, 2024
1 parent 2e52552 commit e7ed676
Show file tree
Hide file tree
Showing 15 changed files with 1,102 additions and 49 deletions.
6 changes: 3 additions & 3 deletions docs/QuickStart.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import torch

# enabling debug output

bitblas.set_debug_level("Debug")
bitblas.set_log_level("Debug")
matmul_config = bitblas.MatmulConfig(
M=1, # M dimension
N=1024, # N dimension
Expand Down Expand Up @@ -129,7 +129,7 @@ import bitblas
import torch

# enabling debug output
bitblas.set_debug_level("Debug")
bitblas.set_log_level("Debug")

model = bitblas.Linear(
in_features=1024,
Expand Down Expand Up @@ -185,7 +185,7 @@ from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import (
)

# enabling debug output
bitblas.set_debug_level("Debug")
bitblas.set_log_level("Debug")

in_features = 1024
out_features = 1024
Expand Down
2 changes: 2 additions & 0 deletions python/bitblas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import logging
from tqdm import tqdm


class TqdmLoggingHandler(logging.Handler):
""" Custom logging handler that directs log output to tqdm progress bar to avoid interference. """

Expand All @@ -61,6 +62,7 @@ def set_log_level(level):
Args:
level (str or int): Can be the string name of the level (e.g., 'INFO') or the actual level (e.g., logging.INFO).
OPTIONS: 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'
"""
if isinstance(level, str):
level = getattr(logging, level.upper(), logging.INFO)
Expand Down
6 changes: 4 additions & 2 deletions python/bitblas/base/roller/arch/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tvm
from tvm.target import Target
from .arch_base import TileDevice
from typing import List, Dict
from typing import List, Dict, Union


def check_sm_version(arch: str) -> int:
Expand All @@ -28,7 +28,9 @@ def __init__(

class CUDA(TileDevice):

def __init__(self, target: Target):
def __init__(self, target: Union[Target, str]):
if isinstance(target, str):
target = tvm.target.Target(target)
self.target = target
self.sm_version = check_sm_version(self.target.arch)
device = tvm.runtime.cuda(0)
Expand Down
12 changes: 9 additions & 3 deletions python/bitblas/base/roller/hint.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,18 +154,20 @@ def __init__(self) -> None:
self.arch = None
self.use_tc = None # todo(lei): this should be renamed.

# special axes tiling info
# Special axes tiling info
self.block = []
self.thread = []
# special axes for tensorCore
# Special axes for MMA
self.warp = []
# reduce axes tiling info
# Reduce axes tiling info
self.rstep = []
self.reduce_thread = []
self.rasterization_plan = NoRasterization()
self.cached_tensors = []
self.output_strides = {}
self.schedule_stages = None
# Config for block reduction
self.block_reduction_depth = None # type: int

# Experimental
self._raxis_order = []
Expand Down Expand Up @@ -203,6 +205,10 @@ def to_dict(self) -> Dict:
dic["raxis_order"] = self._raxis_order
if self.vectorize != {}:
dic["vectorize"] = self.vectorize
if self.pipeline_stage != 1:
dic["pipeline_stage"] = self.pipeline_stage
if self.block_reduction_depth is not None:
dic["block_reduction_depth"] = self.block_reduction_depth
return dic

def from_dict(self, dic: Dict) -> "Hint":
Expand Down
15 changes: 15 additions & 0 deletions python/bitblas/base/roller/policy/tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(self,
self.wmma_k = 16
self.pipeline_stage: int = 1
self.use_async_copy: bool = False
self.block_reduction_depth: Optional[int] = None
self._legalize_info()

def _legalize_info(self):
Expand All @@ -44,6 +45,11 @@ def _legalize_info(self):
self.use_async_copy = True
else:
self.use_async_copy = False
# TODO: block reduction depth is not used for now,
# as there still exist some performance issues with block reduction.
# block_reduction_depth = self.prim_func_node.get_tag("block_reduction_depth")
# if block_reduction_depth:
# self.block_reduction_depth = block_reduction_depth

def _compute_tc_strides(
self,
Expand Down Expand Up @@ -114,6 +120,7 @@ def _check_small_tile(td: TileDict):

smem_limit = min(self.arch.max_smem_usage // td.block_per_SM, self.arch.smem_cap)
rstep_map = td.rstep_map.copy()
is_block_reduction = self.block_reduction_depth is not None

def _optimize(node, rstep):
all_steps = self.get_node_reduce_step_candidates(node)
Expand Down Expand Up @@ -177,6 +184,13 @@ def _enlarge(rstep_id):
if len(node.raxis) > 0:
rstep = _optimize(node, rstep_map)
rstep_map = rstep

if is_block_reduction:
# With block reduction, constrain each reduce step to a max value of 64;
# otherwise it will introduce a CUDA invalid-argument issue.
MAX_REDUCE_K = 64
for k in rstep_map:
rstep_map[k] = min(rstep_map[k], MAX_REDUCE_K)
td.rstep_map = rstep_map
td.smem_cost, td.cached_tensors_map = self._compute_shared_memory_usage(td)
return
Expand Down Expand Up @@ -289,6 +303,7 @@ def _score(node, thread): # small is better
codegen_dict.warp = warp_tile
codegen_dict.use_tc = True
codegen_dict.pipeline_stage = self.pipeline_stage
codegen_dict.block_reduction_depth = self.block_reduction_depth
codegen_dict.use_async = self.use_async_copy
codegen_dict.rstep = [int(rsteps[ax.var.name]) for ax in node.raxis]
codegen_dict.cached_tensors = td.cached_tensors_map[node]
Expand Down
5 changes: 3 additions & 2 deletions python/bitblas/base/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def apply_and_build_parallel(func,
arch,
num_repeats=3,
max_workers=10,
timeout=30,
data_distribution="uniform") -> CompileResult:
cpresults = []

Expand All @@ -187,10 +188,10 @@ def _apply_schedule(f, c):

with ThreadPoolExecutor(max_workers=4) as scheduler:
futures = {scheduler.submit(_apply_schedule, func, config) for config in configs}
for future in as_completed(futures):
for future in as_completed(futures, timeout=timeout):
_sched.append(future.result())

builder = PopenPoolExecutor(max_workers=max_workers)
builder = PopenPoolExecutor(max_workers=max_workers, timeout=timeout)

# build in process parallel
def _build(context) -> str:
Expand Down
2 changes: 1 addition & 1 deletion python/bitblas/gpu/gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Modifications Copyright (c) Microsoft.
# The code below is mostly copied from apache/tvm gemv.py in dlight.
"""A rule for GEMV and DecodeGEMV."""
import re

from functools import reduce
from typing import List, Optional, Union, Dict

Expand Down
1 change: 0 additions & 1 deletion python/bitblas/gpu/gemv_dequantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,6 @@ def get_vectorize_factor(target_format):
assert len(config.thread) == 2, "SplitK only support 2D thread config"
num_warps = int(num_warps // config.thread[0])


# get target dequantize buffer's idx
def get_idx(weight_decode_info: Dict):
# for LUT dequantize, the expr is LUT(w), the idx is 1
Expand Down
10 changes: 10 additions & 0 deletions python/bitblas/gpu/matmul_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,16 @@ def check_last_trait(region: List[Range]):
if func.attrs is not None and "weight_transform_kind" in func.attrs:
intrin_info["weight_transform_kind"] = func.attrs["weight_transform_kind"]
tags["intrin_info"] = intrin_info
# Analyze the block reduction optimization.
# Currently, we only support a block reduction depth of 2 for small M.
# When the func is a dequantize-like op, we should consider M.
if hasattr(func.attrs, "dequantize_info"):
for arg in func.params:
inp_shape = func.buffer_map[arg].shape
M = inp_shape[0]
if isinstance(M, tir.IntImm) and M <= 128:
tags["block_reduction_depth"] = 2
break

return tags

Expand Down
Loading

0 comments on commit e7ed676

Please sign in to comment.