1
1
import logging
2
+ from typing import cast
2
3
3
4
from pytensor .graph .rewriting .basic import node_rewriter
4
- from pytensor .tensor import basic as at
5
+ from pytensor .tensor . basic import TensorVariable , extract_diag , swapaxes
5
6
from pytensor .tensor .blas import Dot22
7
+ from pytensor .tensor .blockwise import Blockwise
6
8
from pytensor .tensor .elemwise import DimShuffle
7
9
from pytensor .tensor .math import Dot , Prod , log , prod
8
- from pytensor .tensor .nlinalg import Det , MatrixInverse
10
+ from pytensor .tensor .nlinalg import MatrixInverse , det
9
11
from pytensor .tensor .rewriting .basic import (
10
12
register_canonicalize ,
11
13
register_specialize ,
17
19
logger = logging .getLogger (__name__ )
18
20
19
21
22
def is_matrix_transpose(x: TensorVariable) -> bool:
    """Return True if `x` is a transpose of the last two axes of its input.

    Checks that `x` is produced by a `DimShuffle` that neither drops nor
    broadcasts dimensions and whose permutation swaps only the final two axes.
    """
    node = x.owner
    # Must be a pure-permutation DimShuffle (no dropped or broadcast dims).
    if node is None or not isinstance(node.op, DimShuffle):
        return False
    if node.op.drop or node.op.augment:
        return False
    [inp] = node.inputs
    ndims = inp.type.ndim
    # A matrix transpose needs at least two axes to swap.
    if ndims < 2:
        return False
    # Identity on all leading axes, with the last two exchanged.
    expected_order = (*range(ndims - 2), ndims - 1, ndims - 2)
    return cast(bool, node.op.new_order == expected_order)
37
+
38
+
39
def _T(x: TensorVariable) -> TensorVariable:
    """Transpose the last two axes of `x` (works for ndim >= 2 tensors)."""
    return swapaxes(x, -2, -1)
42
+
43
+
20
44
@register_canonicalize
@node_rewriter([DimShuffle])
def transinv_to_invtrans(fgraph, node):
    """Rewrite transpose(inv(X)) -> inv(transpose(X))."""
    if not is_matrix_transpose(node.outputs[0]):
        return None
    (A,) = node.inputs
    inner = A.owner
    # Only fire when the transposed value is itself a (blockwise) inverse.
    if inner is None:
        return None
    if isinstance(inner.op, Blockwise) and isinstance(inner.op.core_op, MatrixInverse):
        (X,) = inner.inputs
        # Push the transpose inside the inverse.
        return [inner.op(node.op(X))]
30
56
31
57
32
58
@register_stabilize
@@ -37,86 +63,98 @@ def inv_as_solve(fgraph, node):
37
63
"""
38
64
if isinstance (node .op , (Dot , Dot22 )):
39
65
l , r = node .inputs
40
- if l .owner and isinstance (l .owner .op , MatrixInverse ):
66
+ if (
67
+ l .owner
68
+ and isinstance (l .owner .op , Blockwise )
69
+ and isinstance (l .owner .op .core_op , MatrixInverse )
70
+ ):
41
71
return [solve (l .owner .inputs [0 ], r )]
42
- if r .owner and isinstance (r .owner .op , MatrixInverse ):
72
+ if (
73
+ r .owner
74
+ and isinstance (r .owner .op , Blockwise )
75
+ and isinstance (r .owner .op .core_op , MatrixInverse )
76
+ ):
43
77
x = r .owner .inputs [0 ]
44
78
if getattr (x .tag , "symmetric" , None ) is True :
45
- return [solve (x , l . T ). T ]
79
+ return [_T ( solve (x , _T ( l ))) ]
46
80
else :
47
- return [solve (x . T , l . T ). T ]
81
+ return [_T ( solve (_T ( x ), _T ( l ))) ]
48
82
49
83
50
84
@register_stabilize
@register_canonicalize
@node_rewriter([Blockwise])
def generic_solve_to_solve_triangular(fgraph, node):
    """
    If any solve() is applied to the output of a cholesky op, then
    replace it with a triangular solve.

    Only applies to generic (``assume_a == "gen"``) solves with a vector
    right-hand side (``b_ndim == 1``).
    """
    if isinstance(node.op.core_op, Solve) and node.op.core_op.b_ndim == 1:
        if node.op.core_op.assume_a == "gen":
            A, b = node.inputs  # result is solution Ax=b
            if (
                A.owner
                and isinstance(A.owner.op, Blockwise)
                and isinstance(A.owner.op.core_op, Cholesky)
            ):
                # A is directly a cholesky factor: keep its orientation.
                if A.owner.op.core_op.lower:
                    return [SolveTriangular(lower=True)(A, b)]
                else:
                    return [SolveTriangular(lower=False)(A, b)]
            if is_matrix_transpose(A):
                (A_T,) = A.owner.inputs
                if (
                    A_T.owner
                    and isinstance(A_T.owner.op, Blockwise)
                    # FIX: check Cholesky on core_op, not on the Blockwise
                    # wrapper itself — an op is never both Blockwise and
                    # Cholesky, so the previous check made this branch
                    # unreachable.
                    and isinstance(A_T.owner.op.core_op, Cholesky)
                ):
                    # A is the *transpose* of a cholesky factor, so the
                    # triangularity flips.
                    # FIX: `lower` lives on core_op; Blockwise has no `lower`.
                    if A_T.owner.op.core_op.lower:
                        return [SolveTriangular(lower=False)(A, b)]
                    else:
                        return [SolveTriangular(lower=True)(A, b)]
77
116
78
117
79
118
@register_canonicalize
@register_stabilize
@register_specialize
@node_rewriter([DimShuffle])
def no_transpose_symmetric(fgraph, node):
    """Drop a matrix transpose applied to a tensor tagged `symmetric`."""
    if not is_matrix_transpose(node.outputs[0]):
        return None
    x = node.inputs[0]
    # Transposing a symmetric matrix is a no-op; reuse the input directly.
    if getattr(x.tag, "symmetric", None):
        return [x]
89
127
90
128
91
129
@register_stabilize
@node_rewriter([Blockwise])
def psd_solve_with_chol(fgraph, node):
    """
    This utilizes a boolean `psd` tag on matrices.
    """
    core_op = node.op.core_op
    # Only matrix right-hand sides of a Solve are handled here.
    if not (isinstance(core_op, Solve) and core_op.b_ndim == 2):
        return None
    A, b = node.inputs  # result is solution Ax=b
    if getattr(A.tag, "psd", None) is not True:
        return None
    L = cholesky(A)
    # N.B. this can be further reduced to a yet-unwritten cho_solve Op
    # __if__ no other Op makes use of the L matrix during the
    # stabilization
    Li_b = solve(L, b, assume_a="sym", lower=True, b_ndim=2)
    x = solve(_T(L), Li_b, assume_a="sym", lower=False, b_ndim=2)
    return [x]
107
145
108
146
109
147
@register_canonicalize
110
148
@register_stabilize
111
- @node_rewriter ([Cholesky ])
149
+ @node_rewriter ([Blockwise ])
112
150
def cholesky_ldotlt (fgraph , node ):
113
151
"""
114
152
rewrite cholesky(dot(L, L.T), lower=True) = L, where L is lower triangular,
115
153
or cholesky(dot(U.T, U), upper=True) = U where U is upper triangular.
116
154
117
155
This utilizes a boolean `lower_triangular` or `upper_triangular` tag on matrices.
118
156
"""
119
- if not isinstance (node .op , Cholesky ):
157
+ if not isinstance (node .op . core_op , Cholesky ):
120
158
return
121
159
122
160
A = node .inputs [0 ]
@@ -128,45 +166,40 @@ def cholesky_ldotlt(fgraph, node):
128
166
# cholesky(dot(L,L.T)) case
129
167
if (
130
168
getattr (l .tag , "lower_triangular" , False )
131
- and r .owner
132
- and isinstance (r .owner .op , DimShuffle )
133
- and r .owner .op .new_order == (1 , 0 )
169
+ and is_matrix_transpose (r )
134
170
and r .owner .inputs [0 ] == l
135
171
):
136
- if node .op .lower :
172
+ if node .op .core_op . lower :
137
173
return [l ]
138
174
return [r ]
139
175
140
176
# cholesky(dot(U.T,U)) case
141
177
if (
142
178
getattr (r .tag , "upper_triangular" , False )
143
- and l .owner
144
- and isinstance (l .owner .op , DimShuffle )
145
- and l .owner .op .new_order == (1 , 0 )
179
+ and is_matrix_transpose (l )
146
180
and l .owner .inputs [0 ] == r
147
181
):
148
- if node .op .lower :
182
+ if node .op .core_op . lower :
149
183
return [l ]
150
184
return [r ]
151
185
152
186
153
187
@register_stabilize
@register_specialize
@node_rewriter([det])
def local_det_chol(fgraph, node):
    """
    If we have det(X) and there is already an L=cholesky(X)
    floating around, then we can use prod(diag(L)) to get the determinant.

    """
    (x,) = node.inputs
    for client, _ in fgraph.clients[x]:
        if client == "output":
            continue
        op = client.op
        # Reuse an existing (blockwise) cholesky factor of x if one exists:
        # det(X) == prod(diag(L))**2 for X = L @ L.T.
        if isinstance(op, Blockwise) and isinstance(op.core_op, Cholesky):
            L = client.outputs[0]
            return [prod(extract_diag(L) ** 2, axis=(-1, -2))]
170
203
171
204
172
205
@register_canonicalize
@@ -177,16 +210,15 @@ def local_log_prod_sqr(fgraph, node):
177
210
"""
178
211
This utilizes a boolean `positive` tag on matrices.
179
212
"""
180
- if node .op == log :
181
- (x ,) = node .inputs
182
- if x .owner and isinstance (x .owner .op , Prod ):
183
- # we cannot always make this substitution because
184
- # the prod might include negative terms
185
- p = x .owner .inputs [0 ]
186
-
187
- # p is the matrix we're reducing with prod
188
- if getattr (p .tag , "positive" , None ) is True :
189
- return [log (p ).sum (axis = x .owner .op .axis )]
190
-
191
- # TODO: have a reduction like prod and sum that simply
192
- # returns the sign of the prod multiplication.
213
+ (x ,) = node .inputs
214
+ if x .owner and isinstance (x .owner .op , Prod ):
215
+ # we cannot always make this substitution because
216
+ # the prod might include negative terms
217
+ p = x .owner .inputs [0 ]
218
+
219
+ # p is the matrix we're reducing with prod
220
+ if getattr (p .tag , "positive" , None ) is True :
221
+ return [log (p ).sum (axis = x .owner .op .axis )]
222
+
223
+ # TODO: have a reduction like prod and sum that simply
224
+ # returns the sign of the prod multiplication.
0 commit comments