Skip to content

Commit 736782b

Browse files
committed
fixed merge conflicts
1 parent 981688c commit 736782b

File tree

2 files changed

+124
-0
lines changed

2 files changed

+124
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,3 +611,64 @@ def rewrite_inv_inv(fgraph, node):
611611
):
612612
return None
613613
return [potential_inner_inv.inputs[0]]
614+
615+
616+
@register_canonicalize
617+
@register_stabilize
618+
@node_rewriter([Blockwise])
619+
def rewrite_cholesky_eye_to_eye(fgraph, node):
620+
"""
621+
This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself
622+
623+
The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside Cholesky.
624+
625+
Parameters
626+
----------
627+
fgraph: FunctionGraph
628+
Function graph being optimized
629+
node: Apply
630+
Node of the function graph to be optimized
631+
632+
Returns
633+
-------
634+
list of Variable, optional
635+
List of optimized variables, or None if no optimization was performed
636+
"""
637+
# Find whether cholesky op is being applied
638+
if not isinstance(node.op.core_op, Cholesky):
639+
return None
640+
641+
# Check whether input to Cholesky is Eye and the 1's are on main diagonal
642+
eye_check = node.inputs[0]
643+
if not (
644+
eye_check.owner
645+
and isinstance(eye_check.owner.op, Eye)
646+
and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0
647+
):
648+
return None
649+
return [eye_check]
650+
651+
652+
@register_canonicalize
653+
@register_stabilize
654+
@node_rewriter([Blockwise])
655+
def rewrite_cholesky_diag_from_eye_mul(fgraph, node):
656+
# Find whether cholesky op is being applied
657+
if not isinstance(node.op.core_op, Cholesky):
658+
return None
659+
660+
# Check whether input is diagonal from multiplcation of identity matrix with a tensor
661+
inputs = node.inputs[0]
662+
inputs_or_none = _find_diag_from_eye_mul(inputs)
663+
if inputs_or_none is None:
664+
return None
665+
666+
eye_input, non_eye_inputs = inputs_or_none
667+
668+
# Dealing with only one other input
669+
if len(non_eye_inputs) != 1:
670+
return None
671+
672+
eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]
673+
674+
return [eye_input * (non_eye_input**0.5)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,3 +568,66 @@ def get_pt_function(x, op_name):
568568
op2 = get_pt_function(op1, inv_op_2)
569569
rewritten_out = rewrite_graph(op2)
570570
assert rewritten_out == x
571+
572+
573+
def test_cholesky_eye_rewrite():
574+
x = pt.eye(10)
575+
x_mat = pt.matrix("x")
576+
L = pt.linalg.cholesky(x)
577+
L_mat = pt.linalg.cholesky(x_mat)
578+
f_rewritten = function([], L, mode="FAST_RUN")
579+
f_rewritten_mat = function([x_mat], L_mat, mode="FAST_RUN")
580+
nodes = f_rewritten.maker.fgraph.apply_nodes
581+
nodes_mat = f_rewritten_mat.maker.fgraph.apply_nodes
582+
583+
# Rewrite Test
584+
assert not any(isinstance(node.op, Cholesky) for node in nodes)
585+
assert any(isinstance(node.op, Cholesky) for node in nodes_mat)
586+
587+
# Value Test
588+
x_test = np.eye(10)
589+
L = np.linalg.cholesky(x_test)
590+
rewritten_val = f_rewritten()
591+
592+
assert_allclose(
593+
L,
594+
rewritten_val,
595+
atol=1e-3 if config.floatX == "float32" else 1e-8,
596+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
597+
)
598+
599+
600+
@pytest.mark.parametrize(
601+
"shape",
602+
[(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
603+
ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
604+
)
605+
def test_cholesky_diag_from_eye_mul(shape):
606+
# Initializing x based on scalar/vector/matrix
607+
x = pt.tensor("x", shape=shape)
608+
y = pt.eye(7) * x
609+
# Performing cholesky decomposition using pt.linalg.cholesky
610+
z_cholesky = pt.linalg.cholesky(y)
611+
612+
# REWRITE TEST
613+
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
614+
nodes = f_rewritten.maker.fgraph.apply_nodes
615+
assert not any(isinstance(node.op, Cholesky) for node in nodes)
616+
617+
# NUMERIC VALUE TEST
618+
if len(shape) == 0:
619+
x_test = np.array(np.random.rand()).astype(config.floatX)
620+
elif len(shape) == 1:
621+
x_test = np.random.rand(*shape).astype(config.floatX)
622+
else:
623+
x_test = np.random.rand(*shape).astype(config.floatX)
624+
x_test_matrix = np.eye(7) * x_test
625+
cholesky_val = np.linalg.cholesky(x_test_matrix)
626+
rewritten_val = f_rewritten(x_test)
627+
628+
assert_allclose(
629+
cholesky_val,
630+
rewritten_val,
631+
atol=1e-3 if config.floatX == "float32" else 1e-8,
632+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
633+
)

0 commit comments

Comments
 (0)