Skip to content

Commit

Permalink
Merge pull request #21244 from google:add-tpu-core-count
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 634510539
  • Loading branch information
jax authors committed May 16, 2024
2 parents 74ffed9 + 76a7b19 commit 9dd98dc
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions .github/workflows/cloud-tpu-ci-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@ jobs:
fail-fast: false # don't cancel all jobs on failure
matrix:
jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
tpu-type: ["v3-8", "v4-8", "v5e-4"]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu-type }})"
tpu: [
{type: "v3-8", cores: "4"},
{type: "v4-8", cores: "4"},
{type: "v5e-8", cores: "8"}
]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
env:
LIBTPU_OLDEST_VERSION_DATE: 20240228
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu.type }}"]
timeout-minutes: 120
defaults:
run:
Expand Down Expand Up @@ -84,7 +88,7 @@ jobs:
PY_COLORS: 1
run: |
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=4 --tb=short \
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
--maxfail=20 -m "not multiaccelerator" tests examples
# Run multi-accelerator across all chips
python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
Expand All @@ -95,5 +99,5 @@ jobs:
curl --location --request POST '${{ secrets.BUILD_CHAT_WEBHOOK }}' \
--header 'Content-Type: application/json' \
--data-raw "{
'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu-type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID'
'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu.type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID'
}"

0 comments on commit 9dd98dc

Please sign in to comment.