Skip to content

Commit a52f8bc

Browse files
tamastokesricardoV94
authored andcommitted
pytensor-54: Rewrite products of exponents as exponent of sum. Rewrite e^x*e^y to e^(x+y), e^x/e^y to e^(x-y).
1 parent 958cd14 commit a52f8bc

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,35 @@ def local_sumsqr2dot(fgraph, node):
423423
return [new_out]
424424

425425

426+
@register_canonicalize
427+
@register_specialize
428+
@node_rewriter([Elemwise])
429+
def local_mulexp2expadd(fgraph, node):
430+
"""
431+
This rewrite detects e^x * e^y and converts it to e^(x+y).
432+
Similarly, e^x / e^y becomes e^(x-y).
433+
"""
434+
if (
435+
isinstance(node.op, Elemwise)
436+
and isinstance(node.op.scalar_op, (aes.Mul, aes.TrueDiv))
437+
and node.inputs[0].owner
438+
and isinstance(node.inputs[0].owner.op, Elemwise)
439+
and isinstance(node.inputs[0].owner.op.scalar_op, aes.Exp)
440+
and node.inputs[1].owner
441+
and isinstance(node.inputs[1].owner.op, Elemwise)
442+
and isinstance(node.inputs[1].owner.op.scalar_op, aes.Exp)
443+
):
444+
input1 = node.inputs[0].owner.inputs[0]
445+
input2 = node.inputs[1].owner.inputs[0]
446+
if isinstance(node.op.scalar_op, aes.Mul):
447+
new_out = exp(input1 + input2)
448+
else: # TrueDiv
449+
new_out = exp(input1 - input2)
450+
if new_out.dtype != node.outputs[0].dtype:
451+
new_out = cast(new_out, dtype=node.outputs[0].dtype)
452+
return [new_out]
453+
454+
426455
@register_stabilize
427456
@register_specialize
428457
@register_canonicalize

tests/tensor/rewriting/test_math.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4014,6 +4014,56 @@ def test_local_sumsqr2dot():
40144014
)
40154015

40164016

4017+
def test_local_mulexp2expadd():
4018+
# e^x * e^y = e^(x+y)
4019+
# test simple scalars first
4020+
x = scalar("x")
4021+
y = scalar("y")
4022+
expx = exp(x)
4023+
expy = exp(y)
4024+
expx_expy = expx * expy
4025+
f = function([x, y], expx_expy)
4026+
utt.assert_allclose(f(3, 4), np.exp(3 + 4))
4027+
graph = f.maker.fgraph.toposort()
4028+
assert isinstance(graph[0].op, Elemwise)
4029+
inner_graph = graph[0].op.scalar_op.fgraph.toposort()
4030+
assert any(isinstance(n.op, aes.Add) for n in inner_graph)
4031+
4032+
# expect same for matrices as well
4033+
mx = matrix("mx")
4034+
my = matrix("my")
4035+
f = function([mx, my], exp(mx) * exp(my))
4036+
M1 = np.array([[1.0, 2.0], [3.0, 4.0]])
4037+
M2 = np.array([[5.0, 6.0], [7.0, 8.0]])
4038+
utt.assert_allclose(f(M1, M2), np.exp(M1 + M2))
4039+
graph = f.maker.fgraph.toposort()
4040+
assert isinstance(graph[0].op, Elemwise)
4041+
inner_graph = graph[0].op.scalar_op.fgraph.toposort()
4042+
assert any(isinstance(n.op, aes.Add) for n in inner_graph)
4043+
4044+
# checking whether further rewrites can proceed after this one as one would expect
4045+
# e^x * e^(-x) = e^(x-x) = e^0 = 1
4046+
f = function([x], expx * exp(neg(x)))
4047+
graph = f.maker.fgraph.toposort()
4048+
assert isinstance(graph[0].inputs[0], TensorConstant)
4049+
utt.assert_allclose(f(42), 1)
4050+
4051+
# e^x / e^y = e^(x-y)
4052+
expx_div_expy = expx / expy
4053+
f = function([x, y], expx_div_expy)
4054+
utt.assert_allclose(f(5, 3), np.exp(5 - 3))
4055+
graph = f.maker.fgraph.toposort()
4056+
assert isinstance(graph[0].op, Elemwise)
4057+
inner_graph = graph[0].op.scalar_op.fgraph.toposort()
4058+
assert any(isinstance(n.op, aes.Sub) for n in inner_graph)
4059+
4060+
# e^x / e^x = e^(x-x) = e^0 = 1
4061+
f = function([x], expx / expx)
4062+
graph = f.maker.fgraph.toposort()
4063+
assert isinstance(graph[0].inputs[0], TensorConstant)
4064+
utt.assert_allclose(f(42), 1)
4065+
4066+
40174067
def test_local_expm1():
40184068
x = matrix("x")
40194069
u = scalar("u")

0 commit comments

Comments
 (0)