|
8 | 8 | import pytensor.scalar.basic as ps
|
9 | 9 | from pytensor import compile
|
10 | 10 | from pytensor.compile import optdb
|
11 |
| -from pytensor.graph import FunctionGraph |
12 | 11 | from pytensor.graph.basic import Constant, Variable
|
13 | 12 | from pytensor.graph.rewriting.basic import (
|
14 |
| - EquilibriumGraphRewriter, |
15 | 13 | WalkingGraphRewriter,
|
16 | 14 | copy_stack_trace,
|
17 | 15 | in2out,
|
|
58 | 56 | register_specialize,
|
59 | 57 | register_stabilize,
|
60 | 58 | )
|
61 |
| -from pytensor.tensor.rewriting.extremum import ( |
62 |
| - local_extremum_plus_x, |
63 |
| - local_flatten_extremum, |
64 |
| - local_useless_extremum_branches, |
65 |
| -) |
66 |
| -from pytensor.tensor.rewriting.math import ( |
67 |
| - local_add_canonizer, |
68 |
| - local_intdiv_by_one, |
69 |
| - local_mul_canonizer, |
70 |
| -) |
71 | 59 | from pytensor.tensor.shape import (
|
72 | 60 | Shape,
|
73 | 61 | SpecifyShape,
|
@@ -572,20 +560,20 @@ def local_subtensor_merge(fgraph, node):
|
572 | 560 | out = subtens(x, *sl_ins)
|
573 | 561 |
|
574 | 562 | # Eagerly clean up merged subtensor graph, which can be a mess
|
575 |
| - rewriter = EquilibriumGraphRewriter( |
576 |
| - [ |
577 |
| - local_extremum_plus_x, |
578 |
| - local_add_canonizer, |
579 |
| - local_mul_canonizer, |
580 |
| - local_intdiv_by_one, |
581 |
| - local_useless_extremum_branches, |
582 |
| - local_flatten_extremum, |
583 |
| - ], |
584 |
| - max_use_ratio=10.0, |
585 |
| - ) |
586 |
| - fg = FunctionGraph(outputs=[out], clone=False) |
587 |
| - rewriter.rewrite(fg) |
588 |
| - [out] = fg.outputs |
| 563 | + # rewriter = EquilibriumGraphRewriter( |
| 564 | + # [ |
| 565 | + # local_extremum_plus_x, |
| 566 | + # local_add_canonizer, |
| 567 | + # local_mul_canonizer, |
| 568 | + # local_intdiv_by_one, |
| 569 | + # local_useless_extremum_branches, |
| 570 | + # local_flatten_extremum, |
| 571 | + # ], |
| 572 | + # max_use_ratio=10.0, |
| 573 | + # ) |
| 574 | + # fg = FunctionGraph(outputs=[out], clone=False) |
| 575 | + # rewriter.rewrite(fg) |
| 576 | + # [out] = fg.outputs |
589 | 577 |
|
590 | 578 | # Copy over previous output stacktrace
|
591 | 579 | # and stacktrace from previous slicing operation.
|
|
0 commit comments