Skip to content

Commit 9be43d0

Browse files
Fix bug in tag_solve_triangular rewrite (#383)
* Fix bug in tag_solve_triangular rewrite * Rename tag_solve_triangular to generic_solve_to_solve_triangular
1 parent 7a82a3f commit 9be43d0

File tree

2 files changed

+50
-29
lines changed

2 files changed

+50
-29
lines changed

pytensor/tensor/rewriting/linalg.py

+19-20
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
register_specialize,
1212
register_stabilize,
1313
)
14-
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve
14+
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, cholesky, solve
1515

1616

1717
logger = logging.getLogger(__name__)
@@ -50,31 +50,30 @@ def inv_as_solve(fgraph, node):
5050
@register_stabilize
5151
@register_canonicalize
5252
@node_rewriter([Solve])
53-
def tag_solve_triangular(fgraph, node):
53+
def generic_solve_to_solve_triangular(fgraph, node):
5454
"""
55-
If a general solve() is applied to the output of a cholesky op, then
55+
If any solve() is applied to the output of a cholesky op, then
5656
replace it with a triangular solve.
5757
5858
"""
5959
if isinstance(node.op, Solve):
60-
if node.op.assume_a == "gen":
61-
A, b = node.inputs # result is solution Ax=b
62-
if A.owner and isinstance(A.owner.op, Cholesky):
63-
if A.owner.op.lower:
64-
return [Solve(assume_a="sym", lower=True)(A, b)]
60+
A, b = node.inputs # result is solution Ax=b
61+
if A.owner and isinstance(A.owner.op, Cholesky):
62+
if A.owner.op.lower:
63+
return [SolveTriangular(lower=True)(A, b)]
64+
else:
65+
return [SolveTriangular(lower=False)(A, b)]
66+
if (
67+
A.owner
68+
and isinstance(A.owner.op, DimShuffle)
69+
and A.owner.op.new_order == (1, 0)
70+
):
71+
(A_T,) = A.owner.inputs
72+
if A_T.owner and isinstance(A_T.owner.op, Cholesky):
73+
if A_T.owner.op.lower:
74+
return [SolveTriangular(lower=False)(A, b)]
6575
else:
66-
return [Solve(assume_a="sym", lower=False)(A, b)]
67-
if (
68-
A.owner
69-
and isinstance(A.owner.op, DimShuffle)
70-
and A.owner.op.new_order == (1, 0)
71-
):
72-
(A_T,) = A.owner.inputs
73-
if A_T.owner and isinstance(A_T.owner.op, Cholesky):
74-
if A_T.owner.op.lower:
75-
return [Solve(assume_a="sym", lower=False)(A, b)]
76-
else:
77-
return [Solve(assume_a="sym", lower=True)(A, b)]
76+
return [SolveTriangular(lower=True)(A, b)]
7877

7978

8079
@register_canonicalize

tests/tensor/rewriting/test_linalg.py

+31-9
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy.linalg
33
import pytest
44
import scipy.linalg
5+
from numpy.testing import assert_allclose
56

67
import pytensor
78
from pytensor import function
@@ -12,7 +13,7 @@
1213
from pytensor.tensor.math import _allclose
1314
from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse
1415
from pytensor.tensor.rewriting.linalg import inv_as_solve
15-
from pytensor.tensor.slinalg import Cholesky, Solve, solve
16+
from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve
1617
from pytensor.tensor.type import dmatrix, matrix, vector
1718
from tests import unittest_tools as utt
1819
from tests.test_rop import break_op
@@ -81,25 +82,46 @@ def test_transinv_to_invtrans():
8182
assert node.inputs[0].name == "X"
8283

8384

84-
def test_tag_solve_triangular():
85+
def test_generic_solve_to_solve_triangular():
8586
cholesky_lower = Cholesky(lower=True)
8687
cholesky_upper = Cholesky(lower=False)
8788
A = matrix("A")
88-
x = vector("x")
89+
x = matrix("x")
90+
8991
L = cholesky_lower(A)
9092
U = cholesky_upper(A)
9193
b1 = solve(L, x)
9294
b2 = solve(U, x)
9395
f = pytensor.function([A, x], b1)
96+
97+
X = np.random.normal(size=(10, 10)).astype(config.floatX)
98+
X = X @ X.T
99+
X_chol = np.linalg.cholesky(X)
100+
eye = np.eye(10, dtype=config.floatX)
101+
94102
if config.mode != "FAST_COMPILE":
95-
for node in f.maker.fgraph.toposort():
96-
if isinstance(node.op, Solve):
97-
assert node.op.assume_a != "gen" and node.op.lower
103+
toposort = f.maker.fgraph.toposort()
104+
op_list = [node.op for node in toposort]
105+
106+
assert not any(isinstance(op, Solve) for op in op_list)
107+
assert any(isinstance(op, SolveTriangular) for op in op_list)
108+
109+
assert_allclose(
110+
f(X, eye) @ X_chol, eye, atol=1e-8 if config.floatX.endswith("64") else 1e-4
111+
)
112+
98113
f = pytensor.function([A, x], b2)
114+
99115
if config.mode != "FAST_COMPILE":
100-
for node in f.maker.fgraph.toposort():
101-
if isinstance(node.op, Solve):
102-
assert node.op.assume_a != "gen" and not node.op.lower
116+
toposort = f.maker.fgraph.toposort()
117+
op_list = [node.op for node in toposort]
118+
assert not any(isinstance(op, Solve) for op in op_list)
119+
assert any(isinstance(op, SolveTriangular) for op in op_list)
120+
assert_allclose(
121+
f(X, eye).T @ X_chol,
122+
eye,
123+
atol=1e-8 if config.floatX.endswith("64") else 1e-4,
124+
)
103125

104126

105127
def test_matrix_inverse_solve():

0 commit comments

Comments
 (0)