Skip to content

Commit

Permalink
Install jax things from a separate environment file
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelosthege authored and ricardoV94 committed Oct 11, 2023
1 parent 57d73cc commit df7b267
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 7 deletions.
10 changes: 3 additions & 7 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,12 @@ jobs:
- name: Cache conda
uses: actions/cache@v3
env:
# Increase this value to reset cache if environment-test.yml has not changed
# Increase this value to reset cache if environment-jax.yml has not changed
CACHE_NUMBER: 0
with:
path: ~/conda_pkgs_dir
key: ${{ runner.os }}-py${{matrix.python-version}}-conda-${{ env.CACHE_NUMBER }}-${{
hashFiles('conda-envs/environment-test.yml') }}
hashFiles('conda-envs/environment-jax.yml') }}
- name: Cache multiple paths
uses: actions/cache@v3
env:
Expand All @@ -383,7 +383,7 @@ jobs:
mamba-version: "*"
activate-environment: pymc-test
channel-priority: strict
environment-file: conda-envs/environment-test.yml
environment-file: conda-envs/environment-jax.yml
python-version: ${{matrix.python-version}}
use-mamba: true
use-only-tar-bz2: false # IMPORTANT: This may break caching of conda packages! See https://github.com/conda-incubator/setup-miniconda/issues/267
Expand All @@ -392,10 +392,6 @@ jobs:
conda activate pymc-test
pip install -e .
python --version
- name: Install external samplers
run: |
conda activate pymc-test
pip install "numpyro>=0.8.0" "blackjax>=1.0.0"
- name: Run tests
run: |
python -m pytest -vv --cov=pymc --cov-report=xml --no-cov-on-fail --cov-report term --durations=50 $TEST_SUBSET
Expand Down
38 changes: 38 additions & 0 deletions conda-envs/environment-jax.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# "test" conda envs are used to set up our CI environment in GitHub actions
name: pymc-test
channels:
- conda-forge
- defaults
dependencies:
# Base dependencies
- arviz>=0.13.0
- blas
- cachetools>=4.2.1
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
# Jaxlib version must not be greater than jax version!
- blackjax>=1.0.0
- jaxlib==0.4.14
- jax==0.4.16
- libblas=*=*mkl
- mkl-service
- numpy>=1.15.0
- numpyro>=0.8.0
- pandas>=0.24.0
- pip
- pytensor>=2.17.0,<2.18
- python-graphviz
- networkx
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for testing
- ipython>=7.16
- pre-commit>=2.8.0
- pytest-cov>=2.5
- pytest>=3.0
- mypy=1.5.1
- types-cachetools
- pip:
- numdifftools>=0.9.40
- mcbackend>=0.4.0
1 change: 1 addition & 0 deletions scripts/generate_pip_deps_from_conda.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"networkx",
"blas",
"jax",
"jaxlib",
}
RENAME = {}

Expand Down

0 comments on commit df7b267

Please sign in to comment.