Skip to content

Commit d2e788f

Browse files
committed
More test renaming, forgot a few
1 parent e11777e commit d2e788f

File tree

1 file changed

+13
-17
lines changed

1 file changed

+13
-17
lines changed

tests/link/jax/test_as_jax_op.py

+13-17
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from tests.link.jax.test_basic import compare_jax_and_py
1212

1313

14-
def test_2in_1out():
14+
def test_two_inputs_single_output():
1515
rng = np.random.default_rng(1)
1616
x = tensor("a", shape=(2,))
1717
y = tensor("b", shape=(2,))
@@ -32,7 +32,7 @@ def f(x, y):
3232
fn, _ = compare_jax_and_py(fg, test_values)
3333

3434

35-
def test_2in_tupleout():
35+
def test_two_inputs_tuple_output():
3636
rng = np.random.default_rng(2)
3737
x = tensor("a", shape=(2,))
3838
y = tensor("b", shape=(2,))
@@ -53,7 +53,7 @@ def f(x, y):
5353
fn, _ = compare_jax_and_py(fg, test_values)
5454

5555

56-
def test_2in_listout():
56+
def test_two_inputs_list_output():
5757
rng = np.random.default_rng(3)
5858
x = tensor("a", shape=(2,))
5959
y = tensor("b", shape=(2,))
@@ -74,7 +74,7 @@ def f(x, y):
7474
fn, _ = compare_jax_and_py(fg, test_values)
7575

7676

77-
def test_1din_tupleout():
77+
def test_single_input_tuple_output():
7878
rng = np.random.default_rng(4)
7979
x = tensor("a", shape=(2,))
8080
test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
@@ -92,7 +92,7 @@ def f(x):
9292
fn, _ = compare_jax_and_py(fg, test_values)
9393

9494

95-
def test_0din_tupleout():
95+
def test_scalar_input_tuple_output():
9696
rng = np.random.default_rng(5)
9797
x = tensor("a", shape=())
9898
test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
@@ -110,7 +110,7 @@ def f(x):
110110
fn, _ = compare_jax_and_py(fg, test_values)
111111

112112

113-
def test_1in_listout():
113+
def test_single_input_list_output():
114114
rng = np.random.default_rng(6)
115115
x = tensor("a", shape=(2,))
116116
test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
@@ -129,7 +129,7 @@ def f(x):
129129
fn, _ = compare_jax_and_py(fg, test_values)
130130

131131

132-
def test_pytreein_tupleout():
132+
def test_pytree_input_tuple_output():
133133
rng = np.random.default_rng(7)
134134
x = tensor("a", shape=(2,))
135135
y = tensor("b", shape=(2,))
@@ -152,7 +152,7 @@ def f(x, y):
152152
fn, _ = compare_jax_and_py(fg, test_values)
153153

154154

155-
def test_pytreein_pytreeout():
155+
def test_pytree_input_pytree_output():
156156
rng = np.random.default_rng(8)
157157
x = tensor("a", shape=(3,))
158158
y = tensor("b", shape=(1,))
@@ -172,7 +172,7 @@ def f(x, y):
172172
fn, _ = compare_jax_and_py(fg, test_values)
173173

174174

175-
def test_pytreein_pytreeout_w_nongraphargs():
175+
def test_pytree_input_with_non_graph_args():
176176
rng = np.random.default_rng(9)
177177
x = tensor("a", shape=(3,))
178178
y = tensor("b", shape=(1,))
@@ -212,8 +212,7 @@ def f(x, y, depth, which_variable):
212212
assert out == "Unsupported argument"
213213

214214

215-
def test_as_jax_op10():
216-
# Use "None" in shape specification and have a non-used output of higher rank
215+
def test_unused_matrix_product_and_exp_gradient():
217216
rng = np.random.default_rng(10)
218217
x = tensor("a", shape=(3,))
219218
y = tensor("b", shape=(3,))
@@ -235,8 +234,7 @@ def f(x, y):
235234
fn, _ = compare_jax_and_py(fg, test_values)
236235

237236

238-
def test_as_jax_op11():
239-
# Test unknown static shape
237+
def test_unknown_static_shape():
240238
rng = np.random.default_rng(11)
241239
x = tensor("a", shape=(3,))
242240
y = tensor("b", shape=(3,))
@@ -260,8 +258,7 @@ def f(x, y):
260258
fn, _ = compare_jax_and_py(fg, test_values)
261259

262260

263-
def test_as_jax_op12():
264-
# Test non-array return values
261+
def test_non_array_return_values():
265262
rng = np.random.default_rng(12)
266263
x = tensor("a", shape=(3,))
267264
y = tensor("b", shape=(3,))
@@ -283,8 +280,7 @@ def f(x, y, message):
283280
fn, _ = compare_jax_and_py(fg, test_values)
284281

285282

286-
def test_as_jax_op13():
287-
# Test nested functions
283+
def test_nested_functions():
288284
rng = np.random.default_rng(13)
289285
x = tensor("a", shape=(3,))
290286
y = tensor("b", shape=(3,))

0 commit comments

Comments
 (0)