# Source file: 19_DistillGPT2_triton_red_fused__log_softmax__to_copy_21.py
# (95 lines, 83 loc, 3.99 KB)
import triton
import triton.language as tl
# from torch._inductor.ir import ReductionHint
# from torch._inductor.ir import TileHint
# from intel_extension_for_pytorch._inductor.xpu.triton_heuristics import AutotuneHint, reduction
# from torch._inductor.utils import instance_descriptor
import triton_helpers
from helper import rand_strided
import torch
# from intel_extension_for_pytorch._C import _getCurrentRawStream as get_xpu_stream
# from torch._inductor.triton_heuristics import grid
# @reduction(
# size_hints=[8192, 65536],
# reduction_hint=ReductionHint.DEFAULT,
# filename=__file__,
# meta={'signature': {0: '*bf16', 1: '*fp32', 2: 'i32', 3: 'i32'}, 'device': 0, 'device_type': 'xpu', 'constants': {}, 'mutated_arg_names': [], 'autotune_hints': set(), 'kernel_name': 'triton_red_fused__log_softmax__to_copy_21', 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2,))]}
# )
@triton.jit
def triton_red_fused__log_softmax__to_copy_21(in_ptr0, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr, RBLOCK : tl.constexpr):
    """Fused float32 upcast + row-wise log-softmax over the last dimension.

    For every output row ``x`` the kernel computes
    ``out[x, r] = in[x, r] - max_r(in[x, :]) - log(sum_r(exp(in[x, r] - max)))``
    in three passes over the row (max, sum of exponentials, final write).

    Args:
        in_ptr0: pointer to the input values; each loaded value is converted
            to float32 before use.
        out_ptr2: pointer to the output, written as contiguous rows of
            ``rnumel`` elements.
        xnumel: number of output rows. Overridden below with 8176.
        rnumel: reduction length (row width). Overridden below with 50257.
        XBLOCK: rows handled per program instance (compile-time constant).
        RBLOCK: columns handled per loop iteration (compile-time constant).
    """
    # The sizes are baked in on purpose: the stride constants used in the
    # address arithmetic below are only valid for these exact values.
    xnumel = 8176
    rnumel = 50257
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rbase = tl.arange(0, RBLOCK)[None, :]
    x0 = xindex
    # Output row x0 reads input row 512*(x0 // 511) + (x0 % 511)
    # (25731584 == 512 * 50257), i.e. only the first 511 rows of every group
    # of 512 input rows are used. This offset does not depend on the column,
    # so it is computed once instead of once per load.
    row_base = (50257*(x0 % 511)) + (25731584*(x0 // 511))

    # Pass 1: element-wise running maximum, reduced to one value per row.
    _tmp3 = tl.full([XBLOCK, RBLOCK], float("-inf"), tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        mask = rmask & xmask
        tmp0 = tl.load(in_ptr0 + (rindex + row_base), mask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp2 = tl.broadcast_to(tmp0, [XBLOCK, RBLOCK])
        tmp4 = triton_helpers.maximum(_tmp3, tmp2)
        # Masked-off lanes loaded the filler value 0; keep the old accumulator
        # there so the filler never takes part in the maximum.
        _tmp3 = tl.where(mask, tmp4, _tmp3)
    tmp3 = triton_helpers.max2(_tmp3, 1)[:, None]

    # Pass 2: sum of exp(x - row_max), reduced to one value per row.
    _tmp10 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        mask = rmask & xmask
        tmp5 = tl.load(in_ptr0 + (rindex + row_base), mask, eviction_policy='evict_last', other=0).to(tl.float32)
        tmp7 = tmp5 - tmp3
        tmp8 = tl.exp(tmp7)
        tmp9 = tl.broadcast_to(tmp8, [XBLOCK, RBLOCK])
        tmp11 = _tmp10 + tmp9
        _tmp10 = tl.where(mask, tmp11, _tmp10)
    tmp10 = tl.sum(_tmp10, 1)[:, None]

    # log(sum) is constant per row, so compute it once rather than in every
    # iteration of the write-out loop.
    tmp15 = tl.log(tmp10)

    # Pass 3: write (x - row_max) - log(sum) to the contiguous output.
    for roffset in range(0, rnumel, RBLOCK):
        rindex = roffset + rbase
        rmask = rindex < rnumel
        mask = rmask & xmask
        tmp12 = tl.load(in_ptr0 + (rindex + row_base), mask, other=0).to(tl.float32)
        tmp14 = tmp12 - tmp3
        tmp16 = tmp14 - tmp15
        tl.store(out_ptr2 + (rindex + (50257*x0)), tmp16, mask)
def get_args():
    """Build the (input, output) argument pair for the kernel launch.

    Returns:
        A 2-tuple ``(input, output)``: a bfloat16 tensor of size
        (8192, 50257) and a float32 tensor of size (8176, 50257), both
        requested with strides (50257, 1) on device ``'xpu:0'``.
    """
    # NOTE(review): rand_strided comes from the project-local ``helper``
    # module; presumably it returns a randomly filled tensor with the given
    # size and stride -- confirm against its definition.
    device = 'xpu:0'
    row_major = (50257, 1)
    input_tensor = rand_strided((8192, 50257), row_major, device=device, dtype=torch.bfloat16)
    output_tensor = rand_strided((8176, 50257), row_major, device=device, dtype=torch.float32)
    return (input_tensor, output_tensor)
def call(args):
    """Launch the log-softmax kernel once on ``args``.

    Args:
        args: the positional tensor arguments for the kernel, in order
            ``(input, output)``, as produced by ``get_args()``.
    """
    # with torch.xpu._DeviceGuard(0):
    # torch.xpu.set_device(0)
    # stream0 = get_xpu_stream(0)
    xnumel = 8176
    rnumel = 50257
    # Derive the number of program instances from XBLOCK so that the grid
    # stays correct if the block size below is ever changed; with XBLOCK=1
    # this is exactly 8176 programs.
    grid = lambda meta: (triton.cdiv(xnumel, meta['XBLOCK']),)
    triton_red_fused__log_softmax__to_copy_21[grid](*args, xnumel, rnumel, XBLOCK=1, RBLOCK=1024)
# def benchmark_all_configs(args):
# with torch.xpu._DeviceGuard(0):
# torch.xpu.set_device(0)
# return triton_red_fused__log_softmax__to_copy_21.benchmark_all_configs(*args, 8176, 50257, grid=grid(8176))
if __name__ == '__main__':
    # Script entry point: build the kernel arguments and run the kernel once.
    # The commented-out lines are the original benchmarking scaffold, kept
    # for reference.
    # from torch._inductor.utils import get_num_bytes
    # from intel_extension_for_pytorch._inductor.xpu.utils import do_bench
    kernel_args = get_args()
    call(kernel_args)
    # ms = do_bench(lambda: call(args), rep=40, fast_flush=True)
    # num_gb = get_num_bytes(*args, num_in_out_args=0) / 1e9
    # gb_per_s = num_gb / (ms / 1e3)
    # print(f"{ms:.3f}ms {num_gb:.3f}GB {gb_per_s:.2f}GB/s")