
RandAugment incorrect behavior in tf graph mode #21169


Open
gregstarr opened this issue Apr 15, 2025 · 5 comments
@gregstarr

In TensorFlow graph mode, RandAugment repeatedly applies the same augmentations, which were sampled once during graph tracing.

It relies on shuffling a Python list with random.shuffle, which only takes effect during eager execution. In graph mode, the operations are sampled while the function is traced and the result is compiled into the graph; the sampling itself is not part of the compiled graph, so the same operations are reused on every batch.
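This is the usual Python-versus-graph pitfall and can be reproduced without Keras at all. A minimal sketch with a plain tf.data pipeline (the op names here are placeholders, not RandAugment's layers):

import random
import tensorflow as tf

OPS = ["equalization", "posterization", "solarization"]

def augment(x):
    # random.shuffle is a plain Python call: it runs exactly once, while
    # tf.data traces this function, and its result is frozen into the graph.
    random.shuffle(OPS)
    tf.print(OPS[0], tf.executing_eagerly())
    return x

ds = tf.data.Dataset.from_tensor_slices(tf.zeros((6, 2))).batch(2).map(augment)
for _ in ds:
    pass  # prints the same op name (and False) for every batch

The same trace-time freezing is what the RandAugment printout below demonstrates.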

If I add a tf.print statement to this code:

random.shuffle(self._AUGMENT_LAYERS)
for layer_name in self._AUGMENT_LAYERS[: self.num_ops]:
    tf.print(layer_name, tf.executing_eagerly())  # <----
    augmentation_layer = getattr(self, layer_name)
    transformation[layer_name] = (
        augmentation_layer.get_random_transformation(
            data,
            training=training,
            seed=self._get_seed_generator(self.backend._backend),
        )
    )

then run this test:

def test_graph_issue(self):
    input_data = np.random.random((10, 8, 8, 3))
    layer = layers.RandAugment()
    ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
    print()
    for output in ds:
        output.numpy()

I get this output:

equalization False
random_posterization False
equalization False
random_posterization False
equalization False
random_posterization False
equalization False
random_posterization False
equalization False
random_posterization False
@gregstarr (Author)

To clarify: the correct behavior would apply random augmentations each time, so the printout would not repeat the same two ops but would instead cover all the augmentations uniformly.

@fchollet (Collaborator)

You're right, we should be using cond operations here, conditioned on variables sampled via keras.random ops. That way the graph will include the conditional branches.

@shashaka what do you think?
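A minimal sketch of the cond-style approach described above, assuming the TensorFlow backend: the branch index is sampled as a tensor op (tf.random.uniform here; a keras.random op would be the backend-agnostic choice), and tf.switch_case traces all branches into the graph so a different one can execute on each batch. The augmentation ops are placeholders, not RandAugment's actual layers:

import tensorflow as tf

def apply_one_random_op(images):
    # All branches are traced into the graph; which one runs is decided
    # at execution time by the tensor-valued index below.
    branches = [
        lambda: tf.image.flip_left_right(images),
        lambda: tf.image.adjust_brightness(images, 0.2),
        lambda: tf.image.adjust_contrast(images, 1.5),
    ]
    # Sampled as a graph op, so it is re-evaluated on every batch
    # (unlike random.shuffle, which runs only once at trace time).
    idx = tf.random.uniform([], 0, len(branches), dtype=tf.int32)
    return tf.switch_case(idx, branches)

ds = (
    tf.data.Dataset.from_tensor_slices(tf.zeros((10, 8, 8, 3)))
    .batch(2)
    .map(apply_one_random_op)
)
for batch in ds:
    pass  # each batch applies an independently sampled op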

@shashaka (Contributor)

@gregstarr
Thank you for reporting the issue. I agree that it should be fixed. As @fchollet mentioned, using a conditional operation seems like a good solution. If you have some time, would you be able to create a PR to address this?

@gregstarr (Author)

Yes, I will give it a shot.

@gregstarr (Author)

Can you offer some advice on this PR: #21185?
