
Commit 2af4e05

Support Blockwise in JAX backend
1 parent 9a0e937 commit 2af4e05

File tree

3 files changed: +71 -0


pytensor/link/jax/dispatch/__init__.py (+1)
@@ -13,5 +13,6 @@
 import pytensor.link.jax.dispatch.elemwise
 import pytensor.link.jax.dispatch.scan
 import pytensor.link.jax.dispatch.sparse
+import pytensor.link.jax.dispatch.blockwise

 # isort: on
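
For context, jax_funcify is a dispatch registry keyed on the Op type: importing pytensor.link.jax.dispatch.blockwise executes the @jax_funcify.register(Blockwise) decorator in the new module below, which is all the JAX linker needs to lower Blockwise nodes. A minimal sketch of the same registration pattern, using functools.singledispatch with illustrative names (DemoOp and to_jax are not pytensor internals):

from functools import singledispatch


@singledispatch
def to_jax(op, node=None, **kwargs):
    # Fallback when no JAX conversion is registered for this Op type.
    raise NotImplementedError(f"No JAX conversion for {op}")


class DemoOp:  # stands in for a pytensor Op
    pass


@to_jax.register(DemoOp)
def to_jax_demo(op, node=None, **kwargs):
    # Registered handlers return a plain callable implementing the Op in JAX.
    return lambda x: x + 1


print(to_jax(DemoOp())(41))  # prints 42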
pytensor/link/jax/dispatch/blockwise.py (+29)

@@ -0,0 +1,29 @@
+import jax.numpy as jnp
+
+from pytensor.graph import FunctionGraph
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.blockwise import Blockwise
+
+
+@jax_funcify.register(Blockwise)
+def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
+    signature = op.signature
+    core_node = op._create_dummy_core_node(node.inputs)
+    core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
+    tuple_core_fn = jax_funcify(core_fgraph)
+
+    if len(node.outputs) == 1:
+
+        def core_fn(*inputs):
+            return tuple_core_fn(*inputs)[0]
+
+    else:
+        core_fn = tuple_core_fn
+
+    vect_fn = jnp.vectorize(core_fn, signature=signature)
+
+    def blockwise_fn(*inputs):
+        op._check_runtime_broadcast(node, inputs)
+        return vect_fn(*inputs)
+
+    return blockwise_fn
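
The core mechanism here is jnp.vectorize with a gufunc signature: it broadcasts the leading batch dimensions of the inputs and maps the core function over them, applying it to the trailing core dimensions, which matches the semantics of Blockwise. A minimal standalone sketch of that behavior (illustrative only, not part of the commit):

import jax.numpy as jnp


def core_matmul(x, y):
    # Operates on single 2D "core" inputs, per the signature below.
    return x @ y


# Batch dimensions broadcast; trailing dims must match "(m,k),(k,n)->(m,n)".
batched_matmul = jnp.vectorize(core_matmul, signature="(m,k),(k,n)->(m,n)")

a = jnp.ones((4, 1, 2, 3))  # batch dims (4, 1), core dims (2, 3)
b = jnp.ones((5, 3, 2))     # batch dims (5,), core dims (3, 2)
print(batched_matmul(a, b).shape)  # (4, 5, 2, 2): batch dims broadcast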

tests/link/jax/test_blockwise.py (+41)

@@ -0,0 +1,41 @@
+import jax
+import numpy as np
+import pytest
+from jax import make_jaxpr
+from tests.tensor.test_blockwise import check_blockwise_runtime_broadcasting
+
+from pytensor import config
+from pytensor.graph import FunctionGraph
+from pytensor.tensor import tensor
+from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.math import Dot, matmul
+from tests.link.jax.test_basic import compare_jax_and_py
+
+
+def test_runtime_broadcasting():
+    check_blockwise_runtime_broadcasting("JAX")
+
+
+# Equivalent blockwise to matmul but with dumb signature
+odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)")
+
+
+@pytest.mark.parametrize("matmul_op", (matmul, odd_matmul))
+def test_matmul(matmul_op):
+    rng = np.random.default_rng(14)
+    a = tensor("a", shape=(2, 3, 5))
+    b = tensor("b", shape=(2, 5, 3))
+    test_values = [
+        rng.normal(size=inp.type.shape).astype(config.floatX) for inp in (a, b)
+    ]
+
+    out = matmul_op(a, b)
+    assert isinstance(out.owner.op, Blockwise)
+    fg = FunctionGraph([a, b], [out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+
+    # Check we are not adding any unnecessary stuff
+    jaxpr = str(make_jaxpr(fn.vm.jit_fn)(*test_values))
+    jaxpr = jaxpr.replace("name=jax_funcified_fgraph", "name=matmul")
+    expected_jaxpr = str(make_jaxpr(jax.jit(jax.numpy.matmul))(*test_values))
+    assert jaxpr == expected_jaxpr
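
For readers who want to try the new backend path end to end, a hypothetical usage sketch (not part of the commit; assumes a JAX install and uses pytensor's mode="JAX"):

import numpy as np

from pytensor import function
from pytensor.tensor import tensor
from pytensor.tensor.math import matmul

a = tensor("a", shape=(2, 3, 5))
b = tensor("b", shape=(2, 5, 3))
# mode="JAX" compiles through the JAX linker, which now lowers the
# Blockwise node behind batched matmul via funcify_Blockwise.
f = function([a, b], matmul(a, b), mode="JAX")

x = np.random.normal(size=(2, 3, 5))
y = np.random.normal(size=(2, 5, 3))
print(f(x, y).shape)  # (2, 3, 3)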
