Skip to content

Adds functions to rewrite cholesky decomposition of identity and diagonal matrices #925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,3 +887,82 @@
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]

return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_remove_useless_cholesky(fgraph, node):
"""
This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself

The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside Cholesky.

Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized

Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
# Find whether cholesky op is being applied
if not isinstance(node.op.core_op, Cholesky):
return None

# Check whether input to Cholesky is Eye and the 1's are on main diagonal
potential_eye = node.inputs[0]
if not (
potential_eye.owner
and isinstance(potential_eye.owner.op, Eye)
and hasattr(potential_eye.owner.inputs[-1], "data")
and potential_eye.owner.inputs[-1].data.item() == 0
):
return None
return [potential_eye]


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
# Find whether cholesky op is being applied
if not isinstance(node.op.core_op, Cholesky):
return None

[input] = node.inputs
# Check for use of pt.diag first
if (
input.owner
and isinstance(input.owner.op, AllocDiag)
and AllocDiag.is_offset_zero(input.owner)
):
diag_input = input.owner.inputs[0]
cholesky_val = pt.diag(diag_input**0.5)
return [cholesky_val]

# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
inputs_or_none = _find_diag_from_eye_mul(input)
if inputs_or_none is None:
return None

eye_input, non_eye_inputs = inputs_or_none

# Dealing with only one other input
if len(non_eye_inputs) != 1:
return None

Check warning on line 957 in pytensor/tensor/rewriting/linalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/rewriting/linalg.py#L957

Added line #L957 was not covered by tests

[non_eye_input] = non_eye_inputs

# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
if non_eye_input.type.broadcastable[-2:] == (False, False):
non_eye_input = non_eye_input.diagonal(axis1=-1, axis2=-2)
if eye_input.type.ndim > 2:
non_eye_input = pt.shape_padaxis(non_eye_input, -2)

return [eye_input * (non_eye_input**0.5)]
103 changes: 103 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,3 +803,106 @@ def test_slogdet_kronecker_rewrite():
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


def test_cholesky_eye_rewrite():
x = pt.eye(10)
L = pt.linalg.cholesky(x)
f_rewritten = function([], L, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes

# Rewrite Test
assert not any(isinstance(node.op, Cholesky) for node in nodes)

# Value Test
x_test = np.eye(10)
L = np.linalg.cholesky(x_test)
rewritten_val = f_rewritten()

assert_allclose(
L,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


@pytest.mark.parametrize(
"shape",
[(), (7,), (7, 7), (5, 7, 7)],
ids=["scalar", "vector", "matrix", "batched"],
)
def test_cholesky_diag_from_eye_mul(shape):
# Initializing x based on scalar/vector/matrix
x = pt.tensor("x", shape=shape)
y = pt.eye(7) * x
# Performing cholesky decomposition using pt.linalg.cholesky
z_cholesky = pt.linalg.cholesky(y)

# REWRITE TEST
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, Cholesky) for node in nodes)

# NUMERIC VALUE TEST
if len(shape) == 0:
x_test = np.array(np.random.rand()).astype(config.floatX)
elif len(shape) == 1:
x_test = np.random.rand(*shape).astype(config.floatX)
else:
x_test = np.random.rand(*shape).astype(config.floatX)
x_test_matrix = np.eye(7) * x_test
cholesky_val = np.linalg.cholesky(x_test_matrix)
rewritten_val = f_rewritten(x_test)

assert_allclose(
cholesky_val,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


def test_cholesky_diag_from_diag():
x = pt.dvector("x")
x_diag = pt.diag(x)
x_cholesky = pt.linalg.cholesky(x_diag)

# REWRITE TEST
f_rewritten = function([x], x_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes

assert not any(isinstance(node.op, Cholesky) for node in nodes)

# NUMERIC VALUE TEST
x_test = np.random.rand(10)
x_test_matrix = np.eye(10) * x_test
cholesky_val = np.linalg.cholesky(x_test_matrix)
rewritten_cholesky = f_rewritten(x_test)

assert_allclose(
cholesky_val,
rewritten_cholesky,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
# Case 1 : y is not a diagonal matrix because of k = -1
x = pt.tensor("x", shape=(7, 7))
y = pt.eye(7, k=-1) * x
z_cholesky = pt.linalg.cholesky(y)

# REWRITE TEST (should not be applied)
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert any(isinstance(node.op, Cholesky) for node in nodes)

# Case 2 : eye is degenerate
x = pt.scalar("x")
y = pt.eye(1) * x
z_cholesky = pt.linalg.cholesky(y)
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert any(isinstance(node.op, Cholesky) for node in nodes)
Loading