Skip to content

[TL] Add TL Layout and Macro utils #174

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 35 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
d8884e6
Refactor BatchMatMulEmitter and BatchMatMulSelector for improved read…
LeiWang1999 Jul 5, 2024
fc84173
Refactor import statements for improved readability and maintainability
LeiWang1999 Jul 5, 2024
02f64de
Refactor import statements for improved readability and maintainability
LeiWang1999 Jul 5, 2024
397eee6
disable failure email for ci
LeiWang1999 Jul 5, 2024
20f6ad1
remove email notifications.
LeiWang1999 Jul 6, 2024
b93c394
move relax pass from testing to mlc_llm
LeiWang1999 Jul 6, 2024
ba6a6df
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Jul 6, 2024
257693a
Refactor scripts with se check_eual_ref_scripts_with_emitter function
LeiWang1999 Jul 6, 2024
9bb7f49
Lint Fix
LeiWang1999 Jul 6, 2024
39e7614
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Jul 6, 2024
93eb5a5
Refactor scripts with se check_eual_ref_scripts_with_emitter function
LeiWang1999 Jul 6, 2024
72b9740
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Aug 23, 2024
5b65979
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Aug 27, 2024
d9bd479
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into main
LeiWang1999 Aug 29, 2024
99515cb
buf fix for matrix support
LeiWang1999 Aug 29, 2024
14406ef
lint fix
LeiWang1999 Aug 29, 2024
d30ec4f
dispatch tensor core based on shapes
LeiWang1999 Aug 29, 2024
fde4029
update install commands
LeiWang1999 Aug 30, 2024
6a04749
import scripts
LeiWang1999 Aug 31, 2024
9d90c40
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into docs
LeiWang1999 Aug 31, 2024
9ef14e9
remove shared mem hack
LeiWang1999 Sep 1, 2024
63f363e
revert change for swizzling
LeiWang1999 Sep 1, 2024
b29c66c
bug fix
LeiWang1999 Sep 1, 2024
4643dd9
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into docs
LeiWang1999 Sep 1, 2024
28beb13
tl examples
LeiWang1999 Sep 2, 2024
c0b476f
Enhance Swizzle
LeiWang1999 Sep 2, 2024
2bf14a8
lint fix
LeiWang1999 Sep 2, 2024
52accbf
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into tl-l…
LeiWang1999 Sep 2, 2024
19aa985
test fix
LeiWang1999 Sep 3, 2024
ef8f93c
lint fix
LeiWang1999 Sep 3, 2024
4015cc4
optimize layout
LeiWang1999 Sep 3, 2024
5c5880c
update tl utils.
LeiWang1999 Sep 3, 2024
1042ffd
macro optimization
LeiWang1999 Sep 3, 2024
1ecd76e
Merge branch 'main' of https://github.com/Microsoft/BitBLAS into tl-l…
LeiWang1999 Sep 3, 2024
7bb21e7
test fix
LeiWang1999 Sep 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
10 changes: 10 additions & 0 deletions bitblas/tl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .utils import (
get_swizzle_layout, # noqa: F401
mma_store_index_map, # noqa: F401
get_ldmatrix_offset, # noqa: F401
)

from .macro_generator import TensorCorePTXMacroGenerator # noqa: F401
211 changes: 211 additions & 0 deletions bitblas/tl/macro_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import tvm.tl.language as T

from tvm import DataType
from tvm.runtime import convert
from .utils import (
mma_store_index_map,
get_ldmatrix_offset,
)

lift = convert


class TensorCorePTXMacroGenerator(object):
"""
To eliminate Python syntax within TIR Macro.
"""

M_DIM = 16
N_DIM = 16
WARP_SIZE = 32
dtype_abbrv = {
"float16": "fp16",
"bfloat16": "bf16",
"float32": "fp32",
"int8": "int8",
"int32": "int32",
"e4m3_float8": "e4m3",
"e5m2_float8": "e5m2",
}

