 import pytensor.tensor as pt
 from pytensor import as_jax_op, config, grad
 from pytensor.graph.fg import FunctionGraph
+from pytensor.link.jax.ops import JAXOp
 from pytensor.scalar import all_types
-from pytensor.tensor import tensor
+from pytensor.tensor import TensorType, tensor
 from tests.link.jax.test_basic import compare_jax_and_py


@@ -19,18 +20,29 @@ def test_two_inputs_single_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y):
         return jax.nn.sigmoid(x + y)

-    out = f(x, y)
+    # Test with as_jax_op decorator
+    out = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out), [x, y])

     fg = FunctionGraph([x, y], [out, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out = jax_op(x, y)
+    grad_out = grad(pt.sum(out), [x, y])
+    fg = FunctionGraph([x, y], [out, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_two_inputs_tuple_output():
     rng = np.random.default_rng(2)
@@ -40,11 +52,11 @@ def test_two_inputs_tuple_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y):
         return jax.nn.sigmoid(x + y), y * 2

-    out1, out2 = f(x, y)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out1 + out2), [x, y])

     fg = FunctionGraph([x, y], [out1, out2, *grad_out])
@@ -54,6 +66,17 @@ def f(x, y):
     # inputs are not automatically transformed to jax.Array anymore
     fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out1, out2 = jax_op(x, y)
+    grad_out = grad(pt.sum(out1 + out2), [x, y])
+    fg = FunctionGraph([x, y], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_two_inputs_list_output_one_unused_output():
     # One output is unused, to test whether the wrapper can handle DisconnectedType
@@ -64,72 +87,119 @@ def test_two_inputs_list_output_one_unused_output():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y):
         return [jax.nn.sigmoid(x + y), y * 2]

-    out, _ = f(x, y)
+    # Test with as_jax_op decorator
+    out, _ = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out), [x, y])

     fg = FunctionGraph([x, y], [out, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out, _ = jax_op(x, y)
+    grad_out = grad(pt.sum(out), [x, y])
+    fg = FunctionGraph([x, y], [out, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_single_input_tuple_output():
     rng = np.random.default_rng(4)
     x = tensor("x", shape=(2,))
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

-    @as_jax_op
     def f(x):
         return jax.nn.sigmoid(x), x * 2

-    out1, out2 = f(x)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x)
     grad_out = grad(pt.sum(out1), [x])

     fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type],
+        [TensorType(config.floatX, shape=(2,)), TensorType(config.floatX, shape=(2,))],
+        f,
+    )
+    out1, out2 = jax_op(x)
+    grad_out = grad(pt.sum(out1), [x])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_scalar_input_tuple_output():
     rng = np.random.default_rng(5)
     x = tensor("x", shape=())
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

-    @as_jax_op
     def f(x):
         return jax.nn.sigmoid(x), x

-    out1, out2 = f(x)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x)
     grad_out = grad(pt.sum(out1), [x])

     fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type],
+        [TensorType(config.floatX, shape=()), TensorType(config.floatX, shape=())],
+        f,
+    )
+    out1, out2 = jax_op(x)
+    grad_out = grad(pt.sum(out1), [x])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_single_input_list_output():
     rng = np.random.default_rng(6)
     x = tensor("x", shape=(2,))
     test_values = [rng.normal(size=(x.type.shape)).astype(config.floatX)]

-    @as_jax_op
     def f(x):
         return [jax.nn.sigmoid(x), 2 * x]

-    out1, out2 = f(x)
+    # Test with as_jax_op decorator
+    out1, out2 = as_jax_op(f)(x)
     grad_out = grad(pt.sum(out1), [x])

     fg = FunctionGraph([x], [out1, out2, *grad_out])
     fn, _ = compare_jax_and_py(fg, test_values)
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values, must_be_device_array=False)

+    # Test direct JAXOp usage, with unspecified output shapes
+    jax_op = JAXOp(
+        [x.type],
+        [
+            TensorType(config.floatX, shape=(None,)),
+            TensorType(config.floatX, shape=(None,)),
+        ],
+        f,
+    )
+    out1, out2 = jax_op(x)
+    grad_out = grad(pt.sum(out1), [x])
+    fg = FunctionGraph([x], [out1, out2, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_pytree_input_tuple_output():
     rng = np.random.default_rng(7)
@@ -144,6 +214,7 @@ def test_pytree_input_tuple_output():
     def f(x, y):
         return jax.nn.sigmoid(x), 2 * x + y["y"] + y["y2"][0]

+    # Test with as_jax_op decorator
     out = f(x, y_tmp)
     grad_out = grad(pt.sum(out[1]), [x, y])

@@ -167,6 +238,7 @@ def test_pytree_input_pytree_output():
     def f(x, y):
         return x, jax.tree_util.tree_map(lambda x: jnp.exp(x), y)

+    # Test with as_jax_op decorator
     out = f(x, y_tmp)
     grad_out = grad(pt.sum(out[1]["b"][0]), [x, y])

@@ -198,6 +270,7 @@ def f(x, y, depth, which_variable):
             var = jax.nn.sigmoid(var)
         return var

+    # Test with as_jax_op decorator
     # arguments depth and which_variable are not part of the graph
     out = f(x, y_tmp, depth=3, which_variable="x")
     grad_out = grad(pt.sum(out), [x])
@@ -228,11 +301,11 @@ def test_unused_matrix_product():
         rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (x, y)
     ]

-    @as_jax_op
     def f(x, y):
         return x[:, None] @ y[None], jnp.exp(x)

-    out = f(x, y)
+    # Test with as_jax_op decorator
+    out = as_jax_op(f)(x, y)
     grad_out = grad(pt.sum(out[1]), [x])

     fg = FunctionGraph([x, y], [out[1], *grad_out])
@@ -241,6 +314,20 @@ def f(x, y):
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [
+            TensorType(config.floatX, shape=(3, 3)),
+            TensorType(config.floatX, shape=(3,)),
+        ],
+        f,
+    )
+    out = jax_op(x, y)
+    grad_out = grad(pt.sum(out[1]), [x])
+    fg = FunctionGraph([x, y], [out[1], *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_unknown_static_shape():
     rng = np.random.default_rng(11)
@@ -252,11 +339,10 @@ def test_unknown_static_shape():

     x_cumsum = pt.cumsum(x)  # Now x_cumsum has an unknown shape

-    @as_jax_op
     def f(x, y):
         return x * jnp.ones(3)

-    out = f(x_cumsum, y)
+    out = as_jax_op(f)(x_cumsum, y)
     grad_out = grad(pt.sum(out), [x])

     fg = FunctionGraph([x, y], [out, *grad_out])
@@ -265,6 +351,17 @@ def f(x, y):
     with jax.disable_jit():
         fn, _ = compare_jax_and_py(fg, test_values)

+    # Test direct JAXOp usage
+    jax_op = JAXOp(
+        [x.type, y.type],
+        [TensorType(config.floatX, shape=(None,))],
+        f,
+    )
+    out = jax_op(x_cumsum, y)
+    grad_out = grad(pt.sum(out), [x])
+    fg = FunctionGraph([x, y], [out, *grad_out])
+    fn, _ = compare_jax_and_py(fg, test_values)
+

 def test_nested_functions():
     rng = np.random.default_rng(13)
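For orientation, here is a minimal sketch (not part of the commit) of the two call patterns these tests exercise: wrapping a JAX function with as_jax_op versus instantiating JAXOp directly with explicit output types. It only reuses names and call signatures that appear in the diff above.

# Minimal sketch of the two usage patterns shown in the tests; assumes the
# as_jax_op/JAXOp API imported exactly as in the diff above.
import jax
import pytensor.tensor as pt
from pytensor import as_jax_op, config, grad
from pytensor.link.jax.ops import JAXOp
from pytensor.tensor import TensorType, tensor

x = tensor("x", shape=(2,))
y = tensor("y", shape=(2,))

def f(x, y):
    return jax.nn.sigmoid(x + y)

# Pattern 1: wrap the JAX function, then call it on PyTensor variables;
# output types are inferred.
out = as_jax_op(f)(x, y)
grad_out = grad(pt.sum(out), [x, y])

# Pattern 2: build the Op directly, spelling out input and output types.
jax_op = JAXOp([x.type, y.type], [TensorType(config.floatX, shape=(2,))], f)
out = jax_op(x, y)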