add cuda dep to jaxlib #3

Workflow file for this run

name: spackify
on: push
jobs:
  build:
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v3
      - name: Set up Spack
        uses: spack/setup-spack@v2
        with:
          ref: develop      # Spack version (examples: develop, releases/v0.21)
          buildcache: true  # Configure oci://ghcr.io/spack/github-actions-buildcache
          color: true       # Force color output (SPACK_COLOR=always)
          path: spack       # Where to clone Spack
      - run: spack install --jobs 2 py-jax py-jaxlib+cuda cuda_arch=89
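      # TODO: do this inside Apptainer?
      # Print the device of a small JAX array. GitHub-hosted ubuntu-22.04 runners have no GPU,
      # so this is expected to report a CPU device even with the CUDA build of jaxlib.
      - run: eval "$(spack load --sh py-jax)" && python -c 'import jax.numpy as jp; print(jp.zeros(8).device)'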
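A possible variant of the install step, sketched under the assumption that the goal is for the `py-jax` picked up by `spack load` to be built against the CUDA-enabled `py-jaxlib` (compute capability 8.9, i.e. Ada Lovelace GPUs). Spack's `^` dependency syntax attaches `+cuda` and `cuda_arch=89` to the `py-jaxlib` that `py-jax` itself depends on, rather than installing `py-jaxlib+cuda` as a separate root spec that `py-jax` may or may not share. This is an untested fragment, not part of the original workflow.

```yaml
      # Hypothetical replacement for the install step: constrain py-jax's own
      # py-jaxlib dependency to the CUDA build for compute capability 8.9.
      - run: spack install --jobs 2 py-jax ^py-jaxlib+cuda cuda_arch=89
```

With a single root spec, the subsequent `spack load --sh py-jax` brings in exactly that CUDA-enabled `py-jaxlib` as a run dependency.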