Skip to content

Implement destructive in-place rewrites for Cholesky and Solve Ops #1028

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pytensor/graph/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,12 @@ def make_thunk(
)
return self.make_py_thunk(node, storage_map, compute_map, no_recycling)

def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op":
    """Return a (possibly new) version of self that destructively overwrites some of the given inputs.

    Parameters
    ----------
    allowed_inplace_inputs
        Indices of the node inputs that the returned Op is allowed to destroy
        (i.e., that it may list in its ``destroy_map``). The Op may inplace on
        any subset of these — including none.

    Returns
    -------
    Op
        An Op equivalent to ``self`` except for in-place behavior. The default
        implementation performs no in-place rewrites and simply returns ``self``.
    """
    # TODO: Document this in the Create your own Op docs
    # By default, do nothing
    return self

def __str__(self):
return getattr(type(self), "__name__", super().__str__())

Expand Down
2 changes: 1 addition & 1 deletion pytensor/scan/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2492,7 +2492,7 @@ def scan_push_out_dot1(fgraph, node):
"fast_run",
"inplace",
"scan",
position=75,
position=50.5,
)

scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan")
Expand Down
6 changes: 3 additions & 3 deletions pytensor/sparse/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def local_inplace_addsd_ccode(fgraph, node):
),
"fast_run",
"inplace",
position=60,
position=50.1,
)


Expand Down Expand Up @@ -239,9 +239,9 @@ def local_addsd_ccode(fgraph, node):
pytensor.compile.optdb.register(
"local_addsd_ccode",
WalkingGraphRewriter(local_addsd_ccode),
# Must be after local_inplace_addsd_ccode at 60
# Must be after local_inplace_addsd_ccode at 70.0
"fast_run",
position=61,
position=70.1,
)


Expand Down
10 changes: 10 additions & 0 deletions pytensor/tensor/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
signature: str | None = None,
name: str | None = None,
gufunc_spec: tuple[str, int, int] | None = None,
destroy_map=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -79,6 +80,15 @@
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
self.gufunc_spec = gufunc_spec
self._gufunc = None
if destroy_map is not None:
self.destroy_map = destroy_map
if self.destroy_map != core_op.destroy_map:
# Note: Should be fine for destroy_map of Blockwise to be more extensive than that of core_op
# But we are not using that anywhere yet, so this check is fine for now
raise ValueError(

Check warning on line 88 in pytensor/tensor/blockwise.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/blockwise.py#L88

Added line #L88 was not covered by tests
f"Blockwise destroy_map {self.destroy_map} must be the same as that of the core_op {core_op} {core_op.destroy_map}"
)

super().__init__(**kwargs)

def __getstate__(self):
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/random/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def random_make_inplace(fgraph, node):
in2out(random_make_inplace, ignore_newtrees=True),
"fast_run",
"inplace",
position=99,
position=50.9,
)


Expand Down
5 changes: 2 additions & 3 deletions pytensor/tensor/rewriting/blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,8 +762,6 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
)


# After destroyhandler(49.5) but before we try to make elemwise things
# inplace (75)
blas_opt_inplace = in2out(
local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace"
)
Expand All @@ -773,7 +771,8 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
"fast_run",
"inplace",
"blas_opt_inplace",
position=70.0,
# Before we try to make elemwise things inplace (70.5)
position=50.2,
)


Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/rewriting/blas_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ def make_ger_destructive(fgraph, node):
make_scipy_blas_destructive,
"fast_run",
"inplace",
position=70.0,
position=50.2,
)
82 changes: 80 additions & 2 deletions pytensor/tensor/rewriting/blockwise.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import itertools

from pytensor.compile import Supervisor
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, node_rewriter
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot
Expand Down Expand Up @@ -50,13 +53,14 @@


# We register this rewrite late, so that other rewrites need only target Blockwise Ops
# We do it after position>=60 so that Blockwise inplace rewrites will work also on useless Blockwise Ops
optdb.register(
"local_useless_unbatched_blockwise",
out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
"fast_run",
"fast_compile",
"blockwise",
position=49,
position=60,
)


Expand Down Expand Up @@ -225,3 +229,77 @@
new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
copy_stack_trace(node.outputs[0], new_out)
return [new_out]


