
Commit 941d4cf

Tune Fusion Optimizer constraints to backend
The previous approach was insufficient. For instance, `test_shape_i_const` manually included `fast_run`, which enables the fusion optimization, even though the test's default mode is `fast_compile`. This caused a problem because `fast_compile` mode prevents the creation of `c_thunks` even when a C compiler is available, forcing the use of the Python `perform` method, which is limited to 32 operands. The Fusion Optimizer only looked at the `cxx` flag and assumed the C limit (1024 operands) was in place; in this test, one of the fused `Composite`s surpassed the Python limit.
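For context, the 32-operand ceiling comes from NumPy's compile-time `NPY_MAXARGS` constant, which caps the total number of inputs plus outputs of any ufunc. A minimal sketch of hitting that cap (illustrative only, not part of this commit):

    import numpy as np

    # Illustrative only: NumPy refuses to build a ufunc with more than
    # 32 operands (inputs + outputs), so 33 inputs and 1 output fails.
    try:
        np.frompyfunc(lambda *args: sum(args), 33, 1)
    except ValueError as exc:
        print(exc)  # "Cannot construct ufunc with more than 32 operands ..."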
1 parent 25e98cd commit 941d4cf

File tree

4 files changed: +107 -43 lines


pytensor/compile/mode.py

+20 -5
@@ -441,19 +441,34 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 # FunctionMaker, the Mode will be taken from this dictionary using the
 # string as the key
 # Use VM_linker to allow lazy evaluation by default.
-FAST_COMPILE = Mode(VMLinker(use_cloop=False, c_thunks=False), "fast_compile")
+FAST_COMPILE = Mode(
+    VMLinker(use_cloop=False, c_thunks=False),
+    RewriteDatabaseQuery(include=["fast_compile"], exclude=["cxx_only"]),
+)
 if config.cxx:
-    FAST_RUN = Mode("cvm", "fast_run")
+    FAST_RUN = Mode(
+        "cvm",
+        RewriteDatabaseQuery(include=["fast_run"], exclude=["jax", "numba"]),
+    )
 else:
-    FAST_RUN = Mode("vm", "fast_run")
+    FAST_RUN = Mode(
+        "vm",
+        RewriteDatabaseQuery(
+            include=["fast_run"], exclude=["cxx_only", "jax", "numba"]
+        ),
+    )
 
 JAX = Mode(
     JAXLinker(),
-    RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]),
+    RewriteDatabaseQuery(
+        include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt", "numba"]
+    ),
 )
 NUMBA = Mode(
     NumbaLinker(),
-    RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
+    RewriteDatabaseQuery(
+        include=["fast_run", "numba"], exclude=["cxx_only", "BlasOpt", "jax"]
+    ),
 )
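The practical effect of this change is that every predefined mode now carries an explicit `RewriteDatabaseQuery`, so rewrites tagged for one backend are filtered out of the others. A hedged sketch of assembling a custom mode the same way (assuming the import paths used by `mode.py`; treat this as illustrative rather than documented API):

    from pytensor.compile.mode import Mode
    from pytensor.graph.rewriting.db import RewriteDatabaseQuery

    # Hypothetical example, not from the commit: a VM-based mode that applies
    # fast_run rewrites but skips anything tagged cxx_only, jax, or numba,
    # mirroring the FAST_RUN fallback above.
    custom_mode = Mode(
        "vm",
        RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "jax", "numba"]),
    )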

pytensor/tensor/rewriting/elemwise.py

+85 -36
@@ -593,20 +593,13 @@ def local_add_mul_fusion(fgraph, node):
     return [output]
 
 
-def elemwise_max_operands_fct(node) -> int:
-    # `Elemwise.perform` uses NumPy ufuncs and they are limited to 32 operands (inputs and outputs)
-    if not config.cxx:
-        return 32
-    return 1024
-
-
 class FusionOptimizer(GraphRewriter):
     """Graph optimizer that fuses consecutive Elemwise operations."""
 
-    def __init__(self, local_optimizer=None):
-        # TODO: Figure out what to do with this
+    def __init__(self, backend):
         super().__init__()
-        self.optimizer = local_optimizer
+        assert backend in ("py", "c", "numba")
+        self.backend = backend
 
     def add_requirements(self, fgraph):
         fgraph.attach_feature(ReplaceValidate())
@@ -654,29 +647,29 @@ def elemwise_to_scalar(inputs, outputs):
         return scalar_inputs, scalar_outputs
 
     def apply(self, fgraph):
+        # Even though this rewrite is marked as `cxx_only`,
+        # it may sometimes be called when `cxx` is disabled -.-
+        if self.backend == "c" and not config.cxx:
+            return
+
         nb_replacement = 0
 
         if fgraph.profile:
             validate_before = fgraph.profile.validate_time
             callbacks_before = fgraph.execute_callbacks_times.copy()
             callback_before = fgraph.execute_callbacks_time
 
-        max_operands = elemwise_max_operands_fct(None)
+        # `Elemwise.perform` uses NumPy ufuncs, which are limited to 32 operands (inputs and outputs)
+        max_operands = 32 if self.backend == "py" else 1024
 
-        def find_next_fuseable_subgraph(
-            fg: FunctionGraph,
-        ) -> Generator[Tuple[List[Variable], List[Variable]], None, None]:
-            """Find all subgraphs in a FunctionGraph that can be fused together
-
-            Yields
-            -------
-            List of inputs and outputs that determine subgraphs which can be fused. This
-            method assumes that such replacement is done across iterations of the
-            generator.
-            """
+        if self.backend in ("py", "c"):
+            # Python mode is not really a backend, and it may or may not call C code.
+            # Rewrites don't have access to the linker to make this decision, so we
+            # assume we can only fuse Ops with a C implementation.
 
             @lru_cache(maxsize=None)
-            def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
+            def elemwise_scalar_op_can_be_fused(node: Apply) -> bool:
                 if node.op.scalar_op.supports_c_code(node.inputs, node.outputs):
                     return True
                 else:
@@ -690,6 +683,24 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
                     )
                 return False
 
+        elif self.backend == "numba":
+
+            def elemwise_scalar_op_can_be_fused(node: Apply) -> bool:
+                # Should we truncate at numba elemwise ops that need to run in object mode?
+                return True
+
+        def find_next_fuseable_subgraph(
+            fg: FunctionGraph,
+        ) -> Generator[Tuple[List[Variable], List[Variable]], None, None]:
+            """Find all subgraphs in a FunctionGraph that can be fused together
+
+            Yields
+            -------
+            List of inputs and outputs that determine subgraphs which can be fused. This
+            method assumes that such replacement is done across iterations of the
+            generator.
+            """
+
         # We start by creating two maps, 1) from each node to each potentially
         # fuseable client (both nodes must be single output Elemwise with same
         # broadcast type) and 2) from each node to each certainly unfuseable
@@ -702,7 +713,7 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
                 and isinstance(out.owner.op, Elemwise)
                 # and not isinstance(out.owner.op.scalar_op, aes.Composite)
                 and len(out.owner.outputs) == 1
-                and elemwise_scalar_op_has_c_code(out.owner)
+                and elemwise_scalar_op_can_be_fused(out.owner)
             )
             for client, _ in clients:
                 if (
@@ -713,7 +724,7 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
                     and len(client.outputs) == 1
                     and out.type.broadcastable
                     == client.outputs[0].type.broadcastable
-                    and elemwise_scalar_op_has_c_code(client)
+                    and elemwise_scalar_op_can_be_fused(client)
                 ):
                     if client not in fuseable_clients[out]:
                         fuseable_clients[out].append(client)
@@ -1001,7 +1012,7 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool:
                 if (len(inputs) + len(outputs)) > max_operands:
                     warn(
                         "Loop fusion failed because the resulting node would exceed "
-                        "the kernel argument limit."
+                        "the backend limit for the number of operands."
                     )
                     break
 
@@ -1067,30 +1078,68 @@ def print_profile(stream, prof, level=0):
         print(blanc, " time_toposort", prof[7], file=stream)
 
 
-if config.tensor__local_elemwise_fusion:
-    # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
-    fuse_seqopt = SequenceDB()
-    fuse_seqopt.register(
+fuse_opt_py = SequenceDB()
+fuse_opt_c = SequenceDB()
+fuse_opt_numba = SequenceDB()
+for fuse_opt in (fuse_opt_py, fuse_opt_c, fuse_opt_numba):
+    fuse_opt.register(
         "local_add_mul_fusion",
         EquilibriumGraphRewriter(rewriters=[local_add_mul_fusion], max_use_ratio=1000),
         "fast_run",
         "fusion",
         position=0,
     )
-    fuse_seqopt.register(
-        "composite_elemwise_fusion",
-        FusionOptimizer(),
+fuse_opt_py.register(
+    "composite_elemwise_fusion_py",
+    FusionOptimizer("py"),
+    "fast_run",
+    "fusion",
+    position=1,
+)
+fuse_opt_c.register(
+    "composite_elemwise_fusion_c",
+    FusionOptimizer("c"),
+    "fast_run",
+    "fusion",
+    position=1,
+)
+fuse_opt_numba.register(
+    "composite_elemwise_fusion_numba",
+    FusionOptimizer("numba"),
+    "fast_run",
+    "fusion",
+    position=1,
+)
+
+
+if config.tensor__local_elemwise_fusion:
+    # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
+    compile.optdb.register(  # type: ignore
+        "elemwise_fusion_c",
+        fuse_opt_c,
         "fast_run",
         "fusion",
-        position=1,
+        "local_elemwise_fusion",
+        "FusionOptimizer",
+        "cxx_only",
+        position=49,
     )
+    # We allow the Python version to run afterwards,
+    # since there is no mode for Python only
     compile.optdb.register(  # type: ignore
-        "elemwise_fusion",
-        fuse_seqopt,
+        "elemwise_fusion_py",
+        fuse_opt_py,
        "fast_run",
        "fusion",
        "local_elemwise_fusion",
        "FusionOptimizer",
+        position=49.01,
+    )
+    # TODO: Not sure about this... Could rewrites receive info about the linker that is being used?
+    compile.optdb.register(  # type: ignore
+        "elemwise_fusion_numba",
+        fuse_opt_numba,
+        "numba",
         position=49,
     )
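To make the operand accounting concrete: fusion collapses a chain of `Elemwise` nodes into one `Composite` node whose inputs and outputs are the union of the chain's boundaries, and the rewriter stops growing a subgraph once the backend's cap would be exceeded. A minimal sketch of that check, with a hypothetical helper name simplified from the diff above (illustrative, not the exact code):

    def within_operand_limit(inputs: list, outputs: list, backend: str) -> bool:
        # Hypothetical helper, not part of the commit.
        # `Elemwise.perform` relies on NumPy ufuncs, which cap operands
        # (inputs + outputs) at 32; the C backend allows up to 1024.
        max_operands = 32 if backend == "py" else 1024
        return (len(inputs) + len(outputs)) <= max_operands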

tests/tensor/rewriting/test_elemwise.py

+1 -1
@@ -270,7 +270,7 @@ class TestFusion:
             "fusion",
             "inplace",
         ],
-        exclude=["cxx_only", "BlasOpt"],
+        exclude=["BlasOpt"],
     )
     mode = Mode(get_default_mode().linker, rewrites)
     _shared = staticmethod(shared)
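The `cxx_only` exclusion can be dropped here because `FusionOptimizer("c")` now returns early on its own when `config.cxx` is empty (see `apply` above), so running the fusion rewrites without a C compiler is safe.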

tests/tensor/test_subtensor.py

+1 -1
@@ -986,7 +986,7 @@ def test_adv_sub1_idx_broadcast(self):
     def test_shape_i_const(self):
         # Each axis is treated independently by shape_i/shape operators
 
-        mode_opt = self.mode.including("fast_run").excluding("fusion")
+        mode_opt = self.mode.including("fast_run")
         data = self.shared(np.array(np.arange(5), dtype=self.dtype))
         for start in [None] + [-8, -5, -1, 0, 1, 5, 8]:
             outs = []
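With the backend-aware operand limit in place, the fused `Composite` stays within the Python backend's 32-operand cap, so the test no longer needs to exclude the fusion rewrites as a workaround.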
