@@ -617,7 +617,7 @@ def rewrite_inv_inv(fgraph, node):
617
617
@register_canonicalize
618
618
@register_stabilize
619
619
@node_rewriter ([Blockwise ])
620
- def rewrite_cholesky_eye_to_eye (fgraph , node ):
620
+ def rewrite_remove_useless_cholesky (fgraph , node ):
621
621
"""
622
622
This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself
623
623
@@ -640,14 +640,15 @@ def rewrite_cholesky_eye_to_eye(fgraph, node):
640
640
return None
641
641
642
642
# Check whether input to Cholesky is Eye and the 1's are on main diagonal
643
- eye_check = node .inputs [0 ]
643
+ potential_eye = node .inputs [0 ]
644
644
if not (
645
- eye_check .owner
646
- and isinstance (eye_check .owner .op , Eye )
647
- and getattr (eye_check .owner .inputs [- 1 ], "data" , - 1 ).item () == 0
645
+ potential_eye .owner
646
+ and isinstance (potential_eye .owner .op , Eye )
647
+ and hasattr (potential_eye .owner .inputs [- 1 ], "data" )
648
+ and potential_eye .owner .inputs [- 1 ].data .item () == 0
648
649
):
649
650
return None
650
- return [eye_check ]
651
+ return [potential_eye ]
651
652
652
653
653
654
@register_canonicalize
@@ -665,10 +666,9 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
665
666
and isinstance (inputs .owner .op , AllocDiag )
666
667
and AllocDiag .is_offset_zero (inputs .owner )
667
668
):
668
- cholesky_input = inputs .owner .inputs [0 ]
669
- if cholesky_input .type .ndim == 1 :
670
- cholesky_val = pt .diag (cholesky_input ** 0.5 )
671
- return [cholesky_val ]
669
+ diag_input = inputs .owner .inputs [0 ]
670
+ cholesky_val = pt .diag (diag_input ** 0.5 )
671
+ return [cholesky_val ]
672
672
673
673
# Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
674
674
inputs_or_none = _find_diag_from_eye_mul (inputs )
@@ -686,8 +686,6 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
686
686
# Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
687
687
# For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
688
688
if non_eye_input .type .broadcastable [- 2 :] == (False , False ):
689
- # For Matrix
690
- return [eye_input * (non_eye_input .diagonal (axis1 = - 1 , axis2 = - 2 ) ** 0.5 )]
691
- else :
692
- # For Vector or Scalar
693
- return [eye_input * (non_eye_input ** 0.5 )]
689
+ non_eye_input = non_eye_input .diagonal (axis1 = - 1 , axis2 = - 2 )
690
+
691
+ return [eye_input * (non_eye_input ** 0.5 )]
0 commit comments