Skip to content

Commit 43d91d0

Browse files
committed
Fix JAX dispatch for multi-output Composite
1 parent 21d723b commit 43d91d0

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

pytensor/link/jax/dispatch/scalar.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,18 @@ def clip(x, min, max):
6363

6464

6565
@jax_funcify.register(Composite)
66-
def jax_funcify_Composite(op, vectorize=True, **kwargs):
66+
def jax_funcify_Composite(op, node, vectorize=True, **kwargs):
6767
jax_impl = jax_funcify(op.fgraph)
6868

69-
def composite(*args):
70-
return jax_impl(*args)[0]
69+
if len(node.outputs) == 1:
70+
71+
def composite(*args):
72+
return jax_impl(*args)[0]
73+
74+
else:
75+
76+
def composite(*args):
77+
return jax_impl(*args)
7178

7279
return jnp.vectorize(composite)
7380

tests/link/jax/test_scalar.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_identity():
6363
),
6464
],
6565
)
66-
def test_jax_Composite(x, y, x_val, y_val):
66+
def test_jax_Composite_singe_output(x, y, x_val, y_val):
6767
x_s = aes.float64("x")
6868
y_s = aes.float64("y")
6969

@@ -80,6 +80,16 @@ def test_jax_Composite(x, y, x_val, y_val):
8080
_ = compare_jax_and_py(out_fg, test_input_vals)
8181

8282

83+
def test_jax_Composite_multi_output():
84+
x = vector("x")
85+
86+
x_s = aes.float64("xs")
87+
outs = Elemwise(Composite(inputs=[x_s], outputs=[x_s + 1, x_s - 1]))(x)
88+
89+
fgraph = FunctionGraph([x], outs)
90+
compare_jax_and_py(fgraph, [np.arange(10, dtype=config.floatX)])
91+
92+
8393
def test_erf():
8494
x = scalar("x")
8595
out = erf(x)

0 commit comments

Comments
 (0)