Skip to content

Commit d6b8777

Browse files
committed
Implement vectorize utility
1 parent 8df22d7 commit d6b8777

File tree

11 files changed

+111
-55
lines changed

11 files changed

+111
-55
lines changed

pytensor/graph/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
clone,
1010
ancestors,
1111
)
12-
from pytensor.graph.replace import clone_replace, graph_replace
12+
from pytensor.graph.replace import clone_replace, graph_replace, vectorize
1313
from pytensor.graph.op import Op
1414
from pytensor.graph.type import Type
1515
from pytensor.graph.fg import FunctionGraph

pytensor/graph/replace.py

+65-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from functools import partial
2-
from typing import Iterable, Optional, Sequence, Union, cast, overload
1+
from functools import partial, singledispatch
2+
from typing import Iterable, Mapping, Optional, Sequence, Union, cast, overload
33

44
from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
55
from pytensor.graph.fg import FunctionGraph
6+
from pytensor.graph.op import Op
67

78

89
ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
@@ -198,3 +199,65 @@ def toposort_key(
198199
return list(fg.outputs)
199200
else:
200201
return fg.outputs[0]
202+
203+
204+
@singledispatch
205+
def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
206+
# Default implementation is provided in pytensor.tensor.blockwise
207+
raise NotImplementedError
208+
209+
210+
def vectorize_node(node: Apply, *batched_inputs) -> Apply:
211+
"""Returns vectorized version of node with new batched inputs."""
212+
op = node.op
213+
return _vectorize_node(op, node, *batched_inputs)
214+
215+
216+
def vectorize(
217+
outputs: Sequence[Variable], vectorize: Mapping[Variable, Variable]
218+
) -> Sequence[Variable]:
219+
"""Vectorize outputs graph given mapping from old variables to expanded counterparts version.
220+
221+
Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`.
222+
223+
Examples
224+
--------
225+
.. code-block:: python
226+
227+
import pytensor
228+
import pytensor.tensor as pt
229+
230+
from pytensor.graph import vectorize
231+
232+
# Original graph
233+
x = pt.vector("x")
234+
y = pt.exp(x) / pt.sum(pt.exp(x))
235+
236+
# Vectorized graph
237+
new_x = pt.matrix("new_x")
238+
[new_y] = vectorize([y], {x: new_x})
239+
240+
fn = pytensor.function([new_x], new_y)
241+
fn([[0, 1, 2], [2, 1, 0]])
242+
# array([[0.09003057, 0.24472847, 0.66524096],
243+
# [0.66524096, 0.24472847, 0.09003057]])
244+
245+
"""
246+
# Avoid circular import
247+
248+
inputs = truncated_graph_inputs(outputs, ancestors_to_include=vectorize.keys())
249+
new_inputs = [vectorize.get(inp, inp) for inp in inputs]
250+
251+
def transform(var):
252+
if var in inputs:
253+
return new_inputs[inputs.index(var)]
254+
255+
node = var.owner
256+
batched_inputs = [transform(inp) for inp in node.inputs]
257+
batched_node = vectorize_node(node, *batched_inputs)
258+
batched_var = batched_node.outputs[var.owner.outputs.index(var)]
259+
260+
return batched_var
261+
262+
# TODO: MergeOptimization or node caching?
263+
return [transform(out) for out in outputs]

pytensor/scalar/loop.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
from typing import Optional, Sequence, Tuple
33

44
from pytensor.compile import rebuild_collect_shared
5-
from pytensor.graph import Constant, FunctionGraph, Variable, clone
5+
from pytensor.graph.basic import Constant, Variable, clone
6+
from pytensor.graph.fg import FunctionGraph
67
from pytensor.scalar.basic import ScalarInnerGraphOp, as_scalar
78

89

pytensor/tensor/blockwise.py

+14-44
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import re
2-
from functools import singledispatch
32
from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
43

54
import numpy as np
@@ -9,6 +8,7 @@
98
from pytensor.graph.basic import Apply, Constant, Variable
109
from pytensor.graph.null_type import NullType
1110
from pytensor.graph.op import Op
11+
from pytensor.graph.replace import _vectorize_node, vectorize
1212
from pytensor.tensor import as_tensor_variable
1313
from pytensor.tensor.shape import shape_padleft
1414
from pytensor.tensor.type import continuous_dtypes, discrete_dtypes, tensor
@@ -72,8 +72,8 @@ def operand_sig(operand: Variable, prefix: str) -> str:
7272
return f"{inputs_sig}->{outputs_sig}"
7373

7474

75-
@singledispatch
76-
def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
75+
@_vectorize_node.register(Op)
76+
def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
7777
if hasattr(op, "gufunc_signature"):
7878
signature = op.gufunc_signature
7979
else:
@@ -83,12 +83,6 @@ def _vectorize_node(op: Op, node: Apply, *bached_inputs) -> Apply:
8383
return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))
8484

