Skip to content

Commit 799722b

Browse files
tamastokesricardoV94
authored andcommitted
pytensor-54: Rename functions according to naming conventions. Removed a redundant check. Moved import statement to top of file.
1 parent 8466acd commit 799722b

File tree

2 files changed

+7
-12
lines changed

2 files changed

+7
-12
lines changed

pytensor/tensor/rewriting/math.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import itertools
44
import operator
5+
from collections import defaultdict
56
from functools import partial, reduce
67

78
import numpy as np
@@ -425,14 +426,12 @@ def local_sumsqr2dot(fgraph, node):
425426

426427
@register_specialize
427428
@node_rewriter([mul, true_div])
428-
def local_mulexp2expadd(fgraph, node):
429+
def local_mul_exp_to_exp_add(fgraph, node):
429430
"""
430431
This rewrite detects e^x * e^y and converts it to e^(x+y).
431432
Similarly, e^x / e^y becomes e^(x-y).
432433
"""
433-
if isinstance(node.op, Elemwise) and isinstance(
434-
node.op.scalar_op, (aes.Mul, aes.TrueDiv)
435-
):
434+
if isinstance(node.op.scalar_op, (aes.Mul, aes.TrueDiv)):
436435
exps = [
437436
n.owner.inputs[0]
438437
for n in node.inputs
@@ -468,16 +467,12 @@ def local_mulexp2expadd(fgraph, node):
468467

469468
@register_specialize
470469
@node_rewriter([mul, true_div])
471-
def local_mulpow2powadd(fgraph, node):
470+
def local_mul_pow_to_pow_add(fgraph, node):
472471
"""
473472
This rewrite detects a^x * a^y and converts it to a^(x+y).
474473
Similarly, a^x / a^y becomes a^(x-y).
475474
"""
476-
if isinstance(node.op, Elemwise) and isinstance(
477-
node.op.scalar_op, (aes.Mul, aes.TrueDiv)
478-
):
479-
from collections import defaultdict
480-
475+
if isinstance(node.op.scalar_op, (aes.Mul, aes.TrueDiv)):
481476
# search for pow-s and group them by their bases
482477
pow_nodes = defaultdict(list)
483478
rest = []

tests/tensor/rewriting/test_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4014,7 +4014,7 @@ def test_local_sumsqr2dot():
40144014
)
40154015

40164016

4017-
def test_local_mulexp2expadd():
4017+
def test_local_mul_exp_to_exp_add():
40184018
x = scalar("x")
40194019
y = scalar("y")
40204020
z = scalar("z")
@@ -4105,7 +4105,7 @@ def test_local_mulexp2expadd():
41054105
assert isinstance(graph[0].inputs[0], TensorConstant)
41064106

41074107

4108-
def test_local_mulpow2powadd():
4108+
def test_local_mul_pow_to_pow_add():
41094109
x = scalar("x")
41104110
y = scalar("y")
41114111
z = scalar("z")

0 commit comments

Comments
 (0)