|
2 | 2 | import numpy.linalg
|
3 | 3 | import pytest
|
4 | 4 | import scipy.linalg
|
| 5 | +from numpy.testing import assert_allclose |
5 | 6 |
|
6 | 7 | import pytensor
|
7 | 8 | from pytensor import function
|
|
12 | 13 | from pytensor.tensor.math import _allclose
|
13 | 14 | from pytensor.tensor.nlinalg import MatrixInverse, matrix_inverse
|
14 | 15 | from pytensor.tensor.rewriting.linalg import inv_as_solve
|
15 |
| -from pytensor.tensor.slinalg import Cholesky, Solve, solve |
| 16 | +from pytensor.tensor.slinalg import Cholesky, Solve, SolveTriangular, solve |
16 | 17 | from pytensor.tensor.type import dmatrix, matrix, vector
|
17 | 18 | from tests import unittest_tools as utt
|
18 | 19 | from tests.test_rop import break_op
|
@@ -81,25 +82,46 @@ def test_transinv_to_invtrans():
|
81 | 82 | assert node.inputs[0].name == "X"
|
82 | 83 |
|
83 | 84 |
|
84 |
| -def test_tag_solve_triangular(): |
| 85 | +def test_generic_solve_to_solve_triangular(): |
85 | 86 | cholesky_lower = Cholesky(lower=True)
|
86 | 87 | cholesky_upper = Cholesky(lower=False)
|
87 | 88 | A = matrix("A")
|
88 |
| - x = vector("x") |
| 89 | + x = matrix("x") |
| 90 | + |
89 | 91 | L = cholesky_lower(A)
|
90 | 92 | U = cholesky_upper(A)
|
91 | 93 | b1 = solve(L, x)
|
92 | 94 | b2 = solve(U, x)
|
93 | 95 | f = pytensor.function([A, x], b1)
|
| 96 | + |
| 97 | + X = np.random.normal(size=(10, 10)).astype(config.floatX) |
| 98 | + X = X @ X.T |
| 99 | + X_chol = np.linalg.cholesky(X) |
| 100 | + eye = np.eye(10, dtype=config.floatX) |
| 101 | + |
94 | 102 | if config.mode != "FAST_COMPILE":
|
95 |
| - for node in f.maker.fgraph.toposort(): |
96 |
| - if isinstance(node.op, Solve): |
97 |
| - assert node.op.assume_a != "gen" and node.op.lower |
| 103 | + toposort = f.maker.fgraph.toposort() |
| 104 | + op_list = [node.op for node in toposort] |
| 105 | + |
| 106 | + assert not any(isinstance(op, Solve) for op in op_list) |
| 107 | + assert any(isinstance(op, SolveTriangular) for op in op_list) |
| 108 | + |
| 109 | + assert_allclose( |
| 110 | + f(X, eye) @ X_chol, eye, atol=1e-8 if config.floatX.endswith("64") else 1e-4 |
| 111 | + ) |
| 112 | + |
98 | 113 | f = pytensor.function([A, x], b2)
|
| 114 | + |
99 | 115 | if config.mode != "FAST_COMPILE":
|
100 |
| - for node in f.maker.fgraph.toposort(): |
101 |
| - if isinstance(node.op, Solve): |
102 |
| - assert node.op.assume_a != "gen" and not node.op.lower |
| 116 | + toposort = f.maker.fgraph.toposort() |
| 117 | + op_list = [node.op for node in toposort] |
| 118 | + assert not any(isinstance(op, Solve) for op in op_list) |
| 119 | + assert any(isinstance(op, SolveTriangular) for op in op_list) |
| 120 | + assert_allclose( |
| 121 | + f(X, eye).T @ X_chol, |
| 122 | + eye, |
| 123 | + atol=1e-8 if config.floatX.endswith("64") else 1e-4, |
| 124 | + ) |
103 | 125 |
|
104 | 126 |
|
105 | 127 | def test_matrix_inverse_solve():
|
|
0 commit comments