Skip to content

Commit 3fa2108

Browse files
committed
Remove numba-scipy dependency
1 parent 9b7d707 commit 3fa2108

File tree

3 files changed

+3
-8
lines changed

3 files changed

+3
-8
lines changed

.github/workflows/test.yml

+2-6
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,7 @@ jobs:
141141
shell: bash -l {0}
142142
run: |
143143
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
144-
# numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.9, but
145-
# not numpy, even though scipy 1.7 requires numpy<1.23. When installing
146-
# PyTensor next, pip installs a lower version of numpy via the PyPI.
147-
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION == "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numpy<1.23" "numba>=0.57" numba-scipy; fi
148-
if [[ $INSTALL_NUMBA == "1" ]] && [[ $PYTHON_VERSION != "3.9" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57" numba-scipy; fi
144+
if [[ $INSTALL_NUMBA == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
149145
if [[ $INSTALL_JAX == "1" ]]; then mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
150146
pip install -e ./
151147
mamba list && pip freeze
@@ -199,7 +195,7 @@ jobs:
199195
- name: Install dependencies
200196
shell: bash -l {0}
201197
run: |
202-
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" numba-scipy jax jaxlib pytest-benchmark
198+
mamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
203199
pip install -e ./
204200
mamba list && pip freeze
205201
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'

environment.yml

-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ dependencies:
2323
- libblas=*=*mkl
2424
# numba backend
2525
- numba>=0.57
26-
- numba-scipy
2726
# For testing
2827
- coveralls
2928
- diff-cover

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ tests = [
8080
]
8181
rtd = ["sphinx>=5.1.0,<6", "pygments", "pydot", "pydot2", "pydot-ng"]
8282
jax = ["jax", "jaxlib"]
83-
numba = ["numba>=0.55", "numba-scipy>=0.3.0"]
83+
numba = ["numba>=0.57"]
8484

8585
[tool.setuptools.packages.find]
8686
include = ["pytensor*", "bin"]

0 commit comments

Comments
 (0)