
[Feature] Enhancing MatmulOps with Splitk Support #48

Merged
merged 13 commits on Jun 5, 2024
2 changes: 1 addition & 1 deletion 3rdparty/tvm
1 change: 1 addition & 0 deletions python/bitblas/base/roller/__init__.py
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .node import PrimFuncNode # noqa: F401
from .rasterization import NoRasterization, Rasterization2DRow, Rasterization2DColumn # noqa: F401
from .hint import Hint # noqa: F401
from .policy import DefaultPolicy, TensorCorePolicy # noqa: F401
from .arch import TileDevice, CUDA # noqa: F401
10 changes: 9 additions & 1 deletion python/bitblas/base/roller/hint.py
@@ -6,6 +6,7 @@
import numpy as np
from .rasterization import *


class TensorCoreExtraConfig:
"""
This class is used to store extra information for tensorcore
@@ -210,6 +211,12 @@ def from_dict(self, dic: Dict) -> "Hint":
setattr(self, k, v)
return self

def tensorcore_legalization(self):
# only keep the last 2 axes for tensorcore
self.warp = self.warp[-2:]
self.block = self.block[-2:]
return self

@property
def raxis_order(self) -> List[int]:
if self._raxis_order != []:
@@ -228,7 +235,8 @@ def __repr__(self) -> str:
def complete_config(self, node: PrimFuncNode):
# analysis pass context, for int8 mma, we should merge static shared memory
merge_static_smem = False
if self.use_tc and self.intrin_info.in_dtype == "int8":
# int32 and float32 accum may take too much shared memory
if self.use_tc and self.intrin_info.out_dtype in ["float32", "int32"]:
merge_static_smem = True
self.pass_context = {"tir.merge_static_smem": merge_static_smem}
return self
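
The net effect of the hint.py changes: hints may now carry a leading split-k/batch tile axis, `tensorcore_legalization()` drops everything but the last two axes before tensor-core scheduling, and static shared memory is merged whenever the accumulator is float32 or int32 rather than only for int8 inputs. A minimal sketch of the trimming step, using a stand-in class rather than the real `Hint`:

```python
# Stand-in for Hint, only to illustrate tensorcore_legalization(); the real
# class lives in python/bitblas/base/roller/hint.py.
class FakeHint:
    def __init__(self, block, warp):
        self.block = block  # e.g. [1, 128, 128] when a split-k/batch axis is present
        self.warp = warp    # e.g. [1, 64, 64]

    def tensorcore_legalization(self):
        # tensor cores only tile the last two spatial axes
        self.warp = self.warp[-2:]
        self.block = self.block[-2:]
        return self


hint = FakeHint(block=[1, 128, 128], warp=[1, 64, 64]).tensorcore_legalization()
assert hint.block == [128, 128] and hint.warp == [64, 64]
```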
3 changes: 3 additions & 0 deletions python/bitblas/base/roller/policy/tensorcore.py
@@ -297,6 +297,8 @@ def _score(node, thread): # small is better
intrin_info = node.get_tag("intrin_info")
if intrin_info:
codegen_dict.intrin_info = IntrinInfo(**intrin_info)
if intrin_info["out_dtype"] in ["float32"]:
codegen_dict.shared_scope = "shared.dyn"
# smem capacity
if td.smem_cost > self.arch.smem_cap:
codegen_dict.shared_scope = "shared.dyn"
@@ -305,6 +307,7 @@ def _score(node, thread): # small is better
codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size)
codegen_dict.arch = self.arch
codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes")
codegen_dict.tensorcore_legalization()
return codegen_dict

def plan_rasterization(self, td: TileDict):
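In the policy, a float32 accumulator now forces the dynamic shared-memory scope up front, in addition to the existing capacity-based fallback, and the finished config is legalized for tensor cores. A small sketch of the scope decision; the helper name and the 48 KB figure below are illustrative, not BitBLAS code:

```python
# Mirrors the two conditions visible in the diff above: float32 accumulation
# always uses dynamic shared memory, and any tile whose shared-memory cost
# exceeds the static capacity falls back to it as well.
def choose_shared_scope(out_dtype: str, smem_cost: int, smem_cap: int) -> str:
    if out_dtype in ["float32"]:
        return "shared.dyn"
    if smem_cost > smem_cap:
        return "shared.dyn"
    return "shared"


assert choose_shared_scope("float16", 32 * 1024, 48 * 1024) == "shared"
assert choose_shared_scope("float32", 16 * 1024, 48 * 1024) == "shared.dyn"
```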
39 changes: 36 additions & 3 deletions python/bitblas/gpu/intrin/lop3.py
@@ -1483,7 +1483,11 @@ def fast_decode_impl(
TensorIntrin.register(
LOP3_FAST_DECODE_INT2_TO_INT8_TO_INT8_L16_INTRIN,
*get_fast_decode_intrin(
source_bit=2, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
source_bit=2,
source_format="int",
storage_dtype="int8",
target_dtype="int8",
loops_extent=16),
)

LOP3_FAST_DECODE_UINT1_TO_INT8_TO_INT8_L16_INTRIN = ("lop3_fast_decode_u1_to_int8_to_i8_l16_")
@@ -1497,10 +1501,13 @@ def fast_decode_impl(
TensorIntrin.register(
LOP3_FAST_DECODE_INT1_TO_INT8_TO_INT8_L16_INTRIN,
*get_fast_decode_intrin(
source_bit=1, source_format="int", storage_dtype="int8", target_dtype="int8", loops_extent=16),
source_bit=1,
source_format="int",
storage_dtype="int8",
target_dtype="int8",
loops_extent=16),
)


LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i4_to_int8_to_f16_l8_")
TensorIntrin.register(
LOP3_FAST_DECODE_INT4_TO_INT8_TO_FP16_L8_INTRIN,
@@ -1553,6 +1560,32 @@ def fast_decode_impl(
),
)

LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN = ("lop3_fast_decode_i1_to_int8_to_f16_l8_")
TensorIntrin.register(
LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_INTRIN,
*get_fast_decode_intrin(
source_bit=1,
storage_dtype="int8",
source_format="int",
target_dtype="float16",
loops_extent=8,
),
)

LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = (
"lop3_fast_decode_i1_to_int8_to_f16_l8_scale_")
TensorIntrin.register(
LOP3_FAST_DECODE_INT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN,
*get_fast_decode_intrin(
source_bit=1,
storage_dtype="int8",
source_format="int",
target_dtype="float16",
loops_extent=8,
with_scale=True,
),
)
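
The new int1 registrations follow the existing key naming scheme. A hedged sketch of that convention for illustration only; the helper below is not a BitBLAS API:

```python
# Reconstructs the registration keys from their parameters, matching the
# string literals visible in this diff.
def lop3_intrin_key(source_bit, source_format, storage_dtype, target_dtype,
                    loops_extent, with_scale=False):
    prefix = "i" if source_format == "int" else "u"
    target_abbr = {"float16": "f16", "int8": "i8"}[target_dtype]
    key = (f"lop3_fast_decode_{prefix}{source_bit}_to_{storage_dtype}"
           f"_to_{target_abbr}_l{loops_extent}_")
    return key + "scale_" if with_scale else key


# matches the keys registered in this diff
assert lop3_intrin_key(1, "int", "int8", "float16", 8) == "lop3_fast_decode_i1_to_int8_to_f16_l8_"
assert (lop3_intrin_key(1, "int", "int8", "float16", 8, with_scale=True)
        == "lop3_fast_decode_i1_to_int8_to_f16_l8_scale_")
assert lop3_intrin_key(1, "uint", "int8", "int8", 16) == "lop3_fast_decode_u1_to_int8_to_i8_l16_"
```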


def get_lop3_intrin_group(
out_dtype: Literal["float16", "int8"],
29 changes: 25 additions & 4 deletions python/bitblas/gpu/matmul_analysis.py
@@ -8,7 +8,7 @@
from typing import List, Optional, Set, Union, Tuple, Dict
from tvm import tir
from tvm.ir import Range
from tvm.tir import IterVar, PrimExpr, Var, BufferRegion
from tvm.tir import IterVar, PrimExpr, Var, BufferRegion, IndexMap
from tvm.tir.analysis import undefined_vars
from tvm.tir.schedule.schedule import BlockRV
from ..base.analysis import (
@@ -17,12 +17,25 @@
get_reduction_blocks,
)
from tvm.target.target import Target
from tvm.tir import IndexMap
from tvm.tir.stmt_functor import pre_order_visit
import logging

logger = logging.getLogger(__name__)


def collect_vars_from_expr(prim_expr):
vars = []

def callback(node):
if isinstance(node, Var):
vars.append(node)
return True

pre_order_visit(prim_expr, callback)

return vars


def _is_one(x: PrimExpr) -> bool:
return isinstance(x, tir.IntImm) and x.value == 1

@@ -337,9 +350,13 @@ def is_common_reduce(var: Var) -> bool:
return True
return False

def has_common_reduce(var: Var) -> bool:
vars = collect_vars_from_expr(var)
return any(is_common_reduce(v) for v in vars)

def check_last_trait(region: List[Range]):
axes = get_ordered_axes(region)
return is_common_reduce(axes[-1])
return has_common_reduce(axes[-1])

def infer_layout(layout: str, region: List[Range], kind: str = "A"):
"""
@@ -583,9 +600,13 @@ def is_common_reduce(var: Var) -> bool:
return True
return False

def has_common_reduce(var: Var) -> bool:
vars = collect_vars_from_expr(var)
return any(is_common_reduce(v) for v in vars)

def check_last_trait(region: List[Range]):
axes = get_ordered_axes(region)
return is_common_reduce(axes[-1])
return has_common_reduce(axes[-1])

intrin_info: dict = {}
in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
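Both analysis sites previously required the last axis of a buffer region to be a plain reduction `Var`; with split-k the last index can be a compound expression over several loop variables, so the check now collects every `Var` inside the expression and accepts the region if any of them is a common reduction axis. A standalone sketch of the collection step; the index expression here is a made-up example, not taken from the PR:

```python
from tvm import tir
from tvm.tir.stmt_functor import pre_order_visit


def collect_vars_from_expr(prim_expr):
    found = []

    def callback(node):
        if isinstance(node, tir.Var):
            found.append(node)
        return True

    pre_order_visit(prim_expr, callback)
    return found


k0, k1 = tir.Var("k0", "int32"), tir.Var("k1", "int32")
index = k0 * 16 + k1  # compound last index, as split-k scheduling can produce
# a plain `index is reduce_var` style check would fail here; collecting the
# vars lets has_common_reduce() test each one individually
assert sorted(v.name for v in collect_vars_from_expr(index)) == ["k0", "k1"]
```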
4 changes: 2 additions & 2 deletions python/bitblas/module/__init__.py
@@ -11,7 +11,7 @@

logger = getLogger(__name__)

from typing import List, Union
from typing import List, Union, Optional

from bitblas.cache import global_operator_cache, get_database_path
from bitblas import Matmul, MatmulConfig
@@ -67,7 +67,7 @@ def __init__(
opt_M: Union[int, List[int]] = opt_M,
# performance related configs
enable_tuning: bool = True,
fast_decoding: bool = True,
fast_decoding: Optional[bool] = None,
propagate_b: bool = False,
):
"""
50 changes: 16 additions & 34 deletions python/bitblas/ops/general_matmul.py
@@ -5,14 +5,13 @@
import operator
from functools import reduce
from bitblas.base.roller.arch.cuda import CUDA
from typing import Any, List, Literal, Optional, Tuple, Union
from .operator import Operator, TransformKind
from typing import Any, Literal, Optional, Tuple, Union
from .operator import Operator, TransformKind, OPExecutorCPU
from .impl.matmul_dequantize_impl import (
select_implementation as weight_dequantize_implementation,)
from .impl.matmul_impl import select_implementation as consistent_implementation
from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4
from bitblas.utils.target_detector import auto_detect_nvidia_target
from bitblas.utils.tensor_adapter import tvm_tensor_to_torch
from dataclasses import dataclass
from .ladder_permutate import LadderPermutate, LadderPermutateConfig
from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig
@@ -42,34 +41,6 @@ def is_native_compute(A_dtype, W_dtype) -> bool:
return (A_dtype, W_dtype) in NATIVE_COMPUTE_PATTERNS


class OPExecutorCPU:

def __init__(self, operators: Optional[List[Operator]] = None):
if operators is None:
operators = []
self.operators = operators

def append(self, op):
self.operators.append(op)

def is_none(self):
return len(self.operators) == 0

def forward(self, weight):
inputs = [weight]
for op in self.operators:
inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu())
inputs = [op.forward(*inputs)]
return inputs[-1]

def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.forward(*args, **kwds)

@property
def size(self):
return len(self.operators)


@dataclass(frozen=True)
class MatmulConfig:
M: Union[int, Tuple[int]] = None
@@ -148,9 +119,18 @@ def __initialize_zeros_mode(self, zeros_mode: Optional[str]):
object.__setattr__(self, "zeros_mode", "original")

def __initialize_fast_decoding(self, fast_decoding: Optional[bool]):

def is_not_fast_decoding_supported():
conditions = []
conditions.append("int" not in self.W_dtype)
conditions.append(self.W_dtype == self.A_dtype)
# int8 and uint8 neither implement nor require fast decoding
conditions.append(self.W_dtype in ["int8", "uint8"])
return any(conditions)

if fast_decoding is not None:
object.__setattr__(self, "fast_decoding", fast_decoding)
elif ("int" not in self.W_dtype or self.W_dtype == self.A_dtype):
elif is_not_fast_decoding_supported():
object.__setattr__(self, "fast_decoding", False)
else:
object.__setattr__(self, "fast_decoding", True)
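
Put together, the new default is: an explicitly passed `fast_decoding` always wins; otherwise fast decoding is enabled only for sub-8-bit integer weights whose dtype differs from the activation dtype. A small standalone mirror of that decision; the helper name is mine, not BitBLAS code:

```python
def default_fast_decoding(A_dtype: str, W_dtype: str) -> bool:
    # mirrors the conditions in __initialize_fast_decoding: disable when the
    # weight is not an integer type, when weight and activation dtypes match,
    # or when the weight is a full int8/uint8
    if "int" not in W_dtype:
        return False
    if W_dtype == A_dtype:
        return False
    if W_dtype in ("int8", "uint8"):
        return False
    return True


assert default_fast_decoding("float16", "int4") is True
assert default_fast_decoding("float16", "float16") is False
assert default_fast_decoding("int8", "int8") is False
```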
@@ -450,7 +430,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
source_format, bit = self.source_format, self.bit

# Process integer source format
if source_format == "int":
if source_format == "int" and bit < 8:
assert not self.with_scaling, "scale should be False for int source format"
assert not self.with_zeros, "zeros should be False for int source format"
maxq = 2**(bit - 1)
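
Only sub-8-bit signed-integer weights go through the re-centering path now; for bit = 8 the old guard would have computed maxq = 2**(8 - 1) = 128, so int8 weights no longer take this branch. A tiny illustrative mirror of the guard:

```python
# Mirrors the `source_format == "int" and bit < 8` condition above.
def takes_int_recenter_path(source_format: str, bit: int) -> bool:
    return source_format == "int" and bit < 8


assert takes_int_recenter_path("int", 4) is True    # maxq = 2**(4 - 1) = 8
assert takes_int_recenter_path("int", 8) is False   # int8 weights skip re-centering
assert takes_int_recenter_path("uint", 4) is False
```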
@@ -519,9 +499,11 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
m = reduce(operator.mul, A.shape[:-1], 1)
args.append(m)

stream = torch.cuda.current_stream()

if self.lib is None:
self._forward_from_torch_func(*args)
self._forward_from_prebuild_lib(*args)
self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream)

return output
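
`forward` now reads the active torch stream on every call and passes its raw handle to the prebuilt kernel, so BitBLAS launches follow the caller's stream context instead of always landing on the default stream. A short illustration of what gets forwarded (pure PyTorch, requires a CUDA device; no BitBLAS setup):

```python
import torch

# inside a stream context, current_stream() is the side stream, and its
# .cuda_stream attribute is the raw handle that forward() now passes along
side_stream = torch.cuda.Stream()
with torch.cuda.stream(side_stream):
    assert torch.cuda.current_stream() == side_stream
    raw_handle = torch.cuda.current_stream().cuda_stream
```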
