File tree 2 files changed +17
-2
lines changed
pytensor/tensor/rewriting
2 files changed +17
-2
lines changed Original file line number Diff line number Diff line change @@ -2071,7 +2071,10 @@ def local_pow_specialize(fgraph, node):
2071
2071
rval = [reciprocal (sqr (xsym ))]
2072
2072
if rval :
2073
2073
rval [0 ] = cast (rval [0 ], odtype )
2074
- assert rval [0 ].type == node .outputs [0 ].type , (rval , node .outputs )
2074
+ assert rval [0 ].type .is_super (node .outputs [0 ].type ), (
2075
+ rval [0 ].type ,
2076
+ node .outputs [0 ].type ,
2077
+ )
2075
2078
return rval
2076
2079
else :
2077
2080
return False
Original file line number Diff line number Diff line change 96
96
perform_sigm_times_exp ,
97
97
simplify_mul ,
98
98
)
99
- from pytensor .tensor .shape import Reshape , Shape_i
99
+ from pytensor .tensor .shape import Reshape , Shape_i , SpecifyShape
100
100
from pytensor .tensor .type import (
101
101
TensorType ,
102
102
cmatrix ,
@@ -1671,6 +1671,18 @@ def test_local_pow_specialize():
1671
1671
assert isinstance (nodes [1 ].scalar_op , aes .basic .Reciprocal )
1672
1672
utt .assert_allclose (f (val_no0 ), val_no0 ** (- 0.5 ))
1673
1673
1674
+ twos = np .full (shape = (10 ,), fill_value = 2.0 ).astype (config .floatX )
1675
+ f = function ([v ], v ** twos , mode = mode )
1676
+ topo = f .maker .fgraph .toposort ()
1677
+ assert len (topo ) == 2
1678
+ # Depending on the mode the SpecifyShape is lifted or not
1679
+ if topo [0 ].op == sqr :
1680
+ assert isinstance (topo [1 ].op , SpecifyShape )
1681
+ else :
1682
+ assert isinstance (topo [0 ].op , SpecifyShape )
1683
+ assert topo [1 ].op == sqr
1684
+ utt .assert_allclose (f (val ), val ** twos )
1685
+
1674
1686
1675
1687
def test_local_pow_to_nested_squaring ():
1676
1688
mode = config .mode
You can’t perform that action at this time.
0 commit comments