Skip to content

Commit 28fdc86

Browse files
tamastokesricardoV94
authored andcommitted
pytensor-54: Handle properly the scenarios where a Mul node has more than two factors with some of which may not be an exp
1 parent a52f8bc commit 28fdc86

File tree

2 files changed

+91
-36
lines changed

2 files changed

+91
-36
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -423,33 +423,47 @@ def local_sumsqr2dot(fgraph, node):
423423
return [new_out]
424424

425425

426-
@register_canonicalize
427426
@register_specialize
428-
@node_rewriter([Elemwise])
427+
@node_rewriter([mul, true_div])
429428
def local_mulexp2expadd(fgraph, node):
430429
"""
431430
This rewrite detects e^x * e^y and converts it to e^(x+y).
432431
Similarly, e^x / e^y becomes e^(x-y).
433432
"""
434-
if (
435-
isinstance(node.op, Elemwise)
436-
and isinstance(node.op.scalar_op, (aes.Mul, aes.TrueDiv))
437-
and node.inputs[0].owner
438-
and isinstance(node.inputs[0].owner.op, Elemwise)
439-
and isinstance(node.inputs[0].owner.op.scalar_op, aes.Exp)
440-
and node.inputs[1].owner
441-
and isinstance(node.inputs[1].owner.op, Elemwise)
442-
and isinstance(node.inputs[1].owner.op.scalar_op, aes.Exp)
433+
if isinstance(node.op, Elemwise) and isinstance(
434+
node.op.scalar_op, (aes.Mul, aes.TrueDiv)
443435
):
444-
input1 = node.inputs[0].owner.inputs[0]
445-
input2 = node.inputs[1].owner.inputs[0]
446-
if isinstance(node.op.scalar_op, aes.Mul):
447-
new_out = exp(input1 + input2)
448-
else: # TrueDiv
449-
new_out = exp(input1 - input2)
450-
if new_out.dtype != node.outputs[0].dtype:
451-
new_out = cast(new_out, dtype=node.outputs[0].dtype)
452-
return [new_out]
436+
exps = [
437+
n.owner.inputs[0]
438+
for n in node.inputs
439+
if n.owner
440+
and hasattr(n.owner.op, "scalar_op")
441+
and isinstance(n.owner.op.scalar_op, aes.Exp)
442+
]
443+
# Can only do any rewrite if there are at least two exp-s
444+
if len(exps) >= 2:
445+
# Mul -> add; TrueDiv -> sub
446+
orig_op, new_op = mul, add
447+
if isinstance(node.op.scalar_op, aes.TrueDiv):
448+
orig_op, new_op = true_div, sub
449+
new_out = exp(new_op(*exps))
450+
if new_out.dtype != node.outputs[0].dtype:
451+
new_out = cast(new_out, dtype=node.outputs[0].dtype)
452+
# The original Mul may have more than two factors, some of which may not be exp nodes.
453+
# If so, we keep multiplying them with the new exp(sum) node.
454+
# E.g.: e^x * y * e^z * w --> e^(x+z) * y * w
455+
rest = [
456+
n
457+
for n in node.inputs
458+
if not n.owner
459+
or not hasattr(n.owner.op, "scalar_op")
460+
or not isinstance(n.owner.op.scalar_op, aes.Exp)
461+
]
462+
if len(rest) > 0:
463+
new_out = orig_op(new_out, *rest)
464+
if new_out.dtype != node.outputs[0].dtype:
465+
new_out = cast(new_out, dtype=node.outputs[0].dtype)
466+
return [new_out]
453467

454468

455469
@register_stabilize

tests/tensor/rewriting/test_math.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4015,19 +4015,68 @@ def test_local_sumsqr2dot():
40154015

40164016

40174017
def test_local_mulexp2expadd():
4018-
# e^x * e^y = e^(x+y)
4019-
# test simple scalars first
40204018
x = scalar("x")
40214019
y = scalar("y")
4020+
z = scalar("z")
4021+
w = scalar("w")
40224022
expx = exp(x)
40234023
expy = exp(y)
4024-
expx_expy = expx * expy
4025-
f = function([x, y], expx_expy)
4026-
utt.assert_allclose(f(3, 4), np.exp(3 + 4))
4024+
expz = exp(z)
4025+
expw = exp(w)
4026+
4027+
# e^x * e^y * e^z * e^w = e^(x+y+z+w)
4028+
op = expx * expy * expz * expw
4029+
f = function([x, y, z, w], op)
4030+
utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 + 6))
4031+
graph = f.maker.fgraph.toposort()
4032+
assert isinstance(graph[0].op, Elemwise)
4033+
inner_graph = graph[0].op.scalar_op.fgraph.toposort()
4034+
assert any(isinstance(n.op, aes.Add) for n in inner_graph)
4035+
assert not any(isinstance(n.op, aes.Mul) for n in inner_graph)
4036+
4037+
# e^x * e^y * e^z / e^w = e^(x+y+z-w)
4038+
op = expx * expy * expz / expw
4039+
f = function([x, y, z, w], op)
4040+
utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 + 5 - 6))
4041+
graph = f.maker.fgraph.toposort()
4042+
assert isinstance(graph[0].op, Elemwise)
4043+
inner_graph = graph[0].op.scalar_op.fgraph.toposort()
4044+
assert any(isinstance(n.op, aes.Add) for n in inner_graph)
4045+
assert any(isinstance(n.op, aes.Sub) for n in inner_graph)
4046+
assert not any(isinstance(n.op, aes.Mul) for n in inner_graph)
4047+
assert not any(isinstance(n.op, aes.TrueDiv) for n in inner_graph)
4048+
4049+
# e^x * e^y / e^z * e^w = e^(x+y-z+w)
4050+
op = expx * expy / expz * expw
4051+
f = function([x, y, z, w], op)
4052+
utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 4 - 5 + 6))
4053+
graph = f.maker.fgraph.toposort()
4054+
assert isinstance(graph[0].op, Elemwise)
4055+
inner_graph = graph[0].op.scalar_op.fgraph.toposort()
4056+
assert any(isinstance(n.op, aes.Add) for n in inner_graph)
4057+
assert any(isinstance(n.op, aes.Sub) for n in inner_graph)
4058+
assert not any(isinstance(n.op, aes.Mul) for n in inner_graph)
4059+
assert not any(isinstance(n.op, aes.TrueDiv) for n in inner_graph)
4060+
4061+
# e^x / e^y / e^z = (e^x / e^y) / e^z = e^(x-y-z)
4062+
op = expx / expy / expz
4063+
f = function([x, y, z], op)
4064+
utt.assert_allclose(f(3, 4, 5), np.exp(3 - 4 - 5))
4065+
graph = f.maker.fgraph.toposort()
4066+
assert isinstance(graph[0].op, Elemwise)
4067+
inner_graph = graph[0].op.scalar_op.fgraph.toposort()
4068+
assert any(isinstance(n.op, aes.Sub) for n in inner_graph)
4069+
assert not any(isinstance(n.op, aes.TrueDiv) for n in inner_graph)
4070+
4071+
# e^x * y * e^z * w = e^(x+z) * y * w
4072+
op = expx * y * expz * w
4073+
f = function([x, y, z, w], op)
4074+
utt.assert_allclose(f(3, 4, 5, 6), np.exp(3 + 5) * 4 * 6)
40274075
graph = f.maker.fgraph.toposort()
40284076
assert isinstance(graph[0].op, Elemwise)
40294077
inner_graph = graph[0].op.scalar_op.fgraph.toposort()
40304078
assert any(isinstance(n.op, aes.Add) for n in inner_graph)
4079+
assert any(isinstance(n.op, aes.Mul) for n in inner_graph)
40314080

