10
10
from pytensor .tensor import basic as at
11
11
from pytensor .tensor import math as tm
12
12
from pytensor .tensor .basic import as_tensor_variable , extract_diag
13
+ from pytensor .tensor .blockwise import Blockwise
13
14
from pytensor .tensor .type import dvector , lscalar , matrix , scalar , vector
14
15
15
16
16
17
class MatrixPinv (Op ):
17
18
__props__ = ("hermitian" ,)
19
+ gufunc_signature = "(m,n)->(n,m)"
18
20
19
21
def __init__ (self , hermitian ):
20
22
self .hermitian = hermitian
@@ -75,7 +77,7 @@ def pinv(x, hermitian=False):
75
77
solve op.
76
78
77
79
"""
78
- return MatrixPinv (hermitian = hermitian )(x )
80
+ return Blockwise ( MatrixPinv (hermitian = hermitian ) )(x )
79
81
80
82
81
83
class MatrixInverse (Op ):
@@ -93,6 +95,8 @@ class MatrixInverse(Op):
93
95
"""
94
96
95
97
__props__ = ()
98
+ gufunc_signature = "(m,m)->(m,m)"
99
+ gufunc_spec = ("numpy.linalg.inv" , 1 , 1 )
96
100
97
101
def __init__ (self ):
98
102
pass
@@ -150,7 +154,7 @@ def infer_shape(self, fgraph, node, shapes):
150
154
return shapes
151
155
152
156
153
- inv = matrix_inverse = MatrixInverse ()
157
+ inv = matrix_inverse = Blockwise ( MatrixInverse () )
154
158
155
159
156
160
def matrix_dot (* args ):
@@ -181,6 +185,8 @@ class Det(Op):
181
185
"""
182
186
183
187
__props__ = ()
188
+ gufunc_signature = "(m,m)->()"
189
+ gufunc_spec = ("numpy.linalg.det" , 1 , 1 )
184
190
185
191
def make_node (self , x ):
186
192
x = as_tensor_variable (x )
@@ -209,7 +215,7 @@ def __str__(self):
209
215
return "Det"
210
216
211
217
212
- det = Det ()
218
+ det = Blockwise ( Det () )
213
219
214
220
215
221
class SLogDet (Op ):
@@ -218,6 +224,8 @@ class SLogDet(Op):
218
224
"""
219
225
220
226
__props__ = ()
227
+ gufunc_signature = "(m, m)->(),()"
228
+ gufunc_spec = ("numpy.linalg.slogdet" , 1 , 2 )
221
229
222
230
def make_node (self , x ):
223
231
x = as_tensor_variable (x )
@@ -242,7 +250,7 @@ def __str__(self):
242
250
return "SLogDet"
243
251
244
252
245
- slogdet = SLogDet ()
253
+ slogdet = Blockwise ( SLogDet () )
246
254
247
255
248
256
class Eig (Op ):
@@ -252,6 +260,8 @@ class Eig(Op):
252
260
"""
253
261
254
262
__props__ : Tuple [str , ...] = ()
263
+ gufunc_signature = "(m,m)->(m),(m,m)"
264
+ gufunc_spec = ("numpy.linalg.eig" , 1 , 2 )
255
265
256
266
def make_node (self , x ):
257
267
x = as_tensor_variable (x )
@@ -270,7 +280,7 @@ def infer_shape(self, fgraph, node, shapes):
270
280
return [(n ,), (n , n )]
271
281
272
282
273
- eig = Eig ()
283
+ eig = Blockwise ( Eig () )
274
284
275
285
276
286
class Eigh (Eig ):
0 commit comments