@node_rewriter(tracks=[Blockwise], inplace=True)
def blockwise_inplace(fgraph, node):
    """Try to replace a Blockwise node by an in-place variant.

    Asks the wrapped ``core_op`` (via ``Op.inplace_on_inputs``) for a version
    of itself that destroys some of the inputs that are safe to destroy, and
    rebuilds the Blockwise around it. Returns ``None`` when no in-place
    rewrite is possible.

    NOTE: Codecov annotation text from the PR page was interleaved inside the
    original ``raise ValueError`` call; this body restores the intended code.
    """
    blockwise_op = node.op

    if blockwise_op.destroy_map:
        # Op already has inplace
        return

    # Find out valid inputs for inplacing
    batch_ndim = blockwise_op.batch_ndim(node)
    out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim]

    # Inputs protected by a Supervisor feature, or that are graph outputs,
    # must never be destroyed.
    protected_inputs = [
        f.protected for f in fgraph._features if isinstance(f, Supervisor)
    ]
    protected_inputs = list(itertools.chain.from_iterable(protected_inputs))
    protected_inputs.extend(fgraph.outputs)
    allowed_inplace_inputs = [
        idx
        for idx, inp in enumerate(node.inputs)
        if (
            # Constants would need to be recreated every time if inplaced
            not isinstance(inp, Constant)
            # We can only inplace on inputs that are not being broadcasted
            # As those are reused across iterations of Blockwise
            and inp.type.broadcastable[:batch_ndim] == out_batch_bcast
            # Inputs that are marked as protected or destroyed can't be inplaced
            and not fgraph.has_destroyers([inp])
            and inp not in protected_inputs
        )
    ]

    if not allowed_inplace_inputs:
        return None

    inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
        allowed_inplace_inputs=allowed_inplace_inputs
    )

    # The core_op may decline to inplace at all.
    if not inplace_core_op.destroy_map:
        return None

    # Check Op is not trying to inplace on non-candidate inputs
    for destroyed_inputs in inplace_core_op.destroy_map.values():
        for destroyed_input in destroyed_inputs:
            if destroyed_input not in allowed_inplace_inputs:
                raise ValueError(
                    f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
                )

    # Recreate core_op with inplace
    inplace_blockwise_op = Blockwise(
        core_op=inplace_core_op,
        signature=blockwise_op.signature,
        name=blockwise_op.name,
        gufunc_spec=blockwise_op.gufunc_spec,
        destroy_map=inplace_core_op.destroy_map,
    )

    out = inplace_blockwise_op.make_node(*node.inputs).outputs
    copy_stack_trace(node.outputs, out)
    return out


# Register the Blockwise inplace rewrite with the global rewrite database.
# Position 50.1 places it among the other inplace rewrites in this PR's
# renumbered ordering (cf. local_inplace_addsd_ccode and the BLAS inplace
# rewrites, also moved to the 50.x range) — NOTE(review): ordering rationale
# inferred from the sibling registrations in this diff; confirm against optdb.
optdb.register(
    "blockwise_inplace",
    in2out(blockwise_inplace),
    "fast_run",
    "inplace",
    position=50.1,
)
5 changes: 2 additions & 3 deletions pytensor/tensor/rewriting/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,8 @@ def apply(self, fgraph):
for i in range(len(node.inputs))
if i not in baseline.values()
and not isinstance(node.inputs[i], Constant)
and
# the next line should not be costly most of the time.
not fgraph.has_destroyers([node.inputs[i]])
and not fgraph.has_destroyers([node.inputs[i]])
and node.inputs[i] not in protected_inputs
]
else:
Expand Down Expand Up @@ -362,7 +361,7 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
"inplace_elemwise_optimizer",
"fast_run",
"inplace",
position=75,
position=50.5,
)


Expand Down
6 changes: 3 additions & 3 deletions pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1307,7 +1307,7 @@ def local_inplace_setsubtensor(fgraph, node):
),
"fast_run",
"inplace",
position=60,
position=50.1,
)


Expand All @@ -1329,7 +1329,7 @@ def local_inplace_AdvancedIncSubtensor1(fgraph, node):
),
"fast_run",
"inplace",
position=60,
position=70.6,
)


Expand All @@ -1355,7 +1355,7 @@ def local_inplace_AdvancedIncSubtensor(fgraph, node):
),
"fast_run",
"inplace",
position=60,
position=70.6,
)


Expand Down
Loading
Loading