
Commit e11777e

Rename tests and make static variables test more meaningful
1 parent 104df83 commit e11777e

File tree

1 file changed (+31 −23)

tests/link/jax/test_as_jax_op.py

@@ -11,8 +11,7 @@
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
-def test_as_jax_op1():
-    # 2 parameters input, single output
+def test_2in_1out():
     rng = np.random.default_rng(1)
     x = tensor("a", shape=(2,))
     y = tensor("b", shape=(2,))
@@ -33,8 +32,7 @@ def f(x, y):
     fn, _ = compare_jax_and_py(fg, test_values)
 
 
-def test_as_jax_op2():
-    # 2 parameters input, tuple output
+def test_2in_tupleout():
     rng = np.random.default_rng(2)
     x = tensor("a", shape=(2,))
     y = tensor("b", shape=(2,))
@@ -55,8 +53,7 @@ def f(x, y):
     fn, _ = compare_jax_and_py(fg, test_values)
 
 
-def test_as_jax_op3():
-    # 2 parameters input, list output
+def test_2in_listout():
     rng = np.random.default_rng(3)
     x = tensor("a", shape=(2,))
     y = tensor("b", shape=(2,))
@@ -77,8 +74,7 @@ def f(x, y):
     fn, _ = compare_jax_and_py(fg, test_values)
 
 
-def test_as_jax_op4():
-    # single 1d input, tuple output
+def test_1din_tupleout():
     rng = np.random.default_rng(4)
     x = tensor("a", shape=(2,))
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
@@ -96,8 +92,7 @@ def f(x):
     fn, _ = compare_jax_and_py(fg, test_values)
 
 
-def test_as_jax_op5():
-    # single 0d input, tuple output
+def test_0din_tupleout():
     rng = np.random.default_rng(5)
     x = tensor("a", shape=())
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
@@ -115,8 +110,7 @@ def f(x):
     fn, _ = compare_jax_and_py(fg, test_values)
 
 
-def test_as_jax_op6():
-    # single input, list output
+def test_1in_listout():
     rng = np.random.default_rng(6)
     x = tensor("a", shape=(2,))
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
@@ -135,8 +129,7 @@ def f(x):
     fn, _ = compare_jax_and_py(fg, test_values)
 
 
-def test_as_jax_op7():
-    # 2 parameters input with pytree, tuple output
+def test_pytreein_tupleout():
     rng = np.random.default_rng(7)
     x = tensor("a", shape=(2,))
     y = tensor("b", shape=(2,))
@@ -159,8 +152,7 @@ def f(x, y):
     fn, _ = compare_jax_and_py(fg, test_values)
 
 
-def test_as_jax_op8():
-    # 2 parameters input with pytree, pytree output
+def test_pytreein_pytreeout():
     rng = np.random.default_rng(8)
     x = tensor("a", shape=(3,))
     y = tensor("b", shape=(1,))
@@ -180,8 +172,7 @@ def f(x, y):
     fn, _ = compare_jax_and_py(fg, test_values)
 
 
-def test_as_jax_op9():
-    # 2 parameters input with pytree, pytree output and non-graph argument
+def test_pytreein_pytreeout_w_nongraphargs():
     rng = np.random.default_rng(9)
     x = tensor("a", shape=(3,))
     y = tensor("b", shape=(1,))
@@ -191,18 +182,35 @@ def test_as_jax_op9():
     ]
 
     @as_jax_op
-    def f(x, y, non_model_arg):
-        return jnp.exp(x), jax.tree_util.tree_map(jax.nn.sigmoid, y)
-
-    out = f(x, y_tmp, "Hello World!")
-    grad_out = grad(pt.sum(out[0]), [x])
+    def f(x, y, depth, which_variable):
+        if which_variable == "x":
+            var = x
+        elif which_variable == "y":
+            var = y["a"] + y["b"][0]
+        else:
+            return "Unsupported argument"
+        for _ in range(depth):
+            var = jax.nn.sigmoid(var)
+        return var
 
+    # arguments depth and which_variable are not part of the graph
+    out = f(x, y_tmp, depth=3, which_variable="x")
+    grad_out = grad(pt.sum(out), [x])
     fg = FunctionGraph([x, y], [out[0], *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
+    with jax.disable_jit():
+        fn, _ = compare_jax_and_py(fg, test_values)
 
+    out = f(x, y_tmp, depth=7, which_variable="y")
+    grad_out = grad(pt.sum(out), [x])
+    fg = FunctionGraph([x, y], [out[0], *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)
 
+    out = f(x, y_tmp, depth=10, which_variable="z")
+    assert out == "Unsupported argument"
+
 
 def test_as_jax_op10():
     # Use "None" in shape specification and have a non-used output of higher rank
