Skip to content

Commit

Permalink
fix show_time_distributed_layer
Browse files Browse the repository at this point in the history
  • Loading branch information
Dobiasd committed Apr 10, 2024
1 parent 80c15c2 commit 94a2d36
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions keras_export/convert_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,14 +652,13 @@ class CopiedLayer:

for attr in attributes:
try:
if attr not in ['input_shape', '__class__']:
if attr not in ['batch_shape', '__class__']:
setattr(copied_layer, attr, getattr(layer.layer, attr))
elif attr == 'input_shape':
setattr(copied_layer, 'input_shape', input_shape_new)
except Exception:
continue

setattr(copied_layer, "output_shape", getattr(layer, "output_shape"))
setattr(copied_layer, 'batch_shape', input_shape_new)
setattr(copied_layer, "output_shape", layer.output.shape)

return layer_function(copied_layer)

Expand Down Expand Up @@ -711,7 +710,7 @@ def get_layer_weights(layer, name):

result[name]['td_input_len'] = encode_floats(
np.array([len(get_layer_input_shape(layer)) - 1], dtype=np.float32))
result[name]['td_output_len'] = encode_floats(np.array([len(layer.output_shape) - 1], dtype=np.float32))
result[name]['td_output_len'] = encode_floats(np.array([len(layer.output.shape) - 1], dtype=np.float32))
return result


Expand Down

0 comments on commit 94a2d36

Please sign in to comment.