Skip to content

Commit fd79395

Browse files
tanish1729jessegrabowski
authored andcommitted
minor changes
1 parent 44c13d9 commit fd79395

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

pytensor/tensor/rewriting/linalg.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -865,19 +865,19 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
865865
if not isinstance(node.op.core_op, Cholesky):
866866
return None
867867

868-
inputs = node.inputs[0]
868+
[input] = node.inputs
869869
# Check for use of pt.diag first
870870
if (
871-
inputs.owner
872-
and isinstance(inputs.owner.op, AllocDiag)
873-
and AllocDiag.is_offset_zero(inputs.owner)
871+
input.owner
872+
and isinstance(input.owner.op, AllocDiag)
873+
and AllocDiag.is_offset_zero(input.owner)
874874
):
875-
diag_input = inputs.owner.inputs[0]
875+
diag_input = input.owner.inputs[0]
876876
cholesky_val = pt.diag(diag_input**0.5)
877877
return [cholesky_val]
878878

879879
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
880-
inputs_or_none = _find_diag_from_eye_mul(inputs)
880+
inputs_or_none = _find_diag_from_eye_mul(input)
881881
if inputs_or_none is None:
882882
return None
883883

@@ -887,7 +887,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
887887
if len(non_eye_inputs) != 1:
888888
return None
889889

890-
non_eye_input = non_eye_inputs[0]
890+
[non_eye_input] = non_eye_inputs
891891

892892
# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
893893
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those

0 commit comments

Comments
 (0)