Skip to content

Make map_coordinates differentiable for JAX 0.4.34 #347

Make map_coordinates differentiable for JAX 0.4.34

Make map_coordinates differentiable for JAX 0.4.34 #347

Workflow file for this run

# Runs the unit-test suite against a range of pinned JAX releases.
# Triggered manually (workflow_dispatch) or by adding the `test_jax` label
# to a pull request.
name: Dependency test JAX
on:
  pull_request:
    types: [labeled]
  workflow_dispatch:
jobs:
  jax_tests:
    # Run for manual dispatches, or for pull_request events whose newly added
    # label is `test_jax`. Parentheses only make the existing precedence
    # (`&&` before `||`) explicit.
    if: >-
      ${{ (github.event.label.name == 'test_jax' && github.event_name == 'pull_request')
      || github.event_name == 'workflow_dispatch' }}
    runs-on: ubuntu-latest
    strategy:
      # Let every JAX version finish even if another version fails.
      fail-fast: false
      matrix:
        # 0.4.32 is not available on PyPI
        # earlier jax versions are not compatible with other
        # dependencies as of 2024-10-04
        # Versions are quoted so they always stay strings (e.g. a future
        # `0.10` would otherwise be read as the float 0.1).
        jax-version:
          - '0.4.11'
          - '0.4.12'
          - '0.4.13'
          - '0.4.14'
          - '0.4.16'
          - '0.4.17'
          - '0.4.18'
          - '0.4.19'
          - '0.4.20'
          - '0.4.21'
          - '0.4.22'
          - '0.4.23'
          - '0.4.24'
          - '0.4.25'
          - '0.4.26'
          - '0.4.27'
          - '0.4.28'
          - '0.4.29'
          - '0.4.30'
          - '0.4.31'
          - '0.4.33'
          - '0.4.34'
        # Test-suite shards. The number of entries here MUST equal the
        # `--splits` value passed to pytest below, otherwise the tests
        # assigned to the missing shard(s) are never executed.
        group: [1, 2]
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python 3.10
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Upgrade pip
        run: |
          python -m pip install --upgrade pip
      - name: Install dependencies with given JAX version
        # Delete every line mentioning "jax" from requirements.txt, then
        # insert a pin for the matrix version as the first line.
        # NOTE(review): presumably devtools/dev-requirements.txt includes
        # ./requirements.txt, which is how the pin takes effect - confirm.
        run: |
          sed -i '/jax/d' ./requirements.txt
          sed -i '1i\jax[cpu] == ${{ matrix.jax-version }}' ./requirements.txt
          cat ./requirements.txt
          pip install -r ./devtools/dev-requirements.txt
          pip install matplotlib==3.7.2
      - name: Verify dependencies
        # Log the resolved environment to make failures easier to diagnose.
        run: |
          python --version
          pip --version
          pip list
      - name: Test with pytest
        # `--splits` is 2 to match the two entries in `matrix.group`; with the
        # previous value of 3, shard 3 was never scheduled and its tests were
        # silently skipped.
        run: |
          pwd
          lscpu
          python -m pytest -m unit \
            --durations=0 \
            --mpl \
            --maxfail=1 \
            --splits 2 \
            --group ${{ matrix.group }} \
            --splitting-algorithm least_duration