@@ -865,19 +865,19 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
865
865
if not isinstance (node .op .core_op , Cholesky ):
866
866
return None
867
867
868
- inputs = node .inputs [ 0 ]
868
+ [ input ] = node .inputs
869
869
# Check for use of pt.diag first
870
870
if (
871
- inputs .owner
872
- and isinstance (inputs .owner .op , AllocDiag )
873
- and AllocDiag .is_offset_zero (inputs .owner )
871
+ input .owner
872
+ and isinstance (input .owner .op , AllocDiag )
873
+ and AllocDiag .is_offset_zero (input .owner )
874
874
):
875
- diag_input = inputs .owner .inputs [0 ]
875
+ diag_input = input .owner .inputs [0 ]
876
876
cholesky_val = pt .diag (diag_input ** 0.5 )
877
877
return [cholesky_val ]
878
878
879
879
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
880
- inputs_or_none = _find_diag_from_eye_mul (inputs )
880
+ inputs_or_none = _find_diag_from_eye_mul (input )
881
881
if inputs_or_none is None :
882
882
return None
883
883
@@ -887,7 +887,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
887
887
if len (non_eye_inputs ) != 1 :
888
888
return None
889
889
890
- non_eye_input = non_eye_inputs [ 0 ]
890
+ [ non_eye_input ] = non_eye_inputs
891
891
892
892
# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
893
893
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
0 commit comments