Skip to content

Implement graph.vectorize and Blockwise Op #306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,13 +251,14 @@ def apply(self, fgraph):
# especially constant merge
optdb.register("merge2", MergeOptimizer(), "fast_run", "merge", position=49)

optdb.register("py_only", EquilibriumDB(), "fast_compile", position=49.1)

optdb.register(
"add_destroy_handler", AddDestroyHandler(), "fast_run", "inplace", position=49.5
)

# final pass just to make sure
optdb.register("merge3", MergeOptimizer(), "fast_run", "merge", position=100)
optdb.register("py_only", EquilibriumDB(), "fast_compile", position=100)

_tags: Union[Tuple[str, str], Tuple]

Expand Down
2 changes: 1 addition & 1 deletion pytensor/graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
clone,
ancestors,
)
from pytensor.graph.replace import clone_replace, graph_replace
from pytensor.graph.replace import clone_replace, graph_replace, vectorize
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from pytensor.graph.fg import FunctionGraph
Expand Down
67 changes: 65 additions & 2 deletions pytensor/graph/replace.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from functools import partial
from typing import Iterable, Optional, Sequence, Union, cast, overload
from functools import partial, singledispatch
from typing import Iterable, Mapping, Optional, Sequence, Union, cast, overload

from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op


ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
Expand Down Expand Up @@ -198,3 +199,65 @@ def toposort_key(
return list(fg.outputs)
else:
return fg.outputs[0]


@singledispatch
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
    """Dispatch on ``op`` to build a vectorized (batched) version of ``node``.

    This is the generic fallback; the default implementation is registered
    in ``pytensor.tensor.blockwise``. ``batched_inputs`` are the replacement
    inputs whose extra (batch) dimensions are on the left.
    """
    # Default implementation is provided in pytensor.tensor.blockwise
    raise NotImplementedError


def vectorize_node(node: Apply, *batched_inputs) -> Apply:
    """Return a vectorized version of ``node`` applied to the new batched inputs.

    Dispatches on the node's ``Op`` through the ``_vectorize_node``
    singledispatch function.
    """
    return _vectorize_node(node.op, node, *batched_inputs)


def vectorize(
    outputs: Sequence[Variable], vectorize: Mapping[Variable, Variable]
) -> Sequence[Variable]:
    """Vectorize an output graph, given a mapping from old variables to their expanded counterparts.

    Expanded dimensions must be on the left. Behavior is similar to the functional `numpy.vectorize`.

    Parameters
    ----------
    outputs
        Output variables of the graph to vectorize.
    vectorize
        Mapping from original variables to their batched replacements.

    Returns
    -------
    The vectorized output variables, in the same order as ``outputs``.

    Examples
    --------
    .. code-block:: python

        import pytensor
        import pytensor.tensor as pt

        from pytensor.graph import vectorize

        # Original graph
        x = pt.vector("x")
        y = pt.exp(x) / pt.sum(pt.exp(x))

        # Vectorized graph
        new_x = pt.matrix("new_x")
        [new_y] = vectorize([y], {x: new_x})

        fn = pytensor.function([new_x], new_y)
        fn([[0, 1, 2], [2, 1, 0]])
        # array([[0.09003057, 0.24472847, 0.66524096],
        #        [0.66524096, 0.24472847, 0.09003057]])

    """
    inputs = truncated_graph_inputs(outputs, ancestors_to_include=vectorize.keys())
    new_inputs = [vectorize.get(inp, inp) for inp in inputs]

    # Memoize already-vectorized variables so shared sub-graphs are processed
    # only once, keeping the recursion linear in the number of distinct nodes.
    memo: dict[Variable, Variable] = dict(zip(inputs, new_inputs))

    def transform(var: Variable) -> Variable:
        if var in memo:
            return memo[var]

        node = var.owner
        batched_inputs = [transform(inp) for inp in node.inputs]
        batched_node = vectorize_node(node, *batched_inputs)
        # Record every output of the batched node, so sibling outputs of a
        # multi-output node reuse the same vectorized Apply.
        for out, batched_out in zip(node.outputs, batched_node.outputs):
            memo[out] = batched_out

        return memo[var]

    # TODO: Apply MergeOptimization to deduplicate structurally-equal nodes?
    return [transform(out) for out in outputs]
13 changes: 0 additions & 13 deletions pytensor/link/numba/dispatch/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Det,
Eig,
Eigh,
Inv,
MatrixInverse,
MatrixPinv,
QRFull,
Expand Down Expand Up @@ -125,18 +124,6 @@ def eigh(x):
return eigh


@numba_funcify.register(Inv)
def numba_funcify_Inv(op, node, **kwargs):
    """Build a Numba kernel computing the matrix inverse for an ``Inv`` node."""
    result_dtype = node.outputs[0].type.numpy_dtype
    # Integer inputs are promoted to float before calling np.linalg.inv
    cast_inputs = int_to_float_fn(node.inputs, result_dtype)

    @numba_basic.numba_njit(inline="always")
    def inv(x):
        return np.linalg.inv(cast_inputs(x)).astype(result_dtype)

    return inv


@numba_funcify.register(MatrixInverse)
def numba_funcify_MatrixInverse(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
Expand Down
3 changes: 2 additions & 1 deletion pytensor/scalar/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Optional, Sequence, Tuple

from pytensor.compile import rebuild_collect_shared
from pytensor.graph import Constant, FunctionGraph, Variable, clone
from pytensor.graph.basic import Constant, Variable, clone
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import ScalarInnerGraphOp, as_scalar


Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3764,7 +3764,7 @@ def stacklists(arg):
return arg


def swapaxes(y, axis1, axis2):
def swapaxes(y, axis1: int, axis2: int) -> TensorVariable:
"Swap the axes of a tensor."
y = as_tensor_variable(y)
ndim = y.ndim
Expand Down
Loading