-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathVSRN_EncoderRNN.py
68 lines (53 loc) · 2.52 KB
/
VSRN_EncoderRNN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch.nn as nn
class EncoderRNN(nn.Module):
    """RNN encoder: projects per-frame video features into the hidden
    dimension, applies dropout, and encodes the sequence with a GRU/LSTM.

    Args:
        dim_vid (int): dimensionality of the input video features.
        dim_hidden (int): dimensionality of the RNN hidden state.
        input_dropout_p (float): dropout probability on the projected input.
        rnn_dropout_p (float): dropout probability between stacked RNN
            layers (only effective when ``n_layers > 1``).
        n_layers (int): number of stacked RNN layers.
        bidirectional (bool): run the RNN in both directions.
        rnn_cell (str): type of RNN cell, ``'lstm'`` or ``'gru'``.

    Raises:
        ValueError: if ``rnn_cell`` is not ``'lstm'`` or ``'gru'``.
    """

    def __init__(self, dim_vid, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0.5,
                 n_layers=1, bidirectional=False, rnn_cell='gru'):
        super(EncoderRNN, self).__init__()
        self.dim_vid = dim_vid
        self.dim_hidden = dim_hidden
        self.input_dropout_p = input_dropout_p
        self.rnn_dropout_p = rnn_dropout_p
        self.n_layers = n_layers
        # Whether the RNN is bidirectional (doubles output feature size).
        self.bidirectional = bidirectional

        # Linear projection from raw video features to the hidden dimension.
        self.vid2hid = nn.Linear(dim_vid, dim_hidden)
        self.input_dropout = nn.Dropout(input_dropout_p)

        # Resolve the RNN cell class. Fail fast on unsupported values: the
        # original left self.rnn_cell as a plain string for anything other
        # than 'lstm'/'gru' and crashed later with an opaque TypeError.
        cell = rnn_cell.lower()
        if cell == 'lstm':
            self.rnn_cell = nn.LSTM
        elif cell == 'gru':
            self.rnn_cell = nn.GRU
        else:
            raise ValueError("rnn_cell must be 'lstm' or 'gru', got %r" % rnn_cell)

        # Build the RNN. PyTorch ignores inter-layer dropout when there is a
        # single layer and emits a UserWarning; pass 0 in that case — the
        # effective behavior is identical, just without the warning.
        self.rnn = self.rnn_cell(
            dim_hidden, dim_hidden, n_layers, batch_first=True,
            bidirectional=bidirectional,
            dropout=self.rnn_dropout_p if n_layers > 1 else 0)

        self._init_hidden()

    def _init_hidden(self):
        # Xavier-initialize the input projection; the RNN keeps its defaults.
        nn.init.xavier_normal_(self.vid2hid.weight)

    def forward(self, vid_feats):
        """Encode a batch of video feature sequences.

        Args:
            vid_feats (Tensor): shape ``(batch, seq_len, dim_vid)``.

        Returns:
            output (Tensor): ``(batch, seq_len, dim_hidden * num_directions)``
                — encoded features for every time step.
            hidden: final hidden state of shape
                ``(n_layers * num_directions, batch, dim_hidden)``;
                a ``(h, c)`` tuple when the cell is an LSTM.
        """
        batch_size, seq_len, dim_vid = vid_feats.size()
        # Flatten (batch, seq) so the Linear projects every frame at once,
        # then restore the (batch, seq, hidden) layout expected by the RNN.
        vid_feats = self.vid2hid(vid_feats.view(-1, dim_vid))
        vid_feats = self.input_dropout(vid_feats)
        vid_feats = vid_feats.view(batch_size, seq_len, self.dim_hidden)
        # Compact RNN weights into contiguous memory (relevant on CUDA;
        # harmless no-op on CPU).
        self.rnn.flatten_parameters()
        output, hidden = self.rnn(vid_feats)
        return output, hidden