def test_two_inputs_single_output():
    rng = np.random.default_rng(1)
-    x = tensor("a", shape=(2,))
-    y = tensor("b", shape=(2,))
+    x = tensor("x", shape=(2,))
+    y = tensor("y", shape=(2,))
    test_values = [
        rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
    ]
@@ -34,8 +34,8 @@ def f(x, y):

def test_two_inputs_tuple_output():
    rng = np.random.default_rng(2)
-    x = tensor("a", shape=(2,))
-    y = tensor("b", shape=(2,))
+    x = tensor("x", shape=(2,))
+    y = tensor("y", shape=(2,))
    test_values = [
        rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
    ]
@@ -44,19 +44,22 @@ def test_two_inputs_tuple_output():
    def f(x, y):
        return jax.nn.sigmoid(x + y), y * 2

-    out, _ = f(x, y)
-    grad_out = grad(pt.sum(out), [x, y])
+    out1, out2 = f(x, y)
+    grad_out = grad(pt.sum(out1 + out2), [x, y])

-    fg = FunctionGraph([x, y], [out, *grad_out])
+    fg = FunctionGraph([x, y], [out1, out2, *grad_out])
    fn, _ = compare_jax_and_py(fg, test_values)
    with jax.disable_jit():
-        fn, _ = compare_jax_and_py(fg, test_values)
+        # must_be_device_array is False because, with jit compilation disabled,
+        # inputs are no longer automatically converted to jax.Array
+        fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)


-def test_two_inputs_list_output():
+def test_two_inputs_list_output_one_unused_output():
+    # One output is unused, to test whether the wrapper can handle DisconnectedType
    rng = np.random.default_rng(3)
-    x = tensor("a", shape=(2,))
-    y = tensor("b", shape=(2,))
+    x = tensor("x", shape=(2,))
+    y = tensor("y", shape=(2,))
    test_values = [
        rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
    ]
@@ -76,63 +79,62 @@ def f(x, y):

def test_single_input_tuple_output():
    rng = np.random.default_rng(4)
-    x = tensor("a", shape=(2,))
+    x = tensor("x", shape=(2,))
    test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

    @as_jax_op
    def f(x):
        return jax.nn.sigmoid(x), x * 2

-    out, _ = f(x)
-    grad_out = grad(pt.sum(out), [x])
+    out1, out2 = f(x)
+    grad_out = grad(pt.sum(out1), [x])

-    fg = FunctionGraph([x], [out, *grad_out])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
    fn, _ = compare_jax_and_py(fg, test_values)
    with jax.disable_jit():
-        fn, _ = compare_jax_and_py(fg, test_values)
+        fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)


def test_scalar_input_tuple_output():
    rng = np.random.default_rng(5)
-    x = tensor("a", shape=())
+    x = tensor("x", shape=())
    test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

    @as_jax_op
    def f(x):
        return jax.nn.sigmoid(x), x

-    out, _ = f(x)
-    grad_out = grad(pt.sum(out), [x])
+    out1, out2 = f(x)
+    grad_out = grad(pt.sum(out1), [x])

-    fg = FunctionGraph([x], [out, *grad_out])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
    fn, _ = compare_jax_and_py(fg, test_values)
    with jax.disable_jit():
-        fn, _ = compare_jax_and_py(fg, test_values)
+        fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)


def test_single_input_list_output():
    rng = np.random.default_rng(6)
-    x = tensor("a", shape=(2,))
+    x = tensor("x", shape=(2,))
    test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

    @as_jax_op
    def f(x):
        return [jax.nn.sigmoid(x), 2 * x]

-    out, _ = f(x)
-    grad_out = grad(pt.sum(out), [x])
+    out1, out2 = f(x)
+    grad_out = grad(pt.sum(out1), [x])

-    fg = FunctionGraph([x], [out, *grad_out])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
    fn, _ = compare_jax_and_py(fg, test_values)
-
    with jax.disable_jit():
-        fn, _ = compare_jax_and_py(fg, test_values)
+        fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)


def test_pytree_input_tuple_output():
    rng = np.random.default_rng(7)
-    x = tensor("a", shape=(2,))
-    y = tensor("b", shape=(2,))
+    x = tensor("x", shape=(2,))
+    y = tensor("y", shape=(2,))
    y_tmp = {"y": y, "y2": [y**2]}
    test_values = [
        rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
@@ -149,13 +151,13 @@ def f(x, y):
    fn, _ = compare_jax_and_py(fg, test_values)

    with jax.disable_jit():
-        fn, _ = compare_jax_and_py(fg, test_values)
+        fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)


def test_pytree_input_pytree_output():
    rng = np.random.default_rng(8)
