
Custom activation functions cause TensorFlow to crash #20333

Open
AtticusBeachy opened this issue Oct 8, 2024 · 2 comments

@AtticusBeachy

I originally posted this issue in the TensorFlow GitHub repository and was told that it looks like a Keras issue, so I am reposting it here.

TensorFlow version:
2.17.0

OS:
Linux Mint 22

Python version:
3.12.7

Issue:
I can successfully define and register a custom activation function, but when I try to use it in a layer, TensorFlow crashes with a ValueError.

Minimal reproducible example:

import tensorflow as tf
from tensorflow.keras.utils import get_custom_objects
from tensorflow.keras.layers import Activation

def fourier_activation_lambda(freq):
    fn = lambda x : tf.sin(freq*x)
    return(fn)

freq = 1.0
fourier = fourier_activation_lambda(freq)

get_custom_objects()["fourier"] = Activation(fourier)

print(3*"\n")
print(f"After addition: {get_custom_objects()=}")

x_input = tf.keras.Input(shape=[5])
activation = "fourier"
layer_2 = tf.keras.layers.Dense(100, input_shape = [5],
                                activation=activation,
                                )(x_input)

model = tf.keras.Model(inputs=x_input, outputs=layer_2)
model.compile(optimizer='adam', loss='mse')
model.summary()

The output of the print statement below shows that the custom activation was added to the dictionary successfully. Could the crash be related to "built=False"?

# output of print statement
get_custom_objects()={'fourier': <Activation name=activation, built=False>}

The error message reads:

# error message
Traceback (most recent call last):
  File "/home/orca/Downloads/minimal_tf_err.py", line 20, in <module>
    layer_2 = tf.keras.layers.Dense(100, input_shape = [5],
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/orca/.local/lib/python3.12/site-packages/keras/src/layers/core/dense.py", line 89, in __init__
    self.activation = activations.get(activation)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/orca/.local/lib/python3.12/site-packages/keras/src/activations/__init__.py", line 104, in get
    raise ValueError(
ValueError: Could not interpret activation function identifier: fourier
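On the "built=False" question: for a Keras layer, built=False only means the layer has not yet been called on (built for) an input, so it is probably unrelated to the error. A minimal sketch, assuming the same TensorFlow install as above, showing that the Activation layer stored in the dictionary works when called directly on a tensor:

```python
import numpy as np
import tensorflow as tf

# Same activation as in the report: sin(freq * x) with freq = 1.0.
fourier = lambda x: tf.sin(1.0 * x)

# Wrapping the function gives a layer object; it starts with built=False
# simply because it has not been called on an input yet.
act_layer = tf.keras.layers.Activation(fourier)

# Calling the layer on a tensor works: it is a tensor-in, tensor-out callable.
x = tf.constant([[0.0, np.pi / 2]])
y = act_layer(x)
print(np.round(y.numpy(), 4))  # sin(0) = 0 and sin(pi/2) = 1
```

So the registered object itself is usable; the failure in the traceback happens earlier, when Dense tries to turn the string "fourier" into an activation.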
@fchollet
Member

fchollet commented Oct 9, 2024

In the example above, you are passing the string "fourier" as the activation. A string is not a tensor-in, tensor-out callable, so it doesn't work as an activation.

Your code, simplified:

activation = "fourier"
layer_2 = tf.keras.layers.Dense(100, input_shape=[5], activation=activation)(x_input)
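A minimal sketch of the workaround this suggests, assuming the only change is to pass the function object instead of its registered name. This avoids the string lookup entirely, so it should not depend on whether that lookup consults the custom-objects dictionary:

```python
import tensorflow as tf

def fourier_activation_lambda(freq):
    # Returns a tensor-in, tensor-out callable computing sin(freq * x).
    return lambda x: tf.sin(freq * x)

fourier = fourier_activation_lambda(1.0)

x_input = tf.keras.Input(shape=(5,))
# Pass the callable itself rather than the string "fourier".
# input_shape is omitted because tf.keras.Input already fixes the shape.
layer_2 = tf.keras.layers.Dense(100, activation=fourier)(x_input)

model = tf.keras.Model(inputs=x_input, outputs=layer_2)
model.compile(optimizer="adam", loss="mse")
model.summary()
```

The parameter count should match the summary from the older setup: 5 inputs x 100 units + 100 biases = 600.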

@AtticusBeachy
Author

I added "fourier" as a key in the global dictionary of custom objects, as described here:

# add fourier function to global dictionary of custom objects
get_custom_objects()["fourier"] = Activation(fourier)

Furthermore, the same code runs successfully on my old computer (TensorFlow 2.10.0): instead of crashing, it runs to the end and prints the model summary:

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 5)]               0         
                                                                 
 dense (Dense)               (None, 100)               600       
                                                                 
=================================================================
Total params: 600
Trainable params: 600
Non-trainable params: 0
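One likely reason for the version difference: TensorFlow 2.10 bundles Keras 2, while TensorFlow 2.16 and later use Keras 3 by default, and the two handle string activation identifiers differently (in the reporter's TF 2.17 setup the registered name was evidently not found). If the name-based registration should be kept, a portable option is to resolve the name to a callable before building the layer. resolve_activation below is a hypothetical helper, not a Keras API; it only relies on tf.keras.utils.get_custom_objects and tf.keras.activations.get, and it registers a plain function rather than an Activation layer (either is callable).

```python
import tensorflow as tf

def fourier_activation_lambda(freq):
    # Tensor-in, tensor-out callable computing sin(freq * x).
    return lambda x: tf.sin(freq * x)

# Register a plain function under the name "fourier".
tf.keras.utils.get_custom_objects()["fourier"] = fourier_activation_lambda(1.0)

def resolve_activation(identifier):
    """Hypothetical helper: turn a name into a callable before layer creation.

    Checks the global custom-objects dict first, then falls back to Keras's
    own lookup for built-in names such as "relu".
    """
    if isinstance(identifier, str):
        custom = tf.keras.utils.get_custom_objects()
        if identifier in custom:
            return custom[identifier]
    return tf.keras.activations.get(identifier)

x_input = tf.keras.Input(shape=(5,))
dense = tf.keras.layers.Dense(100, activation=resolve_activation("fourier"))
model = tf.keras.Model(inputs=x_input, outputs=dense(x_input))
model.summary()
```

Because Dense receives a callable in both cases, the same script should behave identically on the TF 2.10 and TF 2.17 machines.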
