
mixed_bfloat16 in TPU is slower than float32 #18448

@chenmoneygithub

Description

In short, we observed that mixed_bfloat16 on TPU is slower than float32 in our model benchmarks. Please refer to this sheet (internal only) for the comparison results.
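
For context, the benchmark selects precision through the Keras global dtype policy. A minimal sketch of what --mixed_precision_policy="mixed_bfloat16" amounts to, assuming the standard Keras 3 mixed-precision API:

import keras

# Assumption: the benchmark maps the --mixed_precision_policy flag onto the
# global Keras dtype policy; "mixed_bfloat16" runs compute in bfloat16 while
# keeping variables in float32.
keras.mixed_precision.set_global_policy("mixed_bfloat16")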

To reproduce with the JAX backend on a TPU VM, run the commands below:

cd benchmarks
KERAS_BACKEND=jax python3 -m model_benchmark.image_classification_benchmark \
    --model="ResNet50V2" \
    --epochs=1 \
    --batch_size=32 \
    --mixed_precision_policy="mixed_bfloat16"
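
The command above only confirms the end-to-end regression; to check whether bfloat16 compute itself is slow on the TPU, a standalone JAX micro-benchmark can help. This is a hypothetical sketch, not part of the benchmark suite:

import time

import jax
import jax.numpy as jnp

def bench_matmul(dtype, n=4096, iters=20):
    # Time an n x n matmul at the given dtype, averaged over iters runs.
    x = jax.random.normal(jax.random.PRNGKey(0), (n, n), dtype=dtype)
    f = jax.jit(lambda a: a @ a)
    f(x).block_until_ready()  # compile and warm up before timing
    start = time.perf_counter()
    for _ in range(iters):
        out = f(x)
    out.block_until_ready()
    return (time.perf_counter() - start) / iters

print("float32 :", bench_matmul(jnp.float32))
print("bfloat16:", bench_matmul(jnp.bfloat16))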

To reproduce with the TF backend, you need to modify the benchmark code to connect to the TPU and run the model under a TPU strategy, as sketched below.
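
A minimal sketch of those changes, assuming the standard TF2 TPU initialization flow (where exactly this goes in the benchmark script is up to you):

import tensorflow as tf

# Assumption: running on a Cloud TPU VM, where the local TPU resolves with an
# empty tpu="" address.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    # Build and compile the benchmark model inside the strategy scope so its
    # variables are placed on the TPU.
    ...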
