Skip to content

Commit e37497f

Browse files
authored
Rewrite products of exponents as exponent of sum (#186)
* Rewrite products of exponents as exponent of sum. Rewrite e^x*e^y to e^(x+y), e^x/e^y to e^(x-y). * Rewrite a^x * a^y to a^(x+y)
1 parent 5628ab1 commit e37497f

File tree

2 files changed

+250
-0
lines changed

2 files changed

+250
-0
lines changed

pytensor/tensor/rewriting/math.py

+95
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import itertools
44
import operator
5+
from collections import defaultdict
56
from functools import partial, reduce
67

78
import numpy as np
@@ -423,6 +424,100 @@ def local_sumsqr2dot(fgraph, node):
423424
return [new_out]
424425

425426

427+
@register_specialize
428+
@node_rewriter([mul, true_div])
429+
def local_mul_exp_to_exp_add(fgraph, node):
430+
"""
431+
This rewrite detects e^x * e^y and converts it to e^(x+y).
432+
Similarly, e^x / e^y becomes e^(x-y).
433+
"""
434+
exps = [
435+
n.owner.inputs[0]
436+
for n in node.inputs
437+
if n.owner
438+
and hasattr(n.owner.op, "scalar_op")
439+
and isinstance(n.owner.op.scalar_op, aes.Exp)
440+
]
441+
# Can only do any rewrite if there are at least two exp-s
442+
if len(exps) >= 2:
443+
# Mul -> add; TrueDiv -> sub
444+
orig_op, new_op = mul, add
445+
if isinstance(node.op.scalar_op, aes.TrueDiv):
446+
orig_op, new_op = true_div, sub
447+
new_out = exp(new_op(*exps))
448+
if new_out.dtype != node.outputs[0].dtype:
449+
new_out = cast(new_out, dtype=node.outputs[0].dtype)
450+
# The original Mul may have more than two factors, some of which may not be exp nodes.
451+
# If so, we keep multiplying them with the new exp(sum) node.
452+
# E.g.: e^x * y * e^z * w --> e^(x+z) * y * w
453+
rest = [
454+
n
455+
for n in node.inputs
456+
if not n.owner
457+
or not hasattr(n.owner.op, "scalar_op")
458+
or not isinstance(n.owner.op.scalar_op, aes.Exp)
459+
]
460+
if len(rest) > 0:
461+
new_out = orig_op(new_out, *rest)
462+
if new_out.dtype != node.outputs[0].dtype:
463+
new_out = cast(new_out, dtype=node.outputs[0].dtype)
464+
return [new_out]
465+
466+
467+
@register_specialize
468+
@node_rewriter([mul, true_div])
469+
def local_mul_pow_to_pow_add(fgraph, node):
470+
"""
471+
This rewrite detects a^x * a^y and converts it to a^(x+y).
472+
Similarly, a^x / a^y becomes a^(x-y).
473+
"""
474+
# search for pow-s and group them by their bases
475+
pow_nodes = defaultdict(list)
476+
rest = []
477+
for n in node.inputs:
478+
if (
479+
n.owner
480+
and hasattr(n.owner.op, "scalar_op")
481+
and isinstance(n.owner.op.scalar_op, aes.Pow)
482+
):
483+
base_node = n.owner.inputs[0]
484+
# exponent is at n.owner.inputs[1], but we need to store the full node
485+
# in case this particular power node remains alone and can't be rewritten
486+
pow_nodes[base_node].append(n)
487+
else:
488+
rest.append(n)
489+
490+
# Can only do any rewrite if there are at least two pow-s with the same base
491+
can_rewrite = [k for k, v in pow_nodes.items() if len(v) >= 2]
492+
if len(can_rewrite) >= 1:
493+
# Mul -> add; TrueDiv -> sub
494+
orig_op, new_op = mul, add
495+
if isinstance(node.op.scalar_op, aes.TrueDiv):
496+
orig_op, new_op = true_div, sub
497+
pow_factors = []
498+
# Rewrite pow-s having the same base for each different base
499+
# E.g.: a^x * a^y --> a^(x+y)
500+
for base in can_rewrite:
501+
exponents = [n.owner.inputs[1] for n in pow_nodes[base]]
502+
new_node = base ** new_op(*exponents)
503+
if new_node.dtype != node.outputs[0].dtype:
504+
new_node = cast(new_node, dtype=node.outputs[0].dtype)
505+
pow_factors.append(new_node)
506+
# Don't forget about those sole pow-s that couldn't be rewriten
507+
sole_pows = [v[0] for k, v in pow_nodes.items() if k not in can_rewrite]
508+
# Combine the rewritten pow-s and other, non-pow factors of the original Mul
509+
# E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+z) * b^(z+t) * y * v
510+
if len(pow_factors) > 1 or len(sole_pows) > 0 or len(rest) > 0:
511+
new_out = orig_op(*pow_factors, *sole_pows, *rest)
512+
if new_out.dtype != node.outputs[0].dtype:
513+
new_out = cast(new_out, dtype=node.outputs[0].dtype)
514+
else:
515+
# if all factors of the original mul were pows-s with the same base,
516+
# we can get rid of the mul completely.
517+
new_out = pow_factors[0]
518+
return [new_out]
519+
520+
426521
@register_stabilize
427522
@register_specialize
428523
@register_canonicalize

tests/tensor/rewriting/test_math.py

+155
Original file line numberDiff line numberDiff line change
@@ -4014,6 +4014,161 @@ def test_local_sumsqr2dot():
40144014
)
40154015

