@@ -659,19 +659,19 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
659
659
if not isinstance (node .op .core_op , Cholesky ):
660
660
return None
661
661
662
- inputs = node .inputs [ 0 ]
662
+ [ input ] = node .inputs
663
663
# Check for use of pt.diag first
664
664
if (
665
- inputs .owner
666
- and isinstance (inputs .owner .op , AllocDiag )
667
- and AllocDiag .is_offset_zero (inputs .owner )
665
+ input .owner
666
+ and isinstance (input .owner .op , AllocDiag )
667
+ and AllocDiag .is_offset_zero (input .owner )
668
668
):
669
- diag_input = inputs .owner .inputs [0 ]
669
+ diag_input = input .owner .inputs [0 ]
670
670
cholesky_val = pt .diag (diag_input ** 0.5 )
671
671
return [cholesky_val ]
672
672
673
673
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
674
- inputs_or_none = _find_diag_from_eye_mul (inputs )
674
+ inputs_or_none = _find_diag_from_eye_mul (input )
675
675
if inputs_or_none is None :
676
676
return None
677
677
@@ -681,7 +681,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
681
681
if len (non_eye_inputs ) != 1 :
682
682
return None
683
683
684
- non_eye_input = non_eye_inputs [ 0 ]
684
+ [ non_eye_input ] = non_eye_inputs
685
685
686
686
# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
687
687
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
0 commit comments