-
Notifications
You must be signed in to change notification settings - Fork 1
/
RNN.cu
30 lines (26 loc) · 1020 Bytes
/
RNN.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#include "RNN.h"

#include <vector>
/*
 * Runs the stacked RNN forward over every time step.
 *
 * inputs:  time_step * [batch_size, input_size], stored contiguously; the
 *          slice for step t starts at offset t * batch_size * input_size.
 * h_0s:    (member) num_layers * [batch_size, hidden_size] initial hidden
 *          states; used as h_prev at t == 0 and never freed here.
 * hiddens: (member) one buffer per layer, each time_step * [batch_size,
 *          hidden_size]; layer l's state for step t is written at offset
 *          t * batch_size * hidden_size.
 *
 * Returns hiddens[num_layers - 1]: the top layer's hidden state for every
 * time step (not just the last one), since the CTC decoder needs all of them.
 *
 * Assumes num_layers >= 1 (hiddens[num_layers - 1] is read unconditionally)
 * -- TODO confirm callers guarantee this.
 *
 * Every temporary matrix created here with `new cuMatrix<float>(src, offset,
 * ...)` is deleted before returning. NOTE(review): this relies on deleting
 * such an object leaving src's storage intact; the code already depends on
 * that when it deletes h_prev slices of hiddens[l].
 */
cuMatrix<float>* RNN::forward(cuMatrix<float>* inputs) {
    // Per-layer hidden-state slice produced at the previous time step.
    // A std::vector replaces `cuMatrix<float>* h_ts[num_layers]`, which is a
    // non-standard variable-length array unless num_layers is a constant.
    // Value-initialisation sets every entry to a null pointer.
    std::vector<cuMatrix<float>*> h_ts(num_layers);

    for (int t = 0; t < time_step; t++) {
        // x_t = inputs[t]
        cuMatrix<float>* x_t = new cuMatrix<float>(
            inputs, t * batch_size * input_size, batch_size, input_size, 1);

        for (int l = 0; l < num_layers; l++) {
            cuMatrix<float>* h_prev = (t == 0) ? h_0s[l] : h_ts[l];
            cuMatrix<float>* h_t = new cuMatrix<float>(
                hiddens[l], t * batch_size * hidden_size,
                batch_size, hidden_size, 1);
            rnn_cell[l]->forward(x_t, h_prev, h_t);

            // For t > 0, h_prev is last step's temporary slice; h_0s[l] is
            // not owned by this function and must not be freed.
            if (t > 0) delete h_prev;
            h_ts[l] = h_t;

            // Only the input slice is owned at this point. For l > 0, x_t is
            // h_ts[l - 1], which the next time step still needs as h_prev.
            if (l == 0) delete x_t;
            x_t = h_t;  // this layer's output is the next layer's input
        }
    }

    // Free the slices left over from the final time step (previously leaked).
    // Entries are still null when time_step == 0, and deleting null is a no-op.
    for (int l = 0; l < num_layers; l++) {
        delete h_ts[l];
    }

    return hiddens[num_layers - 1];
}