-    x = tensor("a", shape=(3,))
-    y = tensor("b", shape=(1,))
+    x = tensor("x", shape=(3,))
+    y = tensor("y", shape=(1,))
    y_tmp = {"a": y, "b": [y**2]}
    test_values = [
        rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
@@ -171,11 +173,14 @@ def f(x, y):
    fg = FunctionGraph([x, y], [out[0], out[1]["a"], *grad_out])
    fn, _ = compare_jax_and_py(fg, test_values)

+    with jax.disable_jit():
+        fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)
+

def test_pytree_input_with_non_graph_args():
    rng = np.random.default_rng(9)
-    x = tensor("a", shape=(3,))
-    y = tensor("b", shape=(1,))
+    x = tensor("x", shape=(3,))
+    y = tensor("y", shape=(1,))
    y_tmp = {"a": y, "b": [y**2]}
    test_values = [
        rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
@@ -212,10 +217,13 @@ def f(x, y, depth, which_variable):
    assert out == "Unsupported argument"


-def test_unused_matrix_product_and_exp_gradient():
+def test_unused_matrix_product():
+    # A matrix output is unused, to test whether the wrapper can handle a
+    # DisconnectedType with a larger dimension.
+
    rng = np.random.default_rng(10)
-    x = tensor("a", shape=(3,))
-    y = tensor("b", shape=(3,))
+    x = tensor("x", shape=(3,))
+    y = tensor("y", shape=(3,))
    test_values = [
        rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
    ]
@@ -236,19 +244,19 @@ def f(x, y):

def test_unknown_static_shape():
    rng = np.random.default_rng(11)
-    x = tensor("a", shape=(3,))
-    y = tensor("b", shape=(3,))
+    x = tensor("x", shape=(3,))
+    y = tensor("y", shape=(3,))
    test_values = [
        rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
    ]

-    x = pt.cumsum(x)  # Now x has an unknown shape
+    x_cumsum = pt.cumsum(x)  # Now x_cumsum has an unknown shape

    @as_jax_op
    def f(x, y):
        return x * jnp.ones(3)

-    out = f(x, y)
+    out = f(x_cumsum, y)
    grad_out = grad(pt.sum(out), [x])

    fg = FunctionGraph([x, y], [out, *grad_out])
@@ -258,32 +266,10 @@ def f(x, y):
    fn, _ = compare_jax_and_py(fg, test_values)


-def test_non_array_return_values():
-    rng = np.random.default_rng(12)
-    x = tensor("a", shape=(3,))
-    y = tensor("b", shape=(3,))
-    test_values = [
-        rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
-    ]
-
-    @as_jax_op
-    def f(x, y, message):
-        return x * jnp.ones(3), "Success: " + message
-
-    out = f(x, y, "Hi")
-    grad_out = grad(pt.sum(out[0]), [x])
-
-    fg = FunctionGraph([x, y], [out[0], *grad_out])
-    fn, _ = compare_jax_and_py(fg, test_values)
-
-    with jax.disable_jit():
-        fn, _ = compare_jax_and_py(fg, test_values)
-
-
def test_nested_functions():
    rng = np.random.default_rng(13)
-    x = tensor("a", shape=(3,))
-    y = tensor("b", shape=(3,))
+    x = tensor("x", shape=(3,))
+    y = tensor("y", shape=(3,))
    test_values = [
        rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
    ]
@@ -319,8 +305,8 @@ class TestDtypes:
    @pytest.mark.parametrize("in_dtype", list(map(str, all_types)))
    @pytest.mark.parametrize("out_dtype", list(map(str, all_types)))
    def test_different_in_output(self, in_dtype, out_dtype):
-        x = tensor("a", shape=(3,), dtype=in_dtype)
-        y = tensor("b", shape=(3,), dtype=in_dtype)
+        x = tensor("x", shape=(3,), dtype=in_dtype)
+        y = tensor("y", shape=(3,), dtype=in_dtype)

        if "int" in in_dtype:
            test_values = [
@@ -356,8 +342,8 @@ def f(x, y):
    @pytest.mark.parametrize("in1_dtype", list(map(str, all_types)))
    @pytest.mark.parametrize("in2_dtype", list(map(str, all_types)))
    def test_test_different_inputs(self, in1_dtype, in2_dtype):
-        x = tensor("a", shape=(3,), dtype=in1_dtype)
-        y = tensor("b", shape=(3,), dtype=in2_dtype)
+        x = tensor("x", shape=(3,), dtype=in1_dtype)
+        y = tensor("y", shape=(3,), dtype=in2_dtype)

        if "int" in in1_dtype:
            test_values = [np.random.randint(0, 10, size=(3,)).astype(x.type.dtype)]