-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsequence_rnn.py
53 lines (41 loc) · 1.04 KB
/
sequence_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import tensorflow as tf
import numpy as np
from tensorflow.models.rnn.rnn import *
class SequenceRNN(object):
    """
    Base class for sequenced data.

    Holds handles to the parts of a sequence model (input, target, initial
    and final state, output, error, training op, learning rate, early-stop
    handle) and exposes them as read-only properties. This base class only
    initializes every handle to ``None``; building the model and assigning
    the ``_``-prefixed attributes is left to subclasses.
    """

    def __init__(self):
        self._early_stop = None
        self._seq_input = None
        self._seq_target = None
        self._initial_state = None
        self._final_state = None
        self._output = None
        # Previously left undefined, so the `error` property raised
        # AttributeError until a subclass happened to set it.
        self._error = None
        self._train_op = None
        # Learning-rate variable; must be created by a subclass before
        # assign_lr() is called (also previously left undefined).
        self._lr = None
        # Built lazily by assign_lr() so repeated calls reuse a single
        # placeholder/assign pair instead of adding ops to the graph each time.
        self._new_lr = None
        self._lr_update = None
        self._lr_update_target = None

    def assign_lr(self, session, lr_value):
        """
        Set the learning-rate variable ``self._lr`` to ``lr_value``.

        The placeholder and assign op are created once per learning-rate
        variable and then reused, so calling this repeatedly (e.g. once per
        epoch) does not keep growing the TensorFlow graph.

        Args:
            session: TensorFlow session used to run the update.
            lr_value: New learning-rate value, fed into a placeholder whose
                dtype is the variable's base dtype.

        Raises:
            AttributeError: if no learning-rate variable has been set on
                ``self._lr``.
        """
        lr_var = getattr(self, "_lr", None)
        if lr_var is None:
            raise AttributeError(
                "SequenceRNN._lr is not set; a subclass must create the "
                "learning-rate variable before calling assign_lr()")
        # Rebuild only if this is the first call or `_lr` was replaced by a
        # different variable; getattr keeps subclasses that skip
        # super().__init__() working.
        if getattr(self, "_lr_update_target", None) is not lr_var:
            # Create the ops in the variable's own graph, which is where the
            # per-call tf.assign op would have been placed.
            with lr_var.graph.as_default():
                self._new_lr = tf.placeholder(lr_var.dtype.base_dtype,
                                              name="new_lr")
                self._lr_update = tf.assign(lr_var, self._new_lr)
            self._lr_update_target = lr_var
        session.run(self._lr_update, feed_dict={self._new_lr: lr_value})

    @property
    def early_stop(self):
        """Handle stored in ``_early_stop`` (``None`` until a subclass sets it)."""
        return self._early_stop

    @property
    def seq_input(self):
        """Input sequence handle (``None`` until a subclass sets it)."""
        return self._seq_input

    @property
    def seq_target(self):
        """Target sequence handle (``None`` until a subclass sets it)."""
        return self._seq_target

    @property
    def initial_state(self):
        """Initial recurrent state handle (``None`` until a subclass sets it)."""
        return self._initial_state

    @property
    def final_state(self):
        """Final recurrent state handle (``None`` until a subclass sets it)."""
        return self._final_state

    @property
    def output(self):
        """Model output handle (``None`` until a subclass sets it)."""
        return self._output

    @property
    def error(self):
        """Error/loss handle (``None`` until a subclass sets it)."""
        return self._error

    @property
    def train_op(self):
        """Training op handle (``None`` until a subclass sets it)."""
        return self._train_op

    @property
    def lr(self):
        """Learning-rate variable (``None`` until a subclass sets it)."""
        return self._lr