Skip to content

Commit 90b0e43

Browse files
committed
minor changes
1 parent 365b6df commit 90b0e43

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -659,19 +659,19 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
659659
if not isinstance(node.op.core_op, Cholesky):
660660
return None
661661

662-
inputs = node.inputs[0]
662+
[input] = node.inputs
663663
# Check for use of pt.diag first
664664
if (
665-
inputs.owner
666-
and isinstance(inputs.owner.op, AllocDiag)
667-
and AllocDiag.is_offset_zero(inputs.owner)
665+
input.owner
666+
and isinstance(input.owner.op, AllocDiag)
667+
and AllocDiag.is_offset_zero(input.owner)
668668
):
669-
diag_input = inputs.owner.inputs[0]
669+
diag_input = input.owner.inputs[0]
670670
cholesky_val = pt.diag(diag_input**0.5)
671671
return [cholesky_val]
672672

673673
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
674-
inputs_or_none = _find_diag_from_eye_mul(inputs)
674+
inputs_or_none = _find_diag_from_eye_mul(input)
675675
if inputs_or_none is None:
676676
return None
677677

@@ -681,7 +681,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
681681
if len(non_eye_inputs) != 1:
682682
return None
683683

684-
non_eye_input = non_eye_inputs[0]
684+
[non_eye_input] = non_eye_inputs
685685

686686
# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
687687
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those

0 commit comments

Comments
 (0)