@@ -1,5 +1,4 @@
 import re
-from functools import singledispatch
 from typing import Any, Dict, List, Optional, Sequence, Tuple, cast

 import numpy as np
@@ -9,6 +8,7 @@
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import Op
+from pytensor.graph.replace import _vectorize_node, vectorize
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.shape import shape_padleft
 from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
@@ -72,8 +72,8 @@ def operand_sig(operand: Variable, prefix: str) -> str:
     return f"{inputs_sig}->{outputs_sig}"


-@singledispatch
-def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
+@_vectorize_node.register(Op)
+def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
     if hasattr(op, "gufunc_signature"):
         signature = op.gufunc_signature
     else:
@@ -83,12 +83,6 @@ def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
     return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))


-def vectorize_node(node: Apply, *batched_inputs) -> Apply:
-    """Returns vectorized version of node with new batched inputs."""
-    op = node.op
-    return _vectorize_node(op, node, *batched_inputs)
-
-
 class Blockwise(Op):
     """Generalizes a core `Op` to work with batched dimensions.

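The fallback registered in the hunk above uses the standard `functools.singledispatch` registration mechanism, so an `Op` subclass can supply its own batching rule instead of the generic `Blockwise` wrapping. Below is a minimal sketch of that pattern, assuming only the `_vectorize_node` dispatcher and argument order shown in this diff; `MyOp` and its rule are hypothetical and for illustration only.

```python
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.graph.replace import _vectorize_node


class MyOp(Op):  # hypothetical Op, only to illustrate the registration pattern
    ...


@_vectorize_node.register(MyOp)
def vectorize_my_op(op: MyOp, node: Apply, *batched_inputs) -> Apply:
    # Skip the generic Blockwise fallback and rebuild the node directly on the
    # batched inputs (only valid if MyOp already broadcasts leading dimensions).
    return op.make_node(*batched_inputs)
```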
@@ -279,42 +273,18 @@ def as_core(t, core_t):

         core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)

-        batch_ndims = self._batch_ndim_from_outputs(outputs)
-
-        def transform(var):
-            # From a graph of ScalarOps, make a graph of Broadcast ops.
-            if isinstance(var.type, (NullType, DisconnectedType)):
-                return var
-            if var in core_inputs:
-                return inputs[core_inputs.index(var)]
-            if var in core_outputs:
-                return outputs[core_outputs.index(var)]
-            if var in core_ograds:
-                return ograds[core_ograds.index(var)]
-
-            node = var.owner
-
-            # The gradient contains a constant, which may be responsible for broadcasting
-            if node is None:
-                if batch_ndims:
-                    var = shape_padleft(var, batch_ndims)
-                return var
-
-            batched_inputs = [transform(inp) for inp in node.inputs]
-            batched_node = vectorize_node(node, *batched_inputs)
-            batched_var = batched_node.outputs[var.owner.outputs.index(var)]
-
-            return batched_var
-
-        ret = []
-        for core_igrad, ipt in zip(core_igrads, inputs):
-            # Undefined gradient
-            if core_igrad is None:
-                ret.append(None)
-            else:
-                ret.append(transform(core_igrad))
+        igrads = vectorize(
+            [core_igrad for core_igrad in core_igrads if core_igrad is not None],
+            vectorize=dict(
+                zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
+            ),
+        )

-        return ret
+        igrads_iter = iter(igrads)
+        return [
+            None if core_igrad is None else next(igrads_iter)
+            for core_igrad in core_igrads
+        ]

     def L_op(self, inputs, outs, ograds):
         from pytensor.tensor.math import sum as pt_sum
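The rewritten gradient code relies on `pytensor.graph.replace.vectorize` rebuilding a graph with core (unbatched) variables swapped for their batched counterparts. A rough usage sketch follows, assuming only the call shape used in the hunk above (a sequence of outputs plus a `vectorize=` mapping from core to batched variables); the graph variables are made up for illustration.

```python
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize

# Core (unbatched) graph: a single vector input and its sum of squares.
x = pt.vector("x")
y = (x**2).sum()

# Batched counterpart of x: one extra leading dimension.
batched_x = pt.matrix("batched_x")

# Rebuild y so every node is vectorized over the new leading dimension,
# mirroring how the new Blockwise gradient code vectorizes the core igrads.
[batched_y] = vectorize([y], vectorize={x: batched_x})
```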