|
36 | 36 | from pytensor.tensor.blas import BatchedDot
|
37 | 37 | from pytensor.tensor.math import Dot
|
38 | 38 | from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
|
39 |
| -from pytensor.tensor.slinalg import Solve |
| 39 | +from pytensor.tensor.slinalg import Cholesky, Solve |
40 | 40 | from pytensor.tensor.type import TensorType
|
41 | 41 | from pytensor.tensor.type_other import MakeSlice, NoneConst
|
42 | 42 |
|
@@ -646,6 +646,36 @@ def softplus(x):
|
646 | 646 | return softplus
|
647 | 647 |
|
648 | 648 |
|
| 649 | +@numba_funcify.register(Cholesky) |
| 650 | +def numba_funcify_Cholesky(op, node, **kwargs): |
| 651 | + lower = op.lower |
| 652 | + out_dtype = node.outputs[0].type.numpy_dtype |
| 653 | + |
| 654 | + if lower: |
| 655 | + inputs_cast = int_to_float_fn(node.inputs, out_dtype) |
| 656 | + |
| 657 | + @numba_njit |
| 658 | + def cholesky(a): |
| 659 | + return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype) |
| 660 | + |
| 661 | + else: |
| 662 | + # TODO: Use SciPy's BLAS/LAPACK Cython wrappers. |
| 663 | + warnings.warn( |
| 664 | + "Numba will use object mode to allow the `lower=False` argument to `scipy.linalg.cholesky`.", |
| 665 | + UserWarning, |
| 666 | + ) |
| 667 | + |
| 668 | + ret_sig = get_numba_type(node.outputs[0].type) |
| 669 | + |
| 670 | + @numba_njit |
| 671 | + def cholesky(a): |
| 672 | + with numba.objmode(ret=ret_sig): |
| 673 | + ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype) |
| 674 | + return ret |
| 675 | + |
| 676 | + return cholesky |
| 677 | + |
| 678 | + |
649 | 679 | @numba_funcify.register(Solve)
|
650 | 680 | def numba_funcify_Solve(op, node, **kwargs):
|
651 | 681 | assume_a = op.assume_a
|
|
0 commit comments