attention_decoder.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from attention import Attention


class AttentionDecoderRNN(nn.Module):
    """Recurrent neural network that uses gated recurrent units to translate
    encoded input with attention."""

    def __init__(self,
                 tgt_vocab_size,
                 embedding_size,
                 hidden_size,
                 attn_model,
                 n_layers=1,
                 dropout=.1):
        super(AttentionDecoderRNN, self).__init__()
        self.tgt_vocab_size = tgt_vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.attn_model = attn_model
        self.n_layers = n_layers
        self.dropout_p = dropout  # keep the rate; self.dropout below is the layer

        # Define layers
        self.embedding = nn.Embedding(tgt_vocab_size, embedding_size)
        self.dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size + embedding_size, hidden_size, n_layers, dropout=dropout)
        self.out = nn.Linear(hidden_size * 2, tgt_vocab_size)

        # Choose attention model
        if attn_model is not None:
            self.attention = Attention(attn_model, hidden_size)

    def forward(self, input, decoder_context, hidden_state, encoder_outputs):
        """Run forward propagation one step at a time.

        Get the embedding of the current input word (the last output word) and
        combine it with the previous context vector. Feed that through the GRU,
        then compute attention weights from the current GRU output and all
        encoder outputs. The next-word prediction is made from the GRU output
        concatenated with the new context vector.

        Args:
            input: tensor of current input word indices
            decoder_context: tensor holding the previous context vector
            hidden_state: tensor holding the previous GRU hidden state
            encoder_outputs: tensor holding the encoder output values

        Returns:
            output: log-probabilities over the target vocabulary
            context: the new context vector
            hidden_state: the new GRU hidden state
            attention_weights: attention weights from the attention model
        """
        # Run through RNN
        input = input.view(1, -1)
        embedded = self.embedding(input)  # [1, batch_size, embedding_size]
        embedded = self.dropout(embedded)
        rnn_input = torch.cat((embedded, decoder_context), 2)  # [1, batch_size, embedding_size + hidden_size]
        rnn_output, hidden_state = self.gru(rnn_input, hidden_state)  # [1, batch_size, hidden_size]

        # Calculate attention over the encoder outputs
        attention_weights = self.attention(rnn_output.squeeze(0), encoder_outputs)
        context = attention_weights.bmm(encoder_outputs.transpose(0, 1))  # [batch_size, 1, hidden_size]
        context = context.transpose(0, 1)  # [1, batch_size, hidden_size]

        # Predict output
        output = F.log_softmax(self.out(torch.cat((rnn_output, context), 2)), dim=2)
        output = output.squeeze(0)  # [batch_size, tgt_vocab_size]

        return output, context, hidden_state, attention_weights
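

# The block below is a usage sketch, not part of the original file. It assumes
# that the Attention module imported above accepts the decoder GRU output
# [batch_size, hidden_size] and the encoder outputs
# [src_len, batch_size, hidden_size] and returns weights shaped
# [batch_size, 1, src_len], and that "general" is a valid attn_model name.
# Adjust the shapes and the score name to match the actual attention.py in
# this repository.
if __name__ == "__main__":
    batch_size, src_len = 4, 7
    embedding_size, hidden_size, tgt_vocab_size = 32, 64, 100

    decoder = AttentionDecoderRNN(tgt_vocab_size=tgt_vocab_size,
                                  embedding_size=embedding_size,
                                  hidden_size=hidden_size,
                                  attn_model="general")

    # Dummy encoder outputs and initial decoder state for a single step.
    encoder_outputs = torch.randn(src_len, batch_size, hidden_size)
    decoder_input = torch.zeros(batch_size, dtype=torch.long)       # e.g. <sos> token indices
    decoder_context = torch.zeros(1, batch_size, hidden_size)       # initial context vector
    hidden_state = torch.zeros(decoder.n_layers, batch_size, hidden_size)

    output, context, hidden_state, attn = decoder(
        decoder_input, decoder_context, hidden_state, encoder_outputs)
    print(output.shape)  # expected: [batch_size, tgt_vocab_size]
    print(attn.shape)    # expected: [batch_size, 1, src_len]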