Skip to content

Commit d92c367

Browse files
aseyboldtricardoV94
authored andcommitted
fix(numba): cholesky did not set off-diag entries to zero
1 parent 75a9fd2 commit d92c367

File tree

2 files changed

+29
-20
lines changed

2 files changed

+29
-20
lines changed

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,11 @@ def solve_triangular(a, b):
310310

311311

312312
def _cholesky(a, lower=False, overwrite_a=False, check_finite=True):
313-
return linalg.cholesky(
314-
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
313+
return (
314+
linalg.cholesky(
315+
a, lower=lower, overwrite_a=overwrite_a, check_finite=check_finite
316+
),
317+
0,
315318
)
316319

317320

@@ -346,6 +349,15 @@ def impl(A, lower=0, overwrite_a=False, check_finite=True):
346349
INFO,
347350
)
348351

352+
if lower:
353+
for j in range(1, _N):
354+
for i in range(j):
355+
A_copy[i, j] = 0.0
356+
else:
357+
for j in range(_N):
358+
for i in range(j + 1, _N):
359+
A_copy[i, j] = 0.0
360+
349361
return A_copy, int_ptr_to_val(INFO)
350362

351363
return impl

tests/link/numba/test_slinalg.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,8 @@
66
import pytensor
77
import pytensor.tensor as pt
88
from pytensor import config
9-
from pytensor.compile import SharedVariable
10-
from pytensor.graph import Constant, FunctionGraph
9+
from pytensor.graph import FunctionGraph
1110
from tests.link.numba.test_basic import compare_numba_and_py
12-
from tests.tensor.test_extra_ops import set_test_value
1311

1412

1513
numba = pytest.importorskip("numba")
@@ -109,23 +107,22 @@ def test_solve_triangular_raises_on_nan_inf(value):
109107

110108

111109
@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"])
112-
def test_numba_Cholesky(lower):
113-
x = set_test_value(
114-
pt.tensor(dtype=config.floatX, shape=(3, 3)),
115-
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype(config.floatX)),
116-
)
110+
@pytest.mark.parametrize("trans", [True, False], ids=["trans=True", "trans=False"])
111+
def test_numba_Cholesky(lower, trans):
112+
cov = pt.matrix("cov")
117113

118-
g = pt.linalg.cholesky(x, lower=lower)
119-
g_fg = FunctionGraph(outputs=[g])
114+
if trans:
115+
cov_ = cov.T
116+
else:
117+
cov_ = cov
118+
chol = pt.linalg.cholesky(cov_, lower=lower)
120119

121-
compare_numba_and_py(
122-
g_fg,
123-
[
124-
i.tag.test_value
125-
for i in g_fg.inputs
126-
if not isinstance(i, SharedVariable | Constant)
127-
],
128-
)
120+
fg = FunctionGraph(outputs=[chol])
121+
122+
x = np.array([0.1, 0.2, 0.3])
123+
val = np.eye(3) + x[None, :] * x[:, None]
124+
125+
compare_numba_and_py(fg, [val])
129126

130127

131128
def test_numba_Cholesky_raises_on_nan_input():

0 commit comments

Comments
 (0)