Description
Repro steps in a Colab Pro terminal with GPU enabled:
pip install -U tensorflow # Update TF to 2.13
git clone https://github.com/keras-team/keras-core.git
cd keras-core
KERAS_BACKEND=jax pytest keras_core --ignore keras_core/applications
This aborts at 98%, in the PyDatasetAdapterTest test, when it runs with the multiprocessing=True option.
But when the test is run independently, or even from a higher-level folder, it doesn't abort. It also doesn't abort for the TensorFlow or Torch backends. I am not sure why it aborts only for JAX on GPU when running the entire test suite. Maybe multiprocessing with pytest doesn't play well with XLA / JAX / CUDA?
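One way to start checking the multiprocessing hypothesis is to look at which process start method is in effect. This is only a diagnostic sketch, not a confirmed cause or fix: it assumes the abort might be related to child processes being created with "fork" after the GPU runtime has already been initialized earlier in the test run, which is a commonly reported problem with CUDA-backed libraries. The helper names `describe_start_methods` and `spawn_context` are hypothetical and are not part of keras-core.

```python
import multiprocessing as mp


def describe_start_methods():
    """Report the default start method and the ones this platform supports.

    Note: calling get_start_method() fixes the default for this process.
    """
    return {
        "default": mp.get_start_method(),
        "available": mp.get_all_start_methods(),
    }


def spawn_context():
    """Return a context whose children start from a fresh interpreter.

    Unlike "fork", "spawn" does not copy the parent's threads or any
    already-initialized GPU runtime state into the child process.
    """
    return mp.get_context("spawn")


if __name__ == "__main__":
    print(describe_start_methods())
    print(spawn_context().get_start_method())
```

If the default turns out to be "fork" in the Colab terminal, that would be consistent with the abort appearing only after earlier tests have already initialized JAX on the GPU, though it would not prove it.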