Commit 44843a4

Blockwise some linalg Ops by default
1 parent 7256c27 · commit 44843a4
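This change wraps the core 2D linear-algebra Ops in `Blockwise`, so they broadcast over leading batch dimensions the way NumPy gufuncs do. A minimal sketch of the user-facing effect, assuming the usual `pytensor.tensor.linalg` re-exports (the example itself is not part of the commit):

```python
import numpy as np
import pytensor.tensor as pt
from pytensor import function

# inv and det are now Blockwise(MatrixInverse()) and Blockwise(Det()), so a
# (batch, m, m) input yields (batch, m, m) and (batch,) outputs, following
# the gufunc signatures "(m,m)->(m,m)" and "(m,m)->()" added below.
x = pt.tensor3("x")
f = function([x], [pt.linalg.inv(x), pt.linalg.det(x)])

vals = np.random.rand(4, 3, 3) + 3 * np.eye(3)  # well-conditioned batch
invs, dets = f(vals)
print(invs.shape, dets.shape)  # (4, 3, 3) (4,)
```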

File tree

7 files changed: +216 -119 lines


pytensor/tensor/basic.py (+1 -1)
```diff
@@ -3638,7 +3638,7 @@ def stacklists(arg):
         return arg
 
 
-def swapaxes(y, axis1, axis2):
+def swapaxes(y, axis1: int, axis2: int) -> TensorVariable:
     "Swap the axes of a tensor."
     y = as_tensor_variable(y)
     ndim = y.ndim
```
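Only the annotations changed here; behavior is identical. Negative axes address the trailing matrix dimensions, which is what the new `T` helper in `pytensor/tensor/rewriting/linalg.py` (further down) builds on. A quick check, assuming the usual top-level `pt.swapaxes` re-export:

```python
import pytensor.tensor as pt

x = pt.tensor3("x")          # think of it as (batch, m, n)
y = pt.swapaxes(x, -1, -2)   # (batch, n, m): a batched matrix transpose
print(y.ndim)                # 3
```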

pytensor/tensor/nlinalg.py (+15 -5)
```diff
@@ -10,11 +10,13 @@
 from pytensor.tensor import basic as at
 from pytensor.tensor import math as tm
 from pytensor.tensor.basic import as_tensor_variable, extract_diag
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
 
 
 class MatrixPinv(Op):
     __props__ = ("hermitian",)
+    gufunc_signature = "(m,n)->(n,m)"
 
     def __init__(self, hermitian):
         self.hermitian = hermitian
@@ -75,7 +77,7 @@ def pinv(x, hermitian=False):
     solve op.
 
     """
-    return MatrixPinv(hermitian=hermitian)(x)
+    return Blockwise(MatrixPinv(hermitian=hermitian))(x)
 
 
 class MatrixInverse(Op):
@@ -93,6 +95,8 @@ class MatrixInverse(Op):
     """
 
     __props__ = ()
+    gufunc_signature = "(m,m)->(m,m)"
+    gufunc_spec = ("numpy.linalg.inv", 1, 1)
 
     def __init__(self):
         pass
@@ -150,7 +154,7 @@ def infer_shape(self, fgraph, node, shapes):
         return shapes
 
 
-inv = matrix_inverse = MatrixInverse()
+inv = matrix_inverse = Blockwise(MatrixInverse())
 
 
 def matrix_dot(*args):
@@ -181,6 +185,8 @@ class Det(Op):
     """
 
     __props__ = ()
+    gufunc_signature = "(m,m)->()"
+    gufunc_spec = ("numpy.linalg.det", 1, 1)
 
     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -209,7 +215,7 @@ def __str__(self):
         return "Det"
 
 
-det = Det()
+det = Blockwise(Det())
 
 
 class SLogDet(Op):
@@ -218,6 +224,8 @@ class SLogDet(Op):
     """
 
     __props__ = ()
+    gufunc_signature = "(m, m)->(),()"
+    gufunc_spec = ("numpy.linalg.slogdet", 1, 2)
 
     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -242,7 +250,7 @@ def __str__(self):
         return "SLogDet"
 
 
-slogdet = SLogDet()
+slogdet = Blockwise(SLogDet())
 
 
 class Eig(Op):
@@ -252,6 +260,8 @@ class Eig(Op):
     """
 
     __props__: Tuple[str, ...] = ()
+    gufunc_signature = "(m,m)->(m),(m,m)"
+    gufunc_spec = ("numpy.linalg.eig", 1, 2)
 
     def make_node(self, x):
         x = as_tensor_variable(x)
@@ -270,7 +280,7 @@ def infer_shape(self, fgraph, node, shapes):
         return [(n,), (n, n)]
 
 
-eig = Eig()
+eig = Blockwise(Eig())
 
 
 class Eigh(Eig):
```
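The new class attributes drive the `Blockwise` machinery: `gufunc_signature` fixes the core dimensions, and `gufunc_spec` appears to name a NumPy function that already implements the batched computation, as a `(qualified name, n_inputs, n_outputs)` triple. That reading is inferred from the values in this diff, not from documentation. A hypothetical resolver under that assumption:

```python
from importlib import import_module

import numpy as np

def resolve_gufunc_spec(spec):
    """Resolve a ("module.func", n_in, n_out) triple to a callable."""
    name, n_in, n_out = spec
    module_name, func_name = name.rsplit(".", 1)
    return getattr(import_module(module_name), func_name), n_in, n_out

func, n_in, n_out = resolve_gufunc_spec(("numpy.linalg.det", 1, 1))
batch = np.random.rand(5, 3, 3)
print(func(batch).shape)  # (5,): numpy's det already batches like a gufunc
```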

pytensor/tensor/rewriting/linalg.py (+96 -65)
```diff
@@ -1,11 +1,13 @@
 import logging
+from typing import cast
 
 from pytensor.graph.rewriting.basic import node_rewriter
-from pytensor.tensor import basic as at
+from pytensor.tensor.basic import TensorVariable, extract_diag, swapaxes
 from pytensor.tensor.blas import Dot22
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.math import Dot, Prod, log, prod
-from pytensor.tensor.nlinalg import Det, MatrixInverse
+from pytensor.tensor.nlinalg import MatrixInverse, det
 from pytensor.tensor.rewriting.basic import (
     register_canonicalize,
     register_specialize,
```
```diff
@@ -17,16 +19,40 @@
 logger = logging.getLogger(__name__)
 
 
+def is_matrix_transpose(x: TensorVariable) -> bool:
+    """Check if a variable corresponds to a transpose of the last two axes"""
+    node = x.owner
+    if (
+        node
+        and isinstance(node.op, DimShuffle)
+        and not (node.op.drop or node.op.augment)
+    ):
+        [inp] = node.inputs
+        ndims = inp.type.ndim
+        if ndims < 2:
+            return False
+        transpose_order = tuple(range(ndims - 2)) + (ndims - 1, ndims - 2)
+        return cast(bool, node.op.new_order == transpose_order)
+    return False
+
+
+def T(x: TensorVariable) -> TensorVariable:
+    """Matrix transpose for potentially higher dimensionality tensors"""
+    return swapaxes(x, -1, -2)
+
+
 @register_canonicalize
 @node_rewriter([DimShuffle])
 def transinv_to_invtrans(fgraph, node):
-    if isinstance(node.op, DimShuffle):
-        if node.op.new_order == (1, 0):
-            (A,) = node.inputs
-            if A.owner:
-                if isinstance(A.owner.op, MatrixInverse):
-                    (X,) = A.owner.inputs
-                    return [A.owner.op(node.op(X))]
+    if is_matrix_transpose(node.outputs[0]):
+        (A,) = node.inputs
+        if (
+            A.owner
+            and isinstance(A.owner.op, Blockwise)
+            and isinstance(A.owner.op.core_op, MatrixInverse)
+        ):
+            (X,) = A.owner.inputs
+            return [A.owner.op(node.op(X))]
 
 
 @register_stabilize
```
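These helpers generalize the old hard-coded `new_order == (1, 0)` checks to tensors with arbitrary batch dimensions. A small usage sketch (assuming the module-level names import as written above):

```python
import pytensor.tensor as pt
from pytensor.tensor.rewriting.linalg import T, is_matrix_transpose

x = pt.tensor3("x")
print(is_matrix_transpose(T(x)))                   # True: swaps only the last two axes
print(is_matrix_transpose(x.dimshuffle(1, 0, 2)))  # False: permutes a batch axis
print(is_matrix_transpose(pt.matrix("m")))         # False: no DimShuffle owner
```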
```diff
@@ -37,87 +63,98 @@ def inv_as_solve(fgraph, node):
     """
     if isinstance(node.op, (Dot, Dot22)):
         l, r = node.inputs
-        if l.owner and isinstance(l.owner.op, MatrixInverse):
+        if (
+            l.owner
+            and isinstance(l.owner.op, Blockwise)
+            and isinstance(l.owner.op.core_op, MatrixInverse)
+        ):
             return [solve(l.owner.inputs[0], r)]
-        if r.owner and isinstance(r.owner.op, MatrixInverse):
+        if (
+            r.owner
+            and isinstance(r.owner.op, Blockwise)
+            and isinstance(r.owner.op.core_op, MatrixInverse)
+        ):
             x = r.owner.inputs[0]
             if getattr(x.tag, "symmetric", None) is True:
-                return [solve(x, l.T).T]
+                return [T(solve(x, T(l)))]
             else:
-                return [solve(x.T, l.T).T]
+                return [T(solve(T(x), T(l)))]
 
 
 @register_stabilize
 @register_canonicalize
-@node_rewriter([Solve])
+@node_rewriter([Blockwise])
 def tag_solve_triangular(fgraph, node):
     """
     If a general solve() is applied to the output of a cholesky op, then
     replace it with a triangular solve.
 
     """
-    if isinstance(node.op, Solve):
-        if node.op.assume_a == "gen":
+    if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 1:
+        if node.op.core_op.assume_a == "gen":
             A, b = node.inputs  # result is solution Ax=b
-            if A.owner and isinstance(A.owner.op, Cholesky):
-                if A.owner.op.lower:
-                    return [Solve(assume_a="sym", lower=True)(A, b)]
-                else:
-                    return [Solve(assume_a="sym", lower=False)(A, b)]
             if (
                 A.owner
-                and isinstance(A.owner.op, DimShuffle)
-                and A.owner.op.new_order == (1, 0)
+                and isinstance(A.owner.op, Blockwise)
+                and isinstance(A.owner.op.core_op, Cholesky)
             ):
+                if A.owner.op.core_op.lower:
+                    return [solve(A, b, assume_a="sym", lower=True)]
+                else:
+                    return [solve(A, b, assume_a="sym", lower=False)]
+            if is_matrix_transpose(A):
                 (A_T,) = A.owner.inputs
-                if A_T.owner and isinstance(A_T.owner.op, Cholesky):
-                    if A_T.owner.op.lower:
-                        return [Solve(assume_a="sym", lower=False)(A, b)]
+                if (
+                    A_T.owner
+                    and isinstance(A_T.owner.op, Blockwise)
+                    and isinstance(A_T.owner.op.core_op, Cholesky)
+                ):
+                    if A_T.owner.op.core_op.lower:
+                        return [solve(A, b, assume_a="sym", lower=False)]
                     else:
-                        return [Solve(assume_a="sym", lower=True)(A, b)]
+                        return [solve(A, b, assume_a="sym", lower=True)]
 
 
 @register_canonicalize
 @register_stabilize
 @register_specialize
 @node_rewriter([DimShuffle])
 def no_transpose_symmetric(fgraph, node):
-    if isinstance(node.op, DimShuffle):
+    if is_matrix_transpose(node.outputs[0]):
         x = node.inputs[0]
-        if x.type.ndim == 2 and getattr(x.tag, "symmetric", None) is True:
-            if node.op.new_order == [1, 0]:
-                return [x]
+        if getattr(x.tag, "symmetric", None):
+            return [x]
 
 
 @register_stabilize
-@node_rewriter([Solve])
+@node_rewriter([Blockwise])
 def psd_solve_with_chol(fgraph, node):
     """
     This utilizes a boolean `psd` tag on matrices.
     """
-    if isinstance(node.op, Solve):
+    if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 2:
         A, b = node.inputs  # result is solution Ax=b
         if getattr(A.tag, "psd", None) is True:
             L = cholesky(A)
             # N.B. this can be further reduced to a yet-unwritten cho_solve Op
-            # __if__ no other Op makes use of the the L matrix during the
+            # __if__ no other Op makes use of the L matrix during the
             # stabilization
-            Li_b = Solve(assume_a="sym", lower=True)(L, b)
-            x = Solve(assume_a="sym", lower=False)(L.T, Li_b)
+            Li_b = solve(L, b, assume_a="sym", lower=True, b_ndim=2)
+            x = solve(T(L), Li_b, assume_a="sym", lower=False, b_ndim=2)
             return [x]
 
 
 @register_canonicalize
 @register_stabilize
-@node_rewriter([Cholesky])
+@node_rewriter([Blockwise])
 def cholesky_ldotlt(fgraph, node):
     """
     rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
     or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular.
 
     This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
     """
-    if not isinstance(node.op, Cholesky):
+    if not isinstance(node.op.core_op, Cholesky):
         return
 
     A = node.inputs[0]
```
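Several of the rewrites above dispatch on user-supplied tags (`symmetric`, `psd`) rather than on Op types alone. A hedged sketch of how the `psd` path would be triggered, assuming `pt.linalg.solve` produces a general `Solve` with `b_ndim == 2` for a matrix right-hand side:

```python
import pytensor.tensor as pt
from pytensor import function

A = pt.matrix("A")
A.tag.psd = True           # structure hint consumed by psd_solve_with_chol
b = pt.matrix("b")         # matrix RHS, so the rewrite's b_ndim == 2 guard holds

x = pt.linalg.solve(A, b)  # stabilization should swap this for cholesky(A)
f = function([A, b], x)    # followed by two triangular solves
```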
```diff
@@ -129,43 +166,38 @@ def cholesky_ldotlt(fgraph, node):
     # cholesky(dot(L,L.T)) case
     if (
         getattr(l.tag, "lower_triangular", False)
-        and r.owner
-        and isinstance(r.owner.op, DimShuffle)
-        and r.owner.op.new_order == (1, 0)
+        and is_matrix_transpose(r)
         and r.owner.inputs[0] == l
     ):
-        if node.op.lower:
+        if node.op.core_op.lower:
             return [l]
         return [r]
 
     # cholesky(dot(U.T,U)) case
     if (
         getattr(r.tag, "upper_triangular", False)
-        and l.owner
-        and isinstance(l.owner.op, DimShuffle)
-        and l.owner.op.new_order == (1, 0)
+        and is_matrix_transpose(l)
         and l.owner.inputs[0] == r
     ):
-        if node.op.lower:
+        if node.op.core_op.lower:
             return [l]
         return [r]
 
 
 @register_stabilize
 @register_specialize
-@node_rewriter([Det])
+@node_rewriter([det])
 def local_det_chol(fgraph, node):
     """
     If we have det(X) and there is already an L=cholesky(X)
     floating around, then we can use prod(diag(L)) to get the determinant.
 
     """
-    if isinstance(node.op, Det):
-        (x,) = node.inputs
-        for cl, xpos in fgraph.clients[x]:
-            if isinstance(cl.op, Cholesky):
-                L = cl.outputs[0]
-                return [prod(at.extract_diag(L) ** 2)]
+    (x,) = node.inputs
+    for cl, xpos in fgraph.clients[x]:
+        if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky):
+            L = cl.outputs[0]
+            return [prod(extract_diag(L) ** 2, axis=(-1, -2))]
 
 
 @register_canonicalize
```
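`local_det_chol` rests on the identity det(X) = prod(diag(L))² when X = LLᵀ is a Cholesky factorization; a quick NumPy check of the math:

```python
import numpy as np

X = np.array([[4.0, 2.0], [2.0, 3.0]])  # symmetric positive definite
L = np.linalg.cholesky(X)               # X = L @ L.T, L lower triangular
assert np.isclose(np.linalg.det(X), np.prod(np.diag(L)) ** 2)
```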
```diff
@@ -176,16 +208,15 @@ def local_log_prod_sqr(fgraph, node):
     """
     This utilizes a boolean `positive` tag on matrices.
     """
-    if node.op == log:
-        (x,) = node.inputs
-        if x.owner and isinstance(x.owner.op, Prod):
-            # we cannot always make this substitution because
-            # the prod might include negative terms
-            p = x.owner.inputs[0]
-
-            # p is the matrix we're reducing with prod
-            if getattr(p.tag, "positive", None) is True:
-                return [log(p).sum(axis=x.owner.op.axis)]
-
-            # TODO: have a reduction like prod and sum that simply
-            # returns the sign of the prod multiplication.
+    (x,) = node.inputs
+    if x.owner and isinstance(x.owner.op, Prod):
+        # we cannot always make this substitution because
+        # the prod might include negative terms
+        p = x.owner.inputs[0]
+
+        # p is the matrix we're reducing with prod
+        if getattr(p.tag, "positive", None) is True:
+            return [log(p).sum(axis=x.owner.op.axis)]
+
+        # TODO: have a reduction like prod and sum that simply
+        # returns the sign of the prod multiplication.
```
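Likewise, `local_log_prod_sqr` uses log(prod(p)) = sum(log(p)), which holds only for elementwise-positive p, hence the `positive` tag guard:

```python
import numpy as np

p = np.random.rand(4, 4) + 1.0  # strictly positive entries
assert np.isclose(np.log(np.prod(p)), np.log(p).sum())
```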
