Commit c88f74c

Blockwise some linalg Ops by default

1 parent eff0721
8 files changed (+226 -128 lines)

pytensor/tensor/basic.py
+1 -1

@@ -3735,7 +3735,7 @@ def stacklists(arg):
         return arg


-def swapaxes(y, axis1, axis2):
+def swapaxes(y, axis1: int, axis2: int) -> TensorVariable:
     "Swap the axes of a tensor."
     y = as_tensor_variable(y)
     ndim = y.ndim
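
The newly annotated swapaxes is what this commit uses elsewhere (via a _T helper) as a batch-aware matrix transpose. A minimal usage sketch, assuming a standard pytensor install; variable names are illustrative:

import pytensor.tensor as pt

x = pt.tensor3("x")          # a stack of matrices, shape (batch, m, n)
xT = pt.swapaxes(x, -1, -2)  # swap only the last two axes: shape (batch, n, m)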

pytensor/tensor/nlinalg.py
+15 -5

@@ -10,11 +10,13 @@
 from pytensor.tensor import basic as at
 from pytensor.tensor import math as tm
 from pytensor.tensor.basic import as_tensor_variable, extract_diag
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector


 class MatrixPinv(Op):
     __props__ = ("hermitian",)
+    gufunc_signature = "(m,n)->(n,m)"

     def __init__(self, hermitian):
         self.hermitian = hermitian
@@ -75,7 +77,7 @@ def pinv(x, hermitian=False):
     solve op.

     """
-    return MatrixPinv(hermitian=hermitian)(x)
+    return Blockwise(MatrixPinv(hermitian=hermitian))(x)


 class MatrixInverse(Op):
@@ -93,6 +95,8 @@ class MatrixInverse(Op):
     """

     __props__ = ()
+    gufunc_signature = "(m,m)->(m,m)"
+    gufunc_spec = ("numpy.linalg.inv", 1, 1)

     def __init__(self):
         pass
@@ -150,7 +154,7 @@ def infer_shape(self, fgraph, node, shapes):
         return shapes


-inv = matrix_inverse = MatrixInverse()
+inv = matrix_inverse = Blockwise(MatrixInverse())


 def matrix_dot(*args):
@@ -181,6 +185,8 @@ class Det(Op):
     """

     __props__ = ()
+    gufunc_signature = "(m,m)->()"
+    gufunc_spec = ("numpy.linalg.det", 1, 1)

     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -209,7 +215,7 @@ def __str__(self):
         return "Det"


-det = Det()
+det = Blockwise(Det())


 class SLogDet(Op):
@@ -218,6 +224,8 @@ class SLogDet(Op):
     """

     __props__ = ()
+    gufunc_signature = "(m,m)->(),()"
+    gufunc_spec = ("numpy.linalg.slogdet", 1, 2)

     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -242,7 +250,7 @@ def __str__(self):
         return "SLogDet"


-slogdet = SLogDet()
+slogdet = Blockwise(SLogDet())


 class Eig(Op):
@@ -252,6 +260,8 @@ class Eig(Op):
     """

     __props__: Tuple[str, ...] = ()
+    gufunc_signature = "(m,m)->(m),(m,m)"
+    gufunc_spec = ("numpy.linalg.eig", 1, 2)

     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -270,7 +280,7 @@ def infer_shape(self, fgraph, node, shapes):
         return [(n,), (n, n)]


-eig = Eig()
+eig = Blockwise(Eig())


 class Eigh(Eig):
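
Wrapping these Ops in Blockwise makes the exposed helpers broadcast their core gufunc signature over leading batch dimensions, mirroring numpy.linalg. A sketch of the intended behavior after this change, with illustrative shapes and names:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import det, matrix_inverse

x = pt.tensor3("x")  # batch of square matrices, shape (b, m, m)
f = pytensor.function([x], [det(x), matrix_inverse(x)])

batch = np.random.default_rng(0).standard_normal((4, 3, 3))
dets, invs = f(batch)
assert dets.shape == (4,)       # core signature (m,m)->(), batched over b
assert invs.shape == (4, 3, 3)  # core signature (m,m)->(m,m), batched over b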

pytensor/tensor/rewriting/linalg.py
+104 -72

@@ -1,11 +1,13 @@
 import logging
+from typing import cast

 from pytensor.graph.rewriting.basic import node_rewriter
-from pytensor.tensor import basic as at
+from pytensor.tensor.basic import TensorVariable, extract_diag, swapaxes
 from pytensor.tensor.blas import Dot22
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import Dot, Prod, log, prod
-from pytensor.tensor.nlinalg import Det, MatrixInverse
+from pytensor.tensor.nlinalg import MatrixInverse, det
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
     register_specialize,
@@ -17,16 +19,40 @@
 logger = logging.getLogger(__name__)


+def is_matrix_transpose(x: TensorVariable) -> bool:
+    """Check if a variable corresponds to a transpose of the last two axes."""
+    node = x.owner
+    if (
+        node
+        and isinstance(node.op, DimShuffle)
+        and not (node.op.drop or node.op.augment)
+    ):
+        [inp] = node.inputs
+        ndims = inp.type.ndim
+        if ndims < 2:
+            return False
+        transpose_order = tuple(range(ndims - 2)) + (ndims - 1, ndims - 2)
+        return cast(bool, node.op.new_order == transpose_order)
+    return False
+
+
+def _T(x: TensorVariable) -> TensorVariable:
+    """Matrix transpose for potentially higher-dimensional tensors."""
+    return swapaxes(x, -1, -2)
+
+
 @register_canonicalize
 @node_rewriter([DimShuffle])
 def transinv_to_invtrans(fgraph, node):
-    if isinstance(node.op, DimShuffle):
-        if node.op.new_order == (1, 0):
-            (A,) = node.inputs
-            if A.owner:
-                if isinstance(A.owner.op, MatrixInverse):
-                    (X,) = A.owner.inputs
-                    return [A.owner.op(node.op(X))]
+    if is_matrix_transpose(node.outputs[0]):
+        (A,) = node.inputs
+        if (
+            A.owner
+            and isinstance(A.owner.op, Blockwise)
+            and isinstance(A.owner.op.core_op, MatrixInverse)
+        ):
+            (X,) = A.owner.inputs
+            return [A.owner.op(node.op(X))]


 @register_stabilize
@@ -37,86 +63,98 @@ def inv_as_solve(fgraph, node):
     """
     if isinstance(node.op, (Dot, Dot22)):
         l, r = node.inputs
-        if l.owner and isinstance(l.owner.op, MatrixInverse):
+        if (
+            l.owner
+            and isinstance(l.owner.op, Blockwise)
+            and isinstance(l.owner.op.core_op, MatrixInverse)
+        ):
             return [solve(l.owner.inputs[0], r)]
-        if r.owner and isinstance(r.owner.op, MatrixInverse):
+        if (
+            r.owner
+            and isinstance(r.owner.op, Blockwise)
+            and isinstance(r.owner.op.core_op, MatrixInverse)
+        ):
             x = r.owner.inputs[0]
             if getattr(x.tag, "symmetric", None) is True:
-                return [solve(x, l.T).T]
+                return [_T(solve(x, _T(l)))]
             else:
-                return [solve(x.T, l.T).T]
+                return [_T(solve(_T(x), _T(l)))]


 @register_stabilize
 @register_canonicalize
-@node_rewriter([Solve])
+@node_rewriter([Blockwise])
 def generic_solve_to_solve_triangular(fgraph, node):
     """
     If any solve() is applied to the output of a cholesky op, then
     replace it with a triangular solve.

     """
-    if isinstance(node.op, Solve):
-        A, b = node.inputs  # result is solution Ax=b
-        if A.owner and isinstance(A.owner.op, Cholesky):
-            if A.owner.op.lower:
-                return [SolveTriangular(lower=True)(A, b)]
-            else:
-                return [SolveTriangular(lower=False)(A, b)]
-        if (
-            A.owner
-            and isinstance(A.owner.op, DimShuffle)
-            and A.owner.op.new_order == (1, 0)
-        ):
-            (A_T,) = A.owner.inputs
-            if A_T.owner and isinstance(A_T.owner.op, Cholesky):
-                if A_T.owner.op.lower:
-                    return [SolveTriangular(lower=False)(A, b)]
-                else:
-                    return [SolveTriangular(lower=True)(A, b)]
+    if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 1:
+        if node.op.core_op.assume_a == "gen":
+            A, b = node.inputs  # result is solution Ax=b
+            if (
+                A.owner
+                and isinstance(A.owner.op, Blockwise)
+                and isinstance(A.owner.op.core_op, Cholesky)
+            ):
+                if A.owner.op.core_op.lower:
+                    return [SolveTriangular(lower=True)(A, b)]
+                else:
+                    return [SolveTriangular(lower=False)(A, b)]
+            if is_matrix_transpose(A):
+                (A_T,) = A.owner.inputs
+                if (
+                    A_T.owner
+                    and isinstance(A_T.owner.op, Blockwise)
+                    and isinstance(A_T.owner.op.core_op, Cholesky)
+                ):
+                    if A_T.owner.op.core_op.lower:
+                        return [SolveTriangular(lower=False)(A, b)]
+                    else:
+                        return [SolveTriangular(lower=True)(A, b)]


 @register_canonicalize
 @register_stabilize
 @register_specialize
 @node_rewriter([DimShuffle])
 def no_transpose_symmetric(fgraph, node):
-    if isinstance(node.op, DimShuffle):
+    if is_matrix_transpose(node.outputs[0]):
         x = node.inputs[0]
-        if x.type.ndim == 2 and getattr(x.tag, "symmetric", None) is True:
-            if node.op.new_order == [1, 0]:
-                return [x]
+        if getattr(x.tag, "symmetric", None):
+            return [x]


 @register_stabilize
-@node_rewriter([Solve])
+@node_rewriter([Blockwise])
 def psd_solve_with_chol(fgraph, node):
     """
     This utilizes a boolean `psd` tag on matrices.
     """
-    if isinstance(node.op, Solve):
+    if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 2:
         A, b = node.inputs  # result is solution Ax=b
         if getattr(A.tag, "psd", None) is True:
             L = cholesky(A)
             # N.B. this can be further reduced to a yet-unwritten cho_solve Op
-            # __if__ no other Op makes use of the the L matrix during the
+            # __if__ no other Op makes use of the L matrix during the
             # stabilization
-            Li_b = Solve(assume_a="sym", lower=True)(L, b)
-            x = Solve(assume_a="sym", lower=False)(L.T, Li_b)
+            Li_b = solve(L, b, assume_a="sym", lower=True, b_ndim=2)
+            x = solve(_T(L), Li_b, assume_a="sym", lower=False, b_ndim=2)
             return [x]


 @register_canonicalize
 @register_stabilize
-@node_rewriter([Cholesky])
+@node_rewriter([Blockwise])
 def cholesky_ldotlt(fgraph, node):
     """
     rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
     or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular.

     This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
     """
-    if not isinstance(node.op, Cholesky):
+    if not isinstance(node.op.core_op, Cholesky):
         return

     A = node.inputs[0]
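
The psd_solve_with_chol rewrite above rests on the factorization A = L @ L.T of a positive-definite A, which splits A x = b into two triangular solves. A quick NumPy/SciPy check of that identity (illustrative only; shapes and names are ours):

import numpy as np
from scipy.linalg import solve_triangular

rng = np.random.default_rng(1)
A0 = rng.standard_normal((4, 4))
A = A0 @ A0.T + 4 * np.eye(4)    # positive definite, the case the `psd` tag marks
b = rng.standard_normal((4, 2))  # b_ndim == 2, the case the rewrite targets

L = np.linalg.cholesky(A)                     # A = L @ L.T, L lower triangular
Li_b = solve_triangular(L, b, lower=True)     # solve L y = b
x = solve_triangular(L.T, Li_b, lower=False)  # solve L.T x = y
assert np.allclose(A @ x, b)
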
@@ -128,45 +166,40 @@ def cholesky_ldotlt(fgraph, node):
     # cholesky(dot(L,L.T)) case
     if (
         getattr(l.tag, "lower_triangular", False)
-        and r.owner
-        and isinstance(r.owner.op, DimShuffle)
-        and r.owner.op.new_order == (1, 0)
+        and is_matrix_transpose(r)
         and r.owner.inputs[0] == l
     ):
-        if node.op.lower:
+        if node.op.core_op.lower:
             return [l]
         return [r]

     # cholesky(dot(U.T,U)) case
     if (
         getattr(r.tag, "upper_triangular", False)
-        and l.owner
-        and isinstance(l.owner.op, DimShuffle)
-        and l.owner.op.new_order == (1, 0)
+        and is_matrix_transpose(l)
         and l.owner.inputs[0] == r
     ):
-        if node.op.lower:
+        if node.op.core_op.lower:
             return [l]
         return [r]


 @register_stabilize
 @register_specialize
-@node_rewriter([Det])
+@node_rewriter([det])
 def local_det_chol(fgraph, node):
     """
     If we have det(X) and there is already an L=cholesky(X)
     floating around, then we can use prod(diag(L)) to get the determinant.

     """
-    if isinstance(node.op, Det):
-        (x,) = node.inputs
-        for cl, xpos in fgraph.clients[x]:
-            if cl == "output":
-                continue
-            if isinstance(cl.op, Cholesky):
-                L = cl.outputs[0]
-                return [prod(at.extract_diag(L) ** 2)]
+    (x,) = node.inputs
+    for cl, xpos in fgraph.clients[x]:
+        if cl == "output":
+            continue
+        if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky):
+            L = cl.outputs[0]
+            return [prod(extract_diag(L) ** 2, axis=(-1, -2))]


 @register_canonicalize
@@ -177,16 +210,15 @@ def local_log_prod_sqr(fgraph, node):
     """
     This utilizes a boolean `positive` tag on matrices.
     """
-    if node.op == log:
-        (x,) = node.inputs
-        if x.owner and isinstance(x.owner.op, Prod):
-            # we cannot always make this substitution because
-            # the prod might include negative terms
-            p = x.owner.inputs[0]
-
-            # p is the matrix we're reducing with prod
-            if getattr(p.tag, "positive", None) is True:
-                return [log(p).sum(axis=x.owner.op.axis)]
-
-            # TODO: have a reduction like prod and sum that simply
-            # returns the sign of the prod multiplication.
+    (x,) = node.inputs
+    if x.owner and isinstance(x.owner.op, Prod):
+        # we cannot always make this substitution because
+        # the prod might include negative terms
+        p = x.owner.inputs[0]
+
+        # p is the matrix we're reducing with prod
+        if getattr(p.tag, "positive", None) is True:
+            return [log(p).sum(axis=x.owner.op.axis)]
+
+        # TODO: have a reduction like prod and sum that simply
+        # returns the sign of the prod multiplication.
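
The local_det_chol rewrite uses the identity det(X) = det(L) * det(L.T) = prod(diag(L)) ** 2 whenever an L = cholesky(X) already exists in the graph. A quick NumPy check of the identity (illustrative only):

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3))
X = A @ A.T + 3 * np.eye(3)  # symmetric positive definite, so cholesky applies
L = np.linalg.cholesky(X)    # lower triangular with X = L @ L.T
# the determinant of a triangular matrix is the product of its diagonal
assert np.isclose(np.linalg.det(X), np.prod(np.diag(L)) ** 2)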
