1
1
import logging
2
+ from typing import cast
2
3
3
4
from pytensor .graph .rewriting .basic import node_rewriter
4
- from pytensor .tensor import basic as at
5
+ from pytensor .tensor . basic import TensorVariable , extract_diag , swapaxes
5
6
from pytensor .tensor .blas import Dot22
7
+ from pytensor .tensor .blockwise import Blockwise
6
8
from pytensor .tensor .elemwise import DimShuffle
7
9
from pytensor .tensor .math import Dot , Prod , log , prod
8
- from pytensor .tensor .nlinalg import Det , MatrixInverse
10
+ from pytensor .tensor .nlinalg import MatrixInverse , det
9
11
from pytensor .tensor .rewriting .basic import (
10
12
register_canonicalize ,
11
13
register_specialize ,
17
19
# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
18
20
19
21
22
def is_matrix_transpose(x: TensorVariable) -> bool:
    """Check if a variable corresponds to a transpose of the last two axes.

    Returns True only when ``x`` is produced by a `DimShuffle` that neither
    drops nor adds (broadcasts) dimensions and whose permutation swaps
    exactly the last two axes while leaving every leading axis in place.
    """
    node = x.owner
    if (
        node
        and isinstance(node.op, DimShuffle)
        # A pure transpose cannot drop or insert dimensions.
        and not (node.op.drop or node.op.augment)
    ):
        [inp] = node.inputs
        ndims = inp.type.ndim
        if ndims < 2:
            # Need at least a matrix for a "last two axes" transpose.
            return False
        transpose_order = tuple(range(ndims - 2)) + (ndims - 1, ndims - 2)
        return cast(bool, node.op.new_order == transpose_order)
    return False
37
+
38
+
39
def T(x: TensorVariable) -> TensorVariable:
    """Matrix transpose (swap of the last two axes) for potentially
    higher-dimensional tensors."""
    return swapaxes(x, -1, -2)
42
+
43
+
20
44
@register_canonicalize
@node_rewriter([DimShuffle])
def transinv_to_invtrans(fgraph, node):
    """Rewrite ``transpose(inv(X)) -> inv(transpose(X))``.

    Only fires when the `DimShuffle` is a transpose of the last two axes and
    its input is a (Blockwise) `MatrixInverse`.
    """
    if is_matrix_transpose(node.outputs[0]):
        (A,) = node.inputs
        if (
            A.owner
            and isinstance(A.owner.op, Blockwise)
            and isinstance(A.owner.op.core_op, MatrixInverse)
        ):
            (X,) = A.owner.inputs
            # Reuse the existing ops: apply the same transpose to X, then
            # the same inverse op on top of it.
            return [A.owner.op(node.op(X))]
30
56
31
57
32
58
@register_stabilize
@@ -37,87 +63,98 @@ def inv_as_solve(fgraph, node):
37
63
"""
38
64
if isinstance (node .op , (Dot , Dot22 )):
39
65
l , r = node .inputs
40
- if l .owner and isinstance (l .owner .op , MatrixInverse ):
66
+ if (
67
+ l .owner
68
+ and isinstance (l .owner .op , Blockwise )
69
+ and isinstance (l .owner .op .core_op , MatrixInverse )
70
+ ):
41
71
return [solve (l .owner .inputs [0 ], r )]
42
- if r .owner and isinstance (r .owner .op , MatrixInverse ):
72
+ if (
73
+ r .owner
74
+ and isinstance (r .owner .op , Blockwise )
75
+ and isinstance (r .owner .op .core_op , MatrixInverse )
76
+ ):
43
77
x = r .owner .inputs [0 ]
44
78
if getattr (x .tag , "symmetric" , None ) is True :
45
- return [solve (x , l . T ). T ]
79
+ return [T ( solve (x , T ( l ))) ]
46
80
else :
47
- return [solve (x . T , l . T ). T ]
81
+ return [T ( solve (T ( x ), T ( l ))) ]
48
82
49
83
50
84
@register_stabilize
@register_canonicalize
@node_rewriter([Blockwise])
def tag_solve_triangular(fgraph, node):
    """
    If a general solve() is applied to the output of a cholesky op, then
    replace it with a triangular solve.

    """
    if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 1:
        if node.op.core_op.assume_a == "gen":
            A, b = node.inputs  # result is solution Ax=b
            if (
                A.owner
                and isinstance(A.owner.op, Blockwise)
                and isinstance(A.owner.op.core_op, Cholesky)
            ):
                if A.owner.op.core_op.lower:
                    return [solve(A, b, assume_a="sym", lower=True)]
                else:
                    return [solve(A, b, assume_a="sym", lower=False)]
            if is_matrix_transpose(A):
                (A_T,) = A.owner.inputs
                if (
                    A_T.owner
                    and isinstance(A_T.owner.op, Blockwise)
                    # BUG FIX: the original tested the Blockwise wrapper
                    # itself against Cholesky (always False after the
                    # Blockwise check above); the core op must be inspected.
                    and isinstance(A_T.owner.op.core_op, Cholesky)
                ):
                    # A is the transpose of the Cholesky factor, so the
                    # lower/upper orientation flips.
                    if A_T.owner.op.core_op.lower:
                        return [solve(A, b, assume_a="sym", lower=False)]
                    else:
                        return [solve(A, b, assume_a="sym", lower=True)]
78
116
79
117
80
118
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([DimShuffle])
def no_transpose_symmetric(fgraph, node):
    """Drop a last-two-axes transpose applied to a matrix tagged symmetric.

    This utilizes a boolean ``symmetric`` tag on matrices: for a symmetric
    X, transpose(X) == X, so the DimShuffle can be removed.
    """
    if is_matrix_transpose(node.outputs[0]):
        x = node.inputs[0]
        if getattr(x.tag, "symmetric", None):
            return [x]
90
127
91
128
92
129
@register_stabilize
@node_rewriter([Blockwise])
def psd_solve_with_chol(fgraph, node):
    """Rewrite ``solve(A, b)`` as two triangular solves via ``cholesky(A)``.

    This utilizes a boolean `psd` tag on matrices.
    """
    if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 2:
        A, b = node.inputs  # result is solution Ax=b
        if getattr(A.tag, "psd", None) is True:
            L = cholesky(A)
            # N.B. this can be further reduced to a yet-unwritten cho_solve Op
            # __if__ no other Op makes use of the L matrix during the
            # stabilization
            Li_b = solve(L, b, assume_a="sym", lower=True, b_ndim=2)
            x = solve(T(L), Li_b, assume_a="sym", lower=False, b_ndim=2)
            return [x]
108
145
109
146
110
147
@register_canonicalize
111
148
@register_stabilize
112
- @node_rewriter ([Cholesky ])
149
+ @node_rewriter ([Blockwise ])
113
150
def cholesky_ldotlt (fgraph , node ):
114
151
"""
115
152
rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
116
153
or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular.
117
154
118
155
This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
119
156
"""
120
- if not isinstance (node .op , Cholesky ):
157
+ if not isinstance (node .op . core_op , Cholesky ):
121
158
return
122
159
123
160
A = node .inputs [0 ]
@@ -129,43 +166,38 @@ def cholesky_ldotlt(fgraph, node):
129
166
# cholesky(dot(L,L.T)) case
130
167
if (
131
168
getattr (l .tag , "lower_triangular" , False )
132
- and r .owner
133
- and isinstance (r .owner .op , DimShuffle )
134
- and r .owner .op .new_order == (1 , 0 )
169
+ and is_matrix_transpose (r )
135
170
and r .owner .inputs [0 ] == l
136
171
):
137
- if node .op .lower :
172
+ if node .op .core_op . lower :
138
173
return [l ]
139
174
return [r ]
140
175
141
176
# cholesky(dot(U.T,U)) case
142
177
if (
143
178
getattr (r .tag , "upper_triangular" , False )
144
- and l .owner
145
- and isinstance (l .owner .op , DimShuffle )
146
- and l .owner .op .new_order == (1 , 0 )
179
+ and is_matrix_transpose (l )
147
180
and l .owner .inputs [0 ] == r
148
181
):
149
- if node .op .lower :
182
+ if node .op .core_op . lower :
150
183
return [l ]
151
184
return [r ]
152
185
153
186
154
187
@register_stabilize
@register_specialize
@node_rewriter([det])
def local_det_chol(fgraph, node):
    """
    If we have det(X) and there is already an L=cholesky(X)
    floating around, then we can use prod(diag(L)) to get the determinant.

    """
    (x,) = node.inputs
    # Look for an existing (Blockwise) Cholesky client of x to reuse.
    for cl, xpos in fgraph.clients[x]:
        if isinstance(cl.op, Blockwise) and isinstance(cl.op.core_op, Cholesky):
            L = cl.outputs[0]
            # det(X) = det(L) * det(L.T) = prod(diag(L))**2; reduce over the
            # last two axes to support batched (Blockwise) inputs.
            return [prod(extract_diag(L) ** 2, axis=(-1, -2))]
169
201
170
202
171
203
@register_canonicalize
@@ -176,16 +208,15 @@ def local_log_prod_sqr(fgraph, node):
176
208
"""
177
209
This utilizes a boolean `positive` tag on matrices.
178
210
"""
179
- if node .op == log :
180
- (x ,) = node .inputs
181
- if x .owner and isinstance (x .owner .op , Prod ):
182
- # we cannot always make this substitution because
183
- # the prod might include negative terms
184
- p = x .owner .inputs [0 ]
185
-
186
- # p is the matrix we're reducing with prod
187
- if getattr (p .tag , "positive" , None ) is True :
188
- return [log (p ).sum (axis = x .owner .op .axis )]
189
-
190
- # TODO: have a reduction like prod and sum that simply
191
- # returns the sign of the prod multiplication.
211
+ (x ,) = node .inputs
212
+ if x .owner and isinstance (x .owner .op , Prod ):
213
+ # we cannot always make this substitution because
214
+ # the prod might include negative terms
215
+ p = x .owner .inputs [0 ]
216
+
217
+ # p is the matrix we're reducing with prod
218
+ if getattr (p .tag , "positive" , None ) is True :
219
+ return [log (p ).sum (axis = x .owner .op .axis )]
220
+
221
+ # TODO: have a reduction like prod and sum that simply
222
+ # returns the sign of the prod multiplication.
0 commit comments