def __init__(
self,
a_dtype="float16",
b_dtype="float16",
accum_dtype="float16",
a_transposed=False,
b_transposed=False,
block_row_warps=2,
block_col_warps=2,
warp_row_tiles=8,
warp_col_tiles=8,
chunk=16,
threads=128,
):
self.a_dtype = a_dtype
self.b_dtype = b_dtype
self.accum_dtype = accum_dtype
self.a_transposed = a_transposed
self.b_transposed = b_transposed
# Hint Information
self.block_row_warps = block_row_warps
self.block_col_warps = block_col_warps
self.warp_row_tiles = warp_row_tiles
self.warp_col_tiles = warp_col_tiles
self.chunk = chunk
self._initialize_k_dim(a_dtype)
self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
self._initialize_mma_prefix(self.k_dim)
self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
self.warp_rows = warp_row_tiles // self.micro_size_x
self.warp_cols = warp_col_tiles // self.micro_size_y
self._initialize_thread_axis(threads, self.WARP_SIZE, block_row_warps, block_col_warps)

def _initialize_k_dim(self, a_dtype="float16"):
self.k_dim = 256 // DataType(a_dtype).bits

def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
self.local_size_a = (m_dim * k_dim) // warp_size
self.local_size_b = (n_dim * k_dim) // warp_size
self.local_size_out = (m_dim * n_dim) // warp_size

def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
self.a_dtype_abbrv = self.dtype_abbrv[a_dtype]
self.b_dtype_abbrv = self.dtype_abbrv[b_dtype]
self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]

def _initialize_mma_prefix(self, k_dim=16):
if k_dim == 16:
self.mma_prefix = "m16n8k16"
elif k_dim == 32:
self.mma_prefix = "m16n8k32"
else:
raise ValueError("Unsupported k_dim")

def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
self.micro_size_x = m_dim
self.micro_size_y = n_dim
self.micro_size_k = k_dim

def _initialize_thread_axis(self,
threads=128,
warp_size=32,
block_row_warps=2,
block_col_warps=2):
self.threads = threads
# thread_bindings = T.env_thread("threadIdx.x")
# self.tx = thread_bindings % warp_size
# self.ty = (thread_bindings // warp_size) % block_row_warps
# self.tz = thread_bindings // (warp_size * block_row_warps)

@staticmethod
@T.macro
def MMA(inst, A_local_buf, B_local_buf, C_local_buf):
for i, j in T.grid(inst.warp_rows, inst.warp_cols):
T.ptx_mma(
inst.accum_dtype,
"m16n8k16",
"row",
"col",
inst.a_dtype_abbrv,
inst.b_dtype_abbrv,
inst.accum_dtype_abbrv,
A_local_buf.data,
i * inst.local_size_a,
B_local_buf.data,
j * inst.local_size_b,
C_local_buf.data,
i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out,
T.bool(False),
)

T.ptx_mma(
inst.accum_dtype,
"m16n8k16",
"row",
"col",
inst.a_dtype_abbrv,
inst.b_dtype_abbrv,
inst.accum_dtype_abbrv,
A_local_buf.data,
i * inst.local_size_a,
B_local_buf.data,
j * inst.local_size_b + lift(inst.local_size_b) // 2,
C_local_buf.data,
i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out +
lift(inst.local_size_out) // 2,
T.bool(False),
)

@staticmethod
@T.macro
def LDMATRIX_A(
inst,
A_local_buf,
A_shared_buf,
ki,
thread_bindings,
):
stride = inst.chunk
tx = thread_bindings % inst.WARP_SIZE
ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
# self.ty = (thread_bindings // warp_size) % block_row_warps
# self.tz = thread_bindings // (warp_size * block_row_warps)
for i in T.serial(inst.warp_rows):
T.ptx_ldmatrix(
"float16",
T.bool(False),
4,
".b16",
A_local_buf.data,
i * inst.local_size_a,
T.address_of(A_shared_buf[ty * inst.warp_row_tiles + i * inst.micro_size_x,
ki * inst.micro_size_k,]),
get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, False),
)

