In TensorFlow graph mode, RandAugment repeatedly applies the same augmentations: the ones that happened to be sampled while the graph was being traced.
It relies on shuffling a Python list with `random.shuffle`, which only yields fresh samples during eager execution. In graph mode the Python code runs once during tracing, so the shuffle happens at that point and its result is compiled into the graph as fixed operations. The sampling itself is not part of the graph, so every call reuses the same operations.
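A minimal, self-contained sketch of the same failure mode, independent of RandAugment: a Python-level `random.shuffle` inside a `tf.function` is executed only while the function is traced. The operation names below are hypothetical placeholders, not the layer's actual augmentation list.

```python
import random

import tensorflow as tf

# Hypothetical names standing in for RandAugment's augmentation list.
OP_NAMES = ["identity", "invert", "brighten", "flip_left_right"]


@tf.function
def sample_two_ops():
    names = list(OP_NAMES)
    # Python's random.shuffle runs only while tf.function traces this body,
    # so the chosen names are baked into the graph as constants.
    random.shuffle(names)
    return tf.constant(names[:2])


# The function is traced once on the first call; later calls replay that graph.
seen = {tuple(sample_two_ops().numpy().tolist()) for _ in range(20)}
print(seen)  # a single pair, repeated on every call
```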
To clarify, the correct behavior would apply random augmentations on each call, so the printout would not repeat the same 2 but would cover all the augmentations uniformly.
You're right, we should be using cond operations here, conditioned on variables sampled via keras.random ops. That way the graph will include the conditional branches.
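One way the suggested approach could look, as a sketch only: the sampled indices come from a random graph op, and every augmentation is compiled into the graph as a branch of `tf.switch_case`, so the choice is made at run time. It uses TensorFlow ops directly for brevity, whereas a backend-agnostic fix in the layer would use `keras.random` (and Keras conditional ops) as suggested above. The four augmentation functions and `random_augment` are hypothetical stand-ins, not the layer's real transforms or API.

```python
import tensorflow as tf


# Hypothetical stand-ins for RandAugment's transforms. Each maps a float32
# HWC image in [0, 1] to an image of the same shape and dtype, because
# tf.switch_case requires all branches to return matching structures.
def identity(image):
    return image


def invert(image):
    return 1.0 - image


def brighten(image):
    return tf.clip_by_value(image + 0.1, 0.0, 1.0)


def flip_left_right(image):
    return tf.reverse(image, axis=[1])


AUGMENTATIONS = [identity, invert, brighten, flip_left_right]


@tf.function
def random_augment(image, num_ops=2):
    # tf.random.shuffle is a graph op, so a fresh permutation is drawn on
    # every call instead of once at trace time. Taking the first num_ops
    # indices samples without replacement, like the original list shuffle.
    chosen = tf.random.shuffle(tf.range(len(AUGMENTATIONS)))[:num_ops]
    for i in range(num_ops):  # Python loop, unrolled while tracing
        # Every augmentation becomes a branch in the graph; the sampled
        # index selects which one runs.
        branches = [lambda fn=fn, img=image: fn(img) for fn in AUGMENTATIONS]
        image = tf.switch_case(chosen[i], branch_fns=branches)
    return image, chosen


image = tf.random.uniform((8, 8, 3))
picks = {tuple(random_augment(image)[1].numpy().tolist()) for _ in range(50)}
print(len(picks))  # typically well above 1
```

Replacing the shuffle with one `tf.random.uniform(..., dtype=tf.int32)` draw per step would instead sample augmentations with replacement; which variant the layer should use depends on its intended semantics.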
@gregstarr
Thank you for reporting the issue. I agree that it should be fixed. As @fchollet mentioned, using a conditional operation seems like a good solution. If you have some time, would you be able to create a PR to address this?
`keras/keras/src/layers/preprocessing/image_preprocessing/rand_augment.py`, line 173 in 44a655b
If I add a `tf.print` statement to this code and then run this test, I get this output: