layout_anno_example.py
import torch
from tvm import tl
import tvm.tl.language as T
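

# Tiled matrix transpose written in TVM's TL (TileLang) DSL. The kernel uses
# T.annotate_layout to pick a padded shared-memory layout (a common way to
# avoid bank conflicts) and an explicit per-thread register fragment layout.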
def transpose(M, N):
    dtype = "float"
    BLK = 64

    @T.prim_func
    def main(ins: T.Buffer((M, N), dtype), outs: T.Buffer((N, M), dtype)):
        with T.Kernel(T.ceildiv(M, BLK), T.ceildiv(N, BLK), threads=256) as (bx, by):
            shared = T.alloc_shared((BLK, BLK), dtype)
            local = T.alloc_fragment((BLK, BLK), dtype)
            local_t = T.alloc_fragment((BLK, BLK), dtype)
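            # The shared tile is laid out row-major with the row stride padded
            # from 64 to 68 elements; this padding is a common way to avoid
            # shared-memory bank conflicts on the column-wise (transposed) read.
            # The register fragment assigns a 4x4 tile to each thread: thread
            # index j // 4 + 16 * (i // 4) covers 16 x 16 = 256 threads,
            # matching threads=256 above.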
            T.annotate_layout(
                {
                    # pad by 4
                    shared: T.Layout(shared.shape, lambda i, j: i * (BLK + 4) + j),
                    # assign 4x4 float tile to each thread
                    local: T.Fragment(local.shape, lambda i, j: j // 4 + 16 * (i // 4)),
                }
            )
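            # Move one block: global -> shared -> registers, transpose within
            # registers, then registers -> shared -> global at the transposed
            # block offset.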
            T.copy(ins[BLK * by, BLK * bx], shared)
            T.copy(shared, local)
            for i, j in T.Parallel(BLK, BLK):
                local_t[i, j] = local[j, i]
            T.copy(local_t, shared)
            T.copy(shared, outs[BLK * bx, BLK * by])

    return main
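

# PyTorch reference: a plain (materialized) transpose, used both for the
# correctness check and as the timing baseline below.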
def ref_program(A):
    return A.T.contiguous()


if __name__ == "__main__":
    M, N = 8192, 8192
    program = transpose(M, N)
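    # Lower the TL program to a runnable module and wrap it in a profiler that
    # supplies integer-valued test tensors, then check against the reference.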
    mod, params = tl.lower(program)
    mod = tl.Profiler(mod, params, [1], tl.TensorSupplyType.Integer)
    mod.assert_allclose(ref_program)
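
    # Benchmark the PyTorch reference first, then the compiled TL kernel.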
    latency = mod.do_bench(ref_program, warmup=500)
    print("{:.2f} ms".format(latency))
    latency = mod.do_bench(mod.func)
    print("{:.2f} ms".format(latency))