
Commit b065112

Add rewrite for 1 ** x = 1 (#1179)
1 parent 2f1d25a commit b065112

2 files changed: +51 -5 lines changed

pytensor/tensor/rewriting/math.py (+32 -5)
@@ -1905,13 +1905,40 @@ def local_reciprocal_canon(fgraph, node):
 @register_canonicalize
 @node_rewriter([pt_pow])
 def local_pow_canonicalize(fgraph, node):
-    cst = get_underlying_scalar_constant_value(
+    """
+    Rewrites for exponential functions with straight-forward simplifications:
+    1. x ** 0 -> 1
+    2. x ** 1 -> x
+    3. 1 ** x -> 1
+
+    In all cases, the shape of the output is the result of broadcasting the shapes of the inputs.
+    """
+    cst_base = get_underlying_scalar_constant_value(
+        node.inputs[0], only_process_constants=True, raise_not_constant=False
+    )
+    cst_exponent = get_underlying_scalar_constant_value(
         node.inputs[1], only_process_constants=True, raise_not_constant=False
     )
-    if cst == 0:
-        return [alloc_like(1, node.outputs[0], fgraph)]
-    if cst == 1:
-        return [alloc_like(node.inputs[0], node.outputs[0], fgraph)]
+
+    new_out = None
+
+    if cst_base == 1:
+        # 1 ** x = 1
+        new_out = broadcast_arrays(*node.inputs)[0]
+    elif cst_exponent == 0:
+        # x ** 0 = 1
+        new_out = broadcast_arrays(ones_like(node.inputs[0]), node.inputs[1])[0]
+    elif cst_exponent == 1:
+        # x ** 1 = x
+        new_out = broadcast_arrays(*node.inputs)[0]
+
+    if not new_out:
+        return
+
+    if new_out.dtype != node.out.dtype:
+        new_out = cast(new_out, dtype=node.out.dtype)
+
+    return [new_out]
 
 
 @register_specialize
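
For context (not part of the commit): a minimal end-to-end sketch of how the rewritten local_pow_canonicalize behaves under the default compilation mode, which includes the canonicalize pass. The variable names, shapes, and the extra multiplication by x in the third output are illustrative choices, not taken from the diff.

# Sketch only: exercise the three simplifications and the broadcasting
# behavior described in the new docstring.
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")  # runtime shape (3, 1)
y = pt.vector("y")  # runtime shape (4,)

# x ** 0 -> ones with x's shape, x ** 1 -> x, 1 ** y -> ones with y's shape
f = pytensor.function([x, y], [x**0, x**1, (1**y) * x])

x_val = np.random.random((3, 1)).astype(pytensor.config.floatX)
y_val = np.random.random(4).astype(pytensor.config.floatX)
out0, out1, out2 = f(x_val, y_val)

np.testing.assert_allclose(out0, np.ones((3, 1)))                 # x ** 0
np.testing.assert_allclose(out1, x_val)                           # x ** 1
np.testing.assert_allclose(out2, np.broadcast_to(x_val, (3, 4)))  # 1 ** y, broadcast against x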

tests/tensor/rewriting/test_math.py (+19 -0)
@@ -4571,3 +4571,22 @@ def test_log_kv_stabilization():
         out.eval({x: 1000.0}, mode=mode),
         -1003.2180912984705,
     )
+
+
+@pytest.mark.parametrize("shape", [(), (4, 5, 6)], ids=["scalar", "tensor"])
+def test_pow_1_rewrite(shape):
+    x = pt.tensor("x", shape=shape)
+    z = 1**x
+
+    assert isinstance(z.owner.op, Elemwise) and isinstance(
+        z.owner.op.scalar_op, ps.basic.Pow
+    )
+
+    f = pytensor.function([x], z)
+    assert not any(
+        isinstance(node.op, Elemwise) and isinstance(node.op.scalar_op, ps.basic.Pow)
+        for node in f.maker.fgraph.toposort()
+    )
+
+    x_val = np.random.random(shape).astype(config.floatX)
+    np.testing.assert_allclose(z.eval({x: x_val}), f(x_val))
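
Not in the commit, but a quick way to see the rewrite at work is to print the compiled graph, mirroring the toposort check in test_pow_1_rewrite above; this sketch assumes the default mode, where the canonicalization runs.

import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(4, 5, 6))
f = pytensor.function([x], 1**x)

# The printed graph should contain no Pow Elemwise node; the power has been
# replaced by an expression that just fills ones in x's (broadcast) shape.
pytensor.dprint(f)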

0 commit comments