Skip to content

Commit 426f0c0

Browse files
tamastokesricardoV94
authored andcommitted
pytensor-54: Removed yet another redundant check
1 parent 799722b commit 426f0c0

File tree

1 file changed

+72
-74
lines changed

1 file changed

+72
-74
lines changed

pytensor/tensor/rewriting/math.py

+72-74
Original file line numberDiff line numberDiff line change
@@ -431,38 +431,37 @@ def local_mul_exp_to_exp_add(fgraph, node):
431431
This rewrite detects e^x * e^y and converts it to e^(x+y).
432432
Similarly, e^x / e^y becomes e^(x-y).
433433
"""
434-
if isinstance(node.op.scalar_op, (aes.Mul, aes.TrueDiv)):
435-
exps = [
436-
n.owner.inputs[0]
434+
exps = [
435+
n.owner.inputs[0]
436+
for n in node.inputs
437+
if n.owner
438+
and hasattr(n.owner.op, "scalar_op")
439+
and isinstance(n.owner.op.scalar_op, aes.Exp)
440+
]
441+
# Can only do any rewrite if there are at least two exp-s
442+
if len(exps) >= 2:
443+
# Mul -> add; TrueDiv -> sub
444+
orig_op, new_op = mul, add
445+
if isinstance(node.op.scalar_op, aes.TrueDiv):
446+
orig_op, new_op = true_div, sub
447+
new_out = exp(new_op(*exps))
448+
if new_out.dtype != node.outputs[0].dtype:
449+
new_out = cast(new_out, dtype=node.outputs[0].dtype)
450+
# The original Mul may have more than two factors, some of which may not be exp nodes.
451+
# If so, we keep multiplying them with the new exp(sum) node.
452+
# E.g.: e^x * y * e^z * w --> e^(x+z) * y * w
453+
rest = [
454+
n
437455
for n in node.inputs
438-
if n.owner
439-
and hasattr(n.owner.op, "scalar_op")
440-
and isinstance(n.owner.op.scalar_op, aes.Exp)
456+
if not n.owner
457+
or not hasattr(n.owner.op, "scalar_op")
458+
or not isinstance(n.owner.op.scalar_op, aes.Exp)
441459
]
442-
# Can only do any rewrite if there are at least two exp-s
443-
if len(exps) >= 2:
444-
# Mul -> add; TrueDiv -> sub
445-
orig_op, new_op = mul, add
446-
if isinstance(node.op.scalar_op, aes.TrueDiv):
447-
orig_op, new_op = true_div, sub
448-
new_out = exp(new_op(*exps))
460+
if len(rest) > 0:
461+
new_out = orig_op(new_out, *rest)
449462
if new_out.dtype != node.outputs[0].dtype:
450463
new_out = cast(new_out, dtype=node.outputs[0].dtype)
451-
# The original Mul may have more than two factors, some of which may not be exp nodes.
452-
# If so, we keep multiplying them with the new exp(sum) node.
453-
# E.g.: e^x * y * e^z * w --> e^(x+z) * y * w
454-
rest = [
455-
n
456-
for n in node.inputs
457-
if not n.owner
458-
or not hasattr(n.owner.op, "scalar_op")
459-
or not isinstance(n.owner.op.scalar_op, aes.Exp)
460-
]
461-
if len(rest) > 0:
462-
new_out = orig_op(new_out, *rest)
463-
if new_out.dtype != node.outputs[0].dtype:
464-
new_out = cast(new_out, dtype=node.outputs[0].dtype)
465-
return [new_out]
464+
return [new_out]
466465

467466

468467
@register_specialize
@@ -472,52 +471,51 @@ def local_mul_pow_to_pow_add(fgraph, node):
472471
This rewrite detects a^x * a^y and converts it to a^(x+y).
473472
Similarly, a^x / a^y becomes a^(x-y).
474473
"""
475-
if isinstance(node.op.scalar_op, (aes.Mul, aes.TrueDiv)):
476-
# search for pow-s and group them by their bases
477-
pow_nodes = defaultdict(list)
478-
rest = []
479-
for n in node.inputs:
480-
if (
481-
n.owner
482-
and hasattr(n.owner.op, "scalar_op")
483-
and isinstance(n.owner.op.scalar_op, aes.Pow)
484-
):
485-
base_node = n.owner.inputs[0]
486-
# exponent is at n.owner.inputs[1], but we need to store the full node
487-
# in case this particular power node remains alone and can't be rewritten
488-
pow_nodes[base_node].append(n)
489-
else:
490-
rest.append(n)
491-
492-
# Can only do any rewrite if there are at least two pow-s with the same base
493-
can_rewrite = [k for k, v in pow_nodes.items() if len(v) >= 2]
494-
if len(can_rewrite) >= 1:
495-
# Mul -> add; TrueDiv -> sub
496-
orig_op, new_op = mul, add
497-
if isinstance(node.op.scalar_op, aes.TrueDiv):
498-
orig_op, new_op = true_div, sub
499-
pow_factors = []
500-
# Rewrite pow-s having the same base for each different base
501-
# E.g.: a^x * a^y --> a^(x+y)
502-
for base in can_rewrite:
503-
exponents = [n.owner.inputs[1] for n in pow_nodes[base]]
504-
new_node = base ** new_op(*exponents)
505-
if new_node.dtype != node.outputs[0].dtype:
506-
new_node = cast(new_node, dtype=node.outputs[0].dtype)
507-
pow_factors.append(new_node)
508-
# Don't forget about those sole pow-s that couldn't be rewriten
509-
sole_pows = [v[0] for k, v in pow_nodes.items() if k not in can_rewrite]
510-
# Combine the rewritten pow-s and other, non-pow factors of the original Mul
511-
# E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+z) * b^(z+t) * y * v
512-
if len(pow_factors) > 1 or len(sole_pows) > 0 or len(rest) > 0:
513-
new_out = orig_op(*pow_factors, *sole_pows, *rest)
514-
if new_out.dtype != node.outputs[0].dtype:
515-
new_out = cast(new_out, dtype=node.outputs[0].dtype)
516-
else:
517-
# if all factors of the original mul were pows-s with the same base,
518-
# we can get rid of the mul completely.
519-
new_out = pow_factors[0]
520-
return [new_out]
474+
# search for pow-s and group them by their bases
475+
pow_nodes = defaultdict(list)
476+
rest = []
477+
for n in node.inputs:
478+
if (
479+
n.owner
480+
and hasattr(n.owner.op, "scalar_op")
481+
and isinstance(n.owner.op.scalar_op, aes.Pow)
482+
):
483+
base_node = n.owner.inputs[0]
484+
# exponent is at n.owner.inputs[1], but we need to store the full node
485+
# in case this particular power node remains alone and can't be rewritten
486+
pow_nodes[base_node].append(n)
487+
else:
488+
rest.append(n)
489+
490+
# Can only do any rewrite if there are at least two pow-s with the same base
491+
can_rewrite = [k for k, v in pow_nodes.items() if len(v) >= 2]
492+
if len(can_rewrite) >= 1:
493+
# Mul -> add; TrueDiv -> sub
494+
orig_op, new_op = mul, add
495+
if isinstance(node.op.scalar_op, aes.TrueDiv):
496+
orig_op, new_op = true_div, sub
497+
pow_factors = []
498+
# Rewrite pow-s having the same base for each different base
499+
# E.g.: a^x * a^y --> a^(x+y)
500+
for base in can_rewrite:
501+
exponents = [n.owner.inputs[1] for n in pow_nodes[base]]
502+
new_node = base ** new_op(*exponents)
503+
if new_node.dtype != node.outputs[0].dtype:
504+
new_node = cast(new_node, dtype=node.outputs[0].dtype)
505+
pow_factors.append(new_node)
506+
# Don't forget about those sole pow-s that couldn't be rewriten
507+
sole_pows = [v[0] for k, v in pow_nodes.items() if k not in can_rewrite]
508+
# Combine the rewritten pow-s and other, non-pow factors of the original Mul
509+
# E.g.: a^x * y * b^z * a^w * v * b^t --> a^(x+z) * b^(z+t) * y * v
510+
if len(pow_factors) > 1 or len(sole_pows) > 0 or len(rest) > 0:
511+
new_out = orig_op(*pow_factors, *sole_pows, *rest)
512+
if new_out.dtype != node.outputs[0].dtype:
513+
new_out = cast(new_out, dtype=node.outputs[0].dtype)
514+
else:
515+
# if all factors of the original mul were pows-s with the same base,
516+
# we can get rid of the mul completely.
517+
new_out = pow_factors[0]
518+
return [new_out]
521519

522520

523521
@register_stabilize

0 commit comments

Comments
 (0)