Commit 2014cd9

Add Numba implementation of Blockwise
Restricted to 3 outputs due to limitations in the jitting of Numba functions
1 parent 06d9a49 commit 2014cd9
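
To illustrate what the commit enables, here is a minimal usage sketch (not part of the commit; assumes a working Numba install and that pt.linalg.cholesky builds a Blockwise node, as the tests below do):

import numpy as np
import pytensor
import pytensor.tensor as pt

# A batch of 5 symmetric positive-definite 3x3 matrices
x = pt.tensor("x", shape=(5, 3, 3))
out = pt.linalg.cholesky(x)  # builds a Blockwise{Cholesky} node

# mode="NUMBA" applies the new core-shape rewrite and Numba dispatch
fn = pytensor.function([x], out, mode="NUMBA")

x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
print(fn(x_test).shape)  # (5, 3, 3)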

File tree

8 files changed: +290 -6 lines changed


pytensor/link/numba/dispatch/__init__.py

+5 -4

@@ -2,15 +2,16 @@
 from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
 
 # Load dispatch specializations
-import pytensor.link.numba.dispatch.scalar
-import pytensor.link.numba.dispatch.tensor_basic
+import pytensor.link.numba.dispatch.blockwise
+import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.extra_ops
 import pytensor.link.numba.dispatch.nlinalg
 import pytensor.link.numba.dispatch.random
-import pytensor.link.numba.dispatch.elemwise
 import pytensor.link.numba.dispatch.scan
-import pytensor.link.numba.dispatch.sparse
+import pytensor.link.numba.dispatch.scalar
 import pytensor.link.numba.dispatch.slinalg
+import pytensor.link.numba.dispatch.sparse
 import pytensor.link.numba.dispatch.subtensor
+import pytensor.link.numba.dispatch.tensor_basic
 
 # isort: on
pytensor/link/numba/dispatch/blockwise.py

+90

@@ -0,0 +1,90 @@
+from numba.core.extending import overload
+from numba.np.unsafe.ndarray import to_fixed_tuple
+
+from pytensor.link.numba.dispatch.basic import numba_funcify
+from pytensor.link.numba.dispatch.vectorize_codegen import (
+    _jit_options,
+    _vectorized,
+    encode_literals,
+    store_core_outputs,
+)
+from pytensor.tensor import get_vector_length
+from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
+
+
+@numba_funcify.register
+def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
+    [blockwise_node] = op.fgraph.apply_nodes
+    blockwise_op: Blockwise = blockwise_node.op
+    core_op = blockwise_op.core_op
+    nin = len(blockwise_node.inputs)
+    nout = len(blockwise_node.outputs)
+    if nout > 3:
+        raise NotImplementedError(
+            "Current implementation of BlockwiseWithCoreShape does not support more than 3 outputs."
+        )
+
+    core_shapes_len = [get_vector_length(sh) for sh in node.inputs[nin:]]
+    core_shape_0 = core_shapes_len[0] if nout > 0 else None
+    core_shape_1 = core_shapes_len[1] if nout > 1 else None
+    core_shape_2 = core_shapes_len[2] if nout > 2 else None
+
+    core_node = blockwise_op._create_dummy_core_node(blockwise_node.inputs)
+    core_op_fn = numba_funcify(
+        core_op,
+        node=core_node,
+        parent_node=node,
+        fastmath=_jit_options["fastmath"],
+        **kwargs,
+    )
+    core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
+
+    batch_ndim = blockwise_op.batch_ndim(node)
+
+    # numba doesn't support nested literals right now...
+    input_bc_patterns = encode_literals(
+        tuple(inp.type.broadcastable[:batch_ndim] for inp in node.inputs)
+    )
+    output_bc_patterns = encode_literals(
+        tuple(out.type.broadcastable[:batch_ndim] for out in node.outputs)
+    )
+    output_dtypes = encode_literals(tuple(out.type.dtype for out in node.outputs))
+    inplace_pattern = encode_literals(())
+
+    def blockwise_wrapper(*inputs_and_core_shapes):
+        inputs, core_shapes = inputs_and_core_shapes[:nin], inputs_and_core_shapes[nin:]
+        # Appease numba Gods :(
+        # Secular solution welcomed
+        if nout == 1:
+            tuple_core_shapes = (to_fixed_tuple(core_shapes[0], core_shape_0),)
+        elif nout == 2:
+            tuple_core_shapes = (
+                to_fixed_tuple(core_shapes[0], core_shape_0),
+                to_fixed_tuple(core_shapes[1], core_shape_1),
+            )
+        else:
+            tuple_core_shapes = (
+                to_fixed_tuple(core_shapes[0], core_shape_0),
+                to_fixed_tuple(core_shapes[1], core_shape_1),
+                to_fixed_tuple(core_shapes[2], core_shape_2),
+            )
+        return _vectorized(
+            core_op_fn,
+            input_bc_patterns,
+            output_bc_patterns,
+            output_dtypes,
+            inplace_pattern,
+            (),  # constant_inputs
+            inputs,
+            tuple_core_shapes,
+            None,  # size
+        )
+
+    def blockwise(*inputs_and_core_shapes):
+        raise NotImplementedError("Non-jitted BlockwiseWithCoreShape not implemented")
+
+    @overload(blockwise, jit_options=_jit_options)
+    def ov_blockwise(*inputs_and_core_shapes):
+        return blockwise_wrapper
+
+    return blockwise

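The 3-output cap enforced above exists because to_fixed_tuple requires its length argument to be a compile-time constant, so each output count needs its own literal branch (the if/elif/else in blockwise_wrapper). A minimal standalone sketch of that constraint (not from the commit):

import numpy as np
from numba import njit
from numba.np.unsafe.ndarray import to_fixed_tuple

@njit
def shape_as_pair(shape_vec):
    # The length (2) must be a literal known at compile time;
    # passing it as a runtime argument would fail type inference.
    return to_fixed_tuple(shape_vec, 2)

print(shape_as_pair(np.array([3, 4])))  # (3, 4)
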
pytensor/link/numba/dispatch/random.py

+1 -1

@@ -388,7 +388,7 @@ def random_wrapper(core_shape, rng, size, *dist_params):
         return rng, draws
 
     def random(core_shape, rng, size, *dist_params):
-        pass
+        raise NotImplementedError("Non-jitted random variable not implemented")
 
     @overload(random, jit_options=_jit_options)
     def ov_random(core_shape, rng, size, *dist_params):
pytensor/tensor/blockwise.py

+8

@@ -402,3 +402,11 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
 
 
 class OpWithCoreShape(OpFromGraph):
     """Generalizes an `Op` to include core shape as an additional input."""
+
+
+class BlockwiseWithCoreShape(OpWithCoreShape):
+    """Generalizes a Blockwise `Op` to include a core shape parameter."""
+
+    def __str__(self):
+        [blockwise_node] = self.fgraph.apply_nodes
+        return f"[{blockwise_node.op!s}]"

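For context, a small sketch of the printing this __str__ enables (not part of the commit; the exact rendering is assumed): the wrapper delegates its display name to the inner Blockwise op, bracketed.

import pytensor
import pytensor.tensor as pt

x = pt.tensor("x", shape=(5, 3, 3))
fn = pytensor.function([x], pt.linalg.cholesky(x), mode="NUMBA")
fn.dprint()
# The wrapped node should print the inner op in brackets, along the lines of
# [Blockwise{Cholesky{...}, (m,m)->(m,m)}] ...
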
pytensor/tensor/rewriting/__init__.py

+1

@@ -9,6 +9,7 @@
 import pytensor.tensor.rewriting.jax
 import pytensor.tensor.rewriting.linalg
 import pytensor.tensor.rewriting.math
+import pytensor.tensor.rewriting.numba
 import pytensor.tensor.rewriting.ofg
 import pytensor.tensor.rewriting.shape
 import pytensor.tensor.rewriting.special

pytensor/tensor/rewriting/numba.py

+112

@@ -0,0 +1,112 @@
+from pytensor.compile import optdb
+from pytensor.graph import node_rewriter
+from pytensor.graph.basic import applys_between
+from pytensor.graph.rewriting.basic import out2in
+from pytensor.tensor.basic import as_tensor, constant
+from pytensor.tensor.blockwise import Blockwise, BlockwiseWithCoreShape
+from pytensor.tensor.rewriting.shape import ShapeFeature
+
+
+@node_rewriter([Blockwise])
+def introduce_explicit_core_shape_blockwise(fgraph, node):
+    """Introduce the core shape of a Blockwise.
+
+    We wrap Blockwise graphs into a BlockwiseWithCoreShape OpFromGraph
+    that has an extra "non-functional" input that represents the core shape of the Blockwise variable.
+    This core_shape is used by the numba backend to pre-allocate the output array.
+
+    If available, the core shape is extracted from the shape feature of the graph,
+    which has a higher chance of having been simplified, optimized, or constant-folded.
+    If missing, we fall back to the op.infer_shape method.
+
+    This rewrite is required for the numba backend implementation of Blockwise.
+
+    Example
+    -------
+
+    .. code-block:: python
+
+        import pytensor
+        import pytensor.tensor as pt
+
+        x = pt.tensor("x", shape=(5, None, None))
+        outs = pt.linalg.svd(x, compute_uv=True)
+        pytensor.dprint(outs)
+        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.0 [id A]
+        # └─ x [id B]
+        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.1 [id A]
+        # └─ ···
+        # Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}.2 [id A]
+        # └─ ···
+
+        # After the rewrite, note the new 3 core shape inputs
+        fn = pytensor.function([x], outs, mode="NUMBA")
+        fn.dprint(print_type=False)
+        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].0 [id A] 6
+        #  ├─ x [id B]
+        #  ├─ MakeVector{dtype='int64'} [id C] 5
+        #  │  ├─ Shape_i{1} [id D] 2
+        #  │  │  └─ x [id B]
+        #  │  └─ Shape_i{1} [id D] 2
+        #  │     └─ ···
+        #  ├─ MakeVector{dtype='int64'} [id E] 4
+        #  │  └─ Minimum [id F] 3
+        #  │     ├─ Shape_i{1} [id D] 2
+        #  │     │  └─ ···
+        #  │     └─ Shape_i{2} [id G] 0
+        #  │        └─ x [id B]
+        #  └─ MakeVector{dtype='int64'} [id H] 1
+        #     ├─ Shape_i{2} [id G] 0
+        #     │  └─ ···
+        #     └─ Shape_i{2} [id G] 0
+        #        └─ ···
+        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].1 [id A] 6
+        # └─ ···
+        # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].2 [id A] 6
+        # └─ ···
+    """
+    if len(node.outputs) > 3:
+        # Current implementation of BlockwiseWithCoreShape does not support more than 3 outputs.
+        return None
+
+    op: Blockwise = node.op
+    batch_ndim = op.batch_ndim(node)
+
+    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)  # type: ignore[annotation-unchecked]
+    if shape_feature:
+        core_shapes = [
+            [shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)]
+            for out in node.outputs
+        ]
+    else:
+        input_shapes = [tuple(inp.shape) for inp in node.inputs]
+        core_shapes = [
+            out_shape[batch_ndim:]
+            for out_shape in op.infer_shape(None, node, input_shapes)
+        ]
+
+    core_shapes = [
+        as_tensor(core_shape) if len(core_shape) else constant([], dtype="int64")
+        for core_shape in core_shapes
+    ]
+
+    if any(
+        isinstance(node.op, Blockwise)
+        for node in applys_between(node.inputs, core_shapes)
+    ):
+        # If Blockwise shows up in the shape graph we can't introduce the core shape
+        return None
+
+    return BlockwiseWithCoreShape(
+        [*node.inputs, *core_shapes],
+        node.outputs,
+        destroy_map=op.destroy_map,
+    )(*node.inputs, *core_shapes, return_list=True)
+
+
+optdb.register(
+    introduce_explicit_core_shape_blockwise.__name__,
+    out2in(introduce_explicit_core_shape_blockwise),
+    "numba",
+    position=100,
+)

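A quick way to confirm the rewrite fired (a sketch, not part of the commit; assumes the NUMBA mode includes rewrites tagged "numba", as registered above):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.blockwise import BlockwiseWithCoreShape

x = pt.tensor("x", shape=(5, None, None))
fn = pytensor.function([x], pt.linalg.cholesky(x), mode="NUMBA")

# After compilation, the Blockwise node should be wrapped with explicit core shapes
assert any(
    isinstance(node.op, BlockwiseWithCoreShape)
    for node in fn.maker.fgraph.apply_nodes
)
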
tests/link/numba/test_basic.py

+1 -1

@@ -242,7 +242,7 @@ def compare_numba_and_py(
     Parameters
     ----------
     fgraph
-        `FunctionGraph` or inputs to compare.
+        `FunctionGraph` or tuple(inputs, outputs) to compare.
     inputs
         Numeric inputs to be passed to the compiled graphs.
     assert_fn

tests/link/numba/test_blockwise.py

+72

@@ -0,0 +1,72 @@
+import numpy as np
+import pytest
+
+from pytensor import function
+from pytensor.compile.builders import OpFromGraph
+from tests.link.numba.test_basic import compare_numba_and_py, numba_mode
+from pytensor.tensor import tensor
+from pytensor.tensor.basic import ARange
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.nlinalg import SVD, Det
+from pytensor.tensor.slinalg import Cholesky, cholesky
+
+
+# Fails if object mode warning is issued when not expected
+pytestmark = pytest.mark.filterwarnings("error")
+
+
+@pytest.mark.parametrize("shape_opt", [True, False], ids=str)
+@pytest.mark.parametrize("core_op", [Det(), Cholesky(), SVD(compute_uv=True)], ids=str)
+def test_blockwise(core_op, shape_opt):
+    x = tensor(shape=(5, None, None))
+    outs = Blockwise(core_op=core_op)(x, return_list=True)
+
+    mode = (
+        numba_mode.including("ShapeOpt")
+        if shape_opt
+        else numba_mode.excluding("ShapeOpt")
+    )
+    x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
+    compare_numba_and_py(
+        ([x], outs),
+        [x_test],
+        numba_mode=mode,
+        eval_obj_mode=False,
+    )
+
+
+def test_non_square_blockwise():
+    """Test that an Op that cannot always be blockwised at runtime fails gracefully."""
+    x = tensor(shape=(3,), dtype="int64")
+    out = Blockwise(core_op=ARange(dtype="int64"), signature="(),(),()->(a)")(0, x, 1)
+
+    with pytest.warns(UserWarning, match="Numba will use object mode"):
+        fn = function([x], out, mode="NUMBA")
+
+    np.testing.assert_allclose(fn([5, 5, 5]), np.broadcast_to(np.arange(5), (3, 5)))
+
+    with pytest.raises(ValueError):
+        fn([3, 4, 5])
+
+
+def test_too_many_outputs_blockwise():
+    """Current implementation of Blockwise does not support more than 3 outputs."""
+    x = tensor("x", shape=())
+    core_op = OpFromGraph([x], [x + i for i in range(4)])
+
+    xs = tensor("x", shape=(3,))
+    outs = Blockwise(core_op=core_op, signature="()->(),(),(),()")(xs)
+
+    with pytest.warns(UserWarning, match="Numba will use object mode"):
+        compare_numba_and_py(([xs], outs), [np.arange(3)])
+
+
+def test_blockwise_benchmark(benchmark):
+    x = tensor(shape=(5, 3, 3))
+    out = cholesky(x)
+    assert isinstance(out.owner.op, Blockwise)
+
+    fn = function([x], out, mode="NUMBA")
+    x_test = np.eye(3) * np.arange(1, 6)[:, None, None]
+    fn(x_test)  # JIT compile
+    benchmark(fn, x_test)

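A usage note (not part of the commit): the final test relies on the pytest-benchmark plugin, which provides the benchmark fixture. Without that plugin installed, the rest of the suite can still be run with the benchmark test deselected:

pytest tests/link/numba/test_blockwise.py -k "not benchmark"
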