@staticmethod
@T.macro
def LDMATRIX_B(
inst,
B_local_buf,
B_shared_buf,
ki,
thread_bindings,
):
stride = inst.chunk
tx = thread_bindings % inst.WARP_SIZE
tz = thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)
for j in T.serial(inst.warp_cols):
T.ptx_ldmatrix(
"float16",
T.bool(False), # TODO(lei): should be optimized
4,
".b16",
B_local_buf.data,
j * inst.local_size_b,
T.address_of(B_shared_buf[tz * inst.warp_col_tiles + j * inst.micro_size_y,
ki * inst.micro_size_k,]),
get_ldmatrix_offset("B", tx, 0, stride, inst.b_dtype, True),
)

# STS
# MMA Store must be in simulated instead of TVM Intrins
# As TVM Intrins is like a hack that the threadIdx.x should be always
# equal to the warp_size
@staticmethod
@T.macro
def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings):
tx = thread_bindings % inst.WARP_SIZE
ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
tz = thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)
for i, j in T.grid(inst.warp_rows, inst.warp_cols):
for local_id in T.serial(inst.local_size_out):
row, col = T.meta_var(mma_store_index_map(tx, local_id))
C_shared_buf[ty * inst.warp_rows + i, tz * inst.warp_cols + j, row,
col] = C_local_buf[i * (inst.warp_cols * inst.local_size_out) +
j * inst.local_size_out + local_id]
119 changes: 119 additions & 0 deletions bitblas/tl/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tvm import arith
from tvm import DataType
from typing import Union, Literal


def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
ana = arith.Analyzer()
BANK_SIZE_BYTES = 128
if isinstance(dtype, str):
dtype = DataType(dtype)
col_idx_outer, col_idx_inner = col_idx // (BANK_SIZE_BYTES // dtype.bits), col_idx % (
BANK_SIZE_BYTES // dtype.bits)
# use transaction bits to support diverse dtype.
# for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 32 bits = 512 bits
# for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits
coalescent_bits = dtype.bits * row_size
# permutation on 4 banks, each bank has 32 bits
bank_elems = BANK_SIZE_BYTES // dtype.bits
new_col_idx_outer = None
print(f"coalescent_bits: {coalescent_bits}")
if coalescent_bits % 1024 == 0:
# Use 8 * 8 permuted layout
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
# Every row below corresponds to 32 banks
# 0 1 2 3 4 5 6 7 ==> 0 1 2 3 4 5 6 7
# 0 1 2 3 4 5 6 7 ==> 1 0 3 2 5 4 7 6
# 0 1 2 3 4 5 6 7 ==> 2 3 0 1 6 7 4 5
# 0 1 2 3 4 5 6 7 ==> 3 2 1 0 7 6 5 4
# 0 1 2 3 4 5 6 7 ==> 4 5 6 7 0 1 2 3
# 0 1 2 3 4 5 6 7 ==> 5 4 7 6 1 0 3 2
# 0 1 2 3 4 5 6 7 ==> 6 7 4 5 2 3 0 1
# 0 1 2 3 4 5 6 7 ==> 7 6 5 4 3 2 1 0
row_idx_sub = row_idx % bank_elems
new_col_idx_outer = col_idx_outer ^ row_idx_sub
else:
assert coalescent_bits % 512 == 0
# Use 8 * 4 permuted layout
# Every number below corresponds to 8 consecutive fp16 number in shared mem, i.e. one read
# Every row below corresponds to 16 banks
# 0 1 2 3 ==> 0 1 2 3
# 0 1 2 3 ==> 0 1 2 3
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 1 0 3 2
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 2 3 0 1
# 0 1 2 3 ==> 3 2 1 0
# 0 1 2 3 ==> 3 2 1 0
# View with 8 elements per row:
# 0 1 2 3 4 0 1 2 3 ==> 0 1 2 3 0 1 2 3
# 0 1 2 3 4 0 1 2 3 ==> 1 0 3 2 1 0 3 2
# 0 1 2 3 4 0 1 2 3 ==> 2 3 0 1 2 3 0 1
# 0 1 2 3 4 0 1 2 3 ==> 3 2 1 0 3 2 1 0
row_idx_sub = row_idx % bank_elems
# Interleave elems per byte
interleave_elems = 32 // dtype.bits
new_col_idx_outer = col_idx_outer ^ (row_idx_sub // interleave_elems)

assert (new_col_idx_outer is not None), f"Unsupported dtype {dtype} with {coalescent_bits} bits"
return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner)


