[TL][BugFix] Disable Buffer Vectorization and Add OP Related TL Test Cases (#197)

* Refactor BatchMatMulEmitter and BatchMatMulSelector for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* Refactor import statements for improved readability and maintainability

* disable failure email for ci

* remove email notifications.

* move relax pass from testing to mlc_llm

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* Lint Fix

* Refactor scripts with the check_eual_ref_scripts_with_emitter function

* bug fix for matrix support

* lint fix

* dispatch tensor core based on shapes

* update install commands

* import scripts

* remove shared mem hack

* revert change for swizzling

* bug fix

* tl examples

* Enhance Swizzle

* lint fix

* test fix

* lint fix

* optimize layout

* update tl utils.

* macro optimization

* test fix

* gemm_ss

* doc fix

* lint fix

* lint fix

* remove debug print

* remove debug print

* vectorization init

* lint fix

* prelude update

* update tvm

* bug fix for reduce_k with shared memory

* bug fix

* bug fix

* Enhance Macro Generation

* Lift Layout to reduce load time

* lint fix

* test fix

* red fix

* tile lang macro example

* tile lang macro example

* optimize the macro generator related items

* lint fix

* Tile Lang Test with Dynamic Symbolic

* more test case with block level programming

* all dynamic test case

* simplify the test case for dequantize gemm.

* dequant gemm update.

* Tile Lang GEMM Implementation

* Tile Lang Gemm Fix

* Update subproject commit in 3rdparty/tvm
Refactor mma_layout.py and remove unused imports
Add matmul_macro_tensorcore and matmul_macro_tensorcore_weight_propagation_level_ldmatrix to dense/__init__.py
Refactor test_general_matmul_tilelang_impl.py to include additional matmul functions

* test fix
LeiWang1999 authored Sep 28, 2024
1 parent 9f5f0ea commit 033495e
Showing 5 changed files with 318 additions and 39 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/tvm
6 changes: 5 additions & 1 deletion bitblas/ops/general_matmul/tilelang/dense/__init__.py
@@ -1,4 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 
-from .matmul import matmul_blocked  # noqa: F401
+from .matmul import (
+    matmul_blocked,  # noqa: F401
+    matmul_macro_tensorcore,  # noqa: F401
+    matmul_macro_tensorcore_weight_propagation_level_ldmatrix  # noqa: F401
+)
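For orientation, the widened export surface lets downstream code (for example, the new tilelang test cases) build the tensor-core kernels directly from the dense package. A minimal usage sketch follows; the problem and tile sizes are hypothetical, and the lowering/build step through the tvm.tl pipeline is omitted:

    from bitblas.ops.general_matmul.tilelang.dense import matmul_macro_tensorcore

    # Hypothetical shapes and tile configuration; the actual test cases may differ.
    func = matmul_macro_tensorcore(
        M=1024, N=1024, K=1024,
        dtypeAB="float16", dtypeC="float16",
        trans_A=False, trans_B=True,   # the only layout the kernel asserts support for
        accum_dtype="float32",
        block_row_warps=2, block_col_warps=2,
        warp_row_tiles=64, warp_col_tiles=64,
        chunk=32,
        num_stages=2,
        enable_rasterization=False,
    )
    # `func` is the @T.prim_func-decorated kernel returned by the builder; it can
    # then be lowered and compiled with the tvm.tl toolchain used by the tests.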
160 changes: 144 additions & 16 deletions bitblas/ops/general_matmul/tilelang/dense/matmul.py
@@ -1,25 +1,20 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
 from bitblas import tvm as tvm
 from tvm import DataType
 import tvm.tl.language as T
 
 from bitblas.tl.utils import (
     get_mma_micro_size,
     make_swizzle_layout,
 )
 
-from bitblas.tl.macro_generator import (TensorCoreIntrinEmitter)
+from bitblas.tl.macro_generator import (
+    TensorCoreIntrinEmitter,
+    TensorCoreIntrinEmitterWithLadderTransform,
+)
 
-def maybe_pipeline(
-    iterable,
-    num_stages,
-):
-    enable_pipeline = num_stages > 1
-    if enable_pipeline:
-        return T.Pipelined(iterable, num_stages=num_stages)
-    else:
-        return T.serial(iterable)
+from bitblas.ops.operator import TransformKind
 
 
 def matmul_blocked(
@@ -59,7 +54,7 @@ def main(
             T.use_swizzle(10)
 
             T.clear(C_local)
-            for k in maybe_pipeline(T.ceildiv(K, block_K), num_stages):
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                 if trans_A:
                     T.copy(A[k * block_K, by * block_M], A_shared)
                 else:
@@ -80,6 +75,8 @@ def matmul_macro_tensorcore(
     K,
     dtypeAB,
     dtypeC,
+    trans_A,
+    trans_B,
     accum_dtype,
     block_row_warps,
     block_col_warps,
@@ -89,6 +86,8 @@ def matmul_macro_tensorcore(
     num_stages=2,
     enable_rasterization=False,
 ):
+    assert trans_A is False, "Currently only support Matrix A is not transposed"
+    assert trans_B is True, "Currently only support Matrix B is transposed"
 
     block_M = block_row_warps * warp_row_tiles
     block_N = block_col_warps * warp_col_tiles
@@ -129,9 +128,9 @@ def main(
     ):
         with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
 
-            A_shared = T.alloc_shared(A_shared_shape, dtypeAB, shared_scope=shared_scope)
-            B_shared = T.alloc_shared(B_shared_shape, dtypeAB, shared_scope=shared_scope)
-            C_shared = T.alloc_shared(C_shared_shape, dtypeC, shared_scope=shared_scope)
+            A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope)
+            B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope)
+            C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope)
             A_local = T.alloc_local((warp_rows * local_size), dtypeAB)
             B_local = T.alloc_local((warp_cols * local_size), dtypeAB)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
@@ -147,7 +146,7 @@ def main(
 
             T.clear(C_local)
 
-            for ko in maybe_pipeline(T.ceildiv(K, block_K), num_stages):
+            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
 
                 for i, k in T.Parallel(block_M, block_K):
                     A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
@@ -187,3 +186,132 @@ def main(
                                              i % micro_size_x, j % micro_size_y]
 
     return main
+
+
+def matmul_macro_tensorcore_weight_propagation_level_ldmatrix(
+    M,
+    N,
+    K,
+    dtypeAB,
+    dtypeC,
+    trans_A,
+    trans_B,
+    accum_dtype,
+    block_row_warps,
+    block_col_warps,
+    warp_row_tiles,
+    warp_col_tiles,
+    chunk,
+    num_stages=2,
+    enable_rasterization=False,
+):
+    assert trans_A is False, "Currently only support Matrix A is not transposed"
+    assert trans_B is True, "Currently only support Matrix B is transposed"
+
+    block_M = block_row_warps * warp_row_tiles
+    block_N = block_col_warps * warp_col_tiles
+    block_K = chunk
+
+    # TODO(lei): Can be generalized to analyzed from bank size
+    pad_factor = 8 if dtypeAB == "float16" else 16
+
+    can_swizzle_a = block_K * DataType(dtypeAB).bits == 512
+    apply_pad_a = not can_swizzle_a
+
+    micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtypeAB)
+
+    A_shape = (M, K)
+    B_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k)
+    A_shared_shape = (block_M, (block_K + pad_factor) if apply_pad_a else block_K)
+    B_shared_shape = (block_N // micro_size_y, block_K // micro_size_k, micro_size_y, micro_size_k)
+    C_shared_shape = (block_M // micro_size_x, block_N // micro_size_y, micro_size_x, micro_size_y)
+
+    warp_size = 32  # nvidia gpu warp size is 32
+    threads = warp_size * (block_row_warps * block_col_warps)
+    local_size = (micro_size_x * micro_size_y) // warp_size
+    warp_rows = warp_row_tiles // micro_size_x
+    warp_cols = warp_col_tiles // micro_size_y
+
+    shared_scope = "shared.dyn"  # Literal["shared", "shared.dyn"] while shared for static shared memory
+    mma_emitter = TensorCoreIntrinEmitterWithLadderTransform(
+        a_dtype=dtypeAB,
+        b_dtype=dtypeAB,
+        accum_dtype=accum_dtype,
+        a_transposed=trans_A,
+        b_transposed=trans_B,
+        block_row_warps=block_row_warps,
+        block_col_warps=block_col_warps,
+        warp_row_tiles=warp_row_tiles,
+        warp_col_tiles=warp_col_tiles,
+        chunk=chunk,
+        transform_kind_b=TransformKind.LDMatrixTransform,
+    )
+
+    @T.prim_func
+    def main(
+            A: T.Buffer(A_shape, dtypeAB),
+            B: T.Buffer(B_shape, dtypeAB),
+            C: T.Buffer((M, N), dtypeC),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
+
+            A_shared = T.alloc_shared(A_shared_shape, dtypeAB, scope=shared_scope)
+            B_shared = T.alloc_shared(B_shared_shape, dtypeAB, scope=shared_scope)
+            C_shared = T.alloc_shared(C_shared_shape, dtypeC, scope=shared_scope)
+            A_local = T.alloc_local((warp_rows * local_size), dtypeAB)
+            B_local = T.alloc_local((warp_cols * local_size), dtypeAB)
+            C_local = T.alloc_local((warp_rows * warp_cols * local_size), accum_dtype)
+            thread_bindings = T.thread_binding(0, threads, "threadIdx.x")
+
+            T.annotate_layout({
+                A_shared: make_swizzle_layout(A_shared),
+                B_shared: make_swizzle_layout(B_shared),
+            })
+
+            if enable_rasterization:
+                T.use_swizzle(panel_size=10)
+
+            T.clear(C_local)
+
+            for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+
+                for i, k in T.Parallel(block_M, block_K):
+                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]
+
+                for j, k, jj, kk in T.Parallel(block_N // micro_size_y, block_K // micro_size_k,
+                                               micro_size_y, micro_size_k):
+                    B_shared[j, k, jj, kk] = B[bx * (block_N // micro_size_y) + j,
+                                               ko * (block_K // micro_size_k) + k, jj, kk]
+
+                for ki in T.serial(0, (block_K // micro_size_k)):
+
+                    # Load A into fragment
+                    mma_emitter.ldmatrix_a(
+                        A_local,
+                        A_shared,
+                        ki,
+                        thread_bindings=thread_bindings,
+                    )
+
+                    # Load B into fragment
+                    mma_emitter.ldmatrix_b(
+                        B_local,
+                        B_shared,
+                        ki,
+                        thread_bindings=thread_bindings,
+                    )
+
+                    mma_emitter.mma(A_local, B_local, C_local)
+
+            mma_emitter.stmatrix(
+                C_local,
+                C_shared,
+                thread_bindings=thread_bindings,
+            )
+
+            for i, j in T.Parallel(block_M, block_N):
+                C[by * block_M + i,
+                  bx * block_N + j] = C_shared[i // micro_size_x, j // micro_size_y,
+                                               i % micro_size_x, j % micro_size_y]
+
+    return main
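As a usage note (an illustrative sketch, not part of the diff): unlike matmul_macro_tensorcore, this variant declares B with the ladder/ldmatrix-transformed 4-D shape, so callers must supply the weight in that layout. The weight transformation itself is assumed to be produced elsewhere (for example, by BitBLAS's ladder permutation utilities) and is elided here; the shapes and tile sizes below are hypothetical:

    from bitblas.tl.utils import get_mma_micro_size
    from bitblas.ops.general_matmul.tilelang.dense import (
        matmul_macro_tensorcore_weight_propagation_level_ldmatrix,)

    M = N = K = 1024  # hypothetical problem size
    micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size("float16")
    # B must already be laid out as declared by B_shape in the kernel above:
    b_transformed_shape = (N // micro_size_y, K // micro_size_k, micro_size_y, micro_size_k)

    func = matmul_macro_tensorcore_weight_propagation_level_ldmatrix(
        M, N, K,
        dtypeAB="float16", dtypeC="float16",
        trans_A=False, trans_B=True,
        accum_dtype="float32",
        block_row_warps=2, block_col_warps=2,
        warp_row_tiles=64, warp_col_tiles=64,
        chunk=32,
    )
    # `func` expects A with shape (M, K) and B with shape b_transformed_shape.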
3 changes: 0 additions & 3 deletions bitblas/tl/mma_layout.py
@@ -1,8 +1,5 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
-from tvm import arith
-from tvm import DataType
-from typing import Union, Literal
 
 
 def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
