@@ -594,8 +594,8 @@ def test_cholesky_eye_rewrite():
594
594
595
595
@pytest .mark .parametrize (
596
596
"shape" ,
597
- [(), (7 ,), (7 , 7 )],
598
- ids = ["scalar" , "vector" , "matrix" ],
597
+ [(), (7 ,), (7 , 7 ), ( 5 , 7 , 7 ) ],
598
+ ids = ["scalar" , "vector" , "matrix" , "batched" ],
599
599
)
600
600
def test_cholesky_diag_from_eye_mul (shape ):
601
601
# Initializing x based on scalar/vector/matrix
@@ -653,13 +653,21 @@ def test_cholesky_diag_from_diag():
653
653
)
654
654
655
655
656
- def test_dont_apply_cholesky ():
656
+ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied ():
657
+ # Case 1 : y is not a diagonal matrix because of k = -1
657
658
x = pt .tensor ("x" , shape = (7 , 7 ))
658
659
y = pt .eye (7 , k = - 1 ) * x
659
- # Here, y is not a diagonal matrix because of k = -1
660
660
z_cholesky = pt .linalg .cholesky (y )
661
661
662
662
# REWRITE TEST (should not be applied)
663
663
f_rewritten = function ([x ], z_cholesky , mode = "FAST_RUN" )
664
664
nodes = f_rewritten .maker .fgraph .apply_nodes
665
665
assert any (isinstance (node .op , Cholesky ) for node in nodes )
666
+
667
+ # Case 2 : eye is degenerate
668
+ x = pt .scalar ("x" )
669
+ y = pt .eye (1 ) * x
670
+ z_cholesky = pt .linalg .cholesky (y )
671
+ f_rewritten = function ([x ], z_cholesky , mode = "FAST_RUN" )
672
+ nodes = f_rewritten .maker .fgraph .apply_nodes
673
+ assert any (isinstance (node .op , Cholesky ) for node in nodes )
0 commit comments