Description
In short, we observed that mixed_bfloat16 on TPU is slower than float32 in our model benchmarks. Please refer to this sheet (internal only) for the comparison results.
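For context, a mixed_bfloat16 policy keeps layer variables in float32 and casts the compute to bfloat16, so this result means the bfloat16 compute path itself is underperforming. A quick way to confirm what the policy implies (Keras 3 API):

```python
import keras

# mixed_bfloat16 means: compute in bfloat16, store variables in float32.
policy = keras.mixed_precision.Policy("mixed_bfloat16")
print(policy.compute_dtype)   # bfloat16
print(policy.variable_dtype)  # float32
```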
To reproduce with the JAX backend on a TPU VM, run:
cd benchmarks
KERAS_BACKEND=jax python3 -m model_benchmark.image_classification_benchmark \
--model="ResNet50V2" \
--epochs=1 \
--batch_size=32 \
--mixed_precision_policy="mixed_bfloat16"
To reproduce with the TF backend, you need to modify the benchmark code to connect to the TPU and run the model under a TPU strategy.
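For the TF backend, the modification usually amounts to resolving the TPU cluster and building the model inside a `TPUStrategy` scope; a minimal sketch (the empty `tpu=""` argument assumes you are running directly on a TPU VM, where the TPU is attached to the local host):

```python
import tensorflow as tf

# Resolve and initialize the TPU system. tpu="" targets the local TPU
# on a TPU VM; pass a name/address when using a remote TPU worker.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

strategy = tf.distribute.TPUStrategy(resolver)

# Build and compile inside the strategy scope so variables are
# created on the TPU and the policy applies to TPU compute.
with strategy.scope():
    tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
    model = tf.keras.applications.ResNet50V2(weights=None)
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy")
```

This snippet requires TPU hardware to run; on other machines the resolver will fail to find a TPU.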