Commit c8f5650

Add Numba implementation of Blockwise
1 parent fa0ab9d commit c8f5650
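
In practical terms, this commit lets graphs that contain Blockwise Ops (batched wrappers around core Ops such as Det, Cholesky, or SVD) compile through the Numba backend. A minimal usage sketch, adapted from the new test file added below; exactly which core Ops benefit depends on which of them already have Numba dispatchers:

import numpy as np
import pytensor
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import Det

# Det is the core op; Blockwise maps it over the leading batch dimension
x = tensor("x", shape=(5, None, None))
out = Blockwise(core_op=Det())(x)

# With this commit, the graph compiles through the Numba linker
fn = pytensor.function([x], out, mode="NUMBA")
x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
print(fn(x_test))  # determinants of the 5 stacked matrices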

File tree: 9 files changed, +261 -10 lines


pytensor/link/numba/dispatch/__init__.py

Lines changed: 5 additions & 4 deletions
@@ -2,15 +2,16 @@
 from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
 
 # Load dispatch specializations
-import pytensor.link.numba.dispatch.scalar
-import pytensor.link.numba.dispatch.tensor_basic
+import pytensor.link.numba.dispatch.blockwise
+import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.extra_ops
 import pytensor.link.numba.dispatch.nlinalg
 import pytensor.link.numba.dispatch.random
-import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.scan
-import pytensor.link.numba.dispatch.sparse
+import pytensor.link.numba.dispatch.scalar
 import pytensor.link.numba.dispatch.slinalg
+import pytensor.link.numba.dispatch.sparse
 import pytensor.link.numba.dispatch.subtensor
+import pytensor.link.numba.dispatch.tensor_basic
 
 # isort: on

pytensor/link/numba/dispatch/blockwise.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
from numba.core.extending import overload
from numba.np.unsafe.ndarray import to_fixed_tuple

from pytensor.link.numba.dispatch.basic import numba_funcify
from pytensor.link.numba.dispatch.vectorize_codegen import (
    _jit_options,
    _vectorized,
    encode_literals,
    store_core_outputs,
)
from pytensor.tensor import get_vector_length
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape


@numba_funcify.register
def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
    [blockwise_node] = op.fgraph.apply_nodes
    blockwise_op: Blockwise = blockwise_node.op
    core_op = blockwise_op.core_op
    nin = len(blockwise_node.inputs)
    nout = len(blockwise_node.outputs)
    core_shapes_len = [get_vector_length(sh) for sh in node.inputs[nin:]]
    core_shape_0 = core_shapes_len[0] if nout > 0 else None
    core_shape_1 = core_shapes_len[1] if nout > 1 else None
    core_shape_2 = core_shapes_len[2] if nout > 2 else None
    if nout > 3:
        raise NotImplementedError(
            "Blockwise with more than 3 outputs not supported in Numba backend"
        )

    core_node = blockwise_op._create_dummy_core_node(blockwise_node.inputs)
    core_op_fn = numba_funcify(
        core_op,
        node=core_node,
        parent_node=node,
        fastmath=_jit_options["fastmath"],
        **kwargs,
    )
    core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)

    batch_ndim = blockwise_op.batch_ndim(node)

    # numba doesn't support nested literals right now...
    input_bc_patterns = encode_literals(
        tuple(inp.type.broadcastable[:batch_ndim] for inp in node.inputs)
    )
    output_bc_patterns = encode_literals(
        tuple(out.type.broadcastable[:batch_ndim] for out in node.outputs)
    )
    output_dtypes = encode_literals(tuple(out.type.dtype for out in node.outputs))
    inplace_pattern = encode_literals(())
    # inplace = rv_op.inplace

    def blockwise_wrapper(*inputs_and_core_shapes):
        inputs, core_shapes = inputs_and_core_shapes[:nin], inputs_and_core_shapes[nin:]
        # Appease numba Gods :(
        # Secular solution welcomed
        if nout == 1:
            tuple_core_shapes = (to_fixed_tuple(core_shapes[0], core_shape_0),)
        elif nout == 2:
            tuple_core_shapes = (
                to_fixed_tuple(core_shapes[0], core_shape_0),
                to_fixed_tuple(core_shapes[1], core_shape_1),
            )
        else:
            tuple_core_shapes = (
                to_fixed_tuple(core_shapes[0], core_shape_0),
                to_fixed_tuple(core_shapes[1], core_shape_1),
                to_fixed_tuple(core_shapes[2], core_shape_2),
            )
        return _vectorized(
            core_op_fn,
            input_bc_patterns,
            output_bc_patterns,
            output_dtypes,
            inplace_pattern,
            (),  # constant_inputs
            inputs,
            tuple_core_shapes,
            None,  # size
        )

    def blockwise(*inputs_and_core_shapes):
        raise NotImplementedError("Non-jitted blockwise not implemented")

    @overload(blockwise, jit_options=_jit_options)
    def ov_blockwise(*inputs_and_core_shapes):
        return blockwise_wrapper

    return blockwise
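
The to_fixed_tuple gymnastics in blockwise_wrapper exist because Numba needs tuples whose length is known at compile time to type the core-shape values, which is also why the dispatcher hard-codes one branch per possible nout. A small standalone illustration of that Numba utility (independent of PyTensor; the function name here is made up):

import numpy as np
from numba import njit
from numba.np.unsafe.ndarray import to_fixed_tuple


@njit
def shape_array_to_tuple(shape_arr):
    # The length argument (2) must be a compile-time constant;
    # a value only known at runtime would not type.
    return to_fixed_tuple(shape_arr, 2)


print(shape_array_to_tuple(np.array([3, 4])))  # (3, 4)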

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 1 addition & 1 deletion
@@ -508,7 +508,7 @@ def elemwise_wrapper(*inputs):
             inplace_pattern_enc,
             (),  # constant_inputs
             inputs,
-            core_output_shapes,  # core_shapes
+            core_output_shapes,
             None,  # size
         )

pytensor/link/numba/dispatch/random.py

Lines changed: 1 addition & 1 deletion
@@ -388,7 +388,7 @@ def random_wrapper(core_shape, rng, size, *dist_params):
         return rng, draws
 
     def random(core_shape, rng, size, *dist_params):
-        pass
+        raise NotImplementedError("Non-jitted random variable not implemented")
 
     @overload(random, jit_options=_jit_options)
     def ov_random(core_shape, rng, size, *dist_params):
pytensor/tensor/blockwise.py

Lines changed: 35 additions & 3 deletions
@@ -7,7 +7,8 @@
 from pytensor import config
 from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import DisconnectedType
-from pytensor.graph.basic import Apply, Constant
+from pytensor.graph import FunctionGraph
+from pytensor.graph.basic import Apply, Constant, ancestors
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import Op
 from pytensor.graph.replace import (

@@ -179,16 +180,39 @@ def infer_shape(
 
         batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)
 
+        # Try to extract the core shapes from the core_op
+        if hasattr(self.core_op, "infer_shape"):
+            dummy_core_node = self._create_dummy_core_node(node.inputs)
+            dummy_core_inputs = dummy_core_node.inputs
+            dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
+            core_input_shapes = [
+                input_shape[batch_ndims:] for input_shape in input_shapes
+            ]
+            core_output_shapes = self.core_op.infer_shape(
+                dummy_fgraph, dummy_core_node, core_input_shapes
+            )
+
         out_shapes = []
-        for output, sig in zip(node.outputs, self.outputs_sig):
+        for o, (output, sig) in enumerate(zip(node.outputs, self.outputs_sig)):
             core_out_shape = []
             for i, dim_name in enumerate(sig):
                 # The output dim is the same as another input dim
                 if dim_name in core_dims:
                     core_out_shape.append(core_dims[dim_name])
                 else:
-                    # TODO: We could try to make use of infer_shape of core_op
+                    if hasattr(self.core_op, "infer_shape"):
+                        # If the input values are needed to compute the dimension length, we can't use the infer_shape
+                        # of the core_node as the value is not constant across batch dims of the Blockwise
+                        core_out_dim = core_output_shapes[o][i]
+                        if not (
+                            set(dummy_core_inputs) & set(ancestors([core_out_dim]))
+                        ):
+                            core_out_shape.append(core_out_dim)
+                            continue
+
+                    # Fallback shape requires evaluating the Blockwise Op
                     core_out_shape.append(Shape_i(batch_ndims + i)(output))
+
             out_shapes.append((*batch_shape, *core_out_shape))
 
         return out_shapes

@@ -379,3 +403,11 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
 
 class OpWithCoreShape(OpFromGraph):
     """Generalizes an `Op` to include core shape as an additional input."""
+
+
+class BlockwiseWithCoreShape(OpWithCoreShape):
+    """Generalizes a Blockwise `Op` to include a core shape parameter."""
+
+    def __str__(self):
+        [blockwise_node] = self.fgraph.apply_nodes
+        return f"[{blockwise_node.op!s}]"

pytensor/tensor/rewriting/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 import pytensor.tensor.rewriting.jax
 import pytensor.tensor.rewriting.linalg
 import pytensor.tensor.rewriting.math
+import pytensor.tensor.rewriting.numba
 import pytensor.tensor.rewriting.ofg
 import pytensor.tensor.rewriting.shape
 import pytensor.tensor.rewriting.special

pytensor/tensor/rewriting/numba.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
from pytensor.compile import optdb
from pytensor.graph import node_rewriter
from pytensor.graph.basic import applys_between
from pytensor.graph.rewriting.basic import out2in
from pytensor.tensor.basic import as_tensor, constant
from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape


@node_rewriter([Blockwise])
def introduce_explicit_core_shape_blockwise(fgraph, node):
    """Introduce the core shape of a Blockwise.

    We wrap Blockwise graphs into a BlockwiseWithCoreShape OpFromGraph
    that has an extra "non-functional" input that represents the core shape of the Blockwise variable.
    This core_shape is used by the numba backend to pre-allocate the output array.

    If available, the core shape is extracted from the shape feature of the graph,
    which has a higher change of having been simplified, optimized, constant-folded.
    If missing, we fall back to the op._supp_shape_from_params method.

    This rewrite is required for the numba backend implementation of Blockwise.

    Example
    -------

    .. code-block:: python

        import pytensor
        import pytensor.tensor as pt

        x = pt.random.dirichlet(alphas=[1, 2, 3], size=(5,))
        pytensor.dprint(x, print_type=True)
        # dirichlet_rv{"(a)->(a)"}.1 [id A] <Matrix(float64, shape=(5, 3))>
        #  ├─ RNG(<Generator(PCG64) at 0x7F09E59C18C0>) [id B] <RandomGeneratorType>
        #  ├─ [5] [id C] <Vector(int64, shape=(1,))>
        #  └─ ExpandDims{axis=0} [id D] <Matrix(int64, shape=(1, 3))>
        #     └─ [1 2 3] [id E] <Vector(int64, shape=(3,))>

        # After the rewrite, note the new core shape input [3] [id B]
        fn = pytensor.function([], x, mode="NUMBA")
        pytensor.dprint(fn.maker.fgraph)
        # [dirichlet_rv{"(a)->(a)"}].1 [id A] 0
        #  ├─ [3] [id B]
        #  ├─ RNG(<Generator(PCG64) at 0x7F15B8E844A0>) [id C]
        #  ├─ [5] [id D]
        #  └─ [[1 2 3]] [id E]
        # Inner graphs:
        # [dirichlet_rv{"(a)->(a)"}] [id A]
        #  ← dirichlet_rv{"(a)->(a)"}.0 [id F]
        #     ├─ *1-<RandomGeneratorType> [id G]
        #     ├─ *2-<Vector(int64, shape=(1,))> [id H]
        #     └─ *3-<Matrix(int64, shape=(1, 3))> [id I]
        #  ← dirichlet_rv{"(a)->(a)"}.1 [id F]
        #     └─ ···
    """
    op: Blockwise = node.op  # type: ignore[annotation-unchecked]
    batch_ndim = op.batch_ndim(node)

    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)  # type: ignore[annotation-unchecked]
    if shape_feature:
        core_shapes = [
            [shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)]
            for out in node.outputs
        ]
    else:
        raise ValueError
        core_shapes = op._supp_shape_from_params(op.dist_params(node))

    core_shapes = [
        as_tensor(core_shape) if len(core_shape) else constant([], dtype="int64")
        for core_shape in core_shapes
    ]

    if any(
        isinstance(node.op, Blockwise)
        for node in applys_between(node.inputs, core_shapes)
    ):
        # If Blockwise shows up in the shape graph we can't introduce the core shape
        return None

    return (
        BlockwiseWithCoreShape(
            [*node.inputs, *core_shapes],
            node.outputs,
            destroy_map=op.destroy_map,
        )
        .make_node(*node.inputs, *core_shapes)
        .outputs
    )


optdb.register(
    introduce_explicit_core_shape_blockwise.__name__,
    out2in(introduce_explicit_core_shape_blockwise),
    "numba",
    position=100,
)

tests/link/numba/test_basic.py

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@ def compare_numba_and_py(
     Parameters
     ----------
     fgraph
-        `FunctionGraph` or inputs to compare.
+        `FunctionGraph` or tuple(inputs, outputs) to compare.
     inputs
         Numeric inputs to be passed to the compiled graphs.
     assert_fn

tests/link/numba/test_blockwise.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
import numpy as np
import pytest
from link.numba.test_basic import compare_numba_and_py, numba_mode

from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.nlinalg import SVD, Det
from pytensor.tensor.slinalg import Cholesky


# Fails if object mode warning is issued
pytestmark = pytest.mark.filterwarnings("error")

# TODO: Test inplace
# TODO: Test non rectangular fails gracefully


@pytest.mark.parametrize("core_op", [Det(), Cholesky(), SVD(compute_uv=True)], ids=str)
def test_blockwise(core_op):
    x = tensor(shape=(5, None, None))
    outs = Blockwise(core_op=core_op)(x, return_list=True)

    x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
    fn, _ = compare_numba_and_py(
        ([x], outs),
        [x_test],
        numba_mode=numba_mode.including("ShapeOpt"),
        eval_obj_mode=False,
    )
    fn.dprint(print_type=True)
