@@ -215,6 +215,58 @@ def test_blockwise_shape():
215
215
assert tuple (shape_fn (inp1_test , inp2_test )[1 ]) == (7 , 5 , 4 )
216
216
217
217
218
+ def test_blockwise_infer_core_shape ():
219
+ class TestOpWithInferShape (Op ):
220
+ def make_node (self , a , b ):
221
+ assert a .type .ndim == 1
222
+ assert b .type .ndim == 1
223
+ c = tensor (shape = (None ,))
224
+ d = tensor (shape = (None ,))
225
+ return Apply (self , [a , b ], [c , d ])
226
+
227
+ def perform (self , node , inputs , outputs ):
228
+ a , b = inputs
229
+ c , d = outputs
230
+ c [0 ] = np .arange (a .size + b .size )
231
+ d [0 ] = np .arange (a .sum () + b .sum ())
232
+
233
+ def infer_shape (self , fgraph , node , input_shapes ):
234
+ # First output shape depends only on input_shapes
235
+ # Second output shape depends on input values
236
+ x , y = node .inputs
237
+ [(x_shape ,), (y_shape ,)] = input_shapes
238
+ return (x_shape + y_shape ,), (x .sum () + y .sum (),)
239
+
240
+ blockwise_op = Blockwise (
241
+ core_op = TestOpWithInferShape (), signature = "(a),(b)->(c),(d)"
242
+ )
243
+
244
+ a = tensor ("a" , shape = (5 , 3 ))
245
+ b = tensor ("b" , shape = (1 , 4 ))
246
+ c , d = blockwise_op (a , b )
247
+ assert c .type .shape == (5 , None )
248
+ assert d .type .shape == (5 , None )
249
+
250
+ c_shape_fn = pytensor .function ([a , b ], c .shape )
251
+ # c_shape can be computed from the input shapes alone
252
+ assert not any (
253
+ isinstance (getattr (n .op , "core_op" , n .op ), TestOpWithInferShape )
254
+ for n in c_shape_fn .maker .fgraph .apply_nodes
255
+ )
256
+
257
+ d_shape_fn = pytensor .function ([a , b ], d .shape )
258
+ # d_shape cannot be computed from the input shapes alone
259
+ assert any (
260
+ isinstance (getattr (n .op , "core_op" , n .op ), TestOpWithInferShape )
261
+ for n in d_shape_fn .maker .fgraph .apply_nodes
262
+ )
263
+
264
+ a_test = np .zeros (a .type .shape )
265
+ b_test = np .zeros (b .type .shape )
266
+ assert tuple (c_shape_fn (a_test , b_test )) == (5 , 7 )
267
+ assert tuple (d_shape_fn (a_test , b_test )) == (5 , 0 )
268
+
269
+
218
270
class BlockwiseOpTester :
219
271
"""Base class to test Blockwise works for specific Ops"""
220
272
0 commit comments