Skip to content

Commit 487d6de

Browse files
committed
added rewrite for cholesky(diag) -> sqrt(diag)
1 parent 9ac68db commit 487d6de

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,3 +572,28 @@ def rewrite_cholesky_eye_to_eye(fgraph, node):
572572
):
573573
return None
574574
return [eye_check]
575+
576+
577+
@register_canonicalize
578+
@register_stabilize
579+
@node_rewriter([Blockwise])
580+
def rewrite_cholesky_diag_from_eye_mul(fgraph, node):
581+
# Find whether cholesky op is being applied
582+
if not isinstance(node.op.core_op, Cholesky):
583+
return None
584+
585+
# Check whether input is diagonal from multiplcation of identity matrix with a tensor
586+
inputs = node.inputs[0]
587+
inputs_or_none = _find_diag_from_eye_mul(inputs)
588+
if inputs_or_none is None:
589+
return None
590+
591+
eye_input, non_eye_inputs = inputs_or_none
592+
593+
# Dealing with only one other input
594+
if len(non_eye_inputs) != 1:
595+
return None
596+
597+
eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]
598+
599+
return [eye_input * (non_eye_input**0.5)]

tests/tensor/rewriting/test_linalg.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,3 +572,39 @@ def test_cholesky_eye_rewrite():
572572
atol=1e-3 if config.floatX == "float32" else 1e-8,
573573
rtol=1e-3 if config.floatX == "float32" else 1e-8,
574574
)
575+
576+
577+
@pytest.mark.parametrize(
578+
"shape",
579+
[(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
580+
ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
581+
)
582+
def test_cholesky_diag_from_eye_mul(shape):
583+
# Initializing x based on scalar/vector/matrix
584+
x = pt.tensor("x", shape=shape)
585+
y = pt.eye(7) * x
586+
# Performing cholesky decomposition using pt.linalg.cholesky
587+
z_cholesky = pt.linalg.cholesky(y)
588+
589+
# REWRITE TEST
590+
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
591+
nodes = f_rewritten.maker.fgraph.apply_nodes
592+
assert not any(isinstance(node.op, Cholesky) for node in nodes)
593+
594+
# NUMERIC VALUE TEST
595+
if len(shape) == 0:
596+
x_test = np.array(np.random.rand()).astype(config.floatX)
597+
elif len(shape) == 1:
598+
x_test = np.random.rand(*shape).astype(config.floatX)
599+
else:
600+
x_test = np.random.rand(*shape).astype(config.floatX)
601+
x_test_matrix = np.eye(7) * x_test
602+
cholesky_val = np.linalg.cholesky(x_test_matrix)
603+
rewritten_val = f_rewritten(x_test)
604+
605+
assert_allclose(
606+
cholesky_val,
607+
rewritten_val,
608+
atol=1e-3 if config.floatX == "float32" else 1e-8,
609+
rtol=1e-3 if config.floatX == "float32" else 1e-8,
610+
)

0 commit comments

Comments
 (0)