Skip to content

Commit 22c5db3

Browse files
committed
Fuse consecutive Elemwise nodes with multiple clients
1 parent 256450f commit 22c5db3

File tree

7 files changed

+594
-376
lines changed

7 files changed

+594
-376
lines changed

pytensor/tensor/elemwise.py

+6-16
Original file line numberDiff line numberDiff line change
@@ -652,10 +652,10 @@ def transform(r):
652652

653653
def prepare_node(self, node, storage_map, compute_map, impl):
654654
# Postpone the ufunc building to the last minutes due to:
655-
# - NumPy ufunc support only up to 31 inputs.
655+
# - NumPy ufunc support only up to 32 operands (inputs and outputs)
656656
# But our c code support more.
657657
# - nfunc is reused for scipy and scipy is optional
658-
if len(node.inputs) > 32 and self.ufunc and impl == "py":
658+
if (len(node.inputs) + len(node.outputs)) > 32 and impl == "py":
659659
impl = "c"
660660

661661
if getattr(self, "nfunc_spec", None) and impl != "c":
@@ -677,7 +677,7 @@ def prepare_node(self, node, storage_map, compute_map, impl):
677677
self.nfunc = module
678678

679679
if (
680-
len(node.inputs) < 32
680+
(len(node.inputs) + len(node.outputs)) <= 32
681681
and (self.nfunc is None or self.scalar_op.nin != len(node.inputs))
682682
and self.ufunc is None
683683
and impl == "py"
@@ -727,28 +727,18 @@ def prepare_node(self, node, storage_map, compute_map, impl):
727727
self.scalar_op.prepare_node(node.tag.fake_node, None, None, impl)
728728

729729
def perform(self, node, inputs, output_storage):
730-
if len(node.inputs) >= 32:
730+
if (len(node.inputs) + len(node.outputs)) > 32:
731731
# Some versions of NumPy will segfault, other will raise a
732-
# ValueError, if the number of inputs to a ufunc is 32 or more.
732+
# ValueError, if the number of operands in an ufunc is more than 32.
733733
# In that case, the C version should be used, or Elemwise fusion
734734
# should be disabled.
735+
# FIXME: This no longer calls the C implementation!
735736
super().perform(node, inputs, output_storage)
736737

737738
for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
738739
if len(set(dim_shapes) - {1}) > 1:
739740
raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
740741

741-
# Determine the shape of outputs
742-
out_shape = []
743-
for values in zip(*[input.shape for input in inputs]):
744-
if any(v == 0 for v in values):
745-
# All non-broadcasted dimensions should be zero
746-
assert max(values) <= 1
747-
out_shape.append(0)
748-
else:
749-
out_shape.append(max(values))
750-
out_shape = tuple(out_shape)
751-
752742
ufunc_args = inputs
753743
ufunc_kwargs = {}
754744
# We supported in the past calling manually op.perform.

0 commit comments

Comments
 (0)