lint fix
LeiWang1999 committed Sep 26, 2024
1 parent 2c93dad · commit e5bbf81
Show file tree
Hide file tree
Showing 3 changed files with 718 additions and 142 deletions.
3rdparty/tvm (2 changes: 1 addition & 1 deletion)
bitblas/tl/macro_generator.py (158 changes: 25 additions & 133 deletions)
@@ -45,8 +45,6 @@ def __init__(self,
                  warp_col_tiles=8,
                  chunk=16,
                  reduce_k=1,
-                 transform_kind_a: Union[int, TransformKind] = 0,
-                 transform_kind_b: Union[int, TransformKind] = 0,
                  num_elems_per_byte=1):
         self.a_dtype = a_dtype
         self.b_dtype = b_dtype
@@ -68,7 +66,6 @@ def __init__(self,
         self.warp_cols = warp_col_tiles // self.micro_size_y
         self.reduce_k = reduce_k
         self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
-        self._initialize_transform_kind(transform_kind_a, transform_kind_b)
         self.num_elems_per_byte = num_elems_per_byte

     def _initialize_k_dim(self, a_dtype="float16"):
@@ -99,23 +96,6 @@ def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
         self.micro_size_y = n_dim
         self.micro_size_k = k_dim

-    def _initialize_transform_kind(self, transform_kind_a, transform_kind_b):
-        if isinstance(transform_kind_a, int):
-            self.transform_kind_a = TransformKind(transform_kind_a)
-        elif isinstance(transform_kind_a, TransformKind):
-            self.transform_kind_a = transform_kind_a
-        else:
-            raise ValueError("Unsupported transform_kind_a")
-
-        if isinstance(transform_kind_b, int):
-            self.transform_kind_b = TransformKind(transform_kind_b)
-        elif isinstance(transform_kind_b, TransformKind):
-            self.transform_kind_b = transform_kind_b
-        else:
-            raise ValueError("Unsupported transform_kind_b")
-
-        assert transform_kind_b in [0, 3], "Currently only support 0 and 3"
-
     @T.macro
     def _warp_ldmatrix_a(
         inst,
@@ -240,24 +220,12 @@ def stmatrix(self, C_local_buf, C_shared_buf, thread_bindings):
         return self._warp_stmatrix(self, C_local_buf, C_shared_buf, thread_bindings)


-class TensorCoreIntrinEmitterWithLadderTransform(object):
+class TensorCoreIntrinEmitterWithLadderTransform(TensorCoreIntrinEmitter):
     """
     To eliminate Python syntax within TIR Macro.
     With Ladder Transform Plugin.
     """

-    M_DIM = 16
-    N_DIM = 16
-    WARP_SIZE = 32
-    dtype_abbrv = {
-        "float16": "fp16",
-        "bfloat16": "bf16",
-        "float32": "fp32",
-        "int8": "int8",
-        "int32": "int32",
-        "e4m3_float8": "e4m3",
-        "e5m2_float8": "e5m2",
-    }

     def __init__(
         self,
         a_dtype="float16",
@@ -275,28 +243,21 @@ def __init__(
         transform_kind_b: Union[int, TransformKind] = 0,
         num_elems_per_byte=1,
     ):
-        self.a_dtype = a_dtype
-        self.b_dtype = b_dtype
-        self.accum_dtype = accum_dtype
-        self.a_transposed = a_transposed
-        self.b_transposed = b_transposed
-        # Hint Information
-        self.block_row_warps = block_row_warps
-        self.block_col_warps = block_col_warps
-        self.warp_row_tiles = warp_row_tiles
-        self.warp_col_tiles = warp_col_tiles
-        self.chunk = chunk
-        self._initialize_k_dim(a_dtype)
-        self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
-        self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
-        self._initialize_mma_prefix(self.k_dim)
-        self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
-        self.warp_rows = warp_row_tiles // self.micro_size_x
-        self.warp_cols = warp_col_tiles // self.micro_size_y
-        self.reduce_k = reduce_k
-        self.threads = self.WARP_SIZE * (block_row_warps * block_col_warps) * reduce_k
+        super().__init__(
+            a_dtype=a_dtype,
+            b_dtype=b_dtype,
+            accum_dtype=accum_dtype,
+            a_transposed=a_transposed,
+            b_transposed=b_transposed,
+            block_row_warps=block_row_warps,
+            block_col_warps=block_col_warps,
+            warp_row_tiles=warp_row_tiles,
+            warp_col_tiles=warp_col_tiles,
+            chunk=chunk,
+            reduce_k=reduce_k,
+            num_elems_per_byte=num_elems_per_byte,
+        )
         self._initialize_transform_kind(transform_kind_a, transform_kind_b)
-        self.num_elems_per_byte = num_elems_per_byte

     def _initialize_k_dim(self, a_dtype="float16"):
         self.k_dim = 256 // DataType(a_dtype).bits
@@ -339,38 +300,13 @@ def _initialize_transform_kind(self, transform_kind_a, transform_kind_b):
         else:
             raise ValueError("Unsupported transform_kind_b")

-        assert transform_kind_b in [0, 3], "Currently only support 0 and 3"
+        if self.transform_kind_a != TransformKind.NonTransform:
+            raise ValueError("TransformKind A is not supported yet")

-    @staticmethod
-    @T.macro
-    def LDMATRIX_A(
-        inst,
-        A_local_buf,
-        A_shared_buf,
-        ki,
-        thread_bindings,
-        rk=0,
-    ):
-        stride = A_shared_buf.shape[-1]
-        tx = thread_bindings % inst.WARP_SIZE
-        ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
-
-        for i in T.serial(inst.warp_rows):
-            T.ptx_ldmatrix(
-                inst.a_dtype,
-                T.bool(False),
-                4,
-                ".b16",
-                A_local_buf.data,
-                i * inst.local_size_a,
-                T.address_of(A_shared_buf[ty * inst.warp_row_tiles + i * inst.micro_size_x,
-                                          rk * inst.chunk + ki * inst.micro_size_k,]),
-                get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, inst.a_transposed),
-            )
+        assert transform_kind_b in [0, 3], "Currently only support 0 and 3"

-    @staticmethod
     @T.macro
-    def LDMATRIX_B(
+    def _warp_ldmatrix_b(
         inst,
         B_local_buf,
         B_shared_buf,
@@ -414,9 +350,8 @@ def LDMATRIX_B(
                         B_local_buf[j * local_size_dequantize + local_id] = B_shared_buf[ri, rj, rii,
                                                                                          rjj]

-    @staticmethod
     @T.macro
-    def MMA(inst, A_local_buf, B_local_buf, C_local_buf):
+    def _warp_mma(inst, A_local_buf, B_local_buf, C_local_buf):
         for i, j in T.grid(inst.warp_rows, inst.warp_cols):
             T.ptx_mma(
                 inst.accum_dtype,
@@ -453,51 +388,8 @@ def MMA(inst, A_local_buf, B_local_buf, C_local_buf):
                 T.bool(False),
             )

-    # STS
-    # MMA Store must be in simulated instead of TVM Intrins
-    # As TVM Intrins is like a hack that the threadIdx.x should be always
-    # equal to the warp_size
-    @staticmethod
-    @T.macro
-    def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings):
-        tx = thread_bindings % inst.WARP_SIZE
-        ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
-        tz = (thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)) % inst.block_col_warps
-        for i, j in T.grid(inst.warp_rows, inst.warp_cols):
-            for local_id_o in T.serial(inst.local_size_out // 2):
-                for local_id_i in T.vectorized(2):
-                    local_id = local_id_o * 2 + local_id_i
-                    row, col = T.meta_var(mma_store_index_map(tx, local_id))
-                    C_shared_buf[ty * inst.warp_rows + i, tz * inst.warp_cols + j, row,
-                                 col] = C_local_buf[i * (inst.warp_cols * inst.local_size_out) +
-                                                    j * inst.local_size_out + local_id]
-
-    # Allow GEMM from shared memory to local memory
-    @staticmethod
-    @T.macro
-    def GEMM_SS(inst, A_shared_buf, B_shared_buf, C_local_buf, thread_bindings):
-        # TODO(lei): alloc_buffer within the macro is not supported yet.
-        A_local_buf = T.alloc_fragment((inst.warp_rows * inst.local_size_a),
-                                       inst.a_dtype,
-                                       scope="local")
-        B_local_buf = T.alloc_fragment((inst.warp_cols * inst.local_size_b),
-                                       inst.b_dtype,
-                                       scope="local")
-        for ki in T.serial(0, (inst.chunk // inst.micro_size_k)):
-            inst.LDMATRIX_A(
-                inst,
-                A_local_buf,
-                A_shared_buf,
-                ki,
-                thread_bindings=thread_bindings,
-            )
-
-            inst.LDMATRIX_B(
-                inst,
-                B_local_buf,
-                B_shared_buf,
-                ki,
-                thread_bindings=thread_bindings,
-            )
-
-            inst.MMA(inst, A_local_buf, B_local_buf, C_local_buf)
+    def ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk=0):
+        return self._warp_ldmatrix_b(self, B_local_buf, B_shared_buf, ki, thread_bindings, rk)
+
+    def mma(self, A_local_buf, B_local_buf, C_local_buf):
+        return self._warp_mma(self, A_local_buf, B_local_buf, C_local_buf)
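After this refactor, TensorCoreIntrinEmitterWithLadderTransform inherits its constructor plumbing from TensorCoreIntrinEmitter and redefines only the ladder-transformed B load (_warp_ldmatrix_b) and the MMA wrapper (_warp_mma). A minimal construction sketch follows; the tile sizes are illustrative, the enclosing TL kernel scaffolding is elided, and an inherited ldmatrix_a wrapper is assumed to exist alongside the stmatrix wrapper visible in the base-class hunk above:

    # Sketch only: tile sizes are hypothetical, and the enclosing T.Kernel /
    # shared-memory allocations are elided.
    emitter = TensorCoreIntrinEmitterWithLadderTransform(
        a_dtype="float16",
        b_dtype="float16",
        accum_dtype="float32",
        a_transposed=False,
        b_transposed=True,
        block_row_warps=2,
        block_col_warps=2,
        warp_row_tiles=32,   # warp_rows = 32 // micro_size_x (16) = 2
        warp_col_tiles=32,   # warp_cols = 32 // micro_size_y (16) = 2
        chunk=32,            # chunk // micro_size_k (16) = 2 k-steps per tile
        transform_kind_b=3,  # ladder transform on B; the assert permits only 0 or 3
    )

    # Inside the kernel body, per reduction step ki:
    #   emitter.ldmatrix_a(A_local, A_shared, ki, thread_bindings)  # inherited (assumed wrapper)
    #   emitter.ldmatrix_b(B_local, B_shared, ki, thread_bindings)  # redefined in the subclass
    #   emitter.mma(A_local, B_local, C_local)                      # redefined in the subclass
    #   emitter.stmatrix(C_local, C_shared, thread_bindings)        # inherited from the base

Passing transform_kind_b as a plain int works because _initialize_transform_kind coerces ints to TransformKind before validating, as shown in the hunk above.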
