
fix time_distributed layer with mask and partial_batch_size #20765

Open
Surya2k1 wants to merge 3 commits into master
Conversation

Surya2k1
Contributor

If the model includes an Embedding layer with mask_zero=True and the model subsequently has a TimeDistributed layer, training fails in graph mode when there is a partial_batch_size. This happens due to concatenation of the partial-batch dataset, which makes batch_size None and hence the shape (None, ...).

The model then fails with a graph execution error when we try to compare batch_size with the corresponding value from the mask.

Hence I am proposing to omit the batch_size comparison for the TF backend in graph mode. It would have been better to apply this check only when there actually is a partial_batch_size, but I am not sure how to propagate that to the TimeDistributed layer.

Fixes #20754
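For illustration, here is a minimal standalone sketch of the proposed behavior (my own simplified version, written as a hypothetical helper; the real check lives inline in keras/src/layers/rnn/time_distributed.py and the exact change is in the diff below):

import tensorflow as tf
from keras.src import backend

def check_mask_shape(mask_shape, batch_size, timesteps):
    # Hypothetical standalone version of the in-layer mask-shape check.
    if backend.backend() == "tensorflow" and not tf.executing_eagerly():
        # In graph mode batch_size can be a symbolic tensor (e.g. a
        # strided_slice); using it in a Python comparison raises
        # OperatorNotAllowedInGraphError, so skip the check entirely.
        return
    if mask_shape is not None and mask_shape[:2] != (batch_size, timesteps):
        raise ValueError(
            "`TimeDistributed` Layer should be passed a `mask` of "
            f"shape ({batch_size}, {timesteps}, ...), "
            f"received: mask.shape={mask_shape}"
        )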

Code to replicate the issue:

import keras
import numpy as np

model = keras.Sequential([
    keras.Input(shape=(20,)),
    keras.layers.Embedding(
        input_dim=10,
        output_dim=5,
        mask_zero=True,
    ),
    keras.layers.Bidirectional(keras.layers.LSTM(units=10, return_sequences=True)),
    keras.layers.TimeDistributed(
        keras.layers.Dense(units=5, activation="softmax")
    ),  # fails when mask_zero=True
])
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)
X_train = np.random.uniform(1, 10, size=(50, 20))
Y_train = np.random.randint(1, 2, size=(50, 20, 5))

model.fit(X_train, Y_train, epochs=2, batch_size=8)
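Note that 50 samples with batch_size=8 leave a partial last batch of 2 (50 % 8 == 2), which is what makes the static batch dimension unknown. A quick standalone illustration of that effect with plain tf.data (my own example, not the PR's test):

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(tf.zeros((50, 20))).batch(8)
# Without drop_remainder=True the last batch holds only 2 samples,
# so the batch dimension cannot be pinned statically:
print(ds.element_spec.shape)  # (None, 20)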

codecov-commenter commented Jan 15, 2025

Codecov Report

Attention: Patch coverage is 42.85714% with 4 lines in your changes missing coverage. Please review.

Project coverage is 82.00%. Comparing base (e345cbd) to head (2ae666a).
Report is 12 commits behind head on master.

Files with missing lines | Patch % | Lines
keras/src/layers/rnn/time_distributed.py | 42.85% | 2 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20765      +/-   ##
==========================================
+ Coverage   81.96%   82.00%   +0.03%     
==========================================
  Files         554      557       +3     
  Lines       51656    52019     +363     
  Branches     7996     8040      +44     
==========================================
+ Hits        42342    42660     +318     
- Misses       7367     7404      +37     
- Partials     1947     1955       +8     
Flag | Coverage | Δ
keras | 81.82% <28.57%> | (+0.03%) ⬆️
keras-jax | 64.21% <0.00%> | (+0.18%) ⬆️
keras-numpy | 58.95% <0.00%> | (+0.03%) ⬆️
keras-openvino | 29.89% <0.00%> | (-0.07%) ⬇️
keras-tensorflow | 64.78% <28.57%> | (+0.05%) ⬆️
keras-torch | 64.16% <0.00%> | (+0.09%) ⬆️


@Surya2k1
Contributor Author

Ping: @mattdangerw

@fchollet fchollet (Collaborator) left a comment

Thanks for the PR!

f"({batch_size}, {timesteps}, ...), "
f"received: mask.shape={mask_shape}"
)
if backend.backend() == "tensorflow" and not tf.executing_eagerly():
fchollet (Collaborator), on the lines above:

Why is the fix TF-only? What makes the problem TF specific?

Surya2k1 (Contributor, Author):

This is an issue with the TF backend + graph mode + a dataset with a partial batch, i.e. when total_samples % batch_size != 0. In that case the dataset adapter returns a strided_slice as the batch shape, which is not an integer value, so the check below fails with OperatorNotAllowedInGraphError:

if mask_shape is not None and mask_shape[:2] != (batch_size, timesteps):
    raise ValueError(...)

I am assuming there is a limitation in setting batch_size to an exact integer for this case, since the last batch has a different batch size. Hence I proposed to skip the evaluation of batch_size for this case, i.e. "TF backend + graph mode + partial batch".

However, this change also omits the batch_size check for datasets that do not have a partial batch. Hence the change below could be a better fix than the current commit:

        if backend.backend() == "tensorflow" and not tf.executing_eagerly():
            # In graph mode the batch dim may be a symbolic tensor; resolve
            # whatever is statically known so the comparison below stays a
            # plain Python comparison. tf.get_static_value returns None for
            # values that are not statically known.
            mask_shape = list(mask_shape)
            mask_shape[0] = tf.get_static_value(mask_shape[0])
            timesteps = tf.get_static_value(timesteps)
            batch_size = tf.get_static_value(batch_size)
            mask_shape = tuple(mask_shape)
        if mask_shape is not None and mask_shape[:2] != (batch_size, timesteps):
            raise ValueError(
                "`TimeDistributed` Layer should be passed a `mask` of "
                f"shape ({batch_size}, {timesteps}, ...), "
                f"received: mask.shape={mask_shape}"
            )
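As a quick standalone check (my own illustration, not part of the PR): tf.get_static_value returns None for a value that is not statically known, instead of raising inside the graph, so the tuple comparison above degrades to something like (None, 20) != (8, 20), which is an ordinary Python comparison:

import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=(None, 20))])
def f(x):
    batch_size = tf.shape(x)[0]  # symbolic: the batch dim is dynamic
    # Resolved at trace time; None because the value is not statically known.
    assert tf.get_static_value(batch_size) is None
    return x

f(tf.zeros((8, 20)))  # traces fine, no OperatorNotAllowedInGraphError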

Projects
Status: Assigned Reviewer

Development
Successfully merging this pull request may close these issues:
Problem with using masking in Embedding Layer for POS Tagging Model

4 participants