40324081
# expect same for matrices as well
40334082
mx = matrix("mx")
@@ -4040,28 +4089,20 @@ def test_local_mulexp2expadd():
40404089
assert isinstance(graph[0].op, Elemwise)
40414090
inner_graph = graph[0].op.scalar_op.fgraph.toposort()
40424091
assert any(isinstance(n.op, aes.Add) for n in inner_graph)
4092+
assert not any(isinstance(n.op, aes.Mul) for n in inner_graph)
40434093

40444094
# checking whether further rewrites can proceed after this one as one would expect
40454095
# e^x * e^(-x) = e^(x-x) = e^0 = 1
40464096
f = function([x], expx * exp(neg(x)))
4047-
graph = f.maker.fgraph.toposort()
4048-
assert isinstance(graph[0].inputs[0], TensorConstant)
40494097
utt.assert_allclose(f(42), 1)
4050-
4051-
# e^x / e^y = e^(x-y)
4052-
expx_div_expy = expx / expy
4053-
f = function([x, y], expx_div_expy)
4054-
utt.assert_allclose(f(5, 3), np.exp(5 - 3))
40554098
graph = f.maker.fgraph.toposort()
4056-
assert isinstance(graph[0].op, Elemwise)
4057-
inner_graph = graph[0].op.scalar_op.fgraph.toposort()
4058-
assert any(isinstance(n.op, aes.Sub) for n in inner_graph)
4099+
assert isinstance(graph[0].inputs[0], TensorConstant)
40594100

40604101
# e^x / e^x = e^(x-x) = e^0 = 1
40614102
f = function([x], expx / expx)
4103+
utt.assert_allclose(f(42), 1)
40624104
graph = f.maker.fgraph.toposort()
40634105
assert isinstance(graph[0].inputs[0], TensorConstant)
4064-
utt.assert_allclose(f(42), 1)
40654106

40664107

40674108
def test_local_expm1():

0 commit comments

Comments
 (0)