Skip to content

Commit 1e9ff57

Browse files
committed
Make all inplace rewrites happen at 50.x
1 parent 9c8e25f commit 1e9ff57

File tree

8 files changed

+14
-16
lines changed

8 files changed

+14
-16
lines changed

pytensor/scan/rewriting.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2492,7 +2492,7 @@ def scan_push_out_dot1(fgraph, node):
24922492
"fast_run",
24932493
"inplace",
24942494
"scan",
2495-
position=75,
2495+
position=50.5,
24962496
)
24972497

24982498
scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan")

pytensor/sparse/rewriting.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def local_inplace_addsd_ccode(fgraph, node):
210210
),
211211
"fast_run",
212212
"inplace",
213-
position=60,
213+
position=50.1,
214214
)
215215

216216

@@ -239,9 +239,9 @@ def local_addsd_ccode(fgraph, node):
239239
pytensor.compile.optdb.register(
240240
"local_addsd_ccode",
241241
WalkingGraphRewriter(local_addsd_ccode),
242-
# Must be after local_inplace_addsd_ccode at 60
242+
# Must be after local_inplace_addsd_ccode at 70.0
243243
"fast_run",
244-
position=61,
244+
position=70.1,
245245
)
246246

247247

pytensor/tensor/random/rewriting/basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def random_make_inplace(fgraph, node):
6060
in2out(random_make_inplace, ignore_newtrees=True),
6161
"fast_run",
6262
"inplace",
63-
position=99,
63+
position=50.9,
6464
)
6565

6666

pytensor/tensor/rewriting/blas.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -757,8 +757,6 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
757757
)
758758

759759

760-
# After destroyhandler(49.5) but before we try to make elemwise things
761-
# inplace (75)
762760
blas_opt_inplace = in2out(
763761
local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace"
764762
)
@@ -768,7 +766,8 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
768766
"fast_run",
769767
"inplace",
770768
"blas_opt_inplace",
771-
position=70.0,
769+
# Before we try to make elemwise things inplace (70.5)
770+
position=50.2,
772771
)
773772

774773

pytensor/tensor/rewriting/blas_scipy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,5 @@ def make_ger_destructive(fgraph, node):
3333
make_scipy_blas_destructive,
3434
"fast_run",
3535
"inplace",
36-
position=70.0,
36+
position=50.2,
3737
)

pytensor/tensor/rewriting/elemwise.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,8 @@ def apply(self, fgraph):
186186
for i in range(len(node.inputs))
187187
if i not in baseline.values()
188188
and not isinstance(node.inputs[i], Constant)
189-
and
190189
# the next line should not be costly most of the time.
191-
not fgraph.has_destroyers([node.inputs[i]])
190+
and not fgraph.has_destroyers([node.inputs[i]])
192191
and node.inputs[i] not in protected_inputs
193192
]
194193
else:
@@ -362,7 +361,7 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
362361
"inplace_elemwise_optimizer",
363362
"fast_run",
364363
"inplace",
365-
position=75,
364+
position=50.5,
366365
)
367366

368367

pytensor/tensor/rewriting/subtensor.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1287,7 +1287,7 @@ def local_inplace_setsubtensor(fgraph, node):
12871287
),
12881288
"fast_run",
12891289
"inplace",
1290-
position=60,
1290+
position=50.1,
12911291
)
12921292

12931293

@@ -1309,7 +1309,7 @@ def local_inplace_AdvancedIncSubtensor1(fgraph, node):
13091309
),
13101310
"fast_run",
13111311
"inplace",
1312-
position=60,
1312+
position=70.6,
13131313
)
13141314

13151315

@@ -1335,7 +1335,7 @@ def local_inplace_AdvancedIncSubtensor(fgraph, node):
13351335
),
13361336
"fast_run",
13371337
"inplace",
1338-
position=60,
1338+
position=70.6,
13391339
)
13401340

13411341

pytensor/typed_list/rewriting.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ def typed_list_inplace_rewrite(fgraph, node):
2222
),
2323
"fast_run",
2424
"inplace",
25-
position=60,
25+
position=50.1,
2626
)

0 commit comments

Comments
 (0)