Skip to content

Commit 0b69d85

Browse files
committed
minor changes; added test to not apply rewrite
1 parent e96a770 commit 0b69d85

File tree

2 files changed

+25
-20
lines changed

2 files changed

+25
-20
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ def rewrite_inv_inv(fgraph, node):
617617
@register_canonicalize
618618
@register_stabilize
619619
@node_rewriter([Blockwise])
620-
def rewrite_cholesky_eye_to_eye(fgraph, node):
620+
def rewrite_remove_useless_cholesky(fgraph, node):
621621
"""
622622
This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself
623623
@@ -640,14 +640,15 @@ def rewrite_cholesky_eye_to_eye(fgraph, node):
640640
return None
641641

642642
# Check whether input to Cholesky is Eye and the 1's are on main diagonal
643-
eye_check = node.inputs[0]
643+
potential_eye = node.inputs[0]
644644
if not (
645-
eye_check.owner
646-
and isinstance(eye_check.owner.op, Eye)
647-
and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0
645+
potential_eye.owner
646+
and isinstance(potential_eye.owner.op, Eye)
647+
and hasattr(potential_eye.owner.inputs[-1], "data")
648+
and potential_eye.owner.inputs[-1].data.item() == 0
648649
):
649650
return None
650-
return [eye_check]
651+
return [potential_eye]
651652

652653

653654
@register_canonicalize
@@ -665,10 +666,9 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
665666
and isinstance(inputs.owner.op, AllocDiag)
666667
and AllocDiag.is_offset_zero(inputs.owner)
667668
):
668-
cholesky_input = inputs.owner.inputs[0]
669-
if cholesky_input.type.ndim == 1:
670-
cholesky_val = pt.diag(cholesky_input**0.5)
671-
return [cholesky_val]
669+
diag_input = inputs.owner.inputs[0]
670+
cholesky_val = pt.diag(diag_input**0.5)
671+
return [cholesky_val]
672672

673673
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
674674
inputs_or_none = _find_diag_from_eye_mul(inputs)
@@ -686,8 +686,6 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
686686
# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
687687
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
688688
if non_eye_input.type.broadcastable[-2:] == (False, False):
689-
# For Matrix
690-
return [eye_input * (non_eye_input.diagonal(axis1=-1, axis2=-2) ** 0.5)]
691-
else:
692-
# For Vector or Scalar
693-
return [eye_input * (non_eye_input**0.5)]
689+
non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2)
690+
691+
return [eye_input * (non_eye_input**0.5)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -572,17 +572,12 @@ def get_pt_function(x, op_name):
572572

573573
def test_cholesky_eye_rewrite():
574574
x = pt.eye(10)
575-
x_mat = pt.matrix("x")
576575
L = pt.linalg.cholesky(x)
577-
L_mat = pt.linalg.cholesky(x_mat)
578576
f_rewritten = function([], L, mode="FAST_RUN")
579-
f_rewritten_mat = function([x_mat], L_mat, mode="FAST_RUN")
580577
nodes = f_rewritten.maker.fgraph.apply_nodes
581-
nodes_mat = f_rewritten_mat.maker.fgraph.apply_nodes
582578

583579
# Rewrite Test
584580
assert not any(isinstance(node.op, Cholesky) for node in nodes)
585-
assert any(isinstance(node.op, Cholesky) for node in nodes_mat)
586581

587582
# Value Test
588583
x_test = np.eye(10)
@@ -656,3 +651,15 @@ def test_cholesky_diag_from_diag():
656651
atol=1e-3 if config.floatX == "float32" else 1e-8,
657652
rtol=1e-3 if config.floatX == "float32" else 1e-8,
658653
)
654+
655+
656+
def test_dont_apply_cholesky():
657+
x = pt.tensor("x", shape=(7, 7))
658+
y = pt.eye(7, k=-1) * x
659+
# Here, y is not a diagonal matrix because of k = -1
660+
z_cholesky = pt.linalg.cholesky(y)
661+
662+
# REWRITE TEST (should not be applied)
663+
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
664+
nodes = f_rewritten.maker.fgraph.apply_nodes
665+
assert any(isinstance(node.op, Cholesky) for node in nodes)

0 commit comments

Comments
 (0)