9
9
10
10
from pytensor .link .numba .dispatch import basic as numba_basic
11
11
from pytensor .link .numba .dispatch .basic import numba_funcify
12
- from pytensor .tensor .slinalg import BlockDiagonal , SolveTriangular
12
+ from pytensor .tensor .slinalg import BlockDiagonal , Cholesky , SolveTriangular
13
13
14
14
15
15
_PTR = ctypes .POINTER
25
25
_ptr_int = _PTR (_int )
26
26
27
27
28
+ @numba .core .extending .register_jitable
29
+ def _check_finite_matrix (a , func_name ):
30
+ for v in np .nditer (a ):
31
+ if not np .isfinite (v .item ()):
32
+ raise np .linalg .LinAlgError (
33
+ "Non-numeric values (nan or inf) in input to " + func_name
34
+ )
35
+
36
+
28
37
@intrinsic
29
38
def val_to_dptr (typingctx , data ):
30
39
def impl (context , builder , signature , args ):
@@ -177,6 +186,22 @@ def numba_xtrtrs(cls, dtype):
177
186
178
187
return functype (lapack_ptr )
179
188
189
+ @classmethod
190
+ def numba_xpotrf (cls , dtype ):
191
+ """
192
+ Called by scipy.linalg.cholesky
193
+ """
194
+ lapack_ptr , float_pointer = _get_lapack_ptr_and_ptr_type (dtype , "potrf" )
195
+ functype = ctypes .CFUNCTYPE (
196
+ None ,
197
+ _ptr_int , # UPLO,
198
+ _ptr_int , # N
199
+ float_pointer , # A
200
+ _ptr_int , # LDA
201
+ _ptr_int , # INFO
202
+ )
203
+ return functype (lapack_ptr )
204
+
180
205
181
206
def _solve_triangular (A , B , trans = 0 , lower = False , unit_diagonal = False ):
182
207
return linalg .solve_triangular (
@@ -190,13 +215,7 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False):
190
215
191
216
_check_scipy_linalg_matrix (A , "solve_triangular" )
192
217
_check_scipy_linalg_matrix (B , "solve_triangular" )
193
-
194
218
dtype = A .dtype
195
- if str (dtype ).startswith ("complex" ):
196
- raise ValueError (
197
- "Complex inputs not currently supported by solve_triangular in Numba mode"
198
- )
199
-
200
219
w_type = _get_underlying_float (dtype )
201
220
numba_trtrs = _LAPACK ().numba_xtrtrs (dtype )
202
221
@@ -249,8 +268,8 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False):
249
268
)
250
269
251
270
if B_is_1d :
252
- return B_copy [..., 0 ]
253
- return B_copy
271
+ return B_copy [..., 0 ], int_ptr_to_val ( INFO )
272
+ return B_copy , int_ptr_to_val ( INFO )
254
273
255
274
return impl
256
275
@@ -262,19 +281,122 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
262
281
unit_diagonal = op .unit_diagonal
263
282
check_finite = op .check_finite
264
283
284
+ dtype = node .inputs [0 ].dtype
285
+ if str (dtype ).startswith ("complex" ):
286
+ raise NotImplementedError (
287
+ "Complex inputs not currently supported by solve_triangular in Numba mode"
288
+ )
289
+
265
290
@numba_basic .numba_njit (inline = "always" )
266
291
def solve_triangular (a , b ):
267
- res = _solve_triangular (a , b , trans , lower , unit_diagonal )
268
292
if check_finite :
269
- if np .any (np .bitwise_or (np .isinf (res ), np .isnan (res ))):
270
- raise ValueError (
271
- "Non-numeric values (nan or inf) returned by solve_triangular"
293
+ if np .any (np .bitwise_or (np .isinf (a ), np .isnan (a ))):
294
+ raise np . linalg . LinAlgError (
295
+ "Non-numeric values (nan or inf) in input A to solve_triangular"
272
296
)
297
+ if np .any (np .bitwise_or (np .isinf (b ), np .isnan (b ))):
298
+ raise np .linalg .LinAlgError (
299
+ "Non-numeric values (nan or inf) in input b to solve_triangular"
300
+ )
301
+
302
+ res , info = _solve_triangular (a , b , trans , lower , unit_diagonal )
303
+ if info != 0 :
304
+ raise np .linalg .LinAlgError (
305
+ "Singular matrix in input A to solve_triangular"
306
+ )
273
307
return res
274
308
275
309
return solve_triangular
276
310
277
311
312
+ def _cholesky (a , lower = False , overwrite_a = False , check_finite = True ):
313
+ return linalg .cholesky (
314
+ a , lower = lower , overwrite_a = overwrite_a , check_finite = check_finite
315
+ )
316
+
317
+
318
+ @overload (_cholesky )
319
+ def cholesky_impl (A , lower = 0 , overwrite_a = False , check_finite = True ):
320
+ ensure_lapack ()
321
+ _check_scipy_linalg_matrix (A , "cholesky" )
322
+ dtype = A .dtype
323
+ w_type = _get_underlying_float (dtype )
324
+ numba_potrf = _LAPACK ().numba_xpotrf (dtype )
325
+
326
+ def impl (A , lower = 0 , overwrite_a = False , check_finite = True ):
327
+ _N = np .int32 (A .shape [- 1 ])
328
+ if A .shape [- 2 ] != _N :
329
+ raise linalg .LinAlgError ("Last 2 dimensions of A must be square" )
330
+
331
+ UPLO = val_to_int_ptr (ord ("L" ) if lower else ord ("U" ))
332
+ N = val_to_int_ptr (_N )
333
+ LDA = val_to_int_ptr (_N )
334
+ INFO = val_to_int_ptr (0 )
335
+
336
+ if not overwrite_a :
337
+ A_copy = _copy_to_fortran_order (A )
338
+ else :
339
+ A_copy = A
340
+
341
+ numba_potrf (
342
+ UPLO ,
343
+ N ,
344
+ A_copy .view (w_type ).ctypes ,
345
+ LDA ,
346
+ INFO ,
347
+ )
348
+
349
+ return A_copy , int_ptr_to_val (INFO )
350
+
351
+ return impl
352
+
353
+
354
+ @numba_funcify .register (Cholesky )
355
+ def numba_funcify_Cholesky (op , node , ** kwargs ):
356
+ """
357
+ Overload scipy.linalg.cholesky with a numba function.
358
+
359
+ Note that np.linalg.cholesky is already implemented in numba, but it does not support additional keyword arguments.
360
+ In particular, the `inplace` argument is not supported, which is why we choose to implement our own version.
361
+ """
362
+ lower = op .lower
363
+ overwrite_a = False
364
+ check_finite = op .check_finite
365
+ on_error = op .on_error
366
+
367
+ dtype = node .inputs [0 ].dtype
368
+ if str (dtype ).startswith ("complex" ):
369
+ raise NotImplementedError (
370
+ "Complex inputs not currently supported by cholesky in Numba mode"
371
+ )
372
+
373
+ @numba_basic .numba_njit (inline = "always" )
374
+ def nb_cholesky (a ):
375
+ if check_finite :
376
+ if np .any (np .bitwise_or (np .isinf (a ), np .isnan (a ))):
377
+ raise np .linalg .LinAlgError (
378
+ "Non-numeric values (nan or inf) found in input to cholesky"
379
+ )
380
+ res , info = _cholesky (a , lower , overwrite_a , check_finite )
381
+
382
+ if on_error == "raise" :
383
+ if info > 0 :
384
+ raise np .linalg .LinAlgError (
385
+ "Input to cholesky is not positive definite"
386
+ )
387
+ if info < 0 :
388
+ raise ValueError (
389
+ 'LAPACK reported an illegal value in input on entry to "POTRF."'
390
+ )
391
+ else :
392
+ if info != 0 :
393
+ res = np .full_like (res , np .nan )
394
+
395
+ return res
396
+
397
+ return nb_cholesky
398
+
399
+
278
400
@numba_funcify .register (BlockDiagonal )
279
401
def numba_funcify_BlockDiagonal (op , node , ** kwargs ):
280
402
dtype = node .outputs [0 ].dtype
0 commit comments