|
| 1 | +import itertools |
| 2 | + |
| 3 | +from pytensor.compile import Supervisor |
1 | 4 | from pytensor.compile.mode import optdb
|
2 | 5 | from pytensor.graph import Constant, node_rewriter
|
3 | 6 | from pytensor.graph.replace import vectorize_node
|
4 |
| -from pytensor.graph.rewriting.basic import copy_stack_trace, out2in |
| 7 | +from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in |
5 | 8 | from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
|
6 | 9 | from pytensor.tensor.blockwise import Blockwise
|
7 | 10 | from pytensor.tensor.math import Dot
|
@@ -50,13 +53,14 @@ def local_useless_unbatched_blockwise(fgraph, node):
|
50 | 53 |
|
51 | 54 |
|
# We register this rewrite late, so that other rewrites need only target Blockwise Ops.
# It is placed at position 60, i.e. after the `blockwise_inplace` rewrite (registered at
# position 50.1), so that the Blockwise inplace rewrite also works on useless Blockwise Ops.
optdb.register(
    "local_useless_unbatched_blockwise",
    out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
    "fast_run",
    "fast_compile",
    "blockwise",
    position=60,
)
|
61 | 65 |
|
62 | 66 |
|
@@ -225,3 +229,77 @@ def local_blockwise_reshape(fgraph, node):
|
225 | 229 | new_out = x.reshape([*tuple(batched_shape), *tuple(core_reshape)])
|
226 | 230 | copy_stack_trace(node.outputs[0], new_out)
|
227 | 231 | return [new_out]
|
| 232 | + |
| 233 | + |
@node_rewriter(tracks=[Blockwise], inplace=True)
def blockwise_inplace(fgraph, node):
    """Swap a Blockwise node for one whose core op is allowed to work inplace.

    The indices of the inputs that may be overwritten are collected and
    handed to ``core_op.inplace_on_inputs``. When the op returned by that call
    has a non-empty ``destroy_map``, a new Blockwise is built around it with
    the same signature, name and gufunc_spec, and applied to the same inputs.

    Returns
    -------
    The outputs of the new node, or ``None`` when the node already has a
    ``destroy_map``, no input qualifies, or the core op returned by
    ``inplace_on_inputs`` has an empty ``destroy_map``.

    Raises
    ------
    ValueError
        If the core op returned by ``inplace_on_inputs`` destroys an input
        index that was not among the allowed ones.
    """
    blockwise_op = node.op

    # A Blockwise that already carries a destroy_map is already inplace
    if blockwise_op.destroy_map:
        return None

    n_batch_dims = blockwise_op.batch_ndim(node)
    batch_bcast_pattern = node.outputs[0].type.broadcastable[:n_batch_dims]

    # Variables protected by any Supervisor feature, followed by the graph outputs
    off_limits = [
        var
        for feature in fgraph._features
        if isinstance(feature, Supervisor)
        for var in feature.protected
    ]
    off_limits.extend(fgraph.outputs)

    def may_be_overwritten(inp):
        # Constants would need to be recreated every time if inplaced
        if isinstance(inp, Constant):
            return False
        # Inputs that are broadcasted along the batch dimensions are reused
        # across iterations of Blockwise, so they cannot be written to
        if inp.type.broadcastable[:n_batch_dims] != batch_bcast_pattern:
            return False
        # Inputs that are already destroyed can't be inplaced on again
        if fgraph.has_destroyers([inp]):
            return False
        # Neither can inputs that are marked as protected
        return inp not in off_limits

    allowed_inplace_inputs = [
        position
        for position, inp in enumerate(node.inputs)
        if may_be_overwritten(inp)
    ]
    if not allowed_inplace_inputs:
        return None

    inplace_core_op = blockwise_op.core_op.inplace_on_inputs(
        allowed_inplace_inputs=allowed_inplace_inputs
    )
    core_destroy_map = inplace_core_op.destroy_map
    if not core_destroy_map:
        return None

    # The core op must only destroy inputs that were offered as candidates
    for destroyed_input in itertools.chain.from_iterable(core_destroy_map.values()):
        if destroyed_input not in allowed_inplace_inputs:
            raise ValueError(
                f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}"
            )

    # Rebuild the Blockwise around the inplace core op
    inplace_blockwise_op = Blockwise(
        core_op=inplace_core_op,
        signature=blockwise_op.signature,
        name=blockwise_op.name,
        gufunc_spec=blockwise_op.gufunc_spec,
        destroy_map=core_destroy_map,
    )

    new_outputs = inplace_blockwise_op.make_node(*node.inputs).outputs
    copy_stack_trace(node.outputs, new_outputs)
    return new_outputs
| 297 | + |
| 298 | + |
# Register the Blockwise inplace rewrite under the "fast_run" and "inplace" tags.
# Its position (50.1) is lower than the one used for `local_useless_unbatched_blockwise` (60),
# in line with the intent that inplacing also gets to act on useless Blockwise Ops.
optdb.register(
    "blockwise_inplace",
    in2out(blockwise_inplace),
    "fast_run",
    "inplace",
    position=50.1,
)
0 commit comments