Skip to content

Make CAReduce more SIMD and memory friendly #1385

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 22 additions & 31 deletions pytensor/tensor/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,7 +1422,7 @@ def infer_shape(self, fgraph, node, shapes):
def _c_all(self, node, name, input_names, output_names, sub):
[inp] = node.inputs
[out] = node.outputs
ndim = inp.type.ndim
inp_ndim = inp.type.ndim

[inp_name] = input_names
[out_name] = output_names
Expand Down Expand Up @@ -1454,10 +1454,10 @@ def _c_all(self, node, name, input_names, output_names, sub):
assert var.dtype == node.outputs[0].dtype
return var.owner.op._c_all(var.owner, name, input_names, output_names, sub)

inp_dims = list(range(ndim))
inp_dims = list(range(inp_ndim))
non_reduced_dims = [i for i in inp_dims if i not in axis]
counter = iter(range(ndim))
acc_dims = ["x" if i in axis else next(counter) for i in range(ndim)]
counter = iter(range(inp_ndim))
acc_dims = ["x" if i in axis else next(counter) for i in range(inp_ndim)]

sub = sub.copy()
sub["lv0"] = inp_name
Expand All @@ -1484,7 +1484,9 @@ def _c_all(self, node, name, input_names, output_names, sub):
cgen.make_declare(
[acc_dims], [out_dtype], out_sub, compute_stride_jump=False
)
+ cgen.make_alloc([non_reduced_dims], out_dtype, sub)
+ cgen.make_careduce_alloc(
inp_name, out_name, inp_ndim, axis, out_dtype, sub["fail"]
)
+ cgen.make_checks(
[acc_dims], [out_dtype], out_sub, compute_stride_jump=False
)
Expand All @@ -1500,7 +1502,10 @@ def _c_all(self, node, name, input_names, output_names, sub):
cgen.make_declare(
[acc_dims], [acc_dtype], acc_sub, compute_stride_jump=False
)
+ cgen.make_alloc([non_reduced_dims], acc_dtype, sub)
+ cgen.make_careduce_alloc(
inp_name, acc_name, inp_ndim, axis, out_dtype, sub["fail"]
)
+ cgen.make_careduce_alloc([non_reduced_dims], acc_dtype, sub)
+ cgen.make_checks(
[acc_dims], [acc_dtype], acc_sub, compute_stride_jump=False
)
Expand All @@ -1524,8 +1529,6 @@ def _c_all(self, node, name, input_names, output_names, sub):
elif identity is None:
raise TypeError(f"The {self.scalar_op} does not define an identity.")

initial_value = f"{acc_name}_i = {identity};"

inner_task = self.scalar_op.c_code(
Apply(
self.scalar_op,
Expand All @@ -1544,28 +1547,16 @@ def _c_all(self, node, name, input_names, output_names, sub):
sub,
)

if out.type.ndim == 0:
# Simple case where everything is reduced, no need for loop ordering
loop = cgen.make_complete_loop_careduce(
inp_var=inp_name,
acc_var=acc_name,
inp_dtype=inp_dtype,
acc_dtype=acc_dtype,
initial_value=initial_value,
inner_task=inner_task,
fail_code=sub["fail"],
)
else:
loop = cgen.make_reordered_loop_careduce(
inp_var=inp_name,
acc_var=acc_name,
inp_dtype=inp_dtype,
acc_dtype=acc_dtype,
inp_ndim=ndim,
reduction_axes=axis,
initial_value=initial_value,
inner_task=inner_task,
)
loop = cgen.make_reordered_loop_careduce(
inp_var=inp_name,
acc_var=acc_name,
inp_dtype=inp_dtype,
acc_dtype=acc_dtype,
inp_ndim=inp_ndim,
reduction_axes=axis,
initial_value=identity,
inner_task=inner_task,
)

if acc_dtype != out_dtype:
cast = dedent(
Expand All @@ -1589,7 +1580,7 @@ def c_headers(self, **kwargs):

def c_code_cache_version_apply(self, node):
# the version corresponding to the c code in this Op
version = [10]
version = [11]

# now we insert versions for the ops on which we depend...
scalar_node = Apply(
Expand Down
Loading
Loading