Commit 0ff0f29

ricardoV94, brandonwillard, purna135, Sayam Kumar, and kc611 committed
Implement Blockwise Op to vectorize existing Ops
Inspired by: aesara-devs/aesara#1215

Co-authored-by: Brandon T. Willard <[email protected]>
Co-authored-by: Purna Chandra Mansingh <[email protected]>
Co-authored-by: Sayam Kumar <[email protected]>
Co-authored-by: Kaustubh <[email protected]>
1 parent f49b2cc commit 0ff0f29

File tree

10 files changed: +966 -46 lines changed


pytensor/tensor/blockwise.py

+413
Large diffs are not rendered by default.
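
Since the large blockwise.py diff is not rendered here, the following minimal sketch shows how the new Op is meant to be used, based only on the constructor and call pattern exercised by the tests added in this commit (Blockwise(core_op, signature=...)); the variable names are illustrative, not part of the commit:

from pytensor import function
from pytensor.tensor import tensor3
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import MatrixPinv

# MatrixPinv's core case is a single (m, n) matrix; the gufunc-like signature
# tells Blockwise which trailing dimensions are core dims, so any leading
# dimensions are treated as batch dimensions.
batched_pinv = Blockwise(MatrixPinv(hermitian=False), signature="(m,n)->(n,m)")

x = tensor3("x")          # shape (batch, m, n)
out = batched_pinv(x)     # pseudo-inverse computed per matrix, shape (batch, n, m)
fn = function([x], out)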

pytensor/tensor/elemwise.py

+39 -34

@@ -22,13 +22,15 @@
 from pytensor.tensor import elemwise_cgen as cgen
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
+from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed
 from pytensor.tensor.type import (
     TensorType,
     continuous_dtypes,
     discrete_dtypes,
     float_dtypes,
     lvector,
 )
+from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string
 from pytensor.tensor.variable import TensorVariable
 from pytensor.utils import uniq

@@ -232,7 +234,7 @@ def __str__(self):
             return f"Transpose{{axes={self.shuffle}}}"
         return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"

-    def perform(self, node, inp, out, params):
+    def perform(self, node, inp, out, params=None):
         (res,) = inp
         (storage,) = out

@@ -429,28 +431,12 @@ def get_output_info(self, dim_shuffle, *inputs):
         # of all inputs in parallel... the all() gives us each output
         # broadcastable bit in turn.

-        def get_most_specialized_shape(shapes):
-            shapes = set(shapes)
-            # All shapes are the same
-            if len(shapes) == 1:
-                return tuple(shapes)[0]
-
-            # Only valid indeterminate case
-            if shapes == {None, 1}:
-                return None
-
-            shapes.discard(1)
-            shapes.discard(None)
-            if len(shapes) > 1:
-                raise ValueError
-            return tuple(shapes)[0]
-
         # it is multiplied by nout because Elemwise supports multiple outputs
         # (nout of them)
         try:
             out_shapes = [
                 [
-                    get_most_specialized_shape(shape)
+                    broadcast_static_dim_lengths(shape)
                     for shape in zip(*[inp.type.shape for inp in inputs])
                 ]
             ] * shadow.nout

@@ -665,22 +651,7 @@ def prepare_node(self, node, storage_map, compute_map, impl):
             impl = "c"

         if getattr(self, "nfunc_spec", None) and impl != "c":
-            self.nfunc = getattr(np, self.nfunc_spec[0], None)
-            if self.nfunc is None:
-                # Not inside NumPy. So probably another package like scipy.
-                symb = self.nfunc_spec[0].split(".")
-                for idx in range(1, len(self.nfunc_spec[0])):
-                    try:
-                        module = __import__(".".join(symb[:idx]))
-                    except ImportError:
-                        break
-                for sub in symb[1:]:
-                    try:
-                        module = getattr(module, sub)
-                    except AttributeError:
-                        module = None
-                        break
-                self.nfunc = module
+            self.nfunc = import_func_from_string(self.nfunc_spec[0])

         if (
             (len(node.inputs) + len(node.outputs)) <= 32

@@ -1768,3 +1739,37 @@ def _get_vector_length_Elemwise(op, var):
         return get_vector_length(var.owner.inputs[0])

     raise ValueError(f"Length of {var} cannot be determined")
+
+
+_vectorize_node.register(Elemwise, vectorize_not_needed)
+
+
+@_vectorize_node.register(DimShuffle)
+def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
+    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
+    if not batched_ndims:
+        return node.op.make_node(x)
+    input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
+    # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
+    # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
+    new_order = list(range(batched_ndims)) + [
+        "x" if (o == "x") else (o + batched_ndims) for o in op.new_order
+    ]
+    return DimShuffle(input_broadcastable, new_order).make_node(x)
+
+
+@_vectorize_node.register(CAReduce)
+def vectorize_careduce(op: CAReduce, node: Apply, x: TensorVariable) -> Apply:
+    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
+    if not batched_ndims:
+        return node.op.make_node(x)
+    axes = op.axis
+    # e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3))
+    # e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
+    if axes is None:
+        axes = list(range(node.inputs[0].type.ndim))
+    else:
+        axes = list(axes)
+    new_axes = [axis + batched_ndims for axis in axes]
+    new_op = op.clone(axis=new_axes)
+    return new_op.make_node(x)
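
The vectorize_dimshuffle and vectorize_careduce dispatches above shift shuffle orders and reduction axes past any new leading batch dimensions. A minimal sketch of the intended effect, assuming the public vectorize_node helper used by the tests in this commit (variable names are illustrative):

import pytensor.tensor as pt
from pytensor.tensor.blockwise import vectorize_node

x = pt.matrix("x")
core_sum = x.sum(axis=0)            # CAReduce over axis 0 of a 2d input

xb = pt.tensor3("xb")               # same input with one leading batch dimension
batched_node = vectorize_node(core_sum.owner, xb)
# Per vectorize_careduce, the reduction axis is shifted by the batch ndim:
# sum(matrix, axis=0) -> sum(tensor3, axis=(1,))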

pytensor/tensor/random/op.py

+27 -2

@@ -5,19 +5,25 @@

 import pytensor
 from pytensor.configdefaults import config
-from pytensor.graph.basic import Apply, Variable
+from pytensor.graph.basic import Apply, Variable, equal_computations
 from pytensor.graph.op import Op
 from pytensor.misc.safe_asarray import _asarray
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor.basic import (
     as_tensor_variable,
+    concatenate,
     constant,
     get_underlying_scalar_constant_value,
     get_vector_length,
     infer_static_shape,
 )
+from pytensor.tensor.blockwise import _vectorize_node
 from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
-from pytensor.tensor.random.utils import normalize_size_param, params_broadcast_shapes
+from pytensor.tensor.random.utils import (
+    broadcast_params,
+    normalize_size_param,
+    params_broadcast_shapes,
+)
 from pytensor.tensor.shape import shape_tuple
 from pytensor.tensor.type import TensorType, all_dtypes
 from pytensor.tensor.type_other import NoneConst

@@ -383,3 +389,22 @@ class DefaultGeneratorMakerOp(AbstractRNGConstructor):


 default_rng = DefaultGeneratorMakerOp()
+
+
+@_vectorize_node.register(RandomVariable)
+def vectorize_random_variable(
+    op: RandomVariable, node: Apply, rng, size, dtype, *dist_params
+) -> Apply:
+    # If size was provided originally and a new size hasn't been provided,
+    # We extend it to accommodate the new input batch dimensions.
+    # Otherwise, we assume the new size already has the right values
+    old_size = node.inputs[1]
+    len_old_size = get_vector_length(old_size)
+    if len_old_size and equal_computations([old_size], [size]):
+        bcasted_param = broadcast_params(dist_params, op.ndims_params)[0]
+        new_param_ndim = (bcasted_param.type.ndim - op.ndims_params[0]) - len_old_size
+        if new_param_ndim >= 0:
+            new_size_dims = bcasted_param.shape[:new_param_ndim]
+            size = concatenate([new_size_dims, size])
+
+    return op.make_node(rng, size, dtype, *dist_params)
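
To make the size-extension arithmetic concrete (a hedged worked example; normal's ndims_params is (0, 0)): if loc is a length-3 vector and size=(3,), then replacing loc by a (2, 3) matrix while passing the old size through gives bcasted_param.type.ndim == 2, so new_param_ndim == (2 - 0) - 1 == 1 and the new size becomes concatenate([loc.shape[:1], (3,)]), i.e. (2, 3), which is exactly what the "new size not provided" case of test_vectorize_node below asserts.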

pytensor/tensor/rewriting/__init__.py

+1

@@ -2,6 +2,7 @@
 import pytensor.tensor.rewriting.blas
 import pytensor.tensor.rewriting.blas_c
 import pytensor.tensor.rewriting.blas_scipy
+import pytensor.tensor.rewriting.blockwise
 import pytensor.tensor.rewriting.elemwise
 import pytensor.tensor.rewriting.extra_ops
78

pytensor/tensor/rewriting/blockwise.py

+41
@@ -0,0 +1,41 @@
+from pytensor.compile.mode import optdb
+from pytensor.graph import node_rewriter
+from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
+from pytensor.tensor.blockwise import Blockwise, vectorize_node
+
+
+@node_rewriter([Blockwise])
+def local_useless_blockwise(fgraph, node):
+    """
+    If there is a dispatch implementation that does not require Blockwise, use that instead.
+    This means a user created a Blockwise manually when there was no need.
+
+    Note: This rewrite is not registered by default anywhere
+    """
+    op = node.op
+    inputs = node.inputs
+    dummy_core_node = op._create_dummy_core_node(node.inputs)
+    vect_node = vectorize_node(dummy_core_node, *inputs)
+    if not isinstance(vect_node.op, Blockwise):
+        return copy_stack_trace(node.outputs, vect_node.outputs)
+
+
+@node_rewriter([Blockwise])
+def local_useless_unbatched_blockwise(fgraph, node):
+    """Remove Blockwise that don't have any batched dims."""
+    op = node.op
+    inputs = node.inputs
+
+    if max(inp.type.ndim - len(sig) for inp, sig in zip(inputs, op.inputs_sig)) == 0:
+        return copy_stack_trace(node.outputs, op.core_op.make_node(*inputs).outputs)
+
+
+# We register this rewrite late, so that other rewrites need only target Blockwise Ops
+optdb.register(
+    "local_useless_unbatched_blockwise",
+    out2in(local_useless_unbatched_blockwise, ignore_newtrees=True),
+    "fast_run",
+    "fast_compile",
+    "blockwise",
+    position=49,
+)
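
For context, a hedged worked example of the batch-dimension arithmetic local_useless_unbatched_blockwise checks (assuming op.inputs_sig holds the per-input core dimensions parsed from the signature): with Blockwise(MatrixPinv(hermitian=False), signature="(m,n)->(n,m)") applied to a matrix, every input has inp.type.ndim - len(core_sig) == 2 - 2 == 0, so the core op is used directly; applied to a tensor3 the difference is 3 - 2 == 1 and the Blockwise node is kept, as the tests further below verify with compiled functions.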

pytensor/tensor/utils.py

+53

@@ -1,3 +1,5 @@
+from typing import Sequence, Union
+
 import numpy as np

 import pytensor

@@ -107,3 +109,54 @@ def as_list(x):
         return list(x)
     except TypeError:
         return [x]
+
+
+def import_func_from_string(func_string: str):  # -> Optional[Callable]:
+    func = getattr(np, func_string, None)
+    if func is not None:
+        return func
+
+    # Not inside NumPy or Scipy. So probably another package like scipy.
+    module = None
+    items = func_string.split(".")
+    for idx in range(1, len(items)):
+        try:
+            module = __import__(".".join(items[:idx]))
+        except ImportError:
+            break
+
+    if module:
+        for sub in items[1:]:
+            try:
+                module = getattr(module, sub)
+            except AttributeError:
+                module = None
+                break
+    return module
+
+
+def broadcast_static_dim_lengths(
+    dim_lengths: Sequence[Union[int, None]]
+) -> Union[int, None]:
+    """Apply static broadcast given static dim length of inputs (obtained from var.type.shape).
+
+    Raises
+    ------
+    ValueError
+        When static dim lengths are incompatible
+    """
+
+    dim_lengths_set = set(dim_lengths)
+    # All dim_lengths are the same
+    if len(dim_lengths_set) == 1:
+        return tuple(dim_lengths_set)[0]
+
+    # Only valid indeterminate case
+    if dim_lengths_set == {None, 1}:
+        return None
+
+    dim_lengths_set.discard(1)
+    dim_lengths_set.discard(None)
+    if len(dim_lengths_set) > 1:
+        raise ValueError
+    return tuple(dim_lengths_set)[0]
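
A few hedged usage examples of the two new helpers (the inputs are chosen for illustration; the expected results follow directly from the code above):

from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string

# Static broadcasting of per-input dim lengths (None means unknown at compile time):
broadcast_static_dim_lengths([5, 1, 5])        # -> 5    (1 broadcasts against 5)
broadcast_static_dim_lengths([None, 1])        # -> None (the only valid indeterminate case)
broadcast_static_dim_lengths([None, 7])        # -> 7    (an unknown length is assumed to match)
# broadcast_static_dim_lengths([3, 5])         # raises ValueError (incompatible static lengths)

# Resolving a dotted nfunc_spec string to a callable, trying NumPy first:
import_func_from_string("exp")                 # -> numpy.exp
import_func_from_string("scipy.special.erf")   # -> scipy.special.erf (if scipy is installed)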

tests/tensor/random/test_op.py

+36

@@ -5,7 +5,9 @@
 from pytensor import config, function
 from pytensor.gradient import NullTypeGradError, grad
 from pytensor.raise_op import Assert
+from pytensor.tensor.blockwise import vectorize_node
 from pytensor.tensor.math import eq
+from pytensor.tensor.random import normal
 from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng
 from pytensor.tensor.shape import specify_shape
 from pytensor.tensor.type import all_dtypes, iscalar, tensor

@@ -202,3 +204,37 @@ def test_RandomVariable_incompatible_size():
         ValueError, match="Size length is incompatible with batched dimensions"
     ):
         rv_op(np.zeros((2, 4, 3)), 1, size=(4,))
+
+
+def test_vectorize_node():
+    vec = tensor(shape=(None,))
+    vec.tag.test_value = [0, 0, 0]
+    mat = tensor(shape=(None, None))
+    mat.tag.test_value = [[0, 0, 0], [1, 1, 1]]
+
+    # Test without size
+    node = normal(vec).owner
+    new_inputs = node.inputs.copy()
+    new_inputs[3] = mat
+    vect_node = vectorize_node(node, *new_inputs)
+    assert vect_node.op is normal
+    assert vect_node.inputs[3] is mat
+
+    # Test with size, new size provided
+    node = normal(vec, size=(3,)).owner
+    new_inputs = node.inputs.copy()
+    new_inputs[1] = (2, 3)
+    new_inputs[3] = mat
+    vect_node = vectorize_node(node, *new_inputs)
+    assert vect_node.op is normal
+    assert tuple(vect_node.inputs[1].eval()) == (2, 3)
+    assert vect_node.inputs[3] is mat
+
+    # Test with size, new size not provided
+    node = normal(vec, size=(3,)).owner
+    new_inputs = node.inputs.copy()
+    new_inputs[3] = mat
+    vect_node = vectorize_node(node, *new_inputs)
+    assert vect_node.op is normal
+    assert vect_node.inputs[3] is mat
+    assert tuple(vect_node.inputs[1].eval({mat: mat.tag.test_value})) == (2, 3)
tests/tensor/rewriting/test_blockwise.py

+38
@@ -0,0 +1,38 @@
+from pytensor import function
+from pytensor.graph import FunctionGraph
+from pytensor.scalar import log as scalar_log
+from pytensor.tensor import matrix, tensor3
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.elemwise import Elemwise
+from pytensor.tensor.nlinalg import MatrixPinv
+from pytensor.tensor.rewriting.blockwise import local_useless_blockwise
+
+
+def test_useless_blockwise_of_elemwise():
+    x = matrix("x")
+    out = Blockwise(Elemwise(scalar_log), signature="()->()")(x)
+    assert isinstance(out.owner.op, Blockwise)
+    assert isinstance(out.owner.op.core_op, Elemwise)
+
+    fg = FunctionGraph([x], [out], clone=False)
+    [new_out] = local_useless_blockwise.transform(fg, out.owner)
+    assert isinstance(new_out.owner.op, Elemwise)
+
+
+def test_useless_unbatched_blockwise():
+    x = matrix("x")
+    blockwise_op = Blockwise(MatrixPinv(hermitian=False), signature="(m,n)->(n,m)")
+    out = blockwise_op(x)
+
+    assert isinstance(out.owner.op, Blockwise)
+    assert isinstance(out.owner.op.core_op, MatrixPinv)
+
+    fn = function([x], out, mode="FAST_COMPILE")
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, MatrixPinv)
+
+    # Test that it's not removed when there are batched dims
+    x = tensor3("x")
+    out = blockwise_op(x)
+    fn = function([x], out, mode="FAST_COMPILE")
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op.core_op, MatrixPinv)