8585

86-
def vectorize_node(node: Apply, *batched_inputs) -> Apply:
87-
"""Returns vectorized version of node with new batched inputs."""
88-
op = node.op
89-
return _vectorize_node(op, node, *batched_inputs)
90-
91-
9286
class Blockwise(Op):
9387
"""Generalizes a core `Op` to work with batched dimensions.
9488
@@ -279,42 +273,18 @@ def as_core(t, core_t):
279273

280274
core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
281275

282-
batch_ndims = self._batch_ndim_from_outputs(outputs)
283-
284-
def transform(var):
285-
# From a graph of ScalarOps, make a graph of Broadcast ops.
286-
if isinstance(var.type, (NullType, DisconnectedType)):
287-
return var
288-
if var in core_inputs:
289-
return inputs[core_inputs.index(var)]
290-
if var in core_outputs:
291-
return outputs[core_outputs.index(var)]
292-
if var in core_ograds:
293-
return ograds[core_ograds.index(var)]
294-
295-
node = var.owner
296-
297-
# The gradient contains a constant, which may be responsible for broadcasting
298-
if node is None:
299-
if batch_ndims:
300-
var = shape_padleft(var, batch_ndims)
301-
return var
302-
303-
batched_inputs = [transform(inp) for inp in node.inputs]
304-
batched_node = vectorize_node(node, *batched_inputs)
305-
batched_var = batched_node.outputs[var.owner.outputs.index(var)]
306-
307-
return batched_var
308-
309-
ret = []
310-
for core_igrad, ipt in zip(core_igrads, inputs):
311-
# Undefined gradient
312-
if core_igrad is None:
313-
ret.append(None)
314-
else:
315-
ret.append(transform(core_igrad))
276+
igrads = vectorize(
277+
[core_igrad for core_igrad in core_igrads if core_igrad is not None],
278+
vectorize=dict(
279+
zip(core_inputs + core_outputs + core_ograds, inputs + outputs + ograds)
280+
),
281+
)
316282

317-
return ret
283+
igrads_iter = iter(igrads)
284+
return [
285+
None if core_igrad is None else next(igrads_iter)
286+
for core_igrad in core_igrads
287+
]
318288

319289
def L_op(self, inputs, outs, ograds):
320290
from pytensor.tensor.math import sum as pt_sum

pytensor/tensor/elemwise.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pytensor.gradient import DisconnectedType
99
from pytensor.graph.basic import Apply
1010
from pytensor.graph.null_type import NullType
11+
from pytensor.graph.replace import _vectorize_node
1112
from pytensor.graph.utils import MethodNotDefined
1213
from pytensor.link.c.basic import failure_code
1314
from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
@@ -22,7 +23,7 @@
2223
from pytensor.tensor import elemwise_cgen as cgen
2324
from pytensor.tensor import get_vector_length
2425
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
25-
from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed
26+
from pytensor.tensor.blockwise import vectorize_not_needed
2627
from pytensor.tensor.type import (
2728
TensorType,
2829
continuous_dtypes,

pytensor/tensor/random/op.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pytensor.configdefaults import config
88
from pytensor.graph.basic import Apply, Variable, equal_computations
99
from pytensor.graph.op import Op
10+
from pytensor.graph.replace import _vectorize_node
1011
from pytensor.misc.safe_asarray import _asarray
1112
from pytensor.scalar import ScalarVariable
1213
from pytensor.tensor.basic import (
@@ -17,7 +18,6 @@
1718
get_vector_length,
1819
infer_static_shape,
1920
)
20-
from pytensor.tensor.blockwise import _vectorize_node
2121
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
2222
from pytensor.tensor.random.utils import (
2323
broadcast_params,

pytensor/tensor/rewriting/blockwise.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from pytensor.compile.mode import optdb
22
from pytensor.graph import node_rewriter
3+
from pytensor.graph.replace import vectorize_node
34
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
4-
from pytensor.tensor.blockwise import Blockwise, vectorize_node
5+
from pytensor.tensor.blockwise import Blockwise
56

67

78
@node_rewriter([Blockwise])

tests/graph/test_replace.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import numpy as np
22
import pytest
3+
import scipy.special
34

45
import pytensor.tensor as pt
56
from pytensor import config, function, shared
67
from pytensor.graph.basic import graph_inputs
7-
from pytensor.graph.replace import clone_replace, graph_replace
8+
from pytensor.graph.replace import clone_replace, graph_replace, vectorize
89
from pytensor.tensor import dvector, fvector, vector
910
from tests import unittest_tools as utt
1011
from tests.graph.utils import MyOp, MyVariable
@@ -223,3 +224,21 @@ def test_graph_replace_disconnected(self):
223224
assert oc[0] is o
224225
with pytest.raises(ValueError, match="Some replacements were not used"):
225226
oc = graph_replace([o], {fake: x.clone()}, strict=True)
227+
228+
229+
class TestVectorize:
230+
# TODO: Add tests with multiple outputs, constants, and other singleton types
231+
232+
def test_basic(self):
233+
x = pt.vector("x")
234+
y = pt.exp(x) / pt.sum(pt.exp(x))
235+
236+
new_x = pt.matrix("new_x")
237+
[new_y] = vectorize([y], {x: new_x})
238+
239+
fn = function([new_x], new_y)
240+
test_new_y = np.array([[0, 1, 2], [2, 1, 0]]).astype(config.floatX)
241+
np.testing.assert_allclose(
242+
fn(test_new_y),
243+
scipy.special.softmax(test_new_y, axis=-1),
244+
)

tests/tensor/random/test_op.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import pytensor.tensor as at
55
from pytensor import config, function
66
from pytensor.gradient import NullTypeGradError, grad
7+
from pytensor.graph.replace import vectorize_node
78
from pytensor.raise_op import Assert
8-
from pytensor.tensor.blockwise import vectorize_node
99
from pytensor.tensor.math import eq
1010
from pytensor.tensor.random import normal
1111
from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng

tests/tensor/test_blockwise.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@
88
from pytensor import config
99
from pytensor.gradient import grad
1010
from pytensor.graph import Apply, Op
11+
from pytensor.graph.replace import vectorize_node
1112
from pytensor.tensor import tensor
12-
from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature, vectorize_node
13+
from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
1314
from pytensor.tensor.nlinalg import MatrixInverse
1415
from pytensor.tensor.slinalg import Cholesky, Solve
1516

tests/tensor/test_elemwise.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
from pytensor.configdefaults import config
1414
from pytensor.graph.basic import Apply, Variable
1515
from pytensor.graph.fg import FunctionGraph
16+
from pytensor.graph.replace import vectorize_node
1617
from pytensor.link.basic import PerformLinker
1718
from pytensor.link.c.basic import CLinker, OpWiseCLinker
1819
from pytensor.tensor import as_tensor_variable
1920
from pytensor.tensor.basic import second
20-
from pytensor.tensor.blockwise import vectorize_node
2121
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
2222
from pytensor.tensor.math import Any, Sum
2323
from pytensor.tensor.math import all as pt_all

0 commit comments

Comments
 (0)