Commit 1de1726

Inplace Blockwise and core versions of Cholesky and Solve Ops.
Co-authored-by: Ricardo Vieira <[email protected]>
1 parent f961052 commit 1de1726

6 files changed (+384, −41 lines)

pytensor/graph/op.py (+6)
@@ -583,6 +583,12 @@ def make_thunk(
         )
         return self.make_py_thunk(node, storage_map, compute_map, no_recycling)
 
+    def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
+        """Try to return a version of self that tries to inplace in as many as `allowed_inplace_inputs`."""
+        # TODO: Document this in the Create your own Op docs
+        # By default, do nothing
+        return self
+
     def __str__(self):
         return getattr(type(self), "__name__", super().__str__())
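
The new `inplace_on_inputs` hook is the extension point that the inplace-capable core Ops in this commit (the Cholesky and Solve variants) are expected to override. A minimal sketch of how a core Op might implement it follows; `MyCoreOp` and its `overwrite_input` flag are hypothetical names for illustration, not code from this commit.

    # Hypothetical core Op; make_node/perform omitted for brevity.
    from pytensor.graph.op import Op

    class MyCoreOp(Op):
        __props__ = ("overwrite_input",)

        def __init__(self, overwrite_input: bool = False):
            self.overwrite_input = overwrite_input
            if overwrite_input:
                # Output 0 reuses (destroys) the buffer of input 0
                self.destroy_map = {0: [0]}

        def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
            # Return a destructive variant only if input 0 may be overwritten;
            # otherwise keep the default, non-destructive behaviour.
            if 0 in allowed_inplace_inputs:
                return type(self)(overwrite_input=True)
            return self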

pytensor/tensor/blockwise.py (+10)
@@ -45,6 +45,7 @@ def __init__(
         signature: str | None = None,
         name: str | None = None,
         gufunc_spec: tuple[str, int, int] | None = None,
+        destroy_map=None,
         **kwargs,
     ):
         """
@@ -79,6 +80,15 @@ def __init__(
         self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
         self.gufunc_spec = gufunc_spec
         self._gufunc = None
+        if destroy_map is not None:
+            self.destroy_map = destroy_map
+        if self.destroy_map != core_op.destroy_map:
+            # Note: Should be fine for destroy_map of Blockwise to be more extensive than that of core_op
+            # But we are not using that anywhere yet, so this check is fine for now
+            raise ValueError(
+                "Blockwise destroy_map must be the same as that of the core_op"
+            )
+
         super().__init__(**kwargs)
 
     def __getstate__(self):
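
The new `destroy_map` argument simply lets a Blockwise advertise the same destruction behaviour as its core op; anything else is rejected by the check above. A small usage sketch, assuming the inplace core Cholesky added elsewhere in this commit exposes a SciPy-style `overwrite_a` flag (that keyword is not visible in this diff and is an assumption):

    # Sketch only; `overwrite_a` on Cholesky is assumed, not shown in this diff.
    from pytensor.tensor.blockwise import Blockwise
    from pytensor.tensor.slinalg import Cholesky

    core_op = Cholesky(lower=True, overwrite_a=True)  # assumed destroy_map == {0: [0]}

    # OK: the Blockwise carries exactly the core op's destroy_map.
    batched = Blockwise(
        core_op, signature="(m,m)->(m,m)", destroy_map=core_op.destroy_map
    )

    # Raises ValueError("Blockwise destroy_map must be the same as that of the core_op"):
    # the default destroy_map ({}) no longer matches the destructive core op's.
    Blockwise(core_op, signature="(m,m)->(m,m)")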

pytensor/tensor/rewriting/blockwise.py (+80, −2)
@@ -1,7 +1,10 @@
+import itertools
+
+from pytensor.compile import Supervisor
 from pytensor.compile.mode import optdb
 from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.replace import vectorize_node
-from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
+from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
 from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.math import Dot
@@ -50,13 +53,14 @@ def local_useless_unbatched_blockwise(fgraph, node):
 
 
 # We register this rewrite late, so that other rewrites need only target Blockwise Ops
+# We do it after position>=60 so that Blockwise inplace rewrites will work also on useless Blockwise Ops
 optdb.register(
     "local_useless_unbatched_blockwise",
     out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
     "fast_run",
     "fast_compile",
     "blockwise",
-    position=49,
+    position=60,
 )
 
 
@@ -225,3 +229,77 @@ def local_blockwise_reshape(fgraph, node):
     new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
     copy_stack_trace(node.outputs[0], new_out)
     return [new_out]
+
+
+@node_rewriter(tracks=[Blockwise], inplace=True)
+def blockwise_inplace(fgraph, node):
+    blockwise_op = node.op
+
+    if blockwise_op.destroy_map:
+        # Op already has inplace
+        return
+
+    # Find out valid inputs for inplacing
+    batch_ndim = blockwise_op.batch_ndim(node)
+    out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]
+
+    protected_inputs = [
+        f.protected for f in fgraph._features if isinstance(f, Supervisor)
+    ]
+    protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
+    protected_inputs.extend(fgraph.outputs)
+    allowed_inplace_inputs = [
+        idx
+        for idx, inp in enumerate(node.inputs)
+        if
+        (
+            # Constants would need to be recreated every time if inplaced
+            not isinstance(inp, Constant)
+            # We can only inplace on inputs that are not being broadcasted
+            # As those are reused across iterations of Blockwise
+            and node.inputs[idx].type.broadcastable[:batch_ndim] == out_batch_bcast
+            # Inputs that are marked as protected or destroyed can't be inplaced
+            and not fgraph.has_destroyers([inp])
+            and inp not in protected_inputs
+        )
+    ]
+
+    if not allowed_inplace_inputs:
+        return None
+
+    inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
+        allowed_inplace_inputs=allowed_inplace_inputs
+    )
+
+    if not inplace_core_op.destroy_map:
+        return None
+
+    # Check Op is not trying to inplace on non-candidate inputs
+    for destroyed_inputs in inplace_core_op.destroy_map.values():
+        for destroyed_input in destroyed_inputs:
+            if destroyed_input not in allowed_inplace_inputs:
+                raise ValueError(
+                    "Op destroy_map does not respect allowed_inplace_inputs"
+                )
+
+    # Recreate core_op with inplace
+    inplace_blockwise_op = Blockwise(
+        core_op=inplace_core_op,
+        signature=blockwise_op.signature,
+        name=blockwise_op.name,
+        gufunc_spec=blockwise_op.gufunc_spec,
+        destroy_map=inplace_core_op.destroy_map,
+    )
+
+    out = inplace_blockwise_op.make_node(*node.inputs).outputs
+    copy_stack_trace(node.outputs, out)
+    return out
+
+
+optdb.register(
+    "blockwise_inplace",
+    in2out(blockwise_inplace),
+    "fast_run",
+    "inplace",
+    position=50.1,
+)
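
Taken together, the registration above means that in the default fast_run mode any Blockwise whose core op implements `inplace_on_inputs` can become destructive during compilation. A rough usage sketch follows; it is illustrative, not a test from this commit, and the exact destroy_map printed depends on the protected-input analysis above.

    import pytensor
    import pytensor.tensor as pt
    from pytensor.tensor.blockwise import Blockwise
    from pytensor.tensor.slinalg import cholesky

    x = pt.tensor3("x")
    # Cholesky of an intermediate value: function inputs themselves are protected
    # by the Supervisor, so only the (2 * x) buffer is a candidate for inplacing.
    out = cholesky(2 * x)

    fn = pytensor.function([x], out)  # default mode includes the "inplace" rewrites
    blockwise_node = next(
        node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Blockwise)
    )
    # After blockwise_inplace this should report the core op's destroy_map,
    # e.g. {0: [0]}, meaning the batched intermediate is overwritten in place.
    print(blockwise_node.op.destroy_map)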
