11
11
from tests .link .jax .test_basic import compare_jax_and_py
12
12
13
13
14
- def test_2in_1out ():
14
+ def test_two_inputs_single_output ():
15
15
rng = np .random .default_rng (1 )
16
16
x = tensor ("a" , shape = (2 ,))
17
17
y = tensor ("b" , shape = (2 ,))
@@ -32,7 +32,7 @@ def f(x, y):
32
32
fn , _ = compare_jax_and_py (fg , test_values )
33
33
34
34
35
- def test_2in_tupleout ():
35
+ def test_two_inputs_tuple_output ():
36
36
rng = np .random .default_rng (2 )
37
37
x = tensor ("a" , shape = (2 ,))
38
38
y = tensor ("b" , shape = (2 ,))
@@ -53,7 +53,7 @@ def f(x, y):
53
53
fn , _ = compare_jax_and_py (fg , test_values )
54
54
55
55
56
- def test_2in_listout ():
56
+ def test_two_inputs_list_output ():
57
57
rng = np .random .default_rng (3 )
58
58
x = tensor ("a" , shape = (2 ,))
59
59
y = tensor ("b" , shape = (2 ,))
@@ -74,7 +74,7 @@ def f(x, y):
74
74
fn , _ = compare_jax_and_py (fg , test_values )
75
75
76
76
77
- def test_1din_tupleout ():
77
+ def test_single_input_tuple_output ():
78
78
rng = np .random .default_rng (4 )
79
79
x = tensor ("a" , shape = (2 ,))
80
80
test_values = [rng .normal (size = (x .type .shape )).astype (config .floatX )]
@@ -92,7 +92,7 @@ def f(x):
92
92
fn , _ = compare_jax_and_py (fg , test_values )
93
93
94
94
95
- def test_0din_tupleout ():
95
+ def test_scalar_input_tuple_output ():
96
96
rng = np .random .default_rng (5 )
97
97
x = tensor ("a" , shape = ())
98
98
test_values = [rng .normal (size = (x .type .shape )).astype (config .floatX )]
@@ -110,7 +110,7 @@ def f(x):
110
110
fn , _ = compare_jax_and_py (fg , test_values )
111
111
112
112
113
- def test_1in_listout ():
113
+ def test_single_input_list_output ():
114
114
rng = np .random .default_rng (6 )
115
115
x = tensor ("a" , shape = (2 ,))
116
116
test_values = [rng .normal (size = (x .type .shape )).astype (config .floatX )]
@@ -129,7 +129,7 @@ def f(x):
129
129
fn , _ = compare_jax_and_py (fg , test_values )
130
130
131
131
132
- def test_pytreein_tupleout ():
132
+ def test_pytree_input_tuple_output ():
133
133
rng = np .random .default_rng (7 )
134
134
x = tensor ("a" , shape = (2 ,))
135
135
y = tensor ("b" , shape = (2 ,))
@@ -152,7 +152,7 @@ def f(x, y):
152
152
fn , _ = compare_jax_and_py (fg , test_values )
153
153
154
154
155
- def test_pytreein_pytreeout ():
155
+ def test_pytree_input_pytree_output ():
156
156
rng = np .random .default_rng (8 )
157
157
x = tensor ("a" , shape = (3 ,))
158
158
y = tensor ("b" , shape = (1 ,))
@@ -172,7 +172,7 @@ def f(x, y):
172
172
fn , _ = compare_jax_and_py (fg , test_values )
173
173
174
174
175
- def test_pytreein_pytreeout_w_nongraphargs ():
175
+ def test_pytree_input_with_non_graph_args ():
176
176
rng = np .random .default_rng (9 )
177
177
x = tensor ("a" , shape = (3 ,))
178
178
y = tensor ("b" , shape = (1 ,))
@@ -212,8 +212,7 @@ def f(x, y, depth, which_variable):
212
212
assert out == "Unsupported argument"
213
213
214
214
215
- def test_as_jax_op10 ():
216
- # Use "None" in shape specification and have a non-used output of higher rank
215
+ def test_unused_matrix_product_and_exp_gradient ():
217
216
rng = np .random .default_rng (10 )
218
217
x = tensor ("a" , shape = (3 ,))
219
218
y = tensor ("b" , shape = (3 ,))
@@ -235,8 +234,7 @@ def f(x, y):
235
234
fn , _ = compare_jax_and_py (fg , test_values )
236
235
237
236
238
- def test_as_jax_op11 ():
239
- # Test unknown static shape
237
+ def test_unknown_static_shape ():
240
238
rng = np .random .default_rng (11 )
241
239
x = tensor ("a" , shape = (3 ,))
242
240
y = tensor ("b" , shape = (3 ,))
@@ -260,8 +258,7 @@ def f(x, y):
260
258
fn , _ = compare_jax_and_py (fg , test_values )
261
259
262
260
263
- def test_as_jax_op12 ():
264
- # Test non-array return values
261
+ def test_non_array_return_values ():
265
262
rng = np .random .default_rng (12 )
266
263
x = tensor ("a" , shape = (3 ,))
267
264
y = tensor ("b" , shape = (3 ,))
@@ -283,8 +280,7 @@ def f(x, y, message):
283
280
fn , _ = compare_jax_and_py (fg , test_values )
284
281
285
282
286
- def test_as_jax_op13 ():
287
- # Test nested functions
283
+ def test_nested_functions ():
288
284
rng = np .random .default_rng (13 )
289
285
x = tensor ("a" , shape = (3 ,))
290
286
y = tensor ("b" , shape = (3 ,))
0 commit comments