Skip to content

Commit a547296

Browse files
committed
Remove tensor__local_elemwise_fusion config.
Same behavior can be obtained with `optimizer_excluding`. The `local_careduce_fusion` rewrite is now included in this database. Otherwise it would usually not be applied because it ran before the fusion rewrites.
1 parent b30205f commit a547296

File tree

3 files changed

+69
-72
lines changed

3 files changed

+69
-72
lines changed

pytensor/configdefaults.py

-10
Original file line numberDiff line numberDiff line change
@@ -640,16 +640,6 @@ def add_tensor_configvars():
640640
in_c_key=False,
641641
)
642642

643-
config.add(
644-
"tensor__local_elemwise_fusion",
645-
(
646-
"Enable or not in fast_run mode(fast_run optimization) the elemwise "
647-
"fusion optimization"
648-
),
649-
BoolParam(True),
650-
in_c_key=False,
651-
)
652-
653643
# http://developer.amd.com/CPU/LIBRARIES/LIBM/Pages/default.aspx
654644
config.add(
655645
"lib__amblibm",

pytensor/tensor/rewriting/elemwise.py

+36-30
Original file line numberDiff line numberDiff line change
@@ -1085,38 +1085,10 @@ def print_profile(stream, prof, level=0):
10851085
print(blanc, " time_toposort", prof[7], file=stream)
10861086

10871087

1088-
if config.tensor__local_elemwise_fusion:
1089-
# Must be after gpu(48.5) and before AddDestroyHandler(49.5)
1090-
fuse_seqopt = SequenceDB()
1091-
fuse_seqopt.register(
1092-
"local_add_mul_fusion",
1093-
EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
1094-
"fast_run",
1095-
"fusion",
1096-
position=0,
1097-
)
1098-
fuse_seqopt.register(
1099-
"composite_elemwise_fusion",
1100-
FusionOptimizer(),
1101-
"fast_run",
1102-
"fusion",
1103-
position=1,
1104-
)
1105-
compile.optdb.register(
1106-
"elemwise_fusion",
1107-
fuse_seqopt,
1108-
"fast_run",
1109-
"fusion",
1110-
"local_elemwise_fusion",
1111-
"FusionOptimizer",
1112-
position=49,
1113-
)
1114-
1115-
11161088
@register_canonicalize
11171089
@register_specialize
11181090
@node_rewriter([Elemwise])
1119-
def local_useless_composite(fgraph, node):
1091+
def local_useless_composite_outputs(fgraph, node):
11201092
"""Remove inputs and outputs of Composite Ops that are not used anywhere."""
11211093
if not isinstance(node.op, Elemwise) or not isinstance(
11221094
node.op.scalar_op, aes.Composite
@@ -1231,11 +1203,45 @@ def local_careduce_fusion(fgraph, node):
12311203
return [new_car_op(*elm_inputs)]
12321204

12331205

1206+
# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
1207+
fuse_seqopt = SequenceDB()
12341208
compile.optdb.register(
1209+
"elemwise_fusion",
1210+
fuse_seqopt,
1211+
"fast_run",
1212+
"fusion",
1213+
"local_elemwise_fusion",
1214+
"FusionOptimizer",
1215+
position=49,
1216+
)
1217+
1218+
fuse_seqopt.register(
1219+
"local_add_mul_fusion",
1220+
EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
1221+
"fast_run",
1222+
"fusion",
1223+
position=0,
1224+
)
1225+
fuse_seqopt.register(
1226+
"composite_elemwise_fusion",
1227+
FusionOptimizer(),
1228+
"fast_run",
1229+
"fusion",
1230+
position=1,
1231+
)
1232+
fuse_seqopt.register(
1233+
"local_useless_composite_outputs",
1234+
in2out(local_useless_composite_outputs),
1235+
"fast_run",
1236+
"fusion",
1237+
position=2,
1238+
)
1239+
fuse_seqopt.register(
12351240
"local_careduce_fusion",
12361241
in2out(local_careduce_fusion),
1242+
"fast_run",
12371243
"fusion",
1238-
position=49,
1244+
position=10,
12391245
)
12401246

12411247

tests/tensor/rewriting/test_elemwise.py

+33-32
Original file line numberDiff line numberDiff line change
@@ -1423,39 +1423,40 @@ def test_nested_composite(self):
14231423
fval = f([1, 2, 3])
14241424
assert np.all(fval == [6, 12, 18])
14251425

1426-
def test_local_useless_composite(self):
1427-
x = aes.float32()
1428-
y = aes.float32()
1429-
z = aes.float32()
1430-
c = aes.Composite([x, y, z], [x + 1, y - 1])
1431-
X = matrix("X")
1432-
Y = matrix("Y")
1433-
Z = matrix("Z")
1434-
o1, o2 = Elemwise(scalar_op=c)(X, Y, Z)
1435-
mode = get_default_mode().including("local_useless_composite")
1436-
1437-
f = function([X, Y, Z], [o1, o2], mode=mode)
1438-
topo = f.maker.fgraph.toposort()
1439-
assert len(topo) == 1
1440-
assert len(topo[0].inputs) == 2
1441-
assert len(topo[0].outputs) == 2
1442-
res1, res2 = f([[1.0]], [[1.0]], [[np.nan]])
1443-
utt.assert_allclose(res1, [[2.0]])
1444-
utt.assert_allclose(res2, [[0.0]])
1445-
1446-
f = function([X, Y, Z], o1, mode=mode)
1447-
topo = f.maker.fgraph.toposort()
1448-
assert len(topo) == 1
1449-
assert len(topo[0].inputs) == 1
1450-
assert len(topo[0].outputs) == 1
1451-
utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]])
14521426

