@@ -11,8 +11,7 @@
 from tests.link.jax.test_basic import compare_jax_and_py


-def test_as_jax_op1():
-    # 2 parameters input, single output
+def test_2in_1out():
     rng = np.random.default_rng(1)
     x = tensor("a", shape=(2,))
     y = tensor("b", shape=(2,))
@@ -33,8 +32,7 @@ def f(x, y):
     fn, _ = compare_jax_and_py(fg, test_values)


-def test_as_jax_op2():
-    # 2 parameters input, tuple output
+def test_2in_tupleout():
     rng = np.random.default_rng(2)
     x = tensor("a", shape=(2,))
     y = tensor("b", shape=(2,))
@@ -55,8 +53,7 @@ def f(x, y):
     fn, _ = compare_jax_and_py(fg, test_values)


-def test_as_jax_op3():
-    # 2 parameters input, list output
+def test_2in_listout():
     rng = np.random.default_rng(3)
     x = tensor("a", shape=(2,))
     y = tensor("b", shape=(2,))
@@ -77,8 +74,7 @@ def f(x, y):
     fn, _ = compare_jax_and_py(fg, test_values)


-def test_as_jax_op4():
-    # single 1d input, tuple output
+def test_1din_tupleout():
     rng = np.random.default_rng(4)
     x = tensor("a", shape=(2,))
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
@@ -96,8 +92,7 @@ def f(x):
     fn, _ = compare_jax_and_py(fg, test_values)


-def test_as_jax_op5():
-    # single 0d input, tuple output
+def test_0din_tupleout():
     rng = np.random.default_rng(5)
     x = tensor("a", shape=())
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
@@ -115,8 +110,7 @@ def f(x):
     fn, _ = compare_jax_and_py(fg, test_values)


-def test_as_jax_op6():
-    # single input, list output
+def test_1in_listout():
     rng = np.random.default_rng(6)
     x = tensor("a", shape=(2,))
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]
@@ -135,8 +129,7 @@ def f(x):
     fn, _ = compare_jax_and_py(fg, test_values)


-def test_as_jax_op7():
-    # 2 parameters input with pytree, tuple output
+def test_pytreein_tupleout():
     rng = np.random.default_rng(7)
     x = tensor("a", shape=(2,))
     y = tensor("b", shape=(2,))
@@ -159,8 +152,7 @@ def f(x, y):
     fn, _ = compare_jax_and_py(fg, test_values)


-def test_as_jax_op8():
-    # 2 parameters input with pytree, pytree output
+def test_pytreein_pytreeout():
     rng = np.random.default_rng(8)
     x = tensor("a", shape=(3,))
     y = tensor("b", shape=(1,))
@@ -180,8 +172,7 @@ def f(x, y):
     fn, _ = compare_jax_and_py(fg, test_values)


-def test_as_jax_op9():
-    # 2 parameters input with pytree, pytree output and non-graph argument
+def test_pytreein_pytreeout_w_nongraphargs():
     rng = np.random.default_rng(9)
     x = tensor("a", shape=(3,))
     y = tensor("b", shape=(1,))
@@ -191,18 +182,35 @@ def test_as_jax_op9():
     ]

     @as_jax_op
-    def f(x, y, non_model_arg):
-        return jnp.exp(x), jax.tree_util.tree_map(jax.nn.sigmoid, y)
-
-    out = f(x, y_tmp, "Hello World!")
-    grad_out = grad(pt.sum(out[0]), [x])
+    def f(x, y, depth, which_variable):
+        if which_variable == "x":
+            var = x
+        elif which_variable == "y":
+            var = y["a"] + y["b"][0]
+        else:
+            return "Unsupported argument"
+        for _ in range(depth):
+            var = jax.nn.sigmoid(var)
+        return var

+    # arguments depth and which_variable are not part of the graph
+    out = f(x, y_tmp, depth=3, which_variable="x")
+    grad_out = grad(pt.sum(out), [x])
     fg = FunctionGraph([x, y], [out[0], *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
+    with jax.disable_jit():
+        fn, _ = compare_jax_and_py(fg, test_values)

+    out = f(x, y_tmp, depth=7, which_variable="y")
+    grad_out = grad(pt.sum(out), [x])
+    fg = FunctionGraph([x, y], [out[0], *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)

+    out = f(x, y_tmp, depth=10, which_variable="z")
+    assert out == "Unsupported argument"
+

 def test_as_jax_op10():
     # Use "None" in shape specification and have a non-used output of higher rank
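Note on the rewritten test above: `depth` and `which_variable` are plain Python values rather than tensor variables, and per the in-test comment they are not part of the graph, which is why `f` can branch and loop on them in ordinary Python. As background only (not part of this commit), the snippet below shows the closest plain-JAX analogue of such non-graph arguments: `jax.jit` with `static_argnames`. The function `toy` is a hypothetical stand-in, not the test's `f`.

from functools import partial

import jax
import jax.numpy as jnp

# Static arguments are fixed at trace time; JAX retraces for each new
# (depth, which_variable) combination, so Python control flow on them is
# resolved during tracing instead of becoming graph operations.
@partial(jax.jit, static_argnames=("depth", "which_variable"))
def toy(x, y, depth, which_variable):
    var = x if which_variable == "x" else y
    for _ in range(depth):  # unrolled while tracing because depth is static
        var = jax.nn.sigmoid(var)
    return var

print(toy(jnp.ones(3), jnp.zeros(3), depth=3, which_variable="x"))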