[Feature] Enhancing MatmulOps with Splitk Support (#48)
* improve e4m3 decoding.

* append fp16xint1

* Update submodule commit reference

* chore: Update shared memory scope for float32 output dtype

* BUGFIX: UINT8/INT8 Decoding

* feat: Add rasterization options for roller module

* Refactor tensorcore_legalization method to optimize tensor core usage

* feat: Add function to collect variables from expression, improve for splitk

* chore: Update typing import in __init__.py

* chore: Refactor CPU execution of operators

* Refactor matmul implementation for splitk layout

* Refactor matmul implementation for splitk layout

---------

Co-authored-by: LeiWang199 <leiwang199>
LeiWang1999 authored Jun 5, 2024
1 parent c4400b3 commit 99a744e
Showing 13 changed files with 1,152 additions and 39 deletions.
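
For context on the feature itself: split-K partitions the reduction dimension K of a GEMM across independent partial results that are summed at the end, which helps occupancy when M and N are too small to fill the GPU. A minimal PyTorch sketch of the idea (function name, shapes, and the k_split value are illustrative, not part of this commit):

import torch

# Illustrative split-K GEMM: partition K into k_split chunks, compute partial
# products into a (k_split, M, N) buffer, then reduce over the split axis --
# mirroring the extra leading dimension and the final sum used by the new op.
def splitk_matmul_sketch(A: torch.Tensor, W: torch.Tensor, k_split: int = 4) -> torch.Tensor:
    M, K = A.shape
    N = W.shape[1]
    assert K % k_split == 0, "sketch assumes K divides evenly by k_split"
    chunk = K // k_split
    partial = torch.empty((k_split, M, N), dtype=A.dtype, device=A.device)
    for s in range(k_split):
        partial[s] = A[:, s * chunk:(s + 1) * chunk] @ W[s * chunk:(s + 1) * chunk, :]
    return partial.sum(dim=0)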
1 change: 1 addition & 0 deletions python/bitblas/base/roller/__init__.py
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .node import PrimFuncNode # noqa: F401
from .rasterization import NoRasterization, Rasterization2DRow, Rasterization2DColumn # noqa: F401
from .hint import Hint # noqa: F401
from .policy import DefaultPolicy, TensorCorePolicy # noqa: F401
from .arch import TileDevice, CUDA # noqa: F401
6 changes: 6 additions & 0 deletions python/bitblas/base/roller/hint.py
@@ -211,6 +211,12 @@ def from_dict(self, dic: Dict) -> "Hint":
setattr(self, k, v)
return self

def tensorcore_legalization(self):
# only keep the last 2 axes for tensorcore
self.warp = self.warp[-2:]
self.block = self.block[-2:]
return self

@property
def raxis_order(self) -> List[int]:
if self._raxis_order != []:
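The new tensorcore_legalization step simply drops all but the last two warp/block tile axes before the hint reaches the tensor-core schedule. A toy illustration of the effect, with a hypothetical Hint-like object and made-up tile sizes:

# Hypothetical stand-in for Hint; legalization keeps only the last two axes.
class _HintSketch:
    def __init__(self, block, warp):
        self.block, self.warp = block, warp

    def tensorcore_legalization(self):
        self.warp = self.warp[-2:]
        self.block = self.block[-2:]
        return self

h = _HintSketch(block=[1, 128, 64], warp=[1, 32, 32]).tensorcore_legalization()
print(h.block, h.warp)  # [128, 64] [32, 32]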
1 change: 1 addition & 0 deletions python/bitblas/base/roller/policy/tensorcore.py
@@ -307,6 +307,7 @@ def _score(node, thread): # small is better
codegen_dict.vectorize = self._plan_vectorize(self.prim_func_node, td, block_size)
codegen_dict.arch = self.arch
codegen_dict.opt_shapes = self.prim_func_node.get_tag("opt_shapes")
codegen_dict.tensorcore_legalization()
return codegen_dict

def plan_rasterization(self, td: TileDict):
29 changes: 25 additions & 4 deletions python/bitblas/gpu/matmul_analysis.py
@@ -8,7 +8,7 @@
from typing import List, Optional, Set, Union, Tuple, Dict
from tvm import tir
from tvm.ir import Range
from tvm.tir import IterVar, PrimExpr, Var, BufferRegion
from tvm.tir import IterVar, PrimExpr, Var, BufferRegion, IndexMap
from tvm.tir.analysis import undefined_vars
from tvm.tir.schedule.schedule import BlockRV
from ..base.analysis import (
@@ -17,12 +17,25 @@
get_reduction_blocks,
)
from tvm.target.target import Target
from tvm.tir import IndexMap
from tvm.tir.stmt_functor import pre_order_visit
import logging

logger = logging.getLogger(__name__)


def collect_vars_from_expr(prim_expr):
vars = []

def callback(node):
if isinstance(node, Var):
vars.append(node)
return True

pre_order_visit(prim_expr, callback)

return vars


def _is_one(x: PrimExpr) -> bool:
return isinstance(x, tir.IntImm) and x.value == 1

@@ -337,9 +350,13 @@ def is_common_reduce(var: Var) -> bool:
return True
return False

def has_common_reduce(var: Var) -> bool:
vars = collect_vars_from_expr(var)
return any(is_common_reduce(v) for v in vars)

def check_last_trait(region: List[Range]):
axes = get_ordered_axes(region)
return is_common_reduce(axes[-1])
return has_common_reduce(axes[-1])

def infer_layout(layout: str, region: List[Range], kind: str = "A"):
"""
@@ -583,9 +600,13 @@ def is_common_reduce(var: Var) -> bool:
return True
return False

def has_common_reduce(var: Var) -> bool:
vars = collect_vars_from_expr(var)
return any(is_common_reduce(v) for v in vars)

def check_last_trait(region: List[Range]):
axes = get_ordered_axes(region)
return is_common_reduce(axes[-1])
return has_common_reduce(axes[-1])

intrin_info: dict = {}
in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
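The point of the new collect_vars_from_expr helper is that after the split-K layout rewrite, the last index of a buffer region may be a compound expression (for example an outer/inner pair) rather than a single Var, so check_last_trait now asks whether any variable inside that expression is a reduction axis. A hedged standalone sketch of the helper using TVM's pre_order_visit (the example expression is made up):

from tvm import tir
from tvm.tir.stmt_functor import pre_order_visit

def collect_vars_from_expr(prim_expr):
    collected = []

    def callback(node):
        # record every Var encountered in the expression tree
        if isinstance(node, tir.Var):
            collected.append(node)
        return True

    pre_order_visit(prim_expr, callback)
    return collected

# A compound index such as k_outer * 32 + k_inner contains two vars.
k_outer = tir.Var("k_outer", "int32")
k_inner = tir.Var("k_inner", "int32")
index = k_outer * 32 + k_inner
print([v.name for v in collect_vars_from_expr(index)])  # ['k_outer', 'k_inner']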
4 changes: 2 additions & 2 deletions python/bitblas/module/__init__.py
@@ -11,7 +11,7 @@

logger = getLogger(__name__)

from typing import List, Union
from typing import List, Union, Optional

from bitblas.cache import global_operator_cache, get_database_path
from bitblas import Matmul, MatmulConfig
@@ -67,7 +67,7 @@ def __init__(
opt_M: Union[int, List[int]] = opt_M,
# performance related configs
enable_tuning: bool = True,
fast_decoding: bool = True,
fast_decoding: Optional[bool] = None,
propagate_b: bool = False,
):
"""
37 changes: 5 additions & 32 deletions python/bitblas/ops/general_matmul.py
@@ -5,14 +5,13 @@
import operator
from functools import reduce
from bitblas.base.roller.arch.cuda import CUDA
from typing import Any, List, Literal, Optional, Tuple, Union
from .operator import Operator, TransformKind
from typing import Any, Literal, Optional, Tuple, Union
from .operator import Operator, TransformKind, OPExecutorCPU
from .impl.matmul_dequantize_impl import (
select_implementation as weight_dequantize_implementation,)
from .impl.matmul_impl import select_implementation as consistent_implementation
from ..base.utils import tensor_replace_dp4a, tensor_remove_make_int4
from bitblas.utils.target_detector import auto_detect_nvidia_target
from bitblas.utils.tensor_adapter import tvm_tensor_to_torch
from dataclasses import dataclass
from .ladder_permutate import LadderPermutate, LadderPermutateConfig
from .lop3_permutate import LOP3Permutate, LOP3PermutateConfig
@@ -42,34 +41,6 @@ def is_native_compute(A_dtype, W_dtype) -> bool:
return (A_dtype, W_dtype) in NATIVE_COMPUTE_PATTERNS


class OPExecutorCPU:

def __init__(self, operators: Optional[List[Operator]] = None):
if operators is None:
operators = []
self.operators = operators

def append(self, op):
self.operators.append(op)

def is_none(self):
return len(self.operators) == 0

def forward(self, weight):
inputs = [weight]
for op in self.operators:
inputs.append(tvm_tensor_to_torch(op.get_profile_tensors()[-1]).cpu())
inputs = [op.forward(*inputs)]
return inputs[-1]

def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.forward(*args, **kwds)

@property
def size(self):
return len(self.operators)


@dataclass(frozen=True)
class MatmulConfig:
M: Union[int, Tuple[int]] = None
@@ -528,9 +499,11 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
m = reduce(operator.mul, A.shape[:-1], 1)
args.append(m)

stream = torch.cuda.current_stream()

if self.lib is None:
self._forward_from_torch_func(*args)
self._forward_from_prebuild_lib(*args)
self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream)

return output

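The forward path now also grabs the current PyTorch CUDA stream and passes its raw handle to the prebuilt library call, so the generated kernel launches in order with surrounding torch work. A small sketch of the pattern (launch_kernel stands in for the prebuilt library entry point and is not a real BitBLAS symbol):

import torch

def run_on_current_stream(launch_kernel, *tensors):
    # torch.cuda.current_stream().cuda_stream exposes the raw CUDA stream handle;
    # passing it keeps the prebuilt kernel ordered with other torch operations.
    stream = torch.cuda.current_stream()
    launch_kernel(*tensors, stream=stream.cuda_stream)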
194 changes: 194 additions & 0 deletions python/bitblas/ops/general_matmul_splitk.py
@@ -0,0 +1,194 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tvm.target import Target
import operator
from functools import reduce
from typing import Any, Optional, Union
from .operator import TransformKind
from .impl.matmul_splitk_impl import select_implementation as consistent_implementation
from .impl.matmul_dequantize_splitk_impl import select_implementation as weight_dequantize_implementation
from dataclasses import dataclass
import logging
import torch
from .general_matmul import MatmulConfig, Matmul
from .general_matmul import is_native_compute

logger = logging.getLogger(__name__)

WORKSPACE_SIZE = 1024 * 1024 * 256


@dataclass(frozen=True)
class MatmulConfigWithSplitK(MatmulConfig):
k_split: int = 1 # split K dimension


class MatmulWithSplitK(Matmul):

def __init__(
self,
config: MatmulConfig,
name: str = "matmul",
target: Optional[Union[str, Target]] = None,
enable_tuning: bool = True,
from_database: bool = False,
):
super().__init__(config, name, target, enable_tuning, from_database)

def _select_implementation(self):
# the major implementation
if is_native_compute(self.A_dtype, self.W_dtype):
return consistent_implementation(
SplitK=self.k_split,
M=self.M,
N=self.N,
K=self.K,
in_dtype=self.A_dtype,
out_dtype=self.out_dtype,
accum_dtype=self.accum_dtype,
with_bias=self.with_bias,
layout=self.layout,
propagate_a=self.propagate_a,
propagate_b=self.propagate_b,
)
else:
return weight_dequantize_implementation(
SplitK=self.k_split,
M=self.M,
N=self.N,
K=self.K,
in_dtype=self.A_dtype,
out_dtype=self.out_dtype,
accum_dtype=self.accum_dtype,
bit=self.bit,
storage_dtype=self.storage_dtype,
source_format=self.source_format,
with_scaling=self.with_scaling,
with_zeros=self.with_zeros,
group_size=self.group_size,
fast_decoding=self.fast_decoding,
with_bias=self.with_bias,
layout=self.layout,
zeros_mode=self.zeros_mode,
propagate_a=self.propagate_a,
propagate_b=self.propagate_b,
)

def retrieve_weight_shape(self):
return [int(i) for i in self.prim_func.buffer_map[self.prim_func.params[1]].shape]

def transform_weight(self, weight, scale=None, zeros=None, bias=None):
"""
Transforms the given weight tensor based on the specified quantization parameters and
returns the transformed weight along with optional scale, zeros, and bias.
Parameters:
- weight: The input weight tensor to be transformed.
- scale: Optional scaling factor for the weight tensor.
- zeros: Optional zero-point adjustment for the weight tensor.
- bias: Optional bias to be added to the weight tensor.
Returns:
A list containing the transformed weight tensor and optionally the scale, zeros, and bias.
"""
weight = weight.contiguous()
if self.W_dtype == self.A_dtype:
if self.weight_transform is not None:
return self.weight_transform(weight.cpu()).cuda().contiguous()
return weight

from bitblas.quantization import general_compress
import torch
import numpy as np

source_format, bit = self.source_format, self.bit

# Process integer source format
if source_format == "int" and bit < 8:
assert not self.with_scaling, "scale should be False for int source format"
assert not self.with_zeros, "zeros should be False for int source format"
maxq = 2**(bit - 1)
# Clamp weight values to be within the quantizable range and adjust
weight = torch.clamp(weight, -maxq, maxq).int() + maxq
elif source_format in ["fp_e5m2", "fp_e4m3"]:
weight = weight.view(torch.int8)
weight = weight.int()
else:
# For non-integer formats, simply convert weights to integers
weight = weight.int()

np_storage_dtype = getattr(np, self.storage_dtype)

weight = general_compress(
weight.cpu().numpy(), source_bits=bit, storage_dtype=np_storage_dtype)

weight = torch.from_numpy(weight).cuda().contiguous()

# Apply an optional weight transformation if specified
if self.weight_transform is not None:
weight = self.weight_transform(weight.cpu()).cuda().contiguous()

# Prepare the return list with the transformed weight and optionally include scale, zeros, and bias
result = [weight]
if scale is not None:
result.append(scale)
if zeros is not None:
result.append(zeros)
if bias is not None:
result.append(bias)

return next(iter(result), result)

def transform_input(self, input_tensor):
if self.propagate_a is not TransformKind.NonTransform:
# check workspace size
if input_tensor.numel() > WORKSPACE_SIZE:
raise ValueError(
f"Input size {input_tensor.numel()} is larger than the workspace size {WORKSPACE_SIZE}, please increase the workspace size."
)
self.ladder_permutate_a._forward_from_prebuild_lib(input_tensor, self.workspace)
return self.workspace
return input_tensor

def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
args = []
args.append(self.transform_input(A))
args.append(W)

if self.lut is not None:
args.append(self.lut)

if output is None:
output = torch.empty(
(self.k_split,) + A.shape[:-1] + (self.N,),
dtype=self.torch_output_dtype,
device=A.device)
if scale is not None:
args.append(scale)
if zeros is not None:
args.append(zeros)
if bias is not None:
args.append(bias)
args.append(output)

if self.dynamic_range is not None:
m = reduce(operator.mul, A.shape[:-1], 1)
args.append(m)

stream = torch.cuda.current_stream()

if self.lib is None:
self._forward_from_torch_func(*args)
self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream)
output = torch.sum(output, dim=0)
return output

def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.forward(*args, **kwds)

@property
def k_split(self):
return self.config.k_split


__all__ = ["MatmulConfigWithSplitK", "MatmulWithSplitK"]
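
A hedged usage sketch of the new classes; the shapes, dtypes, and k_split value below are illustrative, and constructing the operator may tune or compile a kernel depending on enable_tuning:

import torch
from bitblas.ops.general_matmul_splitk import MatmulConfigWithSplitK, MatmulWithSplitK

# Illustrative config; MatmulConfig carries many more optional fields (dtypes, layout, ...).
config = MatmulConfigWithSplitK(M=16, N=4096, K=4096, k_split=4)
matmul = MatmulWithSplitK(config)

A = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
W = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
W = matmul.transform_weight(W)   # apply any layout/quantization transform for the chosen W_dtype
C = matmul(A, W)                 # partial (k_split, M, N) results are summed over dim 0 internally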