Commit 11d5b88

Support Blockwise in JAX backend
Parent: 9a0e937

3 files changed: +72 −0 lines changed


pytensor/link/jax/dispatch/__init__.py

+1
@@ -13,5 +13,6 @@
 import pytensor.link.jax.dispatch.elemwise
 import pytensor.link.jax.dispatch.scan
 import pytensor.link.jax.dispatch.sparse
+import pytensor.link.jax.dispatch.blockwise
 
 # isort: on
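This import exists only for its side effect: loading the module executes the @jax_funcify.register(Blockwise) decorator, which attaches a Blockwise handler to the jax_funcify dispatcher. A minimal sketch of that import-time registration pattern, assuming jax_funcify is a functools.singledispatch function (the register call suggests it); the stand-in names below are illustrative, not PyTensor's:

from functools import singledispatch

@singledispatch
def funcify(op, **kwargs):
    # Fallback when no backend-specific handler has been registered.
    raise NotImplementedError(f"No JAX conversion for {type(op).__name__}")

class FakeBlockwise:  # stand-in for pytensor.tensor.blockwise.Blockwise
    pass

# Running this decorator is what `import pytensor.link.jax.dispatch.blockwise`
# accomplishes at import time: it registers a handler for the Op type.
@funcify.register(FakeBlockwise)
def funcify_blockwise(op, **kwargs):
    return lambda *inputs: inputs

assert funcify(FakeBlockwise())(1, 2) == (1, 2)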
pytensor/link/jax/dispatch/blockwise.py

+29
@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+
+from pytensor.graph import FunctionGraph
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.blockwise import Blockwise
+
+
+@jax_funcify.register(Blockwise)
+def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
+    signature = op.signature
+    core_node = op._create_dummy_core_node(node.inputs)
+    core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
+    tuple_core_fn = jax_funcify(core_fgraph)
+
+    if len(node.outputs) == 1:
+
+        def core_fn(*inputs):
+            return tuple_core_fn(*inputs)[0]
+
+    else:
+        core_fn = tuple_core_fn
+
+    vect_fn = jnp.vectorize(core_fn, signature=signature)
+
+    def blockwise_fn(*inputs):
+        op._check_runtime_broadcast(node, inputs)
+        return vect_fn(*inputs)
+
+    return blockwise_fn
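The dispatcher funcifies only the core Op and leaves all batching to jnp.vectorize, which interprets the Op's gufunc-style signature. A minimal sketch of that mechanism, using a hand-written core function in place of a funcified PyTensor graph:

import jax.numpy as jnp
import numpy as np

def core_matmul(a, b):
    # Core computation, written for single (m, k) and (k, n) operands.
    return a @ b

# The signature marks the trailing "core" dimensions; any leading dimensions
# are treated as batch dimensions and broadcast, which is what Blockwise expects.
batched_matmul = jnp.vectorize(core_matmul, signature="(m,k),(k,n)->(m,n)")

a = np.ones((2, 3, 5))
b = np.ones((2, 5, 3))
assert batched_matmul(a, b).shape == (2, 3, 3)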

tests/link/jax/test_blockwise.py

+42
@@ -0,0 +1,42 @@
+import numpy as np
+import pytest
+
+from pytensor import config
+from pytensor.graph import FunctionGraph
+from pytensor.tensor import tensor
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.math import Dot, matmul
+from tests.link.jax.test_basic import compare_jax_and_py
+from tests.tensor.test_blockwise import check_blockwise_runtime_broadcasting
+
+
+jax = pytest.importorskip("jax")
+
+
+def test_runtime_broadcasting():
+    check_blockwise_runtime_broadcasting("JAX")
+
+
+# Equivalent blockwise to matmul but with dumb signature
+odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)")
+
+
+@pytest.mark.parametrize("matmul_op", (matmul, odd_matmul))
+def test_matmul(matmul_op):
+    rng = np.random.default_rng(14)
+    a = tensor("a", shape=(2, 3, 5))
+    b = tensor("b", shape=(2, 5, 3))
+    test_values = [
+        rng.normal(size=inp.type.shape).astype(config.floatX) for inp in (a, b)
+    ]
+
+    out = matmul_op(a, b)
+    assert isinstance(out.owner.op, Blockwise)
+    fg = FunctionGraph([a, b], [out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+
+    # Check we are not adding any unnecessary stuff
+    jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values))
+    jaxpr = jaxpr.replace("name=jax_funcified_fgraph", "name=matmul")
+    expected_jaxpr = str(jax.make_jaxpr(jax.jit(jax.numpy.matmul))(*test_values))
+    assert jaxpr == expected_jaxpr
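For reference, an end-to-end usage sketch of what this commit enables: compiling a graph that contains a Blockwise op with PyTensor's JAX mode. This assumes jax is installed; the shapes and seed are illustrative:

import numpy as np
import pytensor
from pytensor.tensor import tensor
from pytensor.tensor.math import matmul

a = tensor("a", shape=(2, 3, 5))
b = tensor("b", shape=(2, 5, 3))
out = matmul(a, b)  # batched matmul is represented as Blockwise(Dot())

# Compile through the JAX backend; the Blockwise op now lowers via jnp.vectorize.
fn = pytensor.function([a, b], out, mode="JAX")

rng = np.random.default_rng(0)
a_val = rng.normal(size=(2, 3, 5)).astype(pytensor.config.floatX)
b_val = rng.normal(size=(2, 5, 3)).astype(pytensor.config.floatX)
assert fn(a_val, b_val).shape == (2, 3, 3)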
