Skip to content

Commit bb62817

Browse files
committed
Blockwise some linalg Ops by default
1 parent 0ff0f29 commit bb62817

File tree

8 files changed

+249
-129
lines changed

8 files changed

+249
-129
lines changed

pytensor/tensor/basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3764,7 +3764,7 @@ def stacklists(arg):
37643764
return arg
37653765

37663766

3767-
def swapaxes(y, axis1, axis2):
3767+
def swapaxes(y, axis1: int, axis2: int) -> TensorVariable:
37683768
"Swap the axes of a tensor."
37693769
y = as_tensor_variable(y)
37703770
ndim = y.ndim

pytensor/tensor/nlinalg.py

+15-5
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
from pytensor.tensor import basic as at
1111
from pytensor.tensor import math as tm
1212
from pytensor.tensor.basic import as_tensor_variable, extract_diag
13+
from pytensor.tensor.blockwise import Blockwise
1314
from pytensor.tensor.type import dvector, lscalar, matrix, scalar, vector
1415

1516

1617
class MatrixPinv(Op):
1718
__props__ = ("hermitian",)
19+
gufunc_signature = "(m,n)->(n,m)"
1820

1921
def __init__(self, hermitian):
2022
self.hermitian = hermitian
@@ -75,7 +77,7 @@ def pinv(x, hermitian=False):
7577
solve op.
7678
7779
"""
78-
return MatrixPinv(hermitian=hermitian)(x)
80+
return Blockwise(MatrixPinv(hermitian=hermitian))(x)
7981

8082

8183
class MatrixInverse(Op):
@@ -93,6 +95,8 @@ class MatrixInverse(Op):
9395
"""
9496

9597
__props__ = ()
98+
gufunc_signature = "(m,m)->(m,m)"
99+
gufunc_spec = ("numpy.linalg.inv", 1, 1)
96100

97101
def __init__(self):
98102
pass
@@ -150,7 +154,7 @@ def infer_shape(self, fgraph, node, shapes):
150154
return shapes
151155

152156

153-
inv = matrix_inverse = MatrixInverse()
157+
inv = matrix_inverse = Blockwise(MatrixInverse())
154158

155159

156160
def matrix_dot(*args):
@@ -181,6 +185,8 @@ class Det(Op):
181185
"""
182186

183187
__props__ = ()
188+
gufunc_signature = "(m,m)->()"
189+
gufunc_spec = ("numpy.linalg.det", 1, 1)
184190

185191
def make_node(self, x):
186192
x = as_tensor_variable(x)
@@ -209,7 +215,7 @@ def __str__(self):
209215
return "Det"
210216

211217

212-
det = Det()
218+
det = Blockwise(Det())
213219

214220

215221
class SLogDet(Op):
@@ -218,6 +224,8 @@ class SLogDet(Op):
218224
"""
219225

220226
__props__ = ()
227+
gufunc_signature = "(m, m)->(),()"
228+
gufunc_spec = ("numpy.linalg.slogdet", 1, 2)
221229

222230
def make_node(self, x):
223231
x = as_tensor_variable(x)
@@ -242,7 +250,7 @@ def __str__(self):
242250
return "SLogDet"
243251

244252

245-
slogdet = SLogDet()
253+
slogdet = Blockwise(SLogDet())
246254

247255

248256
class Eig(Op):
@@ -252,6 +260,8 @@ class Eig(Op):
252260
"""
253261

254262
__props__: Tuple[str, ...] = ()
263+
gufunc_signature = "(m,m)->(m),(m,m)"
264+
gufunc_spec = ("numpy.linalg.eig", 1, 2)
255265

256266
def make_node(self, x):
257267
x = as_tensor_variable(x)
@@ -270,7 +280,7 @@ def infer_shape(self, fgraph, node, shapes):
270280
return [(n,), (n, n)]
271281

272282

273-
eig = Eig()
283+
eig = Blockwise(Eig())
274284

275285

276286
class Eigh(Eig):

0 commit comments

Comments
 (0)