1453-
f = function([X, Y, Z], o2, mode=mode)
1454-
topo = f.maker.fgraph.toposort()
1455-
assert len(topo) == 1
1456-
assert len(topo[0].inputs) == 1
1457-
assert len(topo[0].outputs) == 1
1458-
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
1427+
def test_local_useless_composite_outputs():
1428+
x = aes.float32()
1429+
y = aes.float32()
1430+
z = aes.float32()
1431+
c = aes.Composite([x, y, z], [x + 1, y - 1])
1432+
X = matrix("X")
1433+
Y = matrix("Y")
1434+
Z = matrix("Z")
1435+
o1, o2 = Elemwise(scalar_op=c)(X, Y, Z)
1436+
mode = get_default_mode().including("local_useless_composite")
1437+
1438+
f = function([X, Y, Z], [o1, o2], mode=mode)
1439+
topo = f.maker.fgraph.toposort()
1440+
assert len(topo) == 1
1441+
assert len(topo[0].inputs) == 2
1442+
assert len(topo[0].outputs) == 2
1443+
res1, res2 = f([[1.0]], [[1.0]], [[np.nan]])
1444+
utt.assert_allclose(res1, [[2.0]])
1445+
utt.assert_allclose(res2, [[0.0]])
1446+
1447+
f = function([X, Y, Z], o1, mode=mode)
1448+
topo = f.maker.fgraph.toposort()
1449+
assert len(topo) == 1
1450+
assert len(topo[0].inputs) == 1
1451+
assert len(topo[0].outputs) == 1
1452+
utt.assert_allclose(f([[1.0]], [[np.nan]], [[np.nan]]), [[2.0]])
1453+
1454+
f = function([X, Y, Z], o2, mode=mode)
1455+
topo = f.maker.fgraph.toposort()
1456+
assert len(topo) == 1
1457+
assert len(topo[0].inputs) == 1
1458+
assert len(topo[0].outputs) == 1
1459+
utt.assert_allclose(f([[np.nan]], [[1.0]], [[np.nan]]), [[0.0]])
14591460

14601461

14611462
def test_local_useless_dimshuffle_makevector():

0 commit comments

Comments
 (0)