forked from SAITPublic/MLPerf_Training_v2.0
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbmm2.py
130 lines (103 loc) · 5.27 KB
/
bmm2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch
import mhalib
###########################################################################################
class Bmm2Function(torch.autograd.Function):
    """Autograd wrapper around the mhalib ``FastBmm2*`` kernels (non-strided variant).

    Forward calls ``mhalib.FastBmm2Fprop`` on flattened, contiguous copies of
    ``batch2`` and ``batch1`` and returns a float16 CUDA tensor of shape
    ``[ntokens, heads, embed]`` with ``ntokens = seqlen.sum()``. Backward calls
    ``mhalib.FastBmm2Dgrad1`` / ``mhalib.FastBmm2Dgrad2`` to produce the
    gradients of ``batch1`` and ``batch2``.

    NOTE: the trailing flags are taken positionally as ``sync, stream`` here,
    which is the reverse of ``Bmm2StridedFunction``; mhalib receives them as
    ``stream, sync``.
    """

    @staticmethod
    def forward(ctx, batch1, batch2, seqlen, batch, maxseqlen, heads, embed, sync, stream):
        """Run the fused BMM2 forward kernel.

        Args:
            ctx: autograd context; inputs and sizes are stashed for backward.
            batch1: first operand; passed to the kernel flattened and contiguous.
            batch2: second operand; passed to the kernel flattened and contiguous.
            seqlen: integer tensor of per-sequence lengths; its sum gives the
                number of output rows.
            batch: number of sequences.
            maxseqlen: maximum sequence length (stored on ``ctx``, not passed
                to the kernels here).
            heads: number of attention heads.
            embed: per-head embedding size.
            sync: flag forwarded to mhalib.
            stream: flag forwarded to mhalib.

        Returns:
            float16 CUDA tensor of shape ``[ntokens, heads, embed]``.
        """
        ctx.save_for_backward(batch1, batch2, seqlen)
        ctx.batch = batch
        ctx.maxseqlen = maxseqlen
        ctx.heads = heads
        ctx.embed = embed
        ctx.stream = stream
        ctx.sync = sync
        ntokens = seqlen.sum().item()
        ctx.ntokens = ntokens
        output = torch.empty([ntokens, heads, embed], device="cuda", dtype=torch.float16)
        mhalib.FastBmm2Fprop(batch2.flatten().contiguous(), batch1.flatten().contiguous(), output,
                             batch, seqlen, heads, embed, False, False, stream, sync)
        # ``output`` already has exactly ``ntokens`` rows, so a ``[:ntokens]``
        # slice would only wrap it in a view created inside a custom Function
        # (which autograd then forbids modifying in place).
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for ``batch1`` and ``batch2``.

        Returns:
            ``(grad_batch1, grad_batch2)`` followed by ``None`` for each of the
            seven non-differentiable forward arguments.
        """
        batch1, batch2, seqlen = ctx.saved_tensors
        batch, heads, embed = ctx.batch, ctx.heads, ctx.embed
        # Total length of the per-sequence seqlen x seqlen blocks over the
        # first ``batch`` sequences, computed with one reduction rather than a
        # Python loop of per-element tensor ops. ``sum`` accumulates integer
        # tensors in int64.
        lens = seqlen[:batch]
        ntokens2 = int((lens * lens).sum().item())
        grad_batch1 = torch.empty([ntokens2 * heads], device="cuda", dtype=torch.float16)
        grad_batch2 = torch.empty([ctx.ntokens, heads * embed], device="cuda", dtype=torch.float16)
        # grad_output's layout is decided by downstream ops, not by this
        # module; .contiguous() is a no-op when it is already contiguous.
        grad_output = grad_output.contiguous()
        mhalib.FastBmm2Dgrad1(batch2.flatten().contiguous(), grad_output, grad_batch1,
                              batch, seqlen, heads, embed, False, False, ctx.stream, ctx.sync)
        # Forward handed the kernel a contiguous batch1; do the same here.
        mhalib.FastBmm2Dgrad2(grad_output, batch1.contiguous(), grad_batch2,
                              batch, seqlen, heads, embed, False, False, ctx.stream, ctx.sync)
        return grad_batch1, grad_batch2, None, None, None, None, None, None, None
class Bmm2(torch.nn.Module):
    """nn.Module wrapper that applies ``Bmm2Function``.

    Args:
        batch: accepted for interface compatibility; not stored.
        seqlen: maximum sequence length, stored as ``self.maxseqlen``.
        heads: number of attention heads.
        embed: per-head embedding size.
        stream: flag passed through to the mhalib kernels.
        sync: flag passed through to the mhalib kernels.
    """

    def __init__(self, batch, seqlen, heads, embed, stream=True, sync=True):
        super().__init__()
        self.heads = heads
        self.embed = embed
        self.maxseqlen = seqlen
        self.stream = stream
        self.sync = sync

    def forward(self, batch1, batch2, batch, seqlen):
        """Apply ``Bmm2Function`` to ``batch1``/``batch2`` for the given lengths."""
        # Bmm2Function.forward declares its trailing flags as (sync, stream),
        # and autograd.Function.apply accepts positional arguments only, so
        # they must be passed in that order. Passing (stream, sync) swapped
        # the two flags whenever they differed.
        return Bmm2Function.apply(batch1, batch2, seqlen, batch, self.maxseqlen,
                                  self.heads, self.embed, self.sync, self.stream)
###########################################################################################
class Bmm2StridedFunction(torch.autograd.Function):
    """Autograd wrapper around the mhalib ``FastBmm2*`` kernels (strided variant).

    Unlike ``Bmm2Function``, the second operand is the ``mixed`` tensor, which
    is passed to the kernels as-is with the second layout flag set to ``True``.
    Its gradient is allocated as ``[ntokens, heads * 3 * embed]``. When
    ``timers`` is a non-empty dict, its CUDA events are recorded around each
    kernel call (keys ``start_/stop_`` + ``fprop``/``dgrad``/``wgrad``).
    """

    @staticmethod
    def forward(ctx, batch1, mixed, seqlen, batch, maxseqlen, heads, embed, stream, sync, timers):
        """Run the fused strided BMM2 forward kernel.

        Args:
            ctx: autograd context; inputs and sizes are stashed for backward.
            batch1: first operand, passed to the kernel unchanged.
            mixed: packed second operand, passed to the kernel unchanged.
            seqlen: integer tensor of per-sequence lengths; its sum gives the
                number of output rows.
            batch: number of sequences.
            maxseqlen: maximum sequence length (stored on ``ctx`` only).
            heads: number of attention heads.
            embed: per-head embedding size.
            stream: flag forwarded to mhalib.
            sync: flag forwarded to mhalib.
            timers: dict of ``torch.cuda.Event`` objects, or ``None``.

        Returns:
            float16 CUDA tensor of shape ``[ntokens, heads, embed]``.
        """
        ctx.save_for_backward(batch1, mixed, seqlen)
        ctx.batch = batch
        ctx.maxseqlen = maxseqlen
        ctx.heads = heads
        ctx.embed = embed
        ctx.stream = stream
        ctx.sync = sync
        ctx.timers = timers
        ntokens = seqlen.sum().item()
        ctx.ntokens = ntokens
        output = torch.empty([ntokens, heads, embed], device="cuda", dtype=torch.float16)
        if timers: timers['start_fprop'].record()
        mhalib.FastBmm2Fprop(mixed, batch1, output, batch, seqlen, heads, embed, False, True, stream, sync)
        if timers: timers['stop_fprop'].record()
        # ``output`` already has exactly ``ntokens`` rows; return it directly
        # instead of a whole-tensor view created inside a custom Function.
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for ``batch1`` and ``mixed``.

        Returns:
            ``(grad_batch1, grad_mixed)`` followed by ``None`` for each of the
            eight non-differentiable forward arguments.
        """
        batch1, mixed, seqlen = ctx.saved_tensors
        batch, heads, embed = ctx.batch, ctx.heads, ctx.embed
        timers = ctx.timers
        # Total length of the per-sequence seqlen x seqlen blocks over the
        # first ``batch`` sequences, computed with one reduction rather than a
        # Python loop of per-element tensor ops.
        lens = seqlen[:batch]
        ntokens2 = int((lens * lens).sum().item())
        grad_batch1 = torch.empty(ntokens2 * heads, device="cuda", dtype=torch.float16)
        grad_mixed = torch.empty([ctx.ntokens, heads * 3 * embed], device="cuda", dtype=torch.float16)
        # grad_output's layout is decided by downstream ops; .contiguous() is
        # a no-op when it is already contiguous.
        grad_output = grad_output.contiguous()
        if timers: timers['start_dgrad'].record()
        mhalib.FastBmm2Dgrad1(mixed, grad_output, grad_batch1, batch, seqlen, heads, embed, False, True, ctx.stream, ctx.sync)
        if timers: timers['stop_dgrad'].record()
        if timers: timers['start_wgrad'].record()
        mhalib.FastBmm2Dgrad2(grad_output, batch1, grad_mixed, batch, seqlen, heads, embed, False, True, ctx.stream, ctx.sync)
        if timers: timers['stop_wgrad'].record()
        return grad_batch1, grad_mixed, None, None, None, None, None, None, None, None
