
JAX GPU Test aborts in PyDatasetAdapter if multiprocessing=True #18431

Open
@sampathweb

Description


Repro steps in a GPU-enabled Colab Pro terminal:

pip install -U tensorflow  # Update TF to 2.13
git clone https://github.com/keras-team/keras-core.git
cd keras-core

KERAS_BACKEND=jax pytest keras_core --ignore keras_core/applications

The run aborts at 98%, in PyDatasetAdapterTest, when that test runs with the multiprocessing=True option.

[Screenshot: jax-colab-test-run-error2]

But when the test runs on its own, or even from a higher-level folder, it doesn't abort. It also doesn't abort with the TensorFlow or Torch backends. I'm not sure why it aborts only for JAX on GPU, and only when the entire test suite runs. Maybe multiprocessing under pytest doesn't play well with XLA / JAX / CUDA?
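One hypothesis that fits these symptoms (an assumption, not confirmed in this issue): on Linux, Python's multiprocessing defaulted to the "fork" start method before Python 3.14. Forking after JAX has initialized XLA/CUDA in the parent is unsafe, because JAX is multithreaded and CUDA does not support use in a child forked after initialization. In a full-suite run, an earlier test has probably already initialized the backend by the time PyDatasetAdapterTest forks its workers, which would explain why the test passes in isolation. A minimal diagnostic sketch, assuming the adapter's worker pool uses the default multiprocessing context (not verified against keras-core), is to force the "spawn" start method from a hypothetical root-level conftest.py:

```python
# conftest.py (hypothetical placement at the repository root)
import multiprocessing as mp


def pytest_configure(config):
    # pytest calls this hook once at startup, before tests are collected.
    # With "spawn", worker processes start as fresh interpreters instead of
    # forked copies of a parent that may already hold initialized
    # JAX / XLA / CUDA state. force=True overrides any start method that
    # was already set earlier in the session.
    mp.set_start_method("spawn", force=True)
```

If the abort goes away with this hook in place, that points to fork-after-initialization as the cause. Keep in mind that "spawn" requires everything sent to the workers to be picklable and starts workers more slowly, so this is a way to test the hypothesis rather than a drop-in fix.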
