diff --git a/recurrentshop/engine.py b/recurrentshop/engine.py index 13ebdfd..e8deb67 100644 --- a/recurrentshop/engine.py +++ b/recurrentshop/engine.py @@ -272,7 +272,7 @@ class RecurrentModel(Recurrent): # INITIALIZATION - def __init__(self, input, output, initial_states=None, final_states=None, readout_input=None, teacher_force=False, decode=False, output_length=None, return_states=False, state_initializer=None, preprocess_function=None, **kwargs): + def __init__(self, input, output, initial_states=None, final_states=None, readout_input=None, teacher_force=False, decode=False, output_length=None, return_states=False, state_initializer=None, **kwargs): inputs = [input] outputs = [output] state_spec = None @@ -328,14 +328,6 @@ def __init__(self, input, output, initial_states=None, final_states=None, readou state_initializer += [None] * (self.num_states - len(state_initializer)) state_initializer = [initializers.get(init) if init else initializers.get('zeros') for init in state_initializer] self.state_initializer = state_initializer - if preprocess_function is None: - self._preprocess_function = None - elif type(preprocess_function) is tuple: - self._preprocess_function = deserialize_function(preprocess_function) - else: - self._preprocess_function = preprocess_function - preprocess_function = serialize_function(preprocess_function) - self.preprocess_function = preprocess_function def build(self, input_shape): if type(input_shape) is list: @@ -655,11 +647,6 @@ def call(self, inputs, initial_state=None, initial_readout=None, ground_truth=No else: return y - def preprocess_input(self, input, training=None): - if self._preprocess_function is None: - return input - return self._preprocess_function(input) - def step(self, inputs, states): states = list(states) if self.teacher_force: @@ -809,8 +796,7 @@ def get_config(self): 'decode': self.decode, 'output_length': self.output_length, 'return_states': self.return_states, - 'state_initializer': self._serialize_state_initializer(), - 'preprocess_function': self.preprocess_function + 'state_initializer': self._serialize_state_initializer() } base_config = super(RecurrentModel, self).get_config() config.update(base_config)