
Commit b8c4523

Add to some tests a direct call to JAXOp
1 parent 48fbf0a commit b8c4523
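For context: the commit has the tests exercise two equivalent ways of bringing a JAX function into a PyTensor graph, wrapping it with as_jax_op (previously applied as a decorator) and constructing a JAXOp directly with explicit input and output types. The standalone sketch below restates that pattern; the function f and the (2,) shapes mirror the diff, while the variable names (out_wrapped, out_direct) are illustrative choices and not part of the commit.

import jax

import pytensor.tensor as pt
from pytensor import as_jax_op, config, grad
from pytensor.graph.fg import FunctionGraph
from pytensor.link.jax.ops import JAXOp
from pytensor.tensor import TensorType, tensor


# A plain JAX function, as used in the tests.
def f(x, y):
    return jax.nn.sigmoid(x + y)


x = tensor("x", shape=(2,))
y = tensor("y", shape=(2,))

# Path 1: wrap the function with as_jax_op and call it on PyTensor variables.
out_wrapped = as_jax_op(f)(x, y)

# Path 2: build the Op directly, spelling out input and output types.
jax_op = JAXOp(
    [x.type, y.type],
    [TensorType(config.floatX, shape=(2,))],
    f,
)
out_direct = jax_op(x, y)

# Either result behaves like a normal PyTensor variable: it can be
# differentiated and placed in a FunctionGraph, as the tests do.
grad_out = grad(pt.sum(out_direct), [x, y])
fg = FunctionGraph([x, y], [out_direct, *grad_out])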

1 file changed: +114 −17 lines changed

tests/link/jax/test_as_jax_op.py  (+114 −17)

@@ -6,8 +6,9 @@
 import pytensor.tensor as pt
 from pytensor import as_jax_op, config, grad
 from pytensor.graph.fg import FunctionGraph
+from pytensor.link.jax.ops import JAXOp
 from pytensor.scalar import all_types
-from pytensor.tensor import tensor
+from pytensor.tensor import TensorType, tensor
 from tests.link.jax.test_basic import compare_jax_and_py


@@ -19,18 +20,29 @@ def test_two_inputs_single_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y):
         return jax.nn.sigmoid(x + y)

-    out = f(x, y)
+    # Test with as_jax_op decorator
+    out = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out), [x, y])

     fg = FunctionGraph([x, y], [out, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out = jax_op(x, y)
+    grad_out = grad(pt.sum(out), [x, y])
+    fg = FunctionGraph([x, y], [out, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_two_inputs_tuple_output():
     rng = np.random.default_rng(2)
@@ -40,11 +52,11 @@ def test_two_inputs_tuple_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y):
         return jax.nn.sigmoid(x + y), y * 2

-    out1, out2 = f(x, y)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out1 + out2), [x, y])

     fg = FunctionGraph([x, y], [out1, out2, *grad_out])
@@ -54,6 +66,17 @@ def f(x, y):
     # inputs are not automatically transformed to jax.Array anymore
     fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out1, out2 = jax_op(x, y)
+    grad_out = grad(pt.sum(out1 + out2), [x, y])
+    fg = FunctionGraph([x, y], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_two_inputs_list_output_one_unused_output():
     # One output is unused, to test whether the wrapper can handle DisconnectedType
@@ -64,72 +87,119 @@ def test_two_inputs_list_output_one_unused_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y):
         return [jax.nn.sigmoid(x + y), y * 2]

-    out, _ = f(x, y)
+    # Test with as_jax_op decorator
+    out, _ = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out), [x, y])

     fg = FunctionGraph([x, y], [out, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out, _ = jax_op(x, y)
+    grad_out = grad(pt.sum(out), [x, y])
+    fg = FunctionGraph([x, y], [out, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_single_input_tuple_output():
     rng = np.random.default_rng(4)
     x = tensor("x", shape=(2,))
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

-    @as_jax_op
     def f(x):
         return jax.nn.sigmoid(x), x * 2

-    out1, out2 = f(x)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x)
     grad_out = grad(pt.sum(out1), [x])

     fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type],
+        [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out1, out2 = jax_op(x)
+    grad_out = grad(pt.sum(out1), [x])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_scalar_input_tuple_output():
     rng = np.random.default_rng(5)
     x = tensor("x", shape=())
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

-    @as_jax_op
     def f(x):
         return jax.nn.sigmoid(x), x

-    out1, out2 = f(x)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x)
     grad_out = grad(pt.sum(out1), [x])

     fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type],
+        [TensorType(config.floatX, shape=()), TensorType(config.floatX, shape=())],
+        f,
+    )
+    out1, out2 = jax_op(x)
+    grad_out = grad(pt.sum(out1), [x])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_single_input_list_output():
     rng = np.random.default_rng(6)
     x = tensor("x", shape=(2,))
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

-    @as_jax_op
     def f(x):
         return [jax.nn.sigmoid(x), 2 * x]

-    out1, out2 = f(x)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x)
     grad_out = grad(pt.sum(out1), [x])

     fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

+    # Test direct JAXOp usage, with unspecified output shapes
+    jax_op = JAXOp(
+        [x.type],
+        [
+            TensorType(config.floatX, shape=(None,)),
+            TensorType(config.floatX, shape=(None,)),
+        ],
+        f,
+    )
+    out1, out2 = jax_op(x)
+    grad_out = grad(pt.sum(out1), [x])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_pytree_input_tuple_output():
     rng = np.random.default_rng(7)
@@ -144,6 +214,7 @@ def test_pytree_input_tuple_output():
     def f(x, y):
         return jax.nn.sigmoid(x), 2 * x + y["y"] + y["y2"][0]

+    # Test with as_jax_op decorator
     out = f(x, y_tmp)
     grad_out = grad(pt.sum(out[1]), [x, y])

@@ -167,6 +238,7 @@ def test_pytree_input_pytree_output():
     def f(x, y):
         return x, jax.tree_util.tree_map(lambda x: jnp.exp(x), y)

+    # Test with as_jax_op decorator
     out = f(x, y_tmp)
     grad_out = grad(pt.sum(out[1]["b"][0]), [x, y])

@@ -198,6 +270,7 @@ def f(x, y, depth, which_variable):
             var = jax.nn.sigmoid(var)
         return var

+    # Test with as_jax_op decorator
     # arguments depth and which_variable are not part of the graph
     out = f(x, y_tmp, depth=3, which_variable="x")
     grad_out = grad(pt.sum(out), [x])
@@ -228,11 +301,11 @@ def test_unused_matrix_product():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y):
         return x[:, None] @ y[None], jnp.exp(x)

-    out = f(x, y)
+    # Test with as_jax_op decorator
+    out = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out[1]), [x])

     fg = FunctionGraph([x, y], [out[1], *grad_out])
@@ -241,6 +314,20 @@ def f(x, y):
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [
+            TensorType(config.floatX, shape=(3, 3)),
+            TensorType(config.floatX, shape=(3,)),
+        ],
+        f,
+    )
+    out = jax_op(x, y)
+    grad_out = grad(pt.sum(out[1]), [x])
+    fg = FunctionGraph([x, y], [out[1], *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_unknown_static_shape():
     rng = np.random.default_rng(11)
@@ -252,11 +339,10 @@ def test_unknown_static_shape():

     x_cumsum = pt.cumsum(x)  # Now x_cumsum has an unknown shape

-    @as_jax_op
     def f(x, y):
         return x * jnp.ones(3)

-    out = f(x_cumsum, y)
+    out = as_jax_op(f)(x_cumsum, y)
     grad_out = grad(pt.sum(out), [x])

     fg = FunctionGraph([x, y], [out, *grad_out])
@@ -265,6 +351,17 @@ def f(x, y):
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [TensorType(config.floatX, shape=(None,))],
+        f,
+    )
+    out = jax_op(x_cumsum, y)
+    grad_out = grad(pt.sum(out), [x])
+    fg = FunctionGraph([x, y], [out, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_nested_functions():
     rng = np.random.default_rng(13)
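A note on the output types, based on the last two hunks: the output TensorType handed to JAXOp does not need a fully known static shape; shape=(None,) is used where the length is only known at run time (see test_single_input_list_output and test_unknown_static_shape). A minimal sketch of that variant, reusing the names from the sketch above (jax_op_unknown and out_unknown are illustrative names):

# Output static shape left unspecified via shape=(None,).
jax_op_unknown = JAXOp(
    [x.type, y.type],
    [TensorType(config.floatX, shape=(None,))],
    f,
)
out_unknown = jax_op_unknown(x, y)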
