Skip to content

Commit 5c87d74

Browse files
committed
Fix type check in local_pow_specialize
1 parent 7367e8d commit 5c87d74

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

pytensor/tensor/rewriting/math.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -2071,7 +2071,10 @@ def local_pow_specialize(fgraph, node):
20712071
rval = [reciprocal(sqr(xsym))]
20722072
if rval:
20732073
rval[0] = cast(rval[0], odtype)
2074-
assert rval[0].type == node.outputs[0].type, (rval, node.outputs)
2074+
assert rval[0].type.is_super(node.outputs[0].type), (
2075+
rval[0].type,
2076+
node.outputs[0].type,
2077+
)
20752078
return rval
20762079
else:
20772080
return False

tests/tensor/rewriting/test_math.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696
perform_sigm_times_exp,
9797
simplify_mul,
9898
)
99-
from pytensor.tensor.shape import Reshape, Shape_i
99+
from pytensor.tensor.shape import Reshape, Shape_i, SpecifyShape
100100
from pytensor.tensor.type import (
101101
TensorType,
102102
cmatrix,
@@ -1671,6 +1671,18 @@ def test_local_pow_specialize():
16711671
assert isinstance(nodes[1].scalar_op, aes.basic.Reciprocal)
16721672
utt.assert_allclose(f(val_no0), val_no0 ** (-0.5))
16731673

1674+
twos = np.full(shape=(10,), fill_value=2.0).astype(config.floatX)
1675+
f = function([v], v**twos, mode=mode)
1676+
topo = f.maker.fgraph.toposort()
1677+
assert len(topo) == 2
1678+
# Depending on the mode the SpecifyShape is lifted or not
1679+
if topo[0].op == sqr:
1680+
assert isinstance(topo[1].op, SpecifyShape)
1681+
else:
1682+
assert isinstance(topo[0].op, SpecifyShape)
1683+
assert topo[1].op == sqr
1684+
utt.assert_allclose(f(val), val**twos)
1685+
16741686

16751687
def test_local_pow_to_nested_squaring():
16761688
mode = config.mode

0 commit comments

Comments
 (0)