1
+ from pytensor .compile .mode import optdb
2
+ from pytensor .graph import node_rewriter
3
+ from pytensor .graph .rewriting .basic import out2in , copy_stack_trace
4
+ from pytensor .tensor .blockwise import Blockwise , vectorize_node
5
+ from pytensor .tensor .rewriting .basic import register_useless
6
+
7
+
8
@register_useless("fast_compile")
@node_rewriter([Blockwise])
def local_useless_blockwise(fgraph, node):
    """Replace a Blockwise by its vectorize dispatch result when one applies.

    A user may have wrapped an Op in ``Blockwise`` manually even though a
    dedicated vectorize implementation is registered for that Op. Rebuilding
    the node through ``vectorize_node`` reveals this: if the dispatched result
    is not itself a ``Blockwise``, it is used as the replacement.
    """
    blockwise_op: Blockwise = node.op
    node_inputs = node.inputs
    core_node = blockwise_op._create_dummy_core_node(node_inputs)
    dispatched = vectorize_node(core_node, *node_inputs)
    if isinstance(dispatched.op, Blockwise):
        # Dispatch still needs a Blockwise: nothing to simplify here.
        return None
    return copy_stack_trace(node.outputs, dispatched.outputs)
19
+
20
+
21
@node_rewriter([Blockwise])
def local_useless_unbatched_blockwise(fgraph, node):
    """Remove Blockwise that don't have any batched dims."""
    blockwise_op: Blockwise = node.op
    node_inputs = node.inputs

    # Batch dimensions each input carries beyond its core signature.
    batch_ndims = [
        inp.type.ndim - len(core_sig)
        for inp, core_sig in zip(node_inputs, blockwise_op.inputs_sig)
    ]
    if max(batch_ndims) == 0:
        # Nothing is batched: apply the wrapped core Op directly.
        core_outputs = blockwise_op.core_op.make_node(*node_inputs).outputs
        return copy_stack_trace(node.outputs, core_outputs)
31
+
32
+
33
# We register this rewrite late, so that other rewrites need only target Blockwise Ops
_useless_unbatched_blockwise = out2in(
    local_useless_unbatched_blockwise, ignore_newtrees=True
)
optdb.register(
    "local_useless_unbatched_blockwise",
    _useless_unbatched_blockwise,
    "fast_run",
    "fast_compile",
    "blockwise",
    position=49,
)