In short, we observed that mixed_bfloat16 on TPU is slower than float32 in our model benchmarks. Please refer to this sheet (internal only) for the comparison results.
To reproduce with the JAX backend, on a TPU VM, use the command below. To reproduce with the TF backend, you need to modify the code to connect to the TPU and use a TPU strategy.
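For illustration only, here is a minimal sketch of this kind of comparison, not the original benchmark command: it assumes a small dense model with synthetic data and times training under both dtype policies with the JAX backend.

```python
# Minimal sketch (NOT the original benchmark): time a small dense model under
# the float32 and mixed_bfloat16 policies with the JAX backend, e.g. on a TPU VM.
import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import time
import numpy as np
import keras


def time_training(policy_name):
    # Switch the global dtype policy, then build and train a fresh model.
    keras.mixed_precision.set_global_policy(policy_name)
    model = keras.Sequential([
        keras.Input(shape=(256,)),
        keras.layers.Dense(2048, activation="relu"),
        keras.layers.Dense(2048, activation="relu"),
        keras.layers.Dense(10),
    ])
    model.compile(
        optimizer="adam",
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    x = np.random.rand(16384, 256).astype("float32")
    y = np.random.randint(0, 10, size=(16384,))
    model.fit(x, y, batch_size=1024, epochs=1, verbose=0)  # warm-up / compilation
    start = time.time()
    model.fit(x, y, batch_size=1024, epochs=5, verbose=0)
    return time.time() - start


for policy in ("float32", "mixed_bfloat16"):
    print(f"{policy}: {time_training(policy):.2f}s")
```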
The sheet is accessible to me.
Mixed precision will only speed up models on recent NVIDIA GPUs and Google TPUs. NVIDIA GPUs support using a mix of float16 and float32, while TPUs support a mix of bfloat16 and float32. You can find more details here.
On which hardware are you using mixed_bfloat16 and float32?
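For reference, a short sketch (assuming the JAX backend and the standard jax.devices() check; not code from this thread) of picking the policy that matches the accelerator, as described above:

```python
# Sketch under assumptions: select the mixed-precision policy per accelerator.
import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import jax
import keras

platforms = {d.platform for d in jax.devices()}
if "tpu" in platforms:
    policy = "mixed_bfloat16"  # TPUs mix bfloat16 and float32
elif "gpu" in platforms:
    policy = "mixed_float16"   # recent NVIDIA GPUs mix float16 and float32
else:
    policy = "float32"         # CPUs see no benefit from mixed precision

keras.mixed_precision.set_global_policy(policy)
print("Using dtype policy:", policy)
```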
@mehtamansi29 Thanks for looking into that! I am not sure the result is still valid; that benchmark was done before the first official release of Keras 3. The TPU was a v3-8, which is a very old generation as of today.