Skip to content

Commit b30205f

Browse files
committed
Fix local_careduce_fusion rewrite
1 parent 448be82 commit b30205f

File tree

2 files changed

+51
-22
lines changed

2 files changed

+51
-22
lines changed

pytensor/tensor/rewriting/elemwise.py

+31-16
Original file line number | Diff line number | Diff line change
@@ -1150,11 +1150,20 @@ def local_careduce_fusion(fgraph, node):
11501150
"""Fuse a `CAReduce` applied to an `Elemwise`."""
11511151

11521152
(car_input,) = node.inputs
1153+
car_scalar_op = node.op.scalar_op
1154+
1155+
# FIXME: This check is needed because of the faulty logic in the FIXME below!
1156+
# Right now, rewrite only works for `Sum`/`Prod`
1157+
if not isinstance(car_scalar_op, (aes.Add, aes.Mul)):
1158+
return None
1159+
11531160
elm_node = car_input.owner
11541161

11551162
if elm_node is None or not isinstance(elm_node.op, Elemwise):
11561163
return False
11571164

1165+
elm_scalar_op = elm_node.op.scalar_op
1166+
11581167
elm_inputs = elm_node.inputs
11591168
elm_outputs = elm_node.outputs
11601169

@@ -1166,21 +1175,15 @@ def local_careduce_fusion(fgraph, node):
11661175
return False
11671176

11681177
# Don't form the fusion when the target language is Python
1169-
elm_scalar_op = elm_node.op.scalar_op
1170-
car_scalar_op = node.op.scalar_op
1171-
11721178
if get_target_language() == ("py",):
11731179
return False
11741180

1175-
try:
1176-
elm_scalar_op.c_code(
1177-
elm_node,
1178-
"test_presence_of_c_code",
1179-
["x" for x in elm_inputs],
1180-
["z" for z in elm_outputs],
1181-
{"fail": "%(fail)s"},
1182-
)
1181+
if not elm_scalar_op.supports_c_code(elm_inputs, elm_outputs):
1182+
return None
11831183

1184+
# FIXME: This fails with Ops like `Max` whose `c_code` always expects two inputs!
1185+
# Should implement a `CAReduce.supports_c_code`?
1186+
try:
11841187
car_scalar_op.c_code(
11851188
node,
11861189
"test_presence_of_c_code",
@@ -1191,18 +1194,24 @@ def local_careduce_fusion(fgraph, node):
11911194
except (NotImplementedError, MethodNotDefined):
11921195
return False
11931196

1194-
car_axis = node.op.axis
1197+
car_op = node.op
1198+
car_acc_dtype = node.op.acc_dtype
11951199

11961200
scalar_elm_inputs = [
11971201
aes.get_scalar_type(inp.type.dtype).make_variable() for inp in elm_inputs
11981202
]
1203+
11991204
elm_output = elm_scalar_op(*scalar_elm_inputs)
1205+
12001206
# This input represents the previous value in the `CAReduce` binary reduction
1201-
carried_car_input = elm_output.type()
1202-
scalar_fused_outputs = [car_scalar_op(carried_car_input, elm_output)]
1207+
carried_car_input = aes.get_scalar_type(car_acc_dtype).make_variable()
1208+
1209+
scalar_fused_output = car_scalar_op(carried_car_input, elm_output)
1210+
if scalar_fused_output.type.dtype != car_acc_dtype:
1211+
scalar_fused_output = aes.cast(scalar_fused_output, car_acc_dtype)
12031212

12041213
fused_scalar_op = aes.Composite(
1205-
inputs=[carried_car_input] + scalar_elm_inputs, outputs=scalar_fused_outputs
1214+
inputs=[carried_car_input] + scalar_elm_inputs, outputs=[scalar_fused_output]
12061215
)
12071216

12081217
# The fused `Op` needs to look and behave like a `BinaryScalarOp`
@@ -1211,7 +1220,13 @@ def local_careduce_fusion(fgraph, node):
12111220
fused_scalar_op.nin = 2
12121221
fused_scalar_op.nout = 1
12131222

1214-
new_car_op = CAReduce(fused_scalar_op, car_axis)
1223+
new_car_op = CAReduce(
1224+
scalar_op=fused_scalar_op,
1225+
axis=car_op.axis,
1226+
acc_dtype=car_acc_dtype,
1227+
dtype=car_op.dtype,
1228+
upcast_discrete_output=car_op.upcast_discrete_output,
1229+
)
12151230

12161231
return [new_car_op(*elm_inputs)]
12171232

tests/tensor/rewriting/test_elemwise.py

+20-6
Original file line number | Diff line number | Diff line change
@@ -1177,8 +1177,22 @@ def test_test_values(self, test_value):
11771177
)
11781178

11791179
@pytest.mark.parametrize("linker", ["cvm", "py"])
1180+
@pytest.mark.parametrize("inp_dtype", ("floatX", "int32"))
11801181
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1), (0, 1, 2)])
1181-
def test_CAReduce_single_input(self, linker, axis):
1182+
@pytest.mark.parametrize(
1183+
"careduce_op, numpy_op",
1184+
[
1185+
(at_sum, np.sum),
1186+
pytest.param(
1187+
at_all,
1188+
np.all,
1189+
marks=pytest.mark.xfail(
1190+
reason="Rewrite logic does not support all CAReduce"
1191+
),
1192+
),
1193+
],
1194+
)
1195+
def test_CAReduce_single_input(self, linker, inp_dtype, axis, careduce_op, numpy_op):
11821196
"""Make sure that `CAReduce` and `Elemwise` fusions work with a single input."""
11831197

11841198
mode = Mode(linker=linker)
@@ -1188,8 +1202,8 @@ def test_CAReduce_single_input(self, linker, axis):
11881202
"inplace",
11891203
)
11901204

1191-
x = tensor(dtype="floatX", shape=(None, None, None), name="x")
1192-
out = exp(x).sum(axis=axis)
1205+
x = tensor(dtype=inp_dtype, shape=(None, None, None), name="x")
1206+
out = careduce_op(exp(x), axis=axis)
11931207

11941208
out_fn = function([x], out, mode=mode)
11951209

@@ -1198,9 +1212,9 @@ def test_CAReduce_single_input(self, linker, axis):
11981212
assert isinstance(getattr(out_node.op, "scalar_op"), aes.basic.Composite)
11991213

12001214
rng = np.random.default_rng(2320)
1201-
x_val = rng.random((4, 3, 2), dtype=config.floatX)
1215+
x_val = rng.random((4, 3, 2)).astype(x.type.dtype)
12021216

1203-
exp_res = np.exp(x_val).sum(axis=axis)
1217+
exp_res = numpy_op(np.exp(x_val), axis=axis)
12041218

12051219
out_val = out_fn(x_val)
12061220
assert out_val.shape == exp_res.shape
@@ -1216,7 +1230,7 @@ def test_CAReduce_single_input(self, linker, axis):
12161230
# `Elemwise`s with more than one client shouldn't be rewritten
12171231
x = tensor(dtype="floatX", shape=(None, None, None), name="x")
12181232
exp_x = exp(x)
1219-
out = exp_x.sum(axis=axis) + exp(x)
1233+
out = careduce_op(exp_x, axis=axis) + exp(x)
12201234

12211235
out_fn = function([x], out, mode=mode)
12221236
out_nodes = out_fn.maker.fgraph.toposort()

0 commit comments

Comments (0)