@@ -49,7 +49,7 @@ class Cholesky(Op):
49
49
50
50
__props__ = ("lower" , "destructive" , "on_error" )
51
51
52
- def __init__ (self , lower = True , on_error = "raise" ):
52
+ def __init__ (self , * , lower = True , on_error = "raise" ):
53
53
self .lower = lower
54
54
self .destructive = False
55
55
if on_error not in ("raise" , "nan" ):
@@ -127,77 +127,8 @@ def conjugate_solve_triangular(outer, inner):
127
127
return [grad ]
128
128
129
129
130
- cholesky = Cholesky ()
131
-
132
-
133
- class CholeskySolve (Op ):
134
- __props__ = ("lower" , "check_finite" )
135
-
136
- def __init__ (
137
- self ,
138
- lower = True ,
139
- check_finite = True ,
140
- ):
141
- self .lower = lower
142
- self .check_finite = check_finite
143
-
144
- def __repr__ (self ):
145
- return "CholeskySolve{%s}" % str (self ._props ())
146
-
147
- def make_node (self , C , b ):
148
- C = as_tensor_variable (C )
149
- b = as_tensor_variable (b )
150
- assert C .ndim == 2
151
- assert b .ndim in (1 , 2 )
152
-
153
- # infer dtype by solving the most simple
154
- # case with (1, 1) matrices
155
- o_dtype = scipy .linalg .solve (
156
- np .eye (1 ).astype (C .dtype ), np .eye (1 ).astype (b .dtype )
157
- ).dtype
158
- x = tensor (dtype = o_dtype , shape = b .type .shape )
159
- return Apply (self , [C , b ], [x ])
160
-
161
- def perform (self , node , inputs , output_storage ):
162
- C , b = inputs
163
- rval = scipy .linalg .cho_solve (
164
- (C , self .lower ),
165
- b ,
166
- check_finite = self .check_finite ,
167
- )
168
-
169
- output_storage [0 ][0 ] = rval
170
-
171
- def infer_shape (self , fgraph , node , shapes ):
172
- Cshape , Bshape = shapes
173
- rows = Cshape [1 ]
174
- if len (Bshape ) == 1 : # b is a Vector
175
- return [(rows ,)]
176
- else :
177
- cols = Bshape [1 ] # b is a Matrix
178
- return [(rows , cols )]
179
-
180
-
181
- cho_solve = CholeskySolve ()
182
-
183
-
184
- def cho_solve (c_and_lower , b , check_finite = True ):
185
- """Solve the linear equations A x = b, given the Cholesky factorization of A.
186
-
187
- Parameters
188
- ----------
189
- (c, lower) : tuple, (array, bool)
190
- Cholesky factorization of a, as given by cho_factor
191
- b : array
192
- Right-hand side
193
- check_finite : bool, optional
194
- Whether to check that the input matrices contain only finite numbers.
195
- Disabling may give a performance gain, but may result in problems
196
- (crashes, non-termination) if the inputs do contain infinities or NaNs.
197
- """
198
-
199
- A , lower = c_and_lower
200
- return CholeskySolve (lower = lower , check_finite = check_finite )(A , b )
130
+ def cholesky (x , lower = True , on_error = "raise" ):
131
+ return Cholesky (lower = lower , on_error = on_error )(x )
201
132
202
133
203
134
class SolveBase (Op ):
@@ -210,6 +141,7 @@ class SolveBase(Op):
210
141
211
142
def __init__ (
212
143
self ,
144
+ * ,
213
145
lower = False ,
214
146
check_finite = True ,
215
147
):
@@ -276,28 +208,56 @@ def L_op(self, inputs, outputs, output_gradients):
276
208
277
209
return [A_bar , b_bar ]
278
210
279
- def __repr__ (self ):
280
- return f"{ type (self ).__name__ } { self ._props ()} "
211
+
212
+ class CholeskySolve (SolveBase ):
213
+ def __init__ (self , ** kwargs ):
214
+ kwargs .setdefault ("lower" , True )
215
+ super ().__init__ (** kwargs )
216
+
217
+ def perform (self , node , inputs , output_storage ):
218
+ C , b = inputs
219
+ rval = scipy .linalg .cho_solve (
220
+ (C , self .lower ),
221
+ b ,
222
+ check_finite = self .check_finite ,
223
+ )
224
+
225
+ output_storage [0 ][0 ] = rval
226
+
227
+ def L_op (self , * args , ** kwargs ):
228
+ raise NotImplementedError ()
229
+
230
+
231
+ def cho_solve (c_and_lower , b , * , check_finite = True ):
232
+ """Solve the linear equations A x = b, given the Cholesky factorization of A.
233
+
234
+ Parameters
235
+ ----------
236
+ (c, lower) : tuple, (array, bool)
237
+ Cholesky factorization of a, as given by cho_factor
238
+ b : array
239
+ Right-hand side
240
+ check_finite : bool, optional
241
+ Whether to check that the input matrices contain only finite numbers.
242
+ Disabling may give a performance gain, but may result in problems
243
+ (crashes, non-termination) if the inputs do contain infinities or NaNs.
244
+ """
245
+ A , lower = c_and_lower
246
+ return CholeskySolve (lower = lower , check_finite = check_finite )(A , b )
281
247
282
248
283
249
class SolveTriangular (SolveBase ):
284
250
"""Solve a system of linear equations."""
285
251
286
252
__props__ = (
287
- "lower" ,
288
253
"trans" ,
289
254
"unit_diagonal" ,
255
+ "lower" ,
290
256
"check_finite" ,
291
257
)
292
258
293
- def __init__ (
294
- self ,
295
- trans = 0 ,
296
- lower = False ,
297
- unit_diagonal = False ,
298
- check_finite = True ,
299
- ):
300
- super ().__init__ (lower = lower , check_finite = check_finite )
259
+ def __init__ (self , * , trans = 0 , unit_diagonal = False , ** kwargs ):
260
+ super ().__init__ (** kwargs )
301
261
self .trans = trans
302
262
self .unit_diagonal = unit_diagonal
303
263
@@ -326,6 +286,7 @@ def L_op(self, inputs, outputs, output_gradients):
326
286
def solve_triangular (
327
287
a : TensorVariable ,
328
288
b : TensorVariable ,
289
+ * ,
329
290
trans : Union [int , str ] = 0 ,
330
291
lower : bool = False ,
331
292
unit_diagonal : bool = False ,
@@ -373,16 +334,11 @@ class Solve(SolveBase):
373
334
"check_finite" ,
374
335
)
375
336
376
- def __init__ (
377
- self ,
378
- assume_a = "gen" ,
379
- lower = False ,
380
- check_finite = True ,
381
- ):
337
+ def __init__ (self , * , assume_a = "gen" , ** kwargs ):
382
338
if assume_a not in ("gen" , "sym" , "her" , "pos" ):
383
339
raise ValueError (f"{ assume_a } is not a recognized matrix structure" )
384
340
385
- super ().__init__ (lower = lower , check_finite = check_finite )
341
+ super ().__init__ (** kwargs )
386
342
self .assume_a = assume_a
387
343
388
344
def perform (self , node , inputs , outputs ):
@@ -396,7 +352,7 @@ def perform(self, node, inputs, outputs):
396
352
)
397
353
398
354
399
- def solve (a , b , assume_a = "gen" , lower = False , check_finite = True ):
355
+ def solve (a , b , * , assume_a = "gen" , lower = False , check_finite = True ):
400
356
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
401
357
402
358
If the data matrix is known to be a particular type then supplying the
0 commit comments