Skip to content

Commit 9c8e25f

Browse files
jessegrabowskiricardoV94
authored andcommitted
Numba implementation of Cholesky
1 parent 7eca252 commit 9c8e25f

File tree

1 file changed

+31
-1
lines changed

1 file changed

+31
-1
lines changed

pytensor/link/numba/dispatch/basic.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
from pytensor.tensor.blas import BatchedDot
3737
from pytensor.tensor.math import Dot
3838
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
39-
from pytensor.tensor.slinalg import Solve
39+
from pytensor.tensor.slinalg import Cholesky, Solve
4040
from pytensor.tensor.type import TensorType
4141
from pytensor.tensor.type_other import MakeSlice, NoneConst
4242

@@ -646,6 +646,36 @@ def softplus(x):
646646
return softplus
647647

648648

649+
@numba_funcify.register(Cholesky)
650+
def numba_funcify_Cholesky(op, node, **kwargs):
651+
lower = op.lower
652+
out_dtype = node.outputs[0].type.numpy_dtype
653+
654+
if lower:
655+
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
656+
657+
@numba_njit
658+
def cholesky(a):
659+
return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)
660+
661+
else:
662+
# TODO: Use SciPy's BLAS/LAPACK Cython wrappers.
663+
warnings.warn(
664+
"Numba will use object mode to allow the `lower=False` argument to `scipy.linalg.cholesky`.",
665+
UserWarning,
666+
)
667+
668+
ret_sig = get_numba_type(node.outputs[0].type)
669+
670+
@numba_njit
671+
def cholesky(a):
672+
with numba.objmode(ret=ret_sig):
673+
ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype)
674+
return ret
675+
676+
return cholesky
677+
678+
649679
@numba_funcify.register(Solve)
650680
def numba_funcify_Solve(op, node, **kwargs):
651681
assume_a = op.assume_a

0 commit comments

Comments
 (0)