
Functional API does not work as expected when concatenating two models with multiple outputs and inputs #20325

Closed
@malciin

Description


Keras version: 3.6.0
OS: Win

Hello,

Let's say I have the following two models, A and B:

```python
import keras

A_input = keras.Input(shape=(4,))
A = keras.layers.Dense(5)(A_input)
# Model A: one input of shape (4,), two output heads of shape (4,)
A = keras.Model(inputs=A_input, outputs=[keras.layers.Dense(4)(A), keras.layers.Dense(4)(A)])
```
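For reference, a quick way to inspect what model A exposes. This is a sketch assuming Keras 3, where `Model.inputs` and `Model.outputs` are flat lists of symbolic tensors; the intermediate name `hidden` is introduced here only to avoid reusing `A` for two different things.

```python
import keras

A_input = keras.Input(shape=(4,))
hidden = keras.layers.Dense(5)(A_input)
A = keras.Model(inputs=A_input,
                outputs=[keras.layers.Dense(4)(hidden), keras.layers.Dense(4)(hidden)])

# One symbolic input, two symbolic outputs (the batch dimension is None).
print(len(A.inputs), len(A.outputs))        # 1 2
print([tuple(t.shape) for t in A.outputs])  # [(None, 4), (None, 4)]
```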

[Image: plot of model A]

```python
# Model B: two inputs of shape (4,), concatenated, one output of shape (5,)
B_input = [keras.Input(shape=(4,)), keras.Input(shape=(4,))]
B = keras.layers.Concatenate()(B_input)
B = keras.layers.Dense(5)(B)
B = keras.Model(inputs=B_input, outputs=B)
```
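Likewise for model B, a sketch (assuming Keras 3) confirming that it expects a list of two `(None, 4)` tensors and returns a single `(None, 5)` tensor; `joined`, `x1`, `x2` and `y` are illustrative names only.

```python
import keras

B_input = [keras.Input(shape=(4,)), keras.Input(shape=(4,))]
joined = keras.layers.Concatenate()(B_input)  # shape (None, 8)
B = keras.Model(inputs=B_input, outputs=keras.layers.Dense(5)(joined))

# B expects a list of two (None, 4) tensors and returns one (None, 5) tensor.
print(len(B.inputs), tuple(B.outputs[0].shape))  # 2 (None, 5)

# Calling B on two fresh symbolic tensors of the right shape works.
x1, x2 = keras.Input(shape=(4,)), keras.Input(shape=(4,))
y = B([x1, x2])
print(tuple(y.shape))  # (None, 5)
```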

[Image: plot of model B]

I want to merge them into one model via `keras.Model(inputs=A_input, outputs=B(A))`, which unfortunately crashes:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[20], line 1
----> 1 merged = keras.Model(inputs=A_input, outputs=B(A)) # why not work?

File c:\Users\Marcin\.miniconda3\envs\torch\Lib\site-packages\keras\src\utils\traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File c:\Users\Marcin\.miniconda3\envs\torch\Lib\site-packages\keras\src\layers\input_spec.py:160, in assert_input_compatibility(input_spec, inputs, layer_name)
    158 inputs = tree.flatten(inputs)
    159 if len(inputs) != len(input_spec):
--> 160     raise ValueError(
    161         f'Layer "{layer_name}" expects {len(input_spec)} input(s),'
    162         f" but it received {len(inputs)} input tensors. "
    163         f"Inputs received: {inputs}"
    164     )
    165 for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
    166     if spec is None:

ValueError: Layer "functional_4" expects 2 input(s), but it received 1 input tensors. Inputs received: [<Functional name=functional_2, built=True>]
```
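The last line of the error lists a single received input, the `Functional` model object itself, rather than two tensors. A minimal sketch of that reading, assuming Keras 3: `A` is a model object, while `A.outputs` is the flat list of A's two symbolic output tensors, which matches B's two inputs. The exception type and message may differ between Keras versions, so the check below catches any exception.

```python
import keras

# Models A and B, rebuilt as in the report.
A_input = keras.Input(shape=(4,))
hidden = keras.layers.Dense(5)(A_input)
A = keras.Model(inputs=A_input,
                outputs=[keras.layers.Dense(4)(hidden), keras.layers.Dense(4)(hidden)])
B_input = [keras.Input(shape=(4,)), keras.Input(shape=(4,))]
B = keras.Model(inputs=B_input,
                outputs=keras.layers.Dense(5)(keras.layers.Concatenate()(B_input)))

# B(A) hands the model object itself to B: one non-tensor argument
# where B expects a list of two tensors, so the call is rejected.
try:
    B(A)
    failed = False
except Exception as err:  # ValueError in Keras 3.6; may differ in other versions
    failed = True
    print(type(err).__name__, err)

# A.outputs is the flat list of A's two symbolic output tensors,
# which has the same structure as B's inputs.
print(len(A.outputs), [tuple(t.shape) for t in A.outputs])  # 2 [(None, 4), (None, 4)]
out = B(A.outputs)
print(tuple(out.shape))  # (None, 5)
```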

This looks like a bug to me, because the following works:

```python
B(A(keras.ops.ones(shape=(1, 4))))  # works
```

```
tensor([[-0.2388, -0.3490, -0.3166,  0.2736, -1.2349]], device='cuda:0',
       grad_fn=<AddBackward0>)
```

As a temporary workaround, I found the following way to create the merged model:

```python
merged = keras.Model(inputs=A_input, outputs=B(A(A_input)))
```

However, this has a caveat: the plotted model shows a loop at the input:

[Image: plot of the merged model, showing a loop at the input]
