
Commit 48fbf0a

Clean up tests
1 parent ab326e5 commit 48fbf0a

1 file changed: tests/link/jax/test_as_jax_op.py (+55 -69 lines)
@@ -13,8 +13,8 @@

 def test_two_inputs_single_output():
     rng = np.random.default_rng(1)
-    x = tensor("a", shape=(2,))
-    y = tensor("b", shape=(2,))
+    x = tensor("x", shape=(2,))
+    y = tensor("y", shape=(2,))
     test_values = [
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]
@@ -34,8 +34,8 @@ def f(x, y):

 def test_two_inputs_tuple_output():
     rng = np.random.default_rng(2)
-    x = tensor("a", shape=(2,))
-    y = tensor("b", shape=(2,))
+    x = tensor("x", shape=(2,))
+    y = tensor("y", shape=(2,))
     test_values = [
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]
@@ -44,19 +44,22 @@ def test_two_inputs_tuple_output():
     def f(x, y):
         return jax.nn.sigmoid(x + y), y * 2

-    out, _ = f(x, y)
-    grad_out = grad(pt.sum(out), [x, y])
+    out1, out2 = f(x, y)
+    grad_out = grad(pt.sum(out1 + out2), [x, y])

-    fg = FunctionGraph([x, y], [out, *grad_out])
+    fg = FunctionGraph([x, y], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
-        fn, _ = compare_jax_and_py(fg, test_values)
+        # must_be_device_array is False because, with jit compilation disabled,
+        # inputs are no longer automatically converted to jax.Array
+        fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)


-def test_two_inputs_list_output():
+def test_two_inputs_list_output_one_unused_output():
+    # One output is unused, to test whether the wrapper can handle DisconnectedType
     rng = np.random.default_rng(3)
-    x = tensor("a", shape=(2,))
-    y = tensor("b", shape=(2,))
+    x = tensor("x", shape=(2,))
+    y = tensor("y", shape=(2,))
     test_values = [
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]
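The must_be_device_array=False comment in the hunk above describes a JAX behaviour that can be checked in isolation. As a minimal sketch (illustrative only, not part of the test suite): under jax.disable_jit() a jitted function runs eagerly, so a NumPy input that is returned unchanged stays a NumPy array instead of being converted to a jax.Array.

    import jax
    import numpy as np

    @jax.jit
    def passthrough(a):
        # Identity: under jit the argument is committed to a device and the
        # result is a jax.Array; with jit disabled the call runs eagerly and
        # the NumPy input comes back untouched.
        return a

    x = np.ones(3)
    assert isinstance(passthrough(x), jax.Array)
    with jax.disable_jit():
        assert isinstance(passthrough(x), np.ndarray)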
@@ -76,63 +79,62 @@ def f(x, y):

 def test_single_input_tuple_output():
     rng = np.random.default_rng(4)
-    x = tensor("a", shape=(2,))
+    x = tensor("x", shape=(2,))
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

     @as_jax_op
     def f(x):
         return jax.nn.sigmoid(x), x * 2

-    out, _ = f(x)
-    grad_out = grad(pt.sum(out), [x])
+    out1, out2 = f(x)
+    grad_out = grad(pt.sum(out1), [x])

-    fg = FunctionGraph([x], [out, *grad_out])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
-        fn, _ = compare_jax_and_py(fg, test_values)
+        fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)


 def test_scalar_input_tuple_output():
     rng = np.random.default_rng(5)
-    x = tensor("a", shape=())
+    x = tensor("x", shape=())
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

     @as_jax_op
     def f(x):
         return jax.nn.sigmoid(x), x

-    out, _ = f(x)
-    grad_out = grad(pt.sum(out), [x])
+    out1, out2 = f(x)
+    grad_out = grad(pt.sum(out1), [x])

-    fg = FunctionGraph([x], [out, *grad_out])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
-        fn, _ = compare_jax_and_py(fg, test_values)
+        fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)


 def test_single_input_list_output():
     rng = np.random.default_rng(6)
-    x = tensor("a", shape=(2,))
+    x = tensor("x", shape=(2,))
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

     @as_jax_op
     def f(x):
         return [jax.nn.sigmoid(x), 2 * x]

-    out, _ = f(x)
-    grad_out = grad(pt.sum(out), [x])
+    out1, out2 = f(x)
+    grad_out = grad(pt.sum(out1), [x])

-    fg = FunctionGraph([x], [out, *grad_out])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
-
     with jax.disable_jit():
-        fn, _ = compare_jax_and_py(fg, test_values)
+        fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)


 def test_pytree_input_tuple_output():
     rng = np.random.default_rng(7)
-    x = tensor("a", shape=(2,))
-    y = tensor("b", shape=(2,))
+    x = tensor("x", shape=(2,))
+    y = tensor("y", shape=(2,))
     y_tmp = {"y": y, "y2": [y**2]}
     test_values = [
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
@@ -149,13 +151,13 @@ def f(x, y):
     fn, _ = compare_jax_and_py(fg, test_values)

     with jax.disable_jit():
-        fn, _ = compare_jax_and_py(fg, test_values)
+        fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)


 def test_pytree_input_pytree_output():
     rng = np.random.default_rng(8)
-    x = tensor("a", shape=(3,))
-    y = tensor("b", shape=(1,))
+    x = tensor("x", shape=(3,))
+    y = tensor("y", shape=(1,))
     y_tmp = {"a": y, "b": [y**2]}
     test_values = [
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
@@ -171,11 +173,14 @@ def f(x, y):
     fg = FunctionGraph([x, y], [out[0], out[1]["a"], *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)

+    with jax.disable_jit():
+        fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)
+

 def test_pytree_input_with_non_graph_args():
     rng = np.random.default_rng(9)
-    x = tensor("a", shape=(3,))
-    y = tensor("b", shape=(1,))
+    x = tensor("x", shape=(3,))
+    y = tensor("y", shape=(1,))
     y_tmp = {"a": y, "b": [y**2]}
     test_values = [
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
@@ -212,10 +217,13 @@ def f(x, y, depth, which_variable):
     assert out == "Unsupported argument"


-def test_unused_matrix_product_and_exp_gradient():
+def test_unused_matrix_product():
+    # A matrix output is unused, to test whether the wrapper can handle a
+    # DisconnectedType with a larger dimension.
+
     rng = np.random.default_rng(10)
-    x = tensor("a", shape=(3,))
-    y = tensor("b", shape=(3,))
+    x = tensor("x", shape=(3,))
+    y = tensor("y", shape=(3,))
     test_values = [
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]
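The renamed tests target unused outputs, which PyTensor reports to a multi-output Op as output gradients of DisconnectedType. The sketch below only illustrates the kind of guard such a wrapper needs; replace_disconnected is a hypothetical helper, not the actual as_jax_op code.

    import pytensor.tensor as pt
    from pytensor.gradient import DisconnectedType

    def replace_disconnected(outputs, output_grads):
        # Hypothetical helper: outputs that do not reach the cost arrive with a
        # DisconnectedType gradient; swap in zeros of the matching shape so a
        # JAX VJP only ever sees numeric cotangents.
        return [
            pt.zeros_like(out) if isinstance(g.type, DisconnectedType) else g
            for out, g in zip(outputs, output_grads)
        ]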
@@ -236,19 +244,19 @@ def f(x, y):

 def test_unknown_static_shape():
     rng = np.random.default_rng(11)
-    x = tensor("a", shape=(3,))
-    y = tensor("b", shape=(3,))
+    x = tensor("x", shape=(3,))
+    y = tensor("y", shape=(3,))
     test_values = [
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    x = pt.cumsum(x)  # Now x has an unknown shape
+    x_cumsum = pt.cumsum(x)  # Now x_cumsum has an unknown shape

     @as_jax_op
     def f(x, y):
         return x * jnp.ones(3)

-    out = f(x, y)
+    out = f(x_cumsum, y)
     grad_out = grad(pt.sum(out), [x])

     fg = FunctionGraph([x, y], [out, *grad_out])
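Renaming the result of pt.cumsum to x_cumsum makes the comment's point explicit: cumsum drops the static shape, which is how the test feeds the wrapped JAX function an input whose length is unknown at graph-construction time. A quick illustration (assuming current PyTensor, not part of the test file):

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(3,))
    print(x.type.shape)             # (3,)   -- static shape known
    print(pt.cumsum(x).type.shape)  # (None,) -- static shape no longer known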
@@ -258,32 +266,10 @@ def f(x, y):
     fn, _ = compare_jax_and_py(fg, test_values)


-def test_non_array_return_values():
-    rng = np.random.default_rng(12)
-    x = tensor("a", shape=(3,))
-    y = tensor("b", shape=(3,))
-    test_values = [
-        rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
-    ]
-
-    @as_jax_op
-    def f(x, y, message):
-        return x * jnp.ones(3), "Success: " + message
-
-    out = f(x, y, "Hi")
-    grad_out = grad(pt.sum(out[0]), [x])
-
-    fg = FunctionGraph([x, y], [out[0], *grad_out])
-    fn, _ = compare_jax_and_py(fg, test_values)
-
-    with jax.disable_jit():
-        fn, _ = compare_jax_and_py(fg, test_values)
-
-
 def test_nested_functions():
     rng = np.random.default_rng(13)
-    x = tensor("a", shape=(3,))
-    y = tensor("b", shape=(3,))
+    x = tensor("x", shape=(3,))
+    y = tensor("y", shape=(3,))
     test_values = [
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]
@@ -319,8 +305,8 @@ class TestDtypes:
     @pytest.mark.parametrize("in_dtype", list(map(str, all_types)))
     @pytest.mark.parametrize("out_dtype", list(map(str, all_types)))
     def test_different_in_output(self, in_dtype, out_dtype):
-        x = tensor("a", shape=(3,), dtype=in_dtype)
-        y = tensor("b", shape=(3,), dtype=in_dtype)
+        x = tensor("x", shape=(3,), dtype=in_dtype)
+        y = tensor("y", shape=(3,), dtype=in_dtype)

         if "int" in in_dtype:
             test_values = [
@@ -356,8 +342,8 @@ def f(x, y):
     @pytest.mark.parametrize("in1_dtype", list(map(str, all_types)))
     @pytest.mark.parametrize("in2_dtype", list(map(str, all_types)))
     def test_test_different_inputs(self, in1_dtype, in2_dtype):
-        x = tensor("a", shape=(3,), dtype=in1_dtype)
-        y = tensor("b", shape=(3,), dtype=in2_dtype)
+        x = tensor("x", shape=(3,), dtype=in1_dtype)
+        y = tensor("y", shape=(3,), dtype=in2_dtype)

         if "int" in in1_dtype:
             test_values = [np.random.randint(0, 10, size=(3,)).astype(x.type.dtype)]
