Skip to content

Commit a570dbf

Browse files
Ch0ronomatoricardoV94
authored andcommitted
Implement Blockwise in PyTorch backend
1 parent 2b106fc commit a570dbf

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed

pytensor/link/pytorch/dispatch/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
import pytensor.link.pytorch.dispatch.shape
1212
import pytensor.link.pytorch.dispatch.sort
1313
import pytensor.link.pytorch.dispatch.subtensor
14+
import pytensor.link.pytorch.dispatch.blockwise
1415
# isort: on
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import torch
2+
import torch.compiler
3+
4+
from pytensor.graph import FunctionGraph
5+
from pytensor.link.pytorch.dispatch import pytorch_funcify
6+
from pytensor.tensor.blockwise import Blockwise
7+
8+
9+
@pytorch_funcify.register(Blockwise)
10+
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
11+
batched_dims = op.batch_ndim(node)
12+
core_node = op._create_dummy_core_node(node.inputs)
13+
core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
14+
inner_func = pytorch_funcify(core_fgraph, squeeze_output=len(node.outputs) == 1)
15+
16+
for _ in range(batched_dims):
17+
inner_func = torch.vmap(inner_func)
18+
19+
@torch.compiler.disable(recursive=False)
20+
def batcher(*inputs):
21+
op._check_runtime_broadcast(node, inputs)
22+
# broadcast on batched_dims
23+
all_batched_dims = tuple(t.shape[:batched_dims] for t in inputs)
24+
batched_shape = torch.broadcast_shapes(*all_batched_dims)
25+
broadcast_inputs = [
26+
torch.broadcast_to(i, batched_shape + i.shape[batched_dims:])
27+
for i in inputs
28+
]
29+
res = inner_func(*broadcast_inputs)
30+
return res
31+
32+
return batcher

tests/link/pytorch/test_blockwise.py

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import numpy as np
2+
import pytest
3+
4+
import pytensor
5+
import pytensor.tensor as pt
6+
from pytensor.graph.basic import Apply
7+
from pytensor.graph.op import Op
8+
from pytensor.tensor.blockwise import Blockwise
9+
10+
11+
torch = pytest.importorskip("torch")
12+
basic = pytest.importorskip("pytensor.link.pytorch.dispatch.basic")
13+
14+
15+
class TestOp(Op):
16+
gufunc_signature = "(m,n),(n,p)->(m,p)"
17+
18+
def __init__(self, final_shape):
19+
super().__init__()
20+
self.final_shape = final_shape
21+
self.call_shapes = []
22+
23+
def make_node(self, *args):
24+
return Apply(self, list(args), [pt.matrix("_", shape=self.final_shape)])
25+
26+
def perform(self, *_):
27+
raise RuntimeError("In perform")
28+
29+
30+
@basic.pytorch_funcify.register(TestOp)
31+
def evaluate_test_op(op, **_):
32+
@torch.compiler.disable(recursive=False)
33+
def func(a, b):
34+
op.call_shapes.extend(map(torch.Tensor.size, [a, b]))
35+
return a @ b
36+
37+
return func
38+
39+
40+
def test_blockwise_broadcast():
41+
_x = np.random.rand(5, 1, 2, 3)
42+
_y = np.random.rand(3, 3, 2)
43+
44+
x = pt.tensor4("x", shape=(5, 1, 2, 3))
45+
y = pt.tensor3("y", shape=(3, 3, 2))
46+
op = TestOp((2, 2))
47+
z = Blockwise(op)(x, y)
48+
49+
f = pytensor.function([x, y], z, mode="PYTORCH")
50+
res = f(_x, _y)
51+
assert tuple(res.shape) == (5, 3, 2, 2)
52+
np.testing.assert_allclose(res, _x @ _y)
53+
assert op.call_shapes == [(2, 3), (3, 2)]

0 commit comments

Comments
 (0)