Skip to content

Commit b68b973

Browse files
Merge branch 'main' into inplace-linalg
2 parents bb90822 + f799219 commit b68b973

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

97 files changed

+1442
-967
lines changed

.github/workflows/test.yml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ jobs:
4444
- '.github/workflows/*.yml'
4545
- 'setup.cfg'
4646
- 'requirements.txt'
47+
- '.pre-commit-config.yaml'
4748
4849
style:
4950
name: Check code style
@@ -52,7 +53,7 @@ jobs:
5253
if: ${{ needs.changes.outputs.changes == 'true' }}
5354
strategy:
5455
matrix:
55-
python-version: ["3.9", "3.10", "3.11"]
56+
python-version: ["3.9", "3.12"]
5657
steps:
5758
- uses: actions/checkout@v4
5859
- uses: actions/setup-python@v5
@@ -70,7 +71,7 @@ jobs:
7071
strategy:
7172
fail-fast: false
7273
matrix:
73-
python-version: ["3.9", "3.11"]
74+
python-version: ["3.9", "3.12"]
7475
fast-compile: [0,1]
7576
float32: [0,1]
7677
install-numba: [0]
@@ -101,7 +102,7 @@ jobs:
101102
float32: 0
102103
part: "tests/link/numba"
103104
- install-numba: 1
104-
python-version: "3.11"
105+
python-version: "3.11" # TODO: bump to 3.12 when numba 0.59 is released
105106
fast-compile: 0
106107
float32: 0
107108
part: "tests/link/numba"
@@ -111,7 +112,7 @@ jobs:
111112
float32: 0
112113
part: "tests/link/jax"
113114
- install-jax: 1
114-
python-version: "3.11"
115+
python-version: "3.12"
115116
fast-compile: 0
116117
float32: 0
117118
part: "tests/link/jax"
@@ -139,7 +140,7 @@ jobs:
139140
- name: Install dependencies
140141
shell: bash -l {0}
141142
run: |
142-
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl "numpy<1.26" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
143+
mamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
143144
# numba-scipy downgrades the installed scipy to 1.7.3 in Python 3.9, but
144145
# not numpy, even though scipy 1.7 requires numpy<1.23. When installing
145146
# PyTensor next, pip installs a lower version of numpy via the PyPI.
@@ -249,7 +250,7 @@ jobs:
249250
- name: Set up Python
250251
uses: actions/setup-python@v5
251252
with:
252-
python-version: "3.11"
253+
python-version: "3.12"
253254

254255
- name: Install dependencies
255256
run: |

.pre-commit-config.yaml

Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ exclude: |
77
)$
88
repos:
99
- repo: https://github.com/pre-commit/pre-commit-hooks
10-
rev: v4.4.0
10+
rev: v4.5.0
1111
hooks:
1212
- id: debug-statements
1313
exclude: |
@@ -19,42 +19,15 @@ repos:
1919
pytensor/tensor/variable\.py|
2020
)$
2121
- id: check-merge-conflict
22-
- repo: https://github.com/asottile/pyupgrade
23-
rev: v3.3.1
22+
- repo: https://github.com/astral-sh/ruff-pre-commit
23+
rev: v0.1.13
2424
hooks:
25-
- id: pyupgrade
26-
args: [--py39-plus]
27-
- repo: https://github.com/psf/black
28-
rev: 23.1.0
29-
hooks:
30-
- id: black
31-
language_version: python3
32-
- repo: https://github.com/pycqa/flake8
33-
rev: 6.0.0
34-
hooks:
35-
- id: flake8
36-
additional_dependencies:
37-
- flake8-comprehensions
38-
- repo: https://github.com/pycqa/isort
39-
rev: 5.12.0
40-
hooks:
41-
- id: isort
42-
- repo: https://github.com/humitos/mirrors-autoflake.git
43-
rev: v1.1
44-
hooks:
45-
- id: autoflake
46-
exclude: |
47-
(?x)^(
48-
.*/?__init__\.py|
49-
pytensor/graph/toolbox\.py|
50-
pytensor/link/jax/jax_dispatch\.py|
51-
pytensor/link/jax/jax_linker\.py|
52-
pytensor/scalar/basic_scipy\.py|
53-
pytensor/tensor/linalg\.py
54-
)$
55-
args: ['--in-place', '--remove-all-unused-imports', '--remove-unused-variable']
25+
- id: ruff
26+
args: ["--fix", "--show-source"]
27+
- id: ruff-format
28+
args: ["--line-length=88"]
5629
- repo: https://github.com/pre-commit/mirrors-mypy
57-
rev: v1.0.0
30+
rev: v1.8.0
5831
hooks:
5932
- id: mypy
6033
language: python

README.rst

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66
|Tests Status| |Coverage|
77

8-
|Project Name| is a fork of `Aesara <https://github.com/aesara-devs/aesara>`__ -- a Python library that allows one to define, optimize, and
8+
|Project Name| is a Python library that allows one to define, optimize, and
99
efficiently evaluate mathematical expressions involving multi-dimensional arrays.
10+
It provides the computational backend for `PyMC <https://github.com/pymc-devs/pymc>`__.
1011

