From df7b2674f4389cdfcf7aedbdbd1e0eeae3d923eb Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Tue, 10 Oct 2023 23:33:41 +0200 Subject: [PATCH] Install jax things from a separate environment file --- .github/workflows/tests.yml | 10 ++----- conda-envs/environment-jax.yml | 38 +++++++++++++++++++++++++ scripts/generate_pip_deps_from_conda.py | 1 + 3 files changed, 42 insertions(+), 7 deletions(-) create mode 100644 conda-envs/environment-jax.yml diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4fe668ee0ba..8d33a7b898e 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -358,12 +358,12 @@ jobs: - name: Cache conda uses: actions/cache@v3 env: - # Increase this value to reset cache if environment-test.yml has not changed + # Increase this value to reset cache if environment-jax.yml has not changed CACHE_NUMBER: 0 with: path: ~/conda_pkgs_dir key: ${{ runner.os }}-py${{matrix.python-version}}-conda-${{ env.CACHE_NUMBER }}-${{ - hashFiles('conda-envs/environment-test.yml') }} + hashFiles('conda-envs/environment-jax.yml') }} - name: Cache multiple paths uses: actions/cache@v3 env: @@ -383,7 +383,7 @@ jobs: mamba-version: "*" activate-environment: pymc-test channel-priority: strict - environment-file: conda-envs/environment-test.yml + environment-file: conda-envs/environment-jax.yml python-version: ${{matrix.python-version}} use-mamba: true use-only-tar-bz2: false # IMPORTANT: This may break caching of conda packages! See https://github.com/conda-incubator/setup-miniconda/issues/267 @@ -392,10 +392,6 @@ jobs: conda activate pymc-test pip install -e . python --version - - name: Install external samplers - run: | - conda activate pymc-test - pip install "numpyro>=0.8.0" "blackjax>=1.0.0" - name: Run tests run: | python -m pytest -vv --cov=pymc --cov-report=xml --no-cov-on-fail --cov-report term --durations=50 $TEST_SUBSET diff --git a/conda-envs/environment-jax.yml b/conda-envs/environment-jax.yml new file mode 100644 index 00000000000..542d0ae27d3 --- /dev/null +++ b/conda-envs/environment-jax.yml @@ -0,0 +1,38 @@ +# "test" conda envs are used to set up our CI environment in GitHub actions +name: pymc-test +channels: +- conda-forge +- defaults +dependencies: +# Base dependencies +- arviz>=0.13.0 +- blas +- cachetools>=4.2.1 +- cloudpickle +- fastprogress>=0.2.0 +- h5py>=2.7 +# Jaxlib version must not be greater than jax version! +- blackjax>=1.0.0 +- jaxlib==0.4.14 +- jax==0.4.16 +- libblas=*=*mkl +- mkl-service +- numpy>=1.15.0 +- numpyro>=0.8.0 +- pandas>=0.24.0 +- pip +- pytensor>=2.17.0,<2.18 +- python-graphviz +- networkx +- scipy>=1.4.1 +- typing-extensions>=3.7.4 +# Extra dependencies for testing +- ipython>=7.16 +- pre-commit>=2.8.0 +- pytest-cov>=2.5 +- pytest>=3.0 +- mypy=1.5.1 +- types-cachetools +- pip: + - numdifftools>=0.9.40 + - mcbackend>=0.4.0 diff --git a/scripts/generate_pip_deps_from_conda.py b/scripts/generate_pip_deps_from_conda.py index cbdc7791fa4..69bdbb49f2a 100755 --- a/scripts/generate_pip_deps_from_conda.py +++ b/scripts/generate_pip_deps_from_conda.py @@ -54,6 +54,7 @@ "networkx", "blas", "jax", + "jaxlib", } RENAME = {}