diff --git a/recurrentshop/engine.py b/recurrentshop/engine.py index e8deb67..aa8f439 100644 --- a/recurrentshop/engine.py +++ b/recurrentshop/engine.py @@ -3,7 +3,8 @@ from keras import initializers from .backend import rnn, learning_phase_scope from .generic_utils import serialize_function, deserialize_function -from keras.engine.topology import Node, _collect_previous_mask, _collect_input_shape +from keras.engine.topology import Node +#from keras.engine.topology import Node, _collect_previous_mask, _collect_input_shape import inspect @@ -839,13 +840,13 @@ def _get_optional_input_placeholder(self, name=None, num=1): self._optional_input_placeholders[name] = self._get_optional_input_placeholder() return self._optional_input_placeholders[name] if num == 1: - optional_input_placeholder = _to_list(_OptionalInputPlaceHolder().inbound_nodes[0].output_tensors)[0] + optional_input_placeholder = _to_list(_OptionalInputPlaceHolder()._inbound_nodes[0].output_tensors)[0] assert self._is_optional_input_placeholder(optional_input_placeholder) return optional_input_placeholder else: y = [] for _ in range(num): - optional_input_placeholder = _to_list(_OptionalInputPlaceHolder().inbound_nodes[0].output_tensors)[0] + optional_input_placeholder = _to_list(_OptionalInputPlaceHolder()._inbound_nodes[0].output_tensors)[0] assert self._is_optional_input_placeholder(optional_input_placeholder) y.append(optional_input_placeholder) return y @@ -1072,6 +1073,47 @@ def from_config(cls, config, custom_objects={}): rs.add(cell) return rs - + +def _collect_input_shape(input_tensors): + """Collects the output shape(s) of a list of Keras tensors. + # Arguments + input_tensors: list of input tensors (or single input tensor). + # Returns + List of shape tuples (or single tuple), one tuple per input. + """ + input_tensors = _to_list(input_tensors) + shapes = [] + for x in input_tensors: + try: + shapes.append(K.int_shape(x)) + except TypeError: + shapes.append(None) + if len(shapes) == 1: + return shapes[0] + return shapes + + +def _collect_previous_mask(input_tensors): + """Retrieves the output mask(s) of the previous node. + # Arguments + input_tensors: A tensor or list of tensors. + # Returns + A mask tensor or list of mask tensors. + """ + input_tensors = _to_list(input_tensors) + masks = [] + for x in input_tensors: + if hasattr(x, '_keras_history'): + inbound_layer, node_index, tensor_index = x._keras_history + node = inbound_layer._inbound_nodes[node_index] + mask = node.output_masks[tensor_index] + masks.append(mask) + else: + masks.append(None) + if len(masks) == 1: + return masks[0] + return masks + + # Legacy RecurrentContainer = RecurrentSequential