# mha.py
import math  # needed by the unfused scaling path in forward()

import torch
import torch.nn as nn
import torch.nn.functional as F
from apex.contrib.multihead_attn import fast_mask_softmax_dropout_func
from bmm1 import *
from bmm2 import *
from padding import *
from softmax import *
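# bmm1, bmm2 and softmax are local modules from this repository providing the
# Bmm1*/Bmm2* batched-GEMM wrappers and Fast*Softmax* helpers used below;
# padding is imported alongside them but not referenced directly in this file.

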
class FastUnpadBertSelfAttention(nn.Module):
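    """Self-attention block for unpadded (sequence-packed) BERT inputs.

    The constructor flags toggle fused paths (fused QKV projection, fused
    scaling inside BMM1, fused mask+softmax+dropout) against plain PyTorch
    fallbacks; with `pad=True` the padded softmax path assumes every sequence
    in the batch has length `seqlen[0]`.
    """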
    def __init__(self, config, enable_stream=True, enable_sync=True, fuse_mask=True,
                 fuse_scale=True, fuse_qkv=True, fuse_dropout=True, apex_softmax=True, pad=True):
        super(FastUnpadBertSelfAttention, self).__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.hidden_size = config.hidden_size

        self.fuse_qkv = fuse_qkv
        self.fuse_scale = fuse_scale
        self.fuse_mask = fuse_mask
        self.fuse_dropout = fuse_dropout
        self.apex_softmax = apex_softmax
        self.pad = pad
        self.enable_stream = enable_stream

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        # Strided BMM kernels operate on the fused QKV buffer; the plain
        # variants take separate query/key/value tensors.
        if self.fuse_qkv:
            self.bmm1 = Bmm1Strided(None, None, self.num_attention_heads, self.attention_head_size,
                                    scale=self.fuse_scale, stream=enable_stream, sync=enable_sync, timer=False)
            self.bmm2 = Bmm2Strided(None, None, self.num_attention_heads, self.attention_head_size,
                                    stream=enable_stream, sync=enable_sync, timer=False)
        else:
            self.bmm1 = Bmm1(None, None, self.num_attention_heads, self.attention_head_size,
                             scale=self.fuse_scale, stream=enable_stream, sync=enable_sync)
            self.bmm2 = Bmm2(None, None, self.num_attention_heads, self.attention_head_size,
                             stream=enable_stream, sync=enable_sync)

        if not self.fuse_dropout:
            self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        # Pick the softmax implementation that matches the requested fusion level.
        if self.fuse_mask and self.fuse_dropout:
            self.softmax = FastMaskSoftmaxDropout(dim=-1, dropout_prob=config.attention_probs_dropout_prob,
                                                  stream=enable_stream, sync=(not self.pad), timer=False)
        elif self.fuse_mask:
            self.softmax = FastMaskSoftmax(dim=-1, stream=enable_stream, sync=enable_sync, timer=False)
        else:
            self.softmax = FastSoftmax(dim=-1, stream=enable_stream, sync=enable_sync, timer=False)

    def transpose_for_scores(self, x):
        # [batch, seq, all_head_size] -> [batch, heads, seq, head_size]
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = torch.reshape(x, new_x_shape)
        return x.permute(0, 2, 1, 3)

    def transpose_key_for_scores(self, x):
        # [batch, seq, all_head_size] -> [batch, heads, head_size, seq], already laid out for Q.K^T
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = torch.reshape(x, new_x_shape)
        return x.permute(0, 2, 3, 1)

    def pytorch_softmax(self, attention_scores, batch, seqlen, heads):
        # Eager fallback for the unpadded layout: attention_scores is a flat
        # buffer holding heads * seqlen[i]^2 scores per sequence, so softmax
        # is applied one sequence at a time.
        ntokens2 = 0
        for i in range(batch):
            ntokens2 += seqlen[i] * seqlen[i] * self.num_attention_heads
        attention_probs = torch.zeros(ntokens2, device="cuda", dtype=torch.float16)
        ntokens2 = 0
        for i in range(batch):
            tokens2 = seqlen[i] * seqlen[i] * self.num_attention_heads
            attention_probs[ntokens2:ntokens2 + tokens2] = F.softmax(
                attention_scores[ntokens2:ntokens2 + tokens2].view(1, self.num_attention_heads, seqlen[i], seqlen[i]),
                dim=-1).flatten().contiguous()
            ntokens2 += tokens2
        return attention_probs

    def forward(self, hidden_states, attention_mask, seqlen, batch, is_training=True):
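        """Run self-attention over a packed batch.

        `hidden_states` is the token-packed 2-D activation consumed by
        torch.addmm below ([total_tokens, hidden_size]), `attention_mask` is
        applied additively on the unfused path, `seqlen` holds the per-sequence
        lengths, and `batch` is the number of sequences. Returns the per-token
        context with the heads merged back to `all_head_size`. Shape notes are
        inferred from how the tensors are used in this method.
        """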
        self.batch = batch

        # QKV projection: either one fused GEMM over the concatenated
        # query/key/value weights, or three separate nn.Linear calls.
        if self.fuse_qkv:
            weight = torch.cat([self.query.weight.view(self.num_attention_heads, self.attention_head_size, 1, self.hidden_size),
                                self.key.weight.view(self.num_attention_heads, self.attention_head_size, 1, self.hidden_size),
                                self.value.weight.view(self.num_attention_heads, self.attention_head_size, 1, self.hidden_size)],
                               dim=1).reshape(self.all_head_size * 3, self.hidden_size).contiguous()
            bias = torch.cat([self.query.bias.view(self.num_attention_heads, 1, self.attention_head_size),
                              self.key.bias.view(self.num_attention_heads, 1, self.attention_head_size),
                              self.value.bias.view(self.num_attention_heads, 1, self.attention_head_size)],
                             dim=1).reshape(3 * self.hidden_size).contiguous()
            mixed_x_layer = torch.addmm(bias, hidden_states, weight.t())
        else:
            query_layer = self.query(hidden_states)
            key_layer = self.key(hidden_states)
            value_layer = self.value(hidden_states)

        # BMM1: raw attention scores Q.K^T (optionally pre-scaled inside the kernel).
        if self.enable_stream: torch.cuda.synchronize()
        if self.fuse_qkv:
            attention_scores, qkv_layer = self.bmm1(mixed_x_layer, self.batch, seqlen)
        else:
            attention_scores = self.bmm1(query_layer, key_layer, self.batch, seqlen)
        if self.enable_stream: torch.cuda.synchronize()
        if not self.fuse_scale:
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # Softmax (mask and dropout are folded in when the fused kernels are enabled).
        if self.enable_stream: torch.cuda.synchronize()
        if self.fuse_mask and self.fuse_dropout:
            attention_probs = self.softmax(attention_scores, attention_mask, self.batch,
                                           seqlen, self.num_attention_heads, is_training)
        elif self.fuse_mask:
            attention_probs = self.softmax(attention_scores, attention_mask, self.batch,
                                           seqlen, self.num_attention_heads)
        else:
            attention_scores = attention_scores + attention_mask.view(-1)
            if self.apex_softmax:
                attention_probs = self.softmax(attention_scores, self.batch, seqlen, self.num_attention_heads)
            else:
                if self.pad:
                    # Padded case: every sequence has length seqlen[0].
                    attention_probs = F.softmax(
                        attention_scores.view(batch, self.num_attention_heads, seqlen[0], seqlen[0]),
                        dim=-1).flatten().contiguous()
                else:
                    attention_probs = self.pytorch_softmax(attention_scores, self.batch, seqlen, self.num_attention_heads)

        # Dropout (only when it is not already fused into the softmax kernel).
        if self.enable_stream: torch.cuda.synchronize()
        if not self.fuse_dropout:
            attention_probs = self.dropout(attention_probs)

        # BMM2: probs.V, then merge the heads back into hidden_size.
        if self.enable_stream: torch.cuda.synchronize()
        if self.fuse_qkv:
            context_layer = self.bmm2(attention_probs, qkv_layer, self.batch, seqlen)
        else:
            context_layer = self.bmm2(attention_probs, value_layer, self.batch, seqlen)
        if self.enable_stream: torch.cuda.synchronize()

        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = torch.reshape(context_layer, new_context_layer_shape)
        return context_layer
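

# ---------------------------------------------------------------------------
# Illustrative reference (not part of the original module): a plain PyTorch
# sketch of the attention that FastUnpadBertSelfAttention computes on its
# padded, unfused path, i.e. softmax(Q.K^T / sqrt(d) + mask) . V. The function
# name, shapes and the smoke test below are assumptions for this sketch, not
# the module's API.
# ---------------------------------------------------------------------------
def _reference_self_attention(query, key, value, attention_mask, dropout_p=0.0):
    """query/key/value: [batch, heads, seq, head_size]; attention_mask is
    additive and broadcastable to [batch, heads, seq, seq]."""
    head_size = query.size(-1)
    scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(head_size)
    scores = scores + attention_mask
    probs = F.softmax(scores, dim=-1)
    probs = F.dropout(probs, p=dropout_p)
    return torch.matmul(probs, value)  # [batch, heads, seq, head_size]


if __name__ == "__main__":
    # Smoke test for the reference sketch only; the fused class above needs
    # the apex/bmm CUDA extensions and a BERT config object.
    q = torch.randn(2, 4, 8, 16)
    k = torch.randn(2, 4, 8, 16)
    v = torch.randn(2, 4, 8, 16)
    mask = torch.zeros(2, 1, 1, 8)
    out = _reference_self_attention(q, k, v, mask)
    print(out.shape)  # torch.Size([2, 4, 8, 16])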