40164016

4017+
def test_local_mul_exp_to_exp_add():
4018+
# Default and FAST_RUN modes put a Composite op into the final graph,
4019+
# whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs,
4020+
# we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites
4021+
mode = get_default_mode().excluding("fusion").including("local_mul_exp_to_exp_add")
4022+
4023+
x = scalar("x")
4024+
y = scalar("y")
4025+
z = scalar("z")
4026+
w = scalar("w")
4027+
expx = exp(x)
4028+
expy = exp(y)
4029+
expz = exp(z)
4030+
expw = exp(w)
4031+
4032+
# e^x * e^y * e^z * e^w = e^(x+y+z+w)
4033+
op = expx * expy * expz * expw
4034+
f = function([x, y, z, w], op, mode)
4035+
pytensor.dprint(f)
4036+
utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 + 6))
4037+
graph = f.maker.fgraph.toposort()
4038+
assert all(isinstance(n.op, Elemwise) for n in graph)
4039+
assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
4040+
assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
4041+
4042+
# e^x * e^y * e^z / e^w = e^(x+y+z-w)
4043+
op = expx * expy * expz / expw
4044+
f = function([x, y, z, w], op, mode)
4045+
utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 - 6))
4046+
graph = f.maker.fgraph.toposort()
4047+
assert all(isinstance(n.op, Elemwise) for n in graph)
4048+
assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
4049+
assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph)
4050+
assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
4051+
assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph)
4052+
4053+
# e^x * e^y / e^z * e^w = e^(x+y-z+w)
4054+
op = expx * expy / expz * expw
4055+
f = function([x, y, z, w], op, mode)
4056+
utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 - 5 + 6))
4057+
graph = f.maker.fgraph.toposort()
4058+
assert all(isinstance(n.op, Elemwise) for n in graph)
4059+
assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
4060+
assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph)
4061+
assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
4062+
assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph)
4063+
4064+
# e^x / e^y / e^z = (e^x / e^y) / e^z = e^(x-y-z)
4065+
op = expx / expy / expz
4066+
f = function([x, y, z], op, mode)
4067+
utt.assert_allclose(f(3, 4, 5), np.exp(3 - 4 - 5))
4068+
graph = f.maker.fgraph.toposort()
4069+
assert all(isinstance(n.op, Elemwise) for n in graph)
4070+
assert any(isinstance(n.op.scalar_op, aes.Sub) for n in graph)
4071+
assert not any(isinstance(n.op.scalar_op, aes.TrueDiv) for n in graph)
4072+
4073+
# e^x * y * e^z * w = e^(x+z) * y * w
4074+
op = expx * y * expz * w
4075+
f = function([x, y, z, w], op, mode)
4076+
utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 5) * 4 * 6)
4077+
graph = f.maker.fgraph.toposort()
4078+
assert all(isinstance(n.op, Elemwise) for n in graph)
4079+
assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
4080+
assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
4081+
4082+
# expect same for matrices as well
4083+
mx = matrix("mx")
4084+
my = matrix("my")
4085+
f = function([mx, my], exp(mx) * exp(my), mode, allow_input_downcast=True)
4086+
M1 = np.array([[1.0, 2.0], [3.0, 4.0]])
4087+
M2 = np.array([[5.0, 6.0], [7.0, 8.0]])
4088+
utt.assert_allclose(f(M1, M2), np.exp(M1 + M2))
4089+
graph = f.maker.fgraph.toposort()
4090+
assert all(isinstance(n.op, Elemwise) for n in graph)
4091+
assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
4092+
assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
4093+
4094+
# checking whether further rewrites can proceed after this one as one would expect
4095+
# e^x * e^(-x) = e^(x-x) = e^0 = 1
4096+
f = function([x], expx * exp(neg(x)), mode)
4097+
utt.assert_allclose(f(42), 1)
4098+
graph = f.maker.fgraph.toposort()
4099+
assert isinstance(graph[0].inputs[0], TensorConstant)
4100+
4101+
# e^x / e^x = e^(x-x) = e^0 = 1
4102+
f = function([x], expx / expx, mode)
4103+
utt.assert_allclose(f(42), 1)
4104+
graph = f.maker.fgraph.toposort()
4105+
assert isinstance(graph[0].inputs[0], TensorConstant)
4106+
4107+
4108+
def test_local_mul_pow_to_pow_add():
4109+
# Default and FAST_RUN modes put a Composite op into the final graph,
4110+
# whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs,
4111+
# we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites
4112+
mode = (
4113+
get_default_mode()
4114+
.excluding("fusion")
4115+
.including("local_mul_exp_to_exp_add")
4116+
.including("local_mul_pow_to_pow_add")
4117+
)
4118+
4119+
x = scalar("x")
4120+
y = scalar("y")
4121+
z = scalar("z")
4122+
w = scalar("w")
4123+
v = scalar("v")
4124+
u = scalar("u")
4125+
t = scalar("t")
4126+
s = scalar("s")
4127+
a = scalar("a")
4128+
b = scalar("b")
4129+
c = scalar("c")
4130+
4131+
# 2^x * 2^y * 2^z * 2^w = 2^(x+y+z+w)
4132+
op = 2**x * 2**y * 2**z * 2**w
4133+
f = function([x, y, z, w], op, mode)
4134+
utt.assert_allclose(f(3, 4, 5, 6), 2 ** (3 + 4 + 5 + 6))
4135+
graph = f.maker.fgraph.toposort()
4136+
assert all(isinstance(n.op, Elemwise) for n in graph)
4137+
assert any(isinstance(n.op.scalar_op, aes.Add) for n in graph)
4138+
assert not any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
4139+
4140+
# 2^x * a^y * 2^z * b^w * c^v * a^u * s * b^t = 2^(x+z) * a^(y+u) * b^(w+t) * c^v * s
4141+
op = 2**x * a**y * 2**z * b**w * c**v * a**u * s * b**t
4142+
f = function([x, y, z, w, v, u, t, s, a, b, c], op, mode)
4143+
utt.assert_allclose(
4144+
f(4, 5, 6, 7, 8, 9, 10, 11, 2.5, 3, 3.5),
4145+
2 ** (4 + 6) * 2.5 ** (5 + 9) * 3 ** (7 + 10) * 3.5**8 * 11,
4146+
)
4147+
graph = f.maker.fgraph.toposort()
4148+
assert all(isinstance(n.op, Elemwise) for n in graph)
4149+
assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Add)]) == 3
4150+
assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Pow)]) == 4
4151+
assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
4152+
4153+
# (2^x / 2^y) * (a^z / a^w) = 2^(x-y) * a^(z-w)
4154+
op = 2**x / 2**y * (a**z / a**w)
4155+
f = function([x, y, z, w, a], op, mode)
4156+
utt.assert_allclose(f(3, 5, 6, 4, 7), 2 ** (3 - 5) * 7 ** (6 - 4))
4157+
graph = f.maker.fgraph.toposort()
4158+
assert all(isinstance(n.op, Elemwise) for n in graph)
4159+
assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Sub)]) == 2
4160+
assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
4161+
4162+
# a^x * a^y * exp(z) * exp(w) = a^(x+y) * exp(z+w)
4163+
op = a**x * a**y * exp(z) * exp(w)
4164+
f = function([x, y, z, w, a], op, mode)
4165+
utt.assert_allclose(f(3, 4, 5, 6, 2), 2 ** (3 + 4) * np.exp(5 + 6))
4166+
graph = f.maker.fgraph.toposort()
4167+
assert all(isinstance(n.op, Elemwise) for n in graph)
4168+
assert len([True for n in graph if isinstance(n.op.scalar_op, aes.Add)]) == 2
4169+
assert any(isinstance(n.op.scalar_op, aes.Mul) for n in graph)
4170+
4171+
40174172
def test_local_expm1():
40184173
x = matrix("x")
40194174
u = scalar("u")

0 commit comments

Comments
 (0)