Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
farizrahman4u committed Sep 30, 2017
1 parent b6872a0 commit 3a8ea08
Showing 1 changed file with 2 additions and 16 deletions.
18 changes: 2 additions & 16 deletions recurrentshop/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ class RecurrentModel(Recurrent):

# INITIALIZATION

def __init__(self, input, output, initial_states=None, final_states=None, readout_input=None, teacher_force=False, decode=False, output_length=None, return_states=False, state_initializer=None, preprocess_function=None, **kwargs):
def __init__(self, input, output, initial_states=None, final_states=None, readout_input=None, teacher_force=False, decode=False, output_length=None, return_states=False, state_initializer=None, **kwargs):
inputs = [input]
outputs = [output]
state_spec = None
Expand Down Expand Up @@ -328,14 +328,6 @@ def __init__(self, input, output, initial_states=None, final_states=None, readou
state_initializer += [None] * (self.num_states - len(state_initializer))
state_initializer = [initializers.get(init) if init else initializers.get('zeros') for init in state_initializer]
self.state_initializer = state_initializer
if preprocess_function is None:
self._preprocess_function = None
elif type(preprocess_function) is tuple:
self._preprocess_function = deserialize_function(preprocess_function)
else:
self._preprocess_function = preprocess_function
preprocess_function = serialize_function(preprocess_function)
self.preprocess_function = preprocess_function

def build(self, input_shape):
if type(input_shape) is list:
Expand Down Expand Up @@ -655,11 +647,6 @@ def call(self, inputs, initial_state=None, initial_readout=None, ground_truth=No
else:
return y

def preprocess_input(self, input, training=None):
if self._preprocess_function is None:
return input
return self._preprocess_function(input)

def step(self, inputs, states):
states = list(states)
if self.teacher_force:
Expand Down Expand Up @@ -809,8 +796,7 @@ def get_config(self):
'decode': self.decode,
'output_length': self.output_length,
'return_states': self.return_states,
'state_initializer': self._serialize_state_initializer(),
'preprocess_function': self.preprocess_function
'state_initializer': self._serialize_state_initializer()
}
base_config = super(RecurrentModel, self).get_config()
config.update(base_config)
Expand Down

0 comments on commit 3a8ea08

Please sign in to comment.