Commit c15744e
[TL] Add TL Layout and Macro utils (#174)
* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability
* Refactor import statements for improved readability and maintainability
* Disable failure emails and notifications for CI
* Move relax pass from testing to mlc_llm
* Refactor scripts with the check_eual_ref_scripts_with_emitter function
* Bug fix for matrix support
* Dispatch tensor core based on shapes
* Update install commands
* Import scripts
* Remove shared memory hack
* Revert change for swizzling
* Add TL examples
* Enhance swizzle
* Optimize layout
* Update TL utils
* Macro optimization
* Lint and test fixes
1 parent 3aa9439 commit c15744e

File tree

6 files changed: +344 −11 lines changed

bitblas/tl/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from .utils import (
    get_swizzle_layout,  # noqa: F401
    mma_store_index_map,  # noqa: F401
    get_ldmatrix_offset,  # noqa: F401
)

from .macro_generator import TensorCorePTXMacroGenerator  # noqa: F401
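These re-exports define the public surface of bitblas.tl. A minimal sketch of how downstream code might use the index helpers (the concrete arguments are illustrative assumptions, not values from this commit; both helpers are pure Python arithmetic):

from bitblas.tl import get_ldmatrix_offset, mma_store_index_map

# Shared-memory offset for lane 0 of an ldmatrix load of matrix A
# (fp16, non-transposed, row stride of 16 elements).
offset = get_ldmatrix_offset("A", 0, 0, 16, "float16", False)

# Position in the 16x16 output tile where lane 5 stores element 2
# of its mma accumulator fragment.
row, col = mma_store_index_map(5, 2)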

bitblas/tl/macro_generator.py

Lines changed: 211 additions & 0 deletions
@@ -0,0 +1,211 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import tvm.tl.language as T

from tvm import DataType
from tvm.runtime import convert
from .utils import (
    mma_store_index_map,
    get_ldmatrix_offset,
)

lift = convert


class TensorCorePTXMacroGenerator(object):
    """
    To eliminate Python syntax within TIR Macro.
    """

    M_DIM = 16
    N_DIM = 16
    WARP_SIZE = 32
    dtype_abbrv = {
        "float16": "fp16",
        "bfloat16": "bf16",
        "float32": "fp32",
        "int8": "int8",
        "int32": "int32",
        "e4m3_float8": "e4m3",
        "e5m2_float8": "e5m2",
    }

    def __init__(
        self,
        a_dtype="float16",
        b_dtype="float16",
        accum_dtype="float16",
        a_transposed=False,
        b_transposed=False,
        block_row_warps=2,
        block_col_warps=2,
        warp_row_tiles=8,
        warp_col_tiles=8,
        chunk=16,
        threads=128,
    ):
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype
        self.accum_dtype = accum_dtype
        self.a_transposed = a_transposed
        self.b_transposed = b_transposed
        # Hint information
        self.block_row_warps = block_row_warps
        self.block_col_warps = block_col_warps
        self.warp_row_tiles = warp_row_tiles
        self.warp_col_tiles = warp_col_tiles
        self.chunk = chunk
        self._initialize_k_dim(a_dtype)
        self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
        self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
        self._initialize_mma_prefix(self.k_dim)
        self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
        self.warp_rows = warp_row_tiles // self.micro_size_x
        self.warp_cols = warp_col_tiles // self.micro_size_y
        self._initialize_thread_axis(threads, self.WARP_SIZE, block_row_warps, block_col_warps)

    def _initialize_k_dim(self, a_dtype="float16"):
        # Each mma tile carries 256 bits per lane along K: k_dim = 16 for fp16, 32 for int8.
        self.k_dim = 256 // DataType(a_dtype).bits

    def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
        # Per-thread register fragment sizes for the A, B, and accumulator tiles.
        self.local_size_a = (m_dim * k_dim) // warp_size
        self.local_size_b = (n_dim * k_dim) // warp_size
        self.local_size_out = (m_dim * n_dim) // warp_size

    def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
        self.a_dtype_abbrv = self.dtype_abbrv[a_dtype]
        self.b_dtype_abbrv = self.dtype_abbrv[b_dtype]
        self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]

    def _initialize_mma_prefix(self, k_dim=16):
        if k_dim == 16:
            self.mma_prefix = "m16n8k16"
        elif k_dim == 32:
            self.mma_prefix = "m16n8k32"
        else:
            raise ValueError("Unsupported k_dim")

    def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
        self.micro_size_x = m_dim
        self.micro_size_y = n_dim
        self.micro_size_k = k_dim

    def _initialize_thread_axis(self,
                                threads=128,
                                warp_size=32,
                                block_row_warps=2,
                                block_col_warps=2):
        self.threads = threads
        # thread_bindings = T.env_thread("threadIdx.x")
        # self.tx = thread_bindings % warp_size
        # self.ty = (thread_bindings // warp_size) % block_row_warps
        # self.tz = thread_bindings // (warp_size * block_row_warps)

    @staticmethod
    @T.macro
    def MMA(inst, A_local_buf, B_local_buf, C_local_buf):
        for i, j in T.grid(inst.warp_rows, inst.warp_cols):
            # Each m16n8k16 PTX mma covers only an 8-wide slice of N, so two
            # calls are issued per 16x16 micro tile: one for each half of the
            # B fragment and the accumulator.
            T.ptx_mma(
                inst.accum_dtype,
                "m16n8k16",
                "row",
                "col",
                inst.a_dtype_abbrv,
                inst.b_dtype_abbrv,
                inst.accum_dtype_abbrv,
                A_local_buf.data,
                i * inst.local_size_a,
                B_local_buf.data,
                j * inst.local_size_b,
                C_local_buf.data,
                i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out,
                T.bool(False),
            )

            T.ptx_mma(
                inst.accum_dtype,
                "m16n8k16",
                "row",
                "col",
                inst.a_dtype_abbrv,
                inst.b_dtype_abbrv,
                inst.accum_dtype_abbrv,
                A_local_buf.data,
                i * inst.local_size_a,
                B_local_buf.data,
                j * inst.local_size_b + lift(inst.local_size_b) // 2,
                C_local_buf.data,
                i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out +
                lift(inst.local_size_out) // 2,
                T.bool(False),
            )

    @staticmethod
    @T.macro
    def LDMATRIX_A(
        inst,
        A_local_buf,
        A_shared_buf,
        ki,
        thread_bindings,
    ):
        stride = inst.chunk
        tx = thread_bindings % inst.WARP_SIZE
        ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
        for i in T.serial(inst.warp_rows):
            T.ptx_ldmatrix(
                "float16",
                T.bool(False),
                4,
                ".b16",
                A_local_buf.data,
                i * inst.local_size_a,
                T.address_of(A_shared_buf[ty * inst.warp_row_tiles + i * inst.micro_size_x,
                                          ki * inst.micro_size_k,]),
                get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, False),
            )

    @staticmethod
    @T.macro
    def LDMATRIX_B(
        inst,
        B_local_buf,
        B_shared_buf,
        ki,
        thread_bindings,
    ):
        stride = inst.chunk
        tx = thread_bindings % inst.WARP_SIZE
        tz = thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)
        for j in T.serial(inst.warp_cols):
            T.ptx_ldmatrix(
                "float16",
                T.bool(False),  # TODO(lei): should be optimized
                4,
                ".b16",
                B_local_buf.data,
                j * inst.local_size_b,
                T.address_of(B_shared_buf[tz * inst.warp_col_tiles + j * inst.micro_size_y,
                                          ki * inst.micro_size_k,]),
                get_ldmatrix_offset("B", tx, 0, stride, inst.b_dtype, True),
            )

    # STS
    # The MMA store must be simulated rather than expressed with the TVM
    # intrinsic, as the intrinsic is a hack that assumes threadIdx.x is
    # always equal to the warp size.
    @staticmethod
    @T.macro
    def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings):
        tx = thread_bindings % inst.WARP_SIZE
        ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
        tz = thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)
        for i, j in T.grid(inst.warp_rows, inst.warp_cols):
            for local_id in T.serial(inst.local_size_out):
                row, col = T.meta_var(mma_store_index_map(tx, local_id))
                C_shared_buf[ty * inst.warp_rows + i, tz * inst.warp_cols + j, row,
                             col] = C_local_buf[i * (inst.warp_cols * inst.local_size_out) +
                                                j * inst.local_size_out + local_id]
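TensorCorePTXMacroGenerator only precomputes tiling constants in Python; the macros themselves are expanded later inside a TL kernel against that kernel's shared and local buffers. A minimal sketch of constructing one and checking the derived sizes (the tile values are illustrative assumptions; the arithmetic follows the initializers above):

from bitblas.tl import TensorCorePTXMacroGenerator

mma_emitter = TensorCorePTXMacroGenerator(
    a_dtype="float16",
    b_dtype="float16",
    accum_dtype="float16",
    block_row_warps=2,
    block_col_warps=2,
    warp_row_tiles=32,
    warp_col_tiles=32,
    chunk=32,
    threads=128,
)
assert mma_emitter.k_dim == 16              # 256 bits / 16-bit fp16
assert mma_emitter.mma_prefix == "m16n8k16"
assert mma_emitter.warp_rows == 2           # 32 warp_row_tiles // 16 micro_size_x
assert mma_emitter.local_size_a == 8        # (16 * 16) // 32 lanes

Inside a kernel, LDMATRIX_A and LDMATRIX_B stage shared-memory tiles into register fragments, MMA consumes them, and STMATRIX writes the accumulator back to shared memory.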

bitblas/tl/utils.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from tvm import arith
from tvm import DataType
from typing import Union, Literal


def get_swizzle_layout(row_idx, col_idx, row_size, dtype: Union[DataType, str]):
    ana = arith.Analyzer()
    BANK_SIZE_BYTES = 128
    if isinstance(dtype, str):
        dtype = DataType(dtype)
    col_idx_outer, col_idx_inner = col_idx // (BANK_SIZE_BYTES // dtype.bits), col_idx % (
        BANK_SIZE_BYTES // dtype.bits)
    # Use transaction bits to support diverse dtypes:
    # for fp16, 64 elems * 16 bits = 1024 bits, 32 elems * 16 bits = 512 bits
    # for int8, 128 elems * 8 bits = 1024 bits, 64 elems * 8 bits = 512 bits
    coalescent_bits = dtype.bits * row_size
    # permutation on 4 banks, each bank has 32 bits
    bank_elems = BANK_SIZE_BYTES // dtype.bits
    new_col_idx_outer = None
    print(f"coalescent_bits: {coalescent_bits}")
    if coalescent_bits % 1024 == 0:
        # Use 8 * 8 permuted layout
        # Every number below corresponds to 8 consecutive fp16 numbers in shared mem, i.e. one read
        # Every row below corresponds to 32 banks
        # 0 1 2 3 4 5 6 7    ==>    0 1 2 3 4 5 6 7
        # 0 1 2 3 4 5 6 7    ==>    1 0 3 2 5 4 7 6
        # 0 1 2 3 4 5 6 7    ==>    2 3 0 1 6 7 4 5
        # 0 1 2 3 4 5 6 7    ==>    3 2 1 0 7 6 5 4
        # 0 1 2 3 4 5 6 7    ==>    4 5 6 7 0 1 2 3
        # 0 1 2 3 4 5 6 7    ==>    5 4 7 6 1 0 3 2
        # 0 1 2 3 4 5 6 7    ==>    6 7 4 5 2 3 0 1
        # 0 1 2 3 4 5 6 7    ==>    7 6 5 4 3 2 1 0
        row_idx_sub = row_idx % bank_elems
        new_col_idx_outer = col_idx_outer ^ row_idx_sub
    else:
        assert coalescent_bits % 512 == 0
        # Use 8 * 4 permuted layout
        # Every number below corresponds to 8 consecutive fp16 numbers in shared mem, i.e. one read
        # Every row below corresponds to 16 banks
        # 0 1 2 3    ==>    0 1 2 3
        # 0 1 2 3    ==>    0 1 2 3
        # 0 1 2 3    ==>    1 0 3 2
        # 0 1 2 3    ==>    1 0 3 2
        # 0 1 2 3    ==>    2 3 0 1
        # 0 1 2 3    ==>    2 3 0 1
        # 0 1 2 3    ==>    3 2 1 0
        # 0 1 2 3    ==>    3 2 1 0
        # View with 8 elements per row:
        # 0 1 2 3 0 1 2 3    ==>    0 1 2 3 0 1 2 3
        # 0 1 2 3 0 1 2 3    ==>    1 0 3 2 1 0 3 2
        # 0 1 2 3 0 1 2 3    ==>    2 3 0 1 2 3 0 1
        # 0 1 2 3 0 1 2 3    ==>    3 2 1 0 3 2 1 0
        row_idx_sub = row_idx % bank_elems
        # Interleave elems per byte
        interleave_elems = 32 // dtype.bits
        new_col_idx_outer = col_idx_outer ^ (row_idx_sub // interleave_elems)

    assert (new_col_idx_outer is not None), f"Unsupported dtype {dtype} with {coalescent_bits} bits"
    return row_idx, ana.simplify(new_col_idx_outer * bank_elems + col_idx_inner)


def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
    row = thread_id % 16
    col = 8 * (thread_id // 16) + local_id % 8
    return row, col


def ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id):
    row = 8 * (thread_id // 16) + (thread_id % 8)
    col = 8 * ((thread_id % 16) // 8) + local_id % 8
    return row, col


def ldmatrix_32x16_to_shared_16x32_layout_a(thread_id, local_id):
    row = thread_id % 16
    col = local_id + (thread_id // 16) * 16
    return row, col


def ldmatrix_32x16_to_shared_16x32_layout_b(thread_id, local_id):
    row = (thread_id // 16) * 8 + (thread_id % 8)
    col = local_id + 16 * ((thread_id % 16) // 8)
    return row, col


def mma_store_32x8_to_shared_16x16_layout(thread_id, local_id):
    row = 8 * (local_id % 4 // 2) + (thread_id // 4)
    col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2)
    return row, col


def get_ldmatrix_offset(
    matrix: Literal["A", "B"],
    row_idx,
    col_idx,
    stride,
    dtype: Literal["float16", "int8"] = "float16",
    transpose: bool = False,
):
    assert matrix in ["A", "B"], "matrix should be either A or B"
    transform_func = (
        ldmatrix_32x8_to_shared_16x16_layout
        if dtype in ["float16", "bfloat16"] else ldmatrix_32x16_to_shared_16x32_layout_b)
    transform_func_trans = (
        ldmatrix_trans_32x8_to_shared_16x16_layout
        if dtype in ["float16", "bfloat16"] else ldmatrix_32x16_to_shared_16x32_layout_a)
    if matrix == "A":
        assert not transpose, "A matrix should not be transposed"
        new_row_idx, new_col_idx = transform_func(row_idx, col_idx)
        return new_row_idx * stride + new_col_idx
    else:
        new_row_idx, new_col_idx = transform_func_trans(row_idx, col_idx)
        return new_row_idx * stride + new_col_idx


def mma_store_index_map(*args, **kwargs):
    return mma_store_32x8_to_shared_16x16_layout(*args, **kwargs)
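The XOR swizzle is easiest to see with concrete numbers. A small standalone sketch (plain Python, mirroring the 8 * 8 fp16 case above rather than calling into TVM):

# fp16: bank_elems = 128 // 16 = 8, so col_idx_outer and row_idx % 8
# both range over [0, 8) and the permutation is a pure XOR.
bank_elems = 8
for row_idx in range(8):
    print([col_idx_outer ^ (row_idx % bank_elems) for col_idx_outer in range(8)])
# Row 0 prints [0, 1, 2, 3, 4, 5, 6, 7] and row 1 prints [1, 0, 3, 2, 5, 4, 7, 6],
# reproducing the comment table in get_swizzle_layout line by line; since each
# row is a distinct permutation, no column hits the same bank twice across rows.

Note that get_swizzle_layout itself passes the final index through arith.Analyzer().simplify, so it is meant to be called with TIR index expressions during layout inference rather than with plain integers.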

integration/BitNet/utils_quant.py

Lines changed: 3 additions & 7 deletions
@@ -165,7 +165,7 @@ def activation_quant(self, x, num_bits=8):
         Qp = 2**(num_bits - 1) - 1
         s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
         result = (x * s).round().clamp(Qn, Qp)
-        return result.type(torch.int8)
+        return result.type(torch.int8), s
 
     @torch.compile
     def post_quant_process(self, input, si, sw):
@@ -186,16 +186,14 @@ def native_forward(self, input):
         return out
 
     def forward_fp32_simulated(self, input):
-        quant_input = self.activation_quant(input, self.input_bits).detach()
+        quant_input, si = self.activation_quant(input, self.input_bits).detach()
         quant_weight = self.weight_quant(self.weight).detach()
 
         fp32_simulated_input = quant_input.float()
         fp32_simulated_weight = quant_weight.float()
         fp32_simulated_out = nn.functional.linear(fp32_simulated_input, fp32_simulated_weight)
 
         sw = 1 / self.weight.abs().mean().clamp(min=1e-5)
-        Qp = 2**(self.input_bits - 1) - 1
-        si = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
         # if / (si * sw) it will inf in some cases
         out = fp32_simulated_out / si
         out = out / sw
@@ -206,11 +204,9 @@ def forward_fp32_simulated(self, input):
 
     def forward(self, input):
         # return self.forward_fp32_simulated(input)
-        quant_input = self.activation_quant(input, self.input_bits).detach()
+        quant_input, si = self.activation_quant(input, self.input_bits)
         fp32_out = self.bitblas_matmul(quant_input, self.qweight)
         sw = self.sw
-        Qp = self.Qp
-        si = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
         # if / (si * sw) it will inf in some cases
         out = self.post_quant_process(fp32_out, si, sw)
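This change threads the activation scale through the call chain: activation_quant now returns the scale s alongside the int8 tensor, so forward reuses si instead of re-deriving Qp and the per-row maximum from the input. A standalone sketch of the quantize/dequantize round trip this enables (illustrative PyTorch mirroring the updated method, not code from this commit):

import torch

def activation_quant(x, num_bits=8):
    # Symmetric per-row int8 quantization; return values and the scale.
    Qn, Qp = -(2**(num_bits - 1)), 2**(num_bits - 1) - 1
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return (x * s).round().clamp(Qn, Qp).type(torch.int8), s

x = torch.randn(4, 16)
q, s = activation_quant(x)
x_hat = q.float() / s  # dequantize with the returned scale
assert (x - x_hat).abs().max() < 1.0 / s.min()  # rounding error is at most 0.5 / s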

testing/python/tilelang/test_tilelang_dequantize_gemm.py

Lines changed: 0 additions & 3 deletions
@@ -113,9 +113,6 @@ def run_gemm(
 
     print(f"output is {out}")
 
-    with open("debug/kernel.cu", "w") as f:
-        f.write(mod.mod.imported_modules[0].get_source())
-
     def ref_program(A, qB):
         import torch
 