# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import tvm.tl.language as T

from tvm import DataType
from tvm.runtime import convert
from .utils import (
    mma_store_index_map,
    get_ldmatrix_offset,
)

lift = convert


class TensorCorePTXMacroGenerator(object):
    """
    Precomputes tiling parameters so that the TIR macros below
    (LDMATRIX_A/B, MMA, STMATRIX) contain no Python-level syntax.
    """

    M_DIM = 16
    N_DIM = 16
    WARP_SIZE = 32
    dtype_abbrv = {
        "float16": "fp16",
        "bfloat16": "bf16",
        "float32": "fp32",
        "int8": "int8",
        "int32": "int32",
        "e4m3_float8": "e4m3",
        "e5m2_float8": "e5m2",
    }

    def __init__(
        self,
        a_dtype="float16",
        b_dtype="float16",
        accum_dtype="float16",
        a_transposed=False,
        b_transposed=False,
        block_row_warps=2,
        block_col_warps=2,
        warp_row_tiles=8,
        warp_col_tiles=8,
        chunk=16,
        threads=128,
    ):
        self.a_dtype = a_dtype
        self.b_dtype = b_dtype
        self.accum_dtype = accum_dtype
        self.a_transposed = a_transposed
        self.b_transposed = b_transposed
        # Tiling hints
        self.block_row_warps = block_row_warps
        self.block_col_warps = block_col_warps
        self.warp_row_tiles = warp_row_tiles
        self.warp_col_tiles = warp_col_tiles
        self.chunk = chunk
        self._initialize_k_dim(a_dtype)
        self._initialize_abbrev(a_dtype, b_dtype, accum_dtype)
        self._initialize_local_size(self.M_DIM, self.N_DIM, self.k_dim, self.WARP_SIZE)
        self._initialize_mma_prefix(self.k_dim)
        self._initialize_micro_size(self.M_DIM, self.N_DIM, self.k_dim)
        self.warp_rows = warp_row_tiles // self.micro_size_x
        self.warp_cols = warp_col_tiles // self.micro_size_y
        self._initialize_thread_axis(threads, self.WARP_SIZE, block_row_warps, block_col_warps)

    def _initialize_k_dim(self, a_dtype="float16"):
        # The MMA K dimension scales inversely with element width:
        # 16 for 16-bit inputs (m16n8k16), 32 for 8-bit inputs (m16n8k32).
        self.k_dim = 256 // DataType(a_dtype).bits

    def _initialize_local_size(self, m_dim=16, n_dim=16, k_dim=16, warp_size=32):
        # Per-thread register-fragment sizes for A, B, and the accumulator.
        self.local_size_a = (m_dim * k_dim) // warp_size
        self.local_size_b = (n_dim * k_dim) // warp_size
        self.local_size_out = (m_dim * n_dim) // warp_size

    def _initialize_abbrev(self, a_dtype, b_dtype, accum_dtype):
        self.a_dtype_abbrv = self.dtype_abbrv[a_dtype]
        self.b_dtype_abbrv = self.dtype_abbrv[b_dtype]
        self.accum_dtype_abbrv = self.dtype_abbrv[accum_dtype]

    def _initialize_mma_prefix(self, k_dim=16):
        if k_dim == 16:
            self.mma_prefix = "m16n8k16"
        elif k_dim == 32:
            self.mma_prefix = "m16n8k32"
        else:
            raise ValueError("Unsupported k_dim: {}".format(k_dim))

    def _initialize_micro_size(self, m_dim=16, n_dim=16, k_dim=16):
        self.micro_size_x = m_dim
        self.micro_size_y = n_dim
        self.micro_size_k = k_dim

    def _initialize_thread_axis(self,
                                threads=128,
                                warp_size=32,
                                block_row_warps=2,
                                block_col_warps=2):
        self.threads = threads
        # tx/ty/tz are derived from threadIdx.x inside each macro via its
        # `thread_bindings` argument rather than being bound here.
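
    # Warp-level GEMM over register fragments: accumulates A_local x B_local
    # into C_local_buf via mma.sync.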
    @staticmethod
    @T.macro
    def MMA(inst, A_local_buf, B_local_buf, C_local_buf):
        for i, j in T.grid(inst.warp_rows, inst.warp_cols):
            # First n8 half of the 16x16 output tile.
            T.ptx_mma(
                inst.accum_dtype,
                inst.mma_prefix,
                "row",
                "col",
                inst.a_dtype_abbrv,
                inst.b_dtype_abbrv,
                inst.accum_dtype_abbrv,
                A_local_buf.data,
                i * inst.local_size_a,
                B_local_buf.data,
                j * inst.local_size_b,
                C_local_buf.data,
                i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out,
                T.bool(False),
            )

            # Second n8 half: B and C advance by half a fragment.
            T.ptx_mma(
                inst.accum_dtype,
                inst.mma_prefix,
                "row",
                "col",
                inst.a_dtype_abbrv,
                inst.b_dtype_abbrv,
                inst.accum_dtype_abbrv,
                A_local_buf.data,
                i * inst.local_size_a,
                B_local_buf.data,
                j * inst.local_size_b + lift(inst.local_size_b) // 2,
                C_local_buf.data,
                i * inst.warp_cols * inst.local_size_out + j * inst.local_size_out +
                lift(inst.local_size_out) // 2,
                T.bool(False),
            )
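
    # ldmatrix loaders: move this warp's A/B tiles for K-step `ki` from
    # shared memory into per-thread register fragments for mma.sync.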
    @staticmethod
    @T.macro
    def LDMATRIX_A(
        inst,
        A_local_buf,
        A_shared_buf,
        ki,
        thread_bindings,
    ):
        stride = inst.chunk
        tx = thread_bindings % inst.WARP_SIZE
        ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
        for i in T.serial(inst.warp_rows):
            T.ptx_ldmatrix(
                inst.a_dtype,
                T.bool(False),
                4,
                ".b16",
                A_local_buf.data,
                i * inst.local_size_a,
                T.address_of(A_shared_buf[ty * inst.warp_row_tiles + i * inst.micro_size_x,
                                          ki * inst.micro_size_k]),
                get_ldmatrix_offset("A", tx, 0, stride, inst.a_dtype, False),
            )

    @staticmethod
    @T.macro
    def LDMATRIX_B(
        inst,
        B_local_buf,
        B_shared_buf,
        ki,
        thread_bindings,
    ):
        stride = inst.chunk
        tx = thread_bindings % inst.WARP_SIZE
        tz = thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)
        for j in T.serial(inst.warp_cols):
            T.ptx_ldmatrix(
                inst.b_dtype,
                T.bool(False),  # TODO(lei): the transpose flag should be optimized
                4,
                ".b16",
                B_local_buf.data,
                j * inst.local_size_b,
                T.address_of(B_shared_buf[tz * inst.warp_col_tiles + j * inst.micro_size_y,
                                          ki * inst.micro_size_k]),
                get_ldmatrix_offset("B", tx, 0, stride, inst.b_dtype, True),
            )

    # STS: store the accumulator tile from registers to shared memory.
    # The store is emulated element by element instead of using the TVM
    # intrinsic, because the intrinsic assumes threadIdx.x always equals
    # the warp size.
    @staticmethod
    @T.macro
    def STMATRIX(inst, C_local_buf, C_shared_buf, thread_bindings):
        tx = thread_bindings % inst.WARP_SIZE
        ty = (thread_bindings // inst.WARP_SIZE) % inst.block_row_warps
        tz = thread_bindings // (inst.WARP_SIZE * inst.block_row_warps)
        for i, j in T.grid(inst.warp_rows, inst.warp_cols):
            for local_id in T.serial(inst.local_size_out):
                row, col = T.meta_var(mma_store_index_map(tx, local_id))
                C_shared_buf[ty * inst.warp_rows + i, tz * inst.warp_cols + j, row,
                             col] = C_local_buf[i * (inst.warp_cols * inst.local_size_out) +
                                                j * inst.local_size_out + local_id]
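

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the tile sizes and dtypes below
# are assumptions, not values mandated by the library). It constructs a
# generator and inspects the derived tiling parameters. Driving the
# LDMATRIX_*/MMA/STMATRIX macros additionally requires a tvm.tl kernel that
# provides shared/local buffers and the threadIdx.x binding, omitted here.
# Because of the relative import above, run this as a module (python -m ...).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    gen = TensorCorePTXMacroGenerator(
        a_dtype="float16",
        b_dtype="float16",
        accum_dtype="float32",
        warp_row_tiles=32,
        warp_col_tiles=32,
        chunk=32,
    )
    # float16 inputs give k_dim = 256 // 16 = 16, hence "m16n8k16"; each
    # warp covers 32 // 16 = 2 x 2 micro tiles of 16x16, and each thread
    # holds 16 * 16 // 32 = 8 elements per fragment.
    print(gen.mma_prefix, gen.warp_rows, gen.warp_cols)
    print(gen.local_size_a, gen.local_size_b, gen.local_size_out)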