Skip to content

Commit df7b267

Browse files
michaelosthegericardoV94
authored andcommitted
Install jax things from a separate environment file
1 parent 57d73cc commit df7b267

File tree

3 files changed

+42
-7
lines changed

3 files changed

+42
-7
lines changed

.github/workflows/tests.yml

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -358,12 +358,12 @@ jobs:
358358
- name: Cache conda
359359
uses: actions/cache@v3
360360
env:
361-
# Increase this value to reset cache if environment-test.yml has not changed
361+
# Increase this value to reset cache if environment-jax.yml has not changed
362362
CACHE_NUMBER: 0
363363
with:
364364
path: ~/conda_pkgs_dir
365365
key: ${{ runner.os }}-py${{matrix.python-version}}-conda-${{ env.CACHE_NUMBER }}-${{
366-
hashFiles('conda-envs/environment-test.yml') }}
366+
hashFiles('conda-envs/environment-jax.yml') }}
367367
- name: Cache multiple paths
368368
uses: actions/cache@v3
369369
env:
@@ -383,7 +383,7 @@ jobs:
383383
mamba-version: "*"
384384
activate-environment: pymc-test
385385
channel-priority: strict
386-
environment-file: conda-envs/environment-test.yml
386+
environment-file: conda-envs/environment-jax.yml
387387
python-version: ${{matrix.python-version}}
388388
use-mamba: true
389389
use-only-tar-bz2: false # IMPORTANT: This may break caching of conda packages! See https://github.com/conda-incubator/setup-miniconda/issues/267
@@ -392,10 +392,6 @@ jobs:
392392
conda activate pymc-test
393393
pip install -e .
394394
python --version
395-
- name: Install external samplers
396-
run: |
397-
conda activate pymc-test
398-
pip install "numpyro>=0.8.0" "blackjax>=1.0.0"
399395
- name: Run tests
400396
run: |
401397
python -m pytest -vv --cov=pymc --cov-report=xml --no-cov-on-fail --cov-report term --durations=50 $TEST_SUBSET

conda-envs/environment-jax.yml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# "test" conda envs are used to set up our CI environment in GitHub actions
2+
name: pymc-test
3+
channels:
4+
- conda-forge
5+
- defaults
6+
dependencies:
7+
# Base dependencies
8+
- arviz>=0.13.0
9+
- blas
10+
- cachetools>=4.2.1
11+
- cloudpickle
12+
- fastprogress>=0.2.0
13+
- h5py>=2.7
14+
# Jaxlib version must not be greater than jax version!
15+
- blackjax>=1.0.0
16+
- jaxlib==0.4.14
17+
- jax==0.4.16
18+
- libblas=*=*mkl
19+
- mkl-service
20+
- numpy>=1.15.0
21+
- numpyro>=0.8.0
22+
- pandas>=0.24.0
23+
- pip
24+
- pytensor>=2.17.0,<2.18
25+
- python-graphviz
26+
- networkx
27+
- scipy>=1.4.1
28+
- typing-extensions>=3.7.4
29+
# Extra dependencies for testing
30+
- ipython>=7.16
31+
- pre-commit>=2.8.0
32+
- pytest-cov>=2.5
33+
- pytest>=3.0
34+
- mypy=1.5.1
35+
- types-cachetools
36+
- pip:
37+
- numdifftools>=0.9.40
38+
- mcbackend>=0.4.0

scripts/generate_pip_deps_from_conda.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
"networkx",
5555
"blas",
5656
"jax",
57+
"jaxlib",
5758
}
5859
RENAME = {}
5960

0 commit comments

Comments
 (0)