Skip to content

Commit 00b5b90

Browse files
committed
Add direct test for nested broadcasted Composite graphs
1 parent a2f101a commit 00b5b90

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

tests/tensor/rewriting/test_elemwise.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,6 +1167,37 @@ def test_CAReduce_multiple_inputs(self, linker, axis):
11671167
assert out_val.shape == exp_res.shape
11681168
assert np.allclose(out_val, exp_res)
11691169

1170+
def test_not_fusing_broadcasted_subgraphs(self):
1171+
"""Test that broadcasted Elemwise subgraphs are not fused in a single Elemwise Composite Op.
1172+
1173+
There are some cases in self.test_elemwise_fusion, but this test confirms that the
1174+
fused subgraphs are exactly the expected ones.
1175+
"""
1176+
xs = vector("xm")
1177+
xm = matrix("xs")
1178+
1179+
es = log(xs + 5)
1180+
em = exp(xm * 5)
1181+
esm = es - em
1182+
1183+
f = pytensor.function([xs, xm], esm, mode=self.mode)
1184+
apply_nodes = f.maker.fgraph.toposort()
1185+
assert len(apply_nodes) == 3
1186+
assert isinstance(apply_nodes[0].op, DimShuffle)
1187+
# Inner Vector output Composite
1188+
assert isinstance(apply_nodes[1].op.scalar_op, Composite)
1189+
assert {node.op for node in apply_nodes[1].op.scalar_op.fgraph.apply_nodes} == {
1190+
aes.add,
1191+
aes.log,
1192+
}
1193+
# Outer Matrix output Composite
1194+
assert isinstance(apply_nodes[2].op.scalar_op, Composite)
1195+
assert {node.op for node in apply_nodes[2].op.scalar_op.fgraph.apply_nodes} == {
1196+
aes.sub,
1197+
aes.exp,
1198+
aes.mul,
1199+
}
1200+
11701201

11711202
class TimesN(aes.basic.UnaryScalarOp):
11721203
"""

0 commit comments

Comments
 (0)