class Bmm2Strided(torch.nn.Module):
    """nn.Module wrapper that applies ``Bmm2StridedFunction``.

    Args:
        batch: accepted for interface compatibility; not stored.
        seqlen: maximum sequence length, stored as ``self.maxseqlen``.
        heads: number of attention heads.
        embed: per-head embedding size.
        stream: flag passed through to the mhalib kernels.
        sync: flag passed through to the mhalib kernels.
        timer: when true, ``self.timers`` maps six event names to
            timing-enabled ``torch.cuda.Event`` objects; otherwise it is ``None``.
    """

    def __init__(self, batch, seqlen, heads, embed, stream=True, sync=True, timer=False):
        super().__init__()
        self.heads, self.embed = heads, embed
        self.maxseqlen = seqlen
        self.stream, self.sync = stream, sync
        event_names = ('start_fprop', 'start_dgrad', 'start_wgrad',
                       'stop_fprop', 'stop_dgrad', 'stop_wgrad')
        self.timers = (
            {name: torch.cuda.Event(enable_timing=True) for name in event_names}
            if timer
            else None
        )

    def forward(self, batch1, mixed, batch, seqlen):
        """Apply ``Bmm2StridedFunction`` with this module's stored settings."""
        call_args = (batch1, mixed, seqlen, batch, self.maxseqlen, self.heads,
                     self.embed, self.stream, self.sync, self.timers)
        return Bmm2StridedFunction.apply(*call_args)
###########################################################################################