Skip to content

Commit 8466acd

Browse files
tamastokesricardoV94
authored andcommitted
pytensor-54: Rewrite a^x * a^y to a^(x+y)
1 parent 28fdc86 commit 8466acd

File tree

2 files changed

+117
-0
lines changed

2 files changed

+117
-0
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,65 @@ def local_mulexp2expadd(fgraph, node):
466466
return [new_out]
467467

468468

469+
@register_specialize
470+
@node_rewriter([mul, true_div])
471+
def local_mulpow2powadd(fgraph, node):
472+
"""
473+
This rewrite detects a^x * a^y and converts it to a^(x+y).
474+
Similarly, a^x / a^y becomes a^(x-y).
475+
"""
476+
if isinstance(node.op, Elemwise) and isinstance(
477+
node.op.scalar_op, (aes.Mul, aes.TrueDiv)
478+
):
479+
from collections import defaultdict
480+
481+
# search for pow-s and group them by their bases
482+
pow_nodes = defaultdict(list)
483+
rest = []
484+
for n in node.inputs:
485+
if (
486+
n.owner
487+
and hasattr(n.owner.op, "scalar_op")
488+
and isinstance(n.owner.op.scalar_op, aes.Pow)
489+
):
490+
base_node = n.owner.inputs[0]
491+
# exponent is at n.owner.inputs[1], but we need to store the full node
492+
# in case this particular power node remains alone and can't be rewritten
493+
pow_nodes[base_node].append(n)
494+
else:
495+
rest.append(n)
496+
497+
# Can only do any rewrite if there are at least two pow-s with the same base
498+
can_rewrite = [k for k, v in pow_nodes.items() if len(v) >= 2]
499+
if len(can_rewrite) >= 1:
500+
# Mul -> add; TrueDiv -> sub
501+
orig_op, new_op = mul, add
502+
if isinstance(node.op.scalar_op, aes.TrueDiv):
503+
orig_op, new_op = true_div, sub
504+
pow_factors = []
505+
# Rewrite pow-s having the same base for each different base
506+
# E.g.: a^x * a^y --> a^(x+y)
507+
for base in can_rewrite:
508+
exponents = [n.owner.inputs[1] for n in pow_nodes[base]]
509+
new_node = base ** new_op(*exponents)
510+
if new_node.dtype != node.outputs[0].dtype:
511+
new_node = cast(new_node, dtype=node.outputs[0].dtype)
512+
pow_factors.append(new_node)
513+
# Don't forget about those sole pow-s that couldn't be rewriten
514+
sole_pows = [v[0] for k, v in pow_nodes.items() if k not in can_rewrite]
515+
# Combine the rewritten pow-s and other, non-pow factors of the original Mul
516+
# E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+z) * b^(z+t) * y * v
517+
if len(pow_factors) > 1 or len(sole_pows) > 0 or len(rest) > 0:
518+
new_out = orig_op(*pow_factors, *sole_pows, *rest)
519+
if new_out.dtype != node.outputs[0].dtype:
520+
new_out = cast(new_out, dtype=node.outputs[0].dtype)
521+
else:
522+
# if all factors of the original mul were pows-s with the same base,
523+
# we can get rid of the mul completely.
524+
new_out = pow_factors[0]
525+
return [new_out]
526+
527+
469528
@register_stabilize
470529
@register_specialize
471530
@register_canonicalize

tests/tensor/rewriting/test_math.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4105,6 +4105,64 @@ def test_local_mulexp2expadd():
41054105
assert isinstance(graph[0].inputs[0], TensorConstant)
41064106

41074107

4108+
def test_local_mulpow2powadd():
4109+
x = scalar("x")
4110+
y = scalar("y")
4111+
z = scalar("z")
4112+
w = scalar("w")
4113+
v = scalar("v")
4114+
u = scalar("u")
4115+
t = scalar("t")
4116+
s = scalar("s")
4117+
a = scalar("a")
4118+
b = scalar("b")
4119+
c = scalar("c")
4120+
4121+
# 2^x * 2^y * 2^z * 2^w = 2^(x+y+z+w)
4122+
op = 2**x * 2**y * 2**z * 2**w
4123+
f = function([x, y, z, w], op)
4124+
utt.assert_allclose(f(3, 4, 5, 6), 2 ** (3 + 4 + 5 + 6))
4125+
graph = f.maker.fgraph.toposort()
4126+
assert isinstance(graph[0].op, Elemwise)
4127+
inner_graph = graph[0].op.scalar_op.fgraph.toposort()
4128+
assert any(isinstance(n.op, aes.Add) for n in inner_graph)
4129+
assert not any(isinstance(n.op, aes.Mul) for n in inner_graph)
4130+
4131+
# 2^x * a^y * 2^z * b^w * c^v * a^u * s * b^t = 2^(x+z) * a^(y+u) * b^(w+t) * c^v * s
4132+
op = 2**x * a**y * 2**z * b**w * c**v * a**u * s * b**t
4133+
f = function([x, y, z, w, v, u, t, s, a, b, c], op)
4134+
utt.assert_allclose(
4135+
f(4, 5, 6, 7, 8, 9, 10, 11, 2.5, 3, 3.5),
4136+
2 ** (4 + 6) * 2.5 ** (5 + 9) * 3 ** (7 + 10) * 3.5**8 * 11,
4137+
)
4138+
graph = f.maker.fgraph.toposort()
4139+
assert isinstance(graph[0].op, Elemwise)
4140+
inner_graph = graph[0].op.scalar_op.fgraph.toposort()
4141+
assert len([True for n in inner_graph if isinstance(n.op, aes.Add)]) == 3
4142+
assert len([True for n in inner_graph if isinstance(n.op, aes.Pow)]) == 4
4143+
assert any(isinstance(n.op, aes.Mul) for n in inner_graph)
4144+
4145+
# (2^x / 2^y) * (a^z / a^w) = 2^(x-y) * a^(z-w)
4146+
op = 2**x / 2**y * (a**z / a**w)
4147+
f = function([x, y, z, w, a], op)
4148+
utt.assert_allclose(f(3, 5, 6, 4, 7), 2 ** (3 - 5) * 7 ** (6 - 4))
4149+
graph = f.maker.fgraph.toposort()
4150+
assert isinstance(graph[0].op, Elemwise)
4151+
inner_graph = graph[0].op.scalar_op.fgraph.toposort()
4152+
assert len([True for n in inner_graph if isinstance(n.op, aes.Sub)]) == 2
4153+
assert any(isinstance(n.op, aes.Mul) for n in inner_graph)
4154+
4155+
# a^x * a^y * exp(z) * exp(w) = a^(x+y) * exp(z+w)
4156+
op = a**x * a**y * exp(z) * exp(w)
4157+
f = function([x, y, z, w, a], op)
4158+
utt.assert_allclose(f(3, 4, 5, 6, 2), 2 ** (3 + 4) * np.exp(5 + 6))
4159+
graph = f.maker.fgraph.toposort()
4160+
assert isinstance(graph[0].op, Elemwise)
4161+
inner_graph = graph[0].op.scalar_op.fgraph.toposort()
4162+
assert len([True for n in inner_graph if isinstance(n.op, aes.Add)]) == 2
4163+
assert any(isinstance(n.op, aes.Mul) for n in inner_graph)
4164+
4165+
41084166
def test_local_expm1():
41094167
x = matrix("x")
41104168
u = scalar("u")

0 commit comments

Comments
 (0)