1112
Features
1213
========
@@ -15,7 +16,8 @@ Features
1516
- Extensible graph framework suitable for rapid development of custom operators and symbolic optimizations
1617
- Implements an extensible graph transpilation framework that currently provides
1718
compilation via C, `JAX <https://github.com/google/jax>`__, and `Numba <https://github.com/numba/numba>`__
18-
- Based on one of the most widely-used Python tensor libraries: `Theano <https://github.com/Theano/Theano>`__
19+
- Contrary to PyTorch and TensorFlow, PyTensor maintains a static graph which can be modified in-place to
20+
allow for advanced optimizations
1921

2022
Getting started
2123
===============
@@ -113,6 +115,11 @@ The current development branch of |Project Name| can be installed from GitHub, a
113115
pip install git+https://github.com/pymc-devs/pytensor
114116

115117

118+
Background
119+
==========
120+
121+
PyTensor is a fork of `Aesara <https://github.com/aesara-devs/aesara>`__, which is a fork of `Theano <https://github.com/Theano/Theano>`__.
122+
116123
Contributing
117124
============
118125

bin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import warnings
22

3+
34
warnings.warn(
45
message= "Importing 'bin.pytensor_cache' is deprecated. Import from "
56
"'pytensor.bin.pytensor_cache' instead.",

environment.yml

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,7 @@ dependencies:
3939
- pydot
4040
- ipython
4141
# code style
42-
- black
43-
- isort
44-
# For linting
45-
- flake8
46-
- pep8
47-
- pyflakes
42+
- ruff
4843
# developer tools
4944
- pre-commit
5045
- packaging

pyproject.toml

Lines changed: 43 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,20 @@
11
[build-system]
22
requires = [
3-
"setuptools>=48.0.0",
4-
"cython",
5-
"numpy>=1.17.0",
6-
"versioneer[toml]==0.28",
3+
"setuptools>=48.0.0",
4+
"cython",
5+
"numpy>=1.17.0",
6+
"versioneer[toml]>=0.28",
77
]
88
build-backend = "setuptools.build_meta"
99

1010
[project]
1111
name = "pytensor"
12-
dynamic = [
13-
'version'
14-
]
15-
requires-python = ">=3.9,<3.12"
16-
authors = [
17-
{name = "pymc-devs", email = "[email protected]"}
18-
]
12+
dynamic = ['version']
13+
requires-python = ">=3.9,<3.13"
14+
authors = [{ name = "pymc-devs", email = "[email protected]" }]
1915
description = "Optimizing compiler for evaluating mathematical expressions on CPUs and GPUs."
2016
readme = "README.rst"
21-
license = {file = "LICENSE.txt"}
17+
license = { file = "LICENSE.txt" }
2218
classifiers = [
2319
"Development Status :: 6 - Mature",
2420
"Intended Audience :: Education",
@@ -37,6 +33,7 @@ classifiers = [
3733
"Programming Language :: Python :: 3.9",
3834
"Programming Language :: Python :: 3.10",
3935
"Programming Language :: Python :: 3.11",
36+
"Programming Language :: Python :: 3.12",
4037
]
4138

4239
keywords = [
@@ -71,15 +68,8 @@ documentation = "https://pytensor.readthedocs.io/en/latest/"
7168
pytensor-cache = "pytensor.bin.pytensor_cache:main"
7269

7370
[project.optional-dependencies]
74-
complete = [
75-
"pytensor[jax]",
76-
"pytensor[numba]",
77-
]
78-
development = [
79-
"pytensor[complete]",
80-
"pytensor[tests]",
81-
"pytensor[rtd]"
82-
]
71+
complete = ["pytensor[jax]", "pytensor[numba]"]
72+
development = ["pytensor[complete]", "pytensor[tests]", "pytensor[rtd]"]
8373
tests = [
8474
"pytest",
8575
"pre-commit",
@@ -88,34 +78,16 @@ tests = [
8878
"pytest-benchmark",
8979
"pytest-mock",
9080
]
91-
rtd = [
92-
"sphinx>=5.1.0,<6",
93-
"pygments",
94-
"pydot",
95-
"pydot2",
96-
"pydot-ng",
97-
]
98-
jax = [
99-
"jax",
100-
"jaxlib",
101-
]
102-
numba = [
103-
"numba>=0.55",
104-
"numba-scipy>=0.3.0"
105-
]
81+
rtd = ["sphinx>=5.1.0,<6", "pygments", "pydot", "pydot2", "pydot-ng"]
82+
jax = ["jax", "jaxlib"]
83+
numba = ["numba>=0.55", "numba-scipy>=0.3.0"]
10684

10785
[tool.setuptools.packages.find]
10886
include = ["pytensor*", "bin"]
10987

11088
[tool.setuptools.package-data]
111-
pytensor = [
112-
"py.typed"
113-
]
114-
"pytensor.d3viz" = [
115-
"html/*",
116-
"css/*",
117-
"js/*",
118-
]
89+
pytensor = ["py.typed"]
90+
"pytensor.d3viz" = ["html/*", "css/*", "js/*"]
11991

12092
[tool.coverage.run]
12193
omit = [
@@ -136,14 +108,8 @@ branch = true
136108
relative_files = true
137109

138110
[tool.coverage.report]
139-
omit = [
140-
"pytensor/_version.py",
141-
"tests/",
142-
]
143-
exclude_lines = [
144-
"pragma: no cover",
145-
"if TYPE_CHECKING:",
146-
]
111+
omit = ["pytensor/_version.py", "tests/"]
112+
exclude_lines = ["pragma: no cover", "if TYPE_CHECKING:"]
147113
show_missing = true
148114

149115
[tool.versioneer]
@@ -163,41 +129,36 @@ max-line-length = 88
163129
[tool.pylint.messages_control]
164130
disable = ["C0330", "C0326"]
165131

166-
[tool.isort]
167-
profile = "black"
168-
lines_after_imports = 2
169-
lines_between_sections = 1
170-
honor_noqa = true
171-
skip_gitignore = true
172-
skip = "pytensor/version.py"
173-
skip_glob = "**/*.pyx"
174132

175133
[tool.ruff]
176-
select=["C","E","F","W"]
177-
ignore=["E501","E741","C408","C901"]
178-
exclude = [
179-
"doc/",
180-
"pytensor/_version.py",
181-
"bin/pytensor_cache.py",
182-
]
134+
select = ["C", "E", "F", "I", "UP", "W"]
135+
ignore = ["C408", "C901", "E501", "E741", "UP031"]
136+
exclude = ["doc/", "pytensor/_version.py", "bin/pytensor_cache.py"]
137+
138+
139+
[tool.ruff.isort]
140+
lines-after-imports = 2
183141

184142
[tool.ruff.per-file-ignores]
185143
# TODO: Get rid of these:
186-
"**/__init__.py"=["F401","E402","F403"]
187-
"pytensor/tensor/linalg.py"=["F401","F403"]
188-
"pytensor/scalar/basic_scipy.py"=["E402","F403","F401"]
189-
"pytensor/graph/toolbox.py"=["E402","F403","F401"]
190-
"pytensor/link/jax/jax_dispatch.py"=["E402","F403","F401"]
191-
"pytensor/link/jax/jax_linker.py"=["E402","F403","F401"]
192-
"pytensor/sparse/sandbox/sp2.py"=["F401"]
193-
"tests/tensor/test_math_scipy.py"=["E402"]
194-
"tests/sparse/test_basic.py"=["E402"]
195-
"tests/sparse/test_opt.py"=["E402"]
196-
"tests/sparse/test_sp2.py"=["E402"]
197-
"tests/sparse/test_utils.py"=["E402","F401"]
198-
"tests/sparse/sandbox/test_sp.py"=["E402","F401"]
199-
"tests/scalar/test_basic_sympy.py"=["E402"]
200-
"pytensor/graph/rewriting/unify.py"=["F811"]
144+
"**/__init__.py" = ["F401", "E402", "F403"]
145+
"pytensor/tensor/linalg.py" = ["F401", "F403"]
146+
"pytensor/scalar/basic_scipy.py" = ["E402"]
147+
"pytensor/graph/toolbox.py" = ["E402"]
148+
# For the tests we skip because `pytest.importorskip` is used:
149+
"tests/link/jax/test_scalar.py" = ["E402"]
150+
"tests/link/jax/test_tensor_basic.py" = ["E402"]
151+
"tests/link/numba/test_basic.py" = ["E402"]
152+
"tests/link/numba/test_cython_support.py" = ["E402"]
153+
"tests/link/numba/test_performance.py" = ["E402"]
154+
"tests/link/numba/test_sparse.py" = ["E402"]
155+
"tests/link/numba/test_tensor_basic.py" = ["E402"]
156+
"tests/tensor/test_math_scipy.py" = ["E402"]
157+
"tests/sparse/test_basic.py" = ["E402"]
158+
"tests/sparse/test_sp2.py" = ["E402"]
159+
"tests/sparse/test_utils.py" = ["E402"]
160+
"tests/sparse/sandbox/test_sp.py" = ["E402", "F401"]
161+
"versioneer.py" = ["I"]
201162

202163

203164
[tool.mypy]

pytensor/compile/debugmode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1866,7 +1866,7 @@ def thunk():
18661866
# Nothing should be in storage map after evaluating
18671867
# each the thunk (specifically the last one)
18681868
for r, s in storage_map.items():
1869-
assert type(s) is list
1869+
assert isinstance(s, list)
18701870
assert s[0] is None
18711871

18721872
# store our output variables to their respective storage lists

pytensor/compile/profiling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,8 +1084,8 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):
10841084
viewof_change = []
10851085
# Use to track view_of changes
10861086

1087-
viewedby_add = defaultdict(lambda: [])
1088-
viewedby_remove = defaultdict(lambda: [])
1087+
viewedby_add = defaultdict(list)
1088+
viewedby_remove = defaultdict(list)
10891089
# Use to track viewed_by changes
10901090

10911091
for var in node.outputs:

0 commit comments

Comments
 (0)