def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
row = thread_id % 16
col = 8 * (thread_id // 16) + local_id % 8
return row, col


def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
row = 8 * (thread_id // 16) + (thread_id % 8)
col = 8 * ((thread_id % 16) // 8) + local_id % 8
return row, col


def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id):
row = thread_id % 16
col = local_id + (thread_id // 16) * 16
return row, col


def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id):
row = (thread_id // 16) * 8 + (thread_id % 8)
col = local_id + 16 * ((thread_id % 16) // 8)
return row, col


def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
row = 8 * (local_id % 4 // 2) + (thread_id // 4)
col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2)
return row, col


def get_ldmatrix_offset(
matrix: Literal["A", "B"],
row_idx,
col_idx,
stride,
dtype: Literal["float16", "int8"] = "float16",
transpose: bool = False,
):
assert matrix in ["A", "B"], "matrix should be either A or B"
transform_func = (
ldmatrix_32x8_to_shared_16x16_layout
if dtype in ["float16", "bfloat16"] else ldmatrix_32x16_to_shared_16x32_layout_b)
transform_func_trans = (
ldmatrix_trans_32x8_to_shared_16x16_layout
if dtype in ["float16", "bfloat16"] else ldmatrix_32x16_to_shared_16x32_layout_a)
if matrix == "A":
assert not transpose, "A matrix should not be transposed"
new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
return new_row_idx * stride + new_col_idx
else:
new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx)
return new_row_idx * stride + new_col_idx


def mma_store_index_map(*args, **kwargs):
return mma_store_32x8_to_shared_16x16_layout(*args, **kwargs)
10 changes: 3 additions & 7 deletions integration/BitNet/utils_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def activation_quant(self, x, num_bits=8):
Qp = 2**(num_bits - 1) - 1
s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
result = (x * s).round().clamp(Qn, Qp)
return result.type(torch.int8)
return result.type(torch.int8), s

@torch.compile
def post_quant_process(self, input, si, sw):
Expand All @@ -186,16 +186,14 @@ def native_forward(self, input):
return out

def forward_fp32_simulated(self, input):
quant_input = self.activation_quant(input, self.input_bits).detach()
quant_input, si = self.activation_quant(input, self.input_bits).detach()
quant_weight = self.weight_quant(self.weight).detach()

fp32_simulated_input = quant_input.float()
fp32_simulated_weight = quant_weight.float()
fp32_simulated_out = nn.functional.linear(fp32_simulated_input, fp32_simulated_weight)

sw = 1 / self.weight.abs().mean().clamp(min=1e-5)
Qp = 2**(self.input_bits - 1) - 1
si = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
# if / (si * sw) it will inf in some cases
out = fp32_simulated_out / si
out = out / sw
Expand All @@ -206,11 +204,9 @@ def forward_fp32_simulated(self, input):

def forward(self, input):
# return self.forward_fp32_simulated(input)
quant_input = self.activation_quant(input, self.input_bits).detach()
quant_input, si = self.activation_quant(input, self.input_bits)
fp32_out = self.bitblas_matmul(quant_input, self.qweight)
sw = self.sw
Qp = self.Qp
si = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
# if / (si * sw) it will inf in some cases
out = self.post_quant_process(fp32_out, si, sw)

Expand Down
3 changes: 0 additions & 3 deletions testing/python/tilelang/test_tilelang_dequantize_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,6 @@ def run_gemm(

print(f"output is {out}")

with open("debug/kernel.cu", "w") as f:
f.write(mod.mod.imported_modules[0].get_source())

def ref_program(A, qB):
import torch

Expand Down
Loading