Skip to content

Commit ba2121c

Browse files
committed
Make all inplace rewrites happen at 50.x
1 parent b248eba commit ba2121c

File tree

8 files changed

+14
-16
lines changed

8 files changed

+14
-16
lines changed

pytensor/scan/rewriting.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2492,7 +2492,7 @@ def scan_push_out_dot1(fgraph, node):
24922492
"fast_run",
24932493
"inplace",
24942494
"scan",
2495-
position=75,
2495+
position=50.5,
24962496
)
24972497

24982498
scan_eqopt1.register("all_pushout_opt", scan_seqopt1, "fast_run", "scan")

pytensor/sparse/rewriting.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def local_inplace_addsd_ccode(fgraph, node):
210210
),
211211
"fast_run",
212212
"inplace",
213-
position=60,
213+
position=50.1,
214214
)
215215

216216

@@ -239,9 +239,9 @@ def local_addsd_ccode(fgraph, node):
239239
pytensor.compile.optdb.register(
240240
"local_addsd_ccode",
241241
WalkingGraphRewriter(local_addsd_ccode),
242-
# Must be after local_inplace_addsd_ccode at 60
242+
# Must be after local_inplace_addsd_ccode at 70.0
243243
"fast_run",
244-
position=61,
244+
position=70.1,
245245
)
246246

247247

pytensor/tensor/random/rewriting/basic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def random_make_inplace(fgraph, node):
6060
in2out(random_make_inplace, ignore_newtrees=True),
6161
"fast_run",
6262
"inplace",
63-
position=99,
63+
position=50.9,
6464
)
6565

6666

pytensor/tensor/rewriting/blas.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -762,8 +762,6 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
762762
)
763763

764764

765-
# After destroyhandler(49.5) but before we try to make elemwise things
766-
# inplace (75)
767765
blas_opt_inplace = in2out(
768766
local_inplace_gemm, local_inplace_gemv, local_inplace_ger, name="blas_opt_inplace"
769767
)
@@ -773,7 +771,8 @@ def local_dot22_to_ger_or_gemv(fgraph, node):
773771
"fast_run",
774772
"inplace",
775773
"blas_opt_inplace",
776-
position=70.0,
774+
# Before we try to make elemwise things inplace (70.5)
775+
position=50.2,
777776
)
778777

779778

pytensor/tensor/rewriting/blas_scipy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,5 @@ def make_ger_destructive(fgraph, node):
3333
make_scipy_blas_destructive,
3434
"fast_run",
3535
"inplace",
36-
position=70.0,
36+
position=50.2,
3737
)

pytensor/tensor/rewriting/elemwise.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,8 @@ def apply(self, fgraph):
186186
for i in range(len(node.inputs))
187187
if i not in baseline.values()
188188
and not isinstance(node.inputs[i], Constant)
189-
and
190189
# the next line should not be costly most of the time.
191-
not fgraph.has_destroyers([node.inputs[i]])
190+
and not fgraph.has_destroyers([node.inputs[i]])
192191
and node.inputs[i] not in protected_inputs
193192
]
194193
else:
@@ -362,7 +361,7 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
362361
"inplace_elemwise_optimizer",
363362
"fast_run",
364363
"inplace",
365-
position=75,
364+
position=50.5,
366365
)
367366

368367

pytensor/tensor/rewriting/subtensor.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1307,7 +1307,7 @@ def local_inplace_setsubtensor(fgraph, node):
13071307
),
13081308
"fast_run",
13091309
"inplace",
1310-
position=60,
1310+
position=50.1,
13111311
)
13121312

13131313

@@ -1329,7 +1329,7 @@ def local_inplace_AdvancedIncSubtensor1(fgraph, node):
13291329
),
13301330
"fast_run",
13311331
"inplace",
1332-
position=60,
1332+
position=70.6,
13331333
)
13341334

13351335

@@ -1355,7 +1355,7 @@ def local_inplace_AdvancedIncSubtensor(fgraph, node):
13551355
),
13561356
"fast_run",
13571357
"inplace",
1358-
position=60,
1358+
position=70.6,
13591359
)
13601360

13611361

pytensor/typed_list/rewriting.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ def typed_list_inplace_rewrite(fgraph, node):
2222
),
2323
"fast_run",
2424
"inplace",
25-
position=60,
25+
position=50.1,
2626
)

0 commit comments

Comments
 (0)