Skip to content

Commit 365b6df

Browse files
committed
added test for batched case and more cases of not applying rewrite
1 parent 6877aea commit 365b6df

File tree

2 files changed

+14
-4
lines changed

2 files changed

+14
-4
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,5 +687,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
687687
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
688688
if non_eye_input.type.broadcastable[-2:] == (False, False):
689689
non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2)
690+
if eye_input.type.ndim > 2:
691+
non_eye_input = pt.shape_padaxis(non_eye_input, -2)
690692

691693
return [eye_input * (non_eye_input**0.5)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -594,8 +594,8 @@ def test_cholesky_eye_rewrite():
594594

595595
@pytest.mark.parametrize(
596596
"shape",
597-
[(), (7,), (7, 7)],
598-
ids=["scalar", "vector", "matrix"],
597+
[(), (7,), (7, 7), (5, 7, 7)],
598+
ids=["scalar", "vector", "matrix", "batched"],
599599
)
600600
def test_cholesky_diag_from_eye_mul(shape):
601601
# Initializing x based on scalar/vector/matrix
@@ -653,13 +653,21 @@ def test_cholesky_diag_from_diag():
653653
)
654654

655655

656-
def test_dont_apply_cholesky():
656+
def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
657+
# Case 1 : y is not a diagonal matrix because of k = -1
657658
x = pt.tensor("x", shape=(7, 7))
658659
y = pt.eye(7, k=-1) * x
659-
# Here, y is not a diagonal matrix because of k = -1
660660
z_cholesky = pt.linalg.cholesky(y)
661661

662662
# REWRITE TEST (should not be applied)
663663
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
664664
nodes = f_rewritten.maker.fgraph.apply_nodes
665665
assert any(isinstance(node.op, Cholesky) for node in nodes)
666+
667+
# Case 2 : eye is degenerate
668+
x = pt.scalar("x")
669+
y = pt.eye(1) * x
670+
z_cholesky = pt.linalg.cholesky(y)
671+
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
672+
nodes = f_rewritten.maker.fgraph.apply_nodes
673+
assert any(isinstance(node.op, Cholesky) for node in nodes)

0 commit comments

Comments
 (0)