diff --git a/.github/workflows/_test_rosetta.yaml b/.github/workflows/_test_rosetta.yaml
index 017662ea3..172d0cc02 100644
--- a/.github/workflows/_test_rosetta.yaml
+++ b/.github/workflows/_test_rosetta.yaml
@@ -8,6 +8,10 @@ on:
description: 'Rosetta image build by NVIDIA/JAX-Toolbox'
required: true
default: 'ghcr.io/nvidia/t5x:latest'
+ TIMEOUT_MINUTES:  # interpolated into the `runner` job's TIME value as "<minutes>:00"
+ type: number
+ description: 'Maximum test runtime, in minutes'
+ default: 60
outputs:
TEST_ARTIFACT_NAME:
description: 'Name of the unit test artifact for downstream workflows'
@@ -21,8 +25,19 @@ env:
TEST_LOG_LOCAL_PATH: /log/unit-report.jsonl
jobs:
+ runner:  # calls the reusable on-demand Slurm runner workflow with the settings below
+ uses: ./.github/workflows/_runner_ondemand_slurm.yaml
+ with:
+ NAME: "A100"
+ LABELS: "A100,${{ github.run_id }}"  # same A100 + run-id labels that rosetta-unit-tests lists under runs-on
+ TIME: "${{ inputs.TIMEOUT_MINUTES }}:00"  # NOTE(review): presumably a Slurm minutes:seconds limit - confirm in _runner_ondemand_slurm.yaml
+ secrets: inherit
+
rosetta-unit-tests:
- runs-on: [self-hosted, V100]
+ runs-on:  # labels mirror the LABELS passed to the `runner` job (A100 + this run's id), plus self-hosted
+ - self-hosted
+ - A100
+ - "${{ github.run_id }}"
outputs:
TEST_ARTIFACT_NAME: ${{ env.TEST_ARTIFACT_NAME }}
steps:
@@ -92,6 +107,6 @@ jobs:
BADGE_COLOR=yellow
fi
fi
- echo "LABEL='V100 Unit'" >> $GITHUB_OUTPUT
+ echo "LABEL='A100 Unit'" >> $GITHUB_OUTPUT
echo "MESSAGE='${PASSED_TESTS}/${SKIPPED_TESTS}/${FAILED_TESTS} pass/skip/fail'" >> $GITHUB_OUTPUT
- echo "COLOR='${BADGE_COLOR}'" >> $GITHUB_OUTPUT
+ echo "COLOR='${BADGE_COLOR}'" >> $GITHUB_OUTPUT
\ No newline at end of file
diff --git a/.github/workflows/_test_unit.yaml b/.github/workflows/_test_unit.yaml
index fa29557e0..d820eb348 100644
--- a/.github/workflows/_test_unit.yaml
+++ b/.github/workflows/_test_unit.yaml
@@ -37,7 +37,7 @@ jobs:
strategy:
fail-fast: false
matrix:
- GPU_ARCH: [V100, A100]
+ GPU_ARCH: [A100]
include:
- EXTRA_LABEL: "self-hosted"
# ensures A100 job lands on dedicated runner for this particular job
diff --git a/README.md b/README.md
index 4438f7efc..491ac32df 100644
--- a/README.md
+++ b/README.md
@@ -71,18 +71,11 @@ We support and test the following JAX frameworks and model architectures. More d