@@ -47,7 +47,7 @@ class Cholesky(Op):
47
47
48
48
__props__ = ("lower" , "destructive" , "on_error" )
49
49
50
- def __init__ (self , lower = True , on_error = "raise" ):
50
+ def __init__ (self , * , lower = True , on_error = "raise" ):
51
51
self .lower = lower
52
52
self .destructive = False
53
53
if on_error not in ("raise" , "nan" ):
@@ -125,77 +125,8 @@ def conjugate_solve_triangular(outer, inner):
125
125
return [grad ]
126
126
127
127
128
- cholesky = Cholesky ()
129
-
130
-
131
- class CholeskySolve (Op ):
132
- __props__ = ("lower" , "check_finite" )
133
-
134
- def __init__ (
135
- self ,
136
- lower = True ,
137
- check_finite = True ,
138
- ):
139
- self .lower = lower
140
- self .check_finite = check_finite
141
-
142
- def __repr__ (self ):
143
- return "CholeskySolve{%s}" % str (self ._props ())
144
-
145
- def make_node (self , C , b ):
146
- C = as_tensor_variable (C )
147
- b = as_tensor_variable (b )
148
- assert C .ndim == 2
149
- assert b .ndim in (1 , 2 )
150
-
151
- # infer dtype by solving the most simple
152
- # case with (1, 1) matrices
153
- o_dtype = scipy .linalg .solve (
154
- np .eye (1 ).astype (C .dtype ), np .eye (1 ).astype (b .dtype )
155
- ).dtype
156
- x = tensor (dtype = o_dtype , shape = b .type .shape )
157
- return Apply (self , [C , b ], [x ])
158
-
159
- def perform (self , node , inputs , output_storage ):
160
- C , b = inputs
161
- rval = scipy .linalg .cho_solve (
162
- (C , self .lower ),
163
- b ,
164
- check_finite = self .check_finite ,
165
- )
166
-
167
- output_storage [0 ][0 ] = rval
168
-
169
- def infer_shape (self , fgraph , node , shapes ):
170
- Cshape , Bshape = shapes
171
- rows = Cshape [1 ]
172
- if len (Bshape ) == 1 : # b is a Vector
173
- return [(rows ,)]
174
- else :
175
- cols = Bshape [1 ] # b is a Matrix
176
- return [(rows , cols )]
177
-
178
-
179
- cho_solve = CholeskySolve ()
180
-
181
-
182
- def cho_solve (c_and_lower , b , check_finite = True ):
183
- """Solve the linear equations A x = b, given the Cholesky factorization of A.
184
-
185
- Parameters
186
- ----------
187
- (c, lower) : tuple, (array, bool)
188
- Cholesky factorization of a, as given by cho_factor
189
- b : array
190
- Right-hand side
191
- check_finite : bool, optional
192
- Whether to check that the input matrices contain only finite numbers.
193
- Disabling may give a performance gain, but may result in problems
194
- (crashes, non-termination) if the inputs do contain infinities or NaNs.
195
- """
196
-
197
- A , lower = c_and_lower
198
- return CholeskySolve (lower = lower , check_finite = check_finite )(A , b )
128
+ def cholesky (x , lower = True , on_error = "raise" ):
129
+ return Cholesky (lower = lower , on_error = on_error )(x )
199
130
200
131
201
132
class SolveBase (Op ):
@@ -208,6 +139,7 @@ class SolveBase(Op):
208
139
209
140
def __init__ (
210
141
self ,
142
+ * ,
211
143
lower = False ,
212
144
check_finite = True ,
213
145
):
@@ -274,28 +206,56 @@ def L_op(self, inputs, outputs, output_gradients):
274
206
275
207
return [A_bar , b_bar ]
276
208
277
- def __repr__ (self ):
278
- return f"{ type (self ).__name__ } { self ._props ()} "
209
+
210
+ class CholeskySolve (SolveBase ):
211
+ def __init__ (self , ** kwargs ):
212
+ kwargs .setdefault ("lower" , True )
213
+ super ().__init__ (** kwargs )
214
+
215
+ def perform (self , node , inputs , output_storage ):
216
+ C , b = inputs
217
+ rval = scipy .linalg .cho_solve (
218
+ (C , self .lower ),
219
+ b ,
220
+ check_finite = self .check_finite ,
221
+ )
222
+
223
+ output_storage [0 ][0 ] = rval
224
+
225
+ def L_op (self , * args , ** kwargs ):
226
+ raise NotImplementedError ()
227
+
228
+
229
+ def cho_solve (c_and_lower , b , * , check_finite = True ):
230
+ """Solve the linear equations A x = b, given the Cholesky factorization of A.
231
+
232
+ Parameters
233
+ ----------
234
+ (c, lower) : tuple, (array, bool)
235
+ Cholesky factorization of a, as given by cho_factor
236
+ b : array
237
+ Right-hand side
238
+ check_finite : bool, optional
239
+ Whether to check that the input matrices contain only finite numbers.
240
+ Disabling may give a performance gain, but may result in problems
241
+ (crashes, non-termination) if the inputs do contain infinities or NaNs.
242
+ """
243
+ A , lower = c_and_lower
244
+ return CholeskySolve (lower = lower , check_finite = check_finite )(A , b )
279
245
280
246
281
247
class SolveTriangular (SolveBase ):
282
248
"""Solve a system of linear equations."""
283
249
284
250
__props__ = (
285
- "lower" ,
286
251
"trans" ,
287
252
"unit_diagonal" ,
253
+ "lower" ,
288
254
"check_finite" ,
289
255
)
290
256
291
- def __init__ (
292
- self ,
293
- trans = 0 ,
294
- lower = False ,
295
- unit_diagonal = False ,
296
- check_finite = True ,
297
- ):
298
- super ().__init__ (lower = lower , check_finite = check_finite )
257
+ def __init__ (self , * , trans = 0 , unit_diagonal = False , ** kwargs ):
258
+ super ().__init__ (** kwargs )
299
259
self .trans = trans
300
260
self .unit_diagonal = unit_diagonal
301
261
@@ -321,12 +281,10 @@ def L_op(self, inputs, outputs, output_gradients):
321
281
return res
322
282
323
283
324
- solvetriangular = SolveTriangular ()
325
-
326
-
327
284
def solve_triangular (
328
285
a : TensorVariable ,
329
286
b : TensorVariable ,
287
+ * ,
330
288
trans : Union [int , str ] = 0 ,
331
289
lower : bool = False ,
332
290
unit_diagonal : bool = False ,
@@ -374,16 +332,11 @@ class Solve(SolveBase):
374
332
"check_finite" ,
375
333
)
376
334
377
- def __init__ (
378
- self ,
379
- assume_a = "gen" ,
380
- lower = False ,
381
- check_finite = True ,
382
- ):
335
+ def __init__ (self , * , assume_a = "gen" , ** kwargs ):
383
336
if assume_a not in ("gen" , "sym" , "her" , "pos" ):
384
337
raise ValueError (f"{ assume_a } is not a recognized matrix structure" )
385
338
386
- super ().__init__ (lower = lower , check_finite = check_finite )
339
+ super ().__init__ (** kwargs )
387
340
self .assume_a = assume_a
388
341
389
342
def perform (self , node , inputs , outputs ):
@@ -397,10 +350,7 @@ def perform(self, node, inputs, outputs):
397
350
)
398
351
399
352
400
- solve = Solve ()
401
-
402
-
403
- def solve (a , b , assume_a = "gen" , lower = False , check_finite = True ):
353
+ def solve (a , b , * , assume_a = "gen" , lower = False , check_finite = True ):
404
354
"""Solves the linear equation set ``a * x = b`` for the unknown ``x`` for square ``a`` matrix.
405
355
406
356
If the data matrix is known to be a particular type then supplying the
0 commit comments