
Functional API does not work as expected when concatenating two models with multiple outputs & inputs #20325

Closed
malciin opened this issue Oct 5, 2024 · 1 comment


malciin commented Oct 5, 2024

Keras version: 3.6.0
OS: Windows

Hello,

Let's say I have the following two models, A and B:

import keras

A_input = keras.Input(shape=(4,))
A = keras.layers.Dense(5)(A_input)
A = keras.Model(inputs=A_input, outputs=[ keras.layers.Dense(4)(A), keras.layers.Dense(4)(A) ])


B_input = [ keras.Input(shape=(4,)), keras.Input(shape=(4,)) ]
B = keras.layers.Concatenate()(B_input)
B = keras.layers.Dense(5)(B)
B = keras.Model(inputs = B_input, outputs=B)


and I want to merge them into one model via keras.Model(inputs=A_input, outputs=B(A)), which unfortunately crashes:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[20], line 1
----> 1 merged = keras.Model(inputs=A_input, outputs=B(A)) # why not work?

File c:\Users\Marcin\.miniconda3\envs\torch\Lib\site-packages\keras\src\utils\traceback_utils.py:122, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    119     filtered_tb = _process_traceback_frames(e.__traceback__)
    120     # To get the full stack trace, call:
    121     # `keras.config.disable_traceback_filtering()`
--> 122     raise e.with_traceback(filtered_tb) from None
    123 finally:
    124     del filtered_tb

File c:\Users\Marcin\.miniconda3\envs\torch\Lib\site-packages\keras\src\layers\input_spec.py:160, in assert_input_compatibility(input_spec, inputs, layer_name)
    158 inputs = tree.flatten(inputs)
    159 if len(inputs) != len(input_spec):
--> 160     raise ValueError(
    161         f'Layer "{layer_name}" expects {len(input_spec)} input(s),'
    162         f" but it received {len(inputs)} input tensors. "
    163         f"Inputs received: {inputs}"
    164     )
    165 for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
    166     if spec is None:

ValueError: Layer "functional_4" expects 2 input(s), but it received 1 input tensors. Inputs received: [<Functional name=functional_2, built=True>]

This looks like a bug to me, because the following works:

B(A(keras.ops.ones(shape=(1, 4)))) #works

tensor([[-0.2388, -0.3490, -0.3166,  0.2736, -1.2349]], device='cuda:0',
       grad_fn=<AddBackward0>)
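The difference between the two calls can be illustrated with a simplified, stdlib-only sketch of the input-count check shown in the traceback. It assumes, as the input_spec.py excerpt suggests, that the check flattens whatever it receives and compares the number of leaves with the number of declared inputs. In B(A(x)), A runs first and returns a list of two outputs, so B sees two inputs; in B(A), the model object itself is a single leaf. The names flatten, check_input_count and TwoOutputModel are hypothetical, not Keras APIs.

```python
# Simplified, hypothetical re-creation of the input-count check from the
# traceback (keras/src/layers/input_spec.py); this is NOT the Keras source.


def flatten(structure):
    """Flatten nested lists/tuples into a flat list of leaves."""
    if isinstance(structure, (list, tuple)):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    return [structure]  # any other object (e.g. a model) counts as one leaf


def check_input_count(num_expected, inputs, layer_name):
    """Raise ValueError if the flattened inputs don't match the expected count."""
    flat = flatten(inputs)
    if len(flat) != num_expected:
        raise ValueError(
            f'Layer "{layer_name}" expects {num_expected} input(s), '
            f"but it received {len(flat)} input tensors."
        )
    return flat


class TwoOutputModel:
    """Hypothetical stand-in for model A: calling it returns two outputs."""

    def __call__(self, x):
        return [x + 1, x + 2]  # placeholder arithmetic instead of Dense layers


model_a = TwoOutputModel()

# Like B(A(x)): A is called first, so B receives a list of two values.
print(check_input_count(2, model_a(1.0), "functional_4"))  # [2.0, 3.0]

# Like B(A): the model object itself is a single leaf, so the count is 1.
try:
    check_input_count(2, model_a, "functional_4")
except ValueError as err:
    print(err)
```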

As a temporary workaround, I found that the following creates the merged model:

merged = keras.Model(inputs=A_input, outputs=B(A(A_input)))

but it has a caveat: the plot of the merged model shows a loop at the input.


@james77777778 (Contributor)

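You cannot pass a keras.Model object directly as the input to another model: B(A) hands B the model object itself (a single item, as the "Inputs received" part of the error shows), not A's two output tensors. Instead, use Model.inputs and Model.outputs to connect the models together.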

Try this one:

import keras

A_input = keras.Input(shape=(4,))
A = keras.layers.Dense(5)(A_input)
A = keras.Model(
    inputs=A_input, outputs=[keras.layers.Dense(4)(A), keras.layers.Dense(4)(A)]
)

B_input = [keras.Input(shape=(4,)), keras.Input(shape=(4,))]
B = keras.layers.Concatenate()(B_input)
B = keras.layers.Dense(5)(B)
B = keras.Model(inputs=B_input, outputs=B)

C = keras.Model(inputs=A_input, outputs=B(A.outputs))
keras.utils.plot_model(C)
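Why B(A.outputs) works: Keras's symbolic tensors record the operation that produced them, so A.outputs already traces back to A_input, and keras.Model(inputs=A_input, ...) can walk from B's new output back to that input. The toy graph here is a hypothetical, stdlib-only illustration of that idea, not how Keras implements functional models; SymbolicTensor, Layer, Model and reaches are made-up names.

```python
# Toy symbolic graph (hypothetical, not Keras internals): each symbolic tensor
# remembers the layer that produced it and the tensors it was computed from.


class SymbolicTensor:
    def __init__(self, name, layer=None, parents=()):
        self.name = name
        self.layer = layer            # None for graph inputs
        self.parents = list(parents)  # tensors this one was computed from


class Layer:
    def __init__(self, name, num_inputs=1):
        self.name = name
        self.num_inputs = num_inputs

    def __call__(self, inputs):
        # A list counts as several inputs; anything else counts as one.
        parents = inputs if isinstance(inputs, list) else [inputs]
        if len(parents) != self.num_inputs:
            raise ValueError(
                f"{self.name} expects {self.num_inputs} input(s), got {len(parents)}"
            )
        return SymbolicTensor(f"{self.name}_out", layer=self, parents=parents)


class Model(Layer):
    def __init__(self, inputs, outputs, name):
        self.inputs = inputs if isinstance(inputs, list) else [inputs]
        self.outputs = outputs if isinstance(outputs, list) else [outputs]
        super().__init__(name, num_inputs=len(self.inputs))


def reaches(output, target):
    """Return True if walking back from `output` arrives at tensor `target`."""
    stack, seen = [output], set()
    while stack:
        t = stack.pop()
        if t is target:
            return True
        if id(t) not in seen:
            seen.add(id(t))
            stack.extend(t.parents)
    return False


# Model A: one input, two outputs (toy layers stand in for Dense layers).
a_input = SymbolicTensor("A_input")
hidden = Layer("dense")(a_input)
A = Model(a_input, [Layer("dense_1")(hidden), Layer("dense_2")(hidden)], "A")

# Model B: two inputs, concatenated, one output.
b_inputs = [SymbolicTensor("B_in_1"), SymbolicTensor("B_in_2")]
B = Model(b_inputs, Layer("dense_3")(Layer("concat", 2)(b_inputs)), "B")

# Wiring A.outputs (two symbolic tensors) into B traces back to A_input.
merged_output = B(A.outputs)
C = Model(a_input, merged_output, "C")
print(reaches(C.outputs[0], a_input))  # True

# Passing the model object itself gives B one input instead of two.
try:
    B(A)
except ValueError as err:
    print(err)  # B expects 2 input(s), got 1
```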

malciin closed this as completed Oct 5, 2024