@@ -318,3 +318,81 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
  }
  return %for0 : tensor<64x128xf32>
}
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @rank_reduced_extract_slice(
+    %prod_in: tensor<1x6x5xf32>, %prod_weight: tensor<1x5x6xf32>,
+    %cons_in: tensor<4x6xf32>, %prod_init: tensor<1x6x6xf32>,
+    %for_iv_init: tensor<4x6xf32>, %cons_init: tensor<4x2xf32>
+) -> tensor<4x6xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c6 = arith.constant 6 : index
+  %mmul_prod = linalg.generic
+      {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+      ins(%prod_in, %prod_weight : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%prod_init : tensor<1x6x6xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %10 = arith.mulf %in, %in_1 : f32
+    %11 = arith.addf %out, %10 : f32
+    linalg.yield %11 : f32
+  } -> tensor<1x6x6xf32>
+  %for = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %for_iv_init) -> (tensor<4x6xf32>) {
+
+    // Extract slice with a rank-reduced result type. When fused into the
+    // loop with sliced operands, the producer linalg.generic must have its
+    // (now sliced) result rank-reduced as well to match the consumer's use
+    // type.
+    %prod_slice = tensor.extract_slice %mmul_prod[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
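+
+    // Illustrative aside, not exercised by this test: tensor.extract_slice
+    // is rank-reducing when the result type drops unit dims, as above
+    // (tensor<1x6x6xf32> to tensor<6x2xf32>). The non-rank-reducing form
+    // would instead keep the tile at full rank:
+    //   %full_rank = tensor.extract_slice %mmul_prod[0, 0, %arg7] [1, 6, 2] [1, 1, 1]
+    //       : tensor<1x6x6xf32> to tensor<1x6x2xf32>
+    // and would then need a tensor.collapse_shape to reach tensor<6x2xf32>.
+    // %full_rank is a hypothetical name used only in this comment.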
+    %mmul_cons = linalg.generic
+        {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%cons_in, %prod_slice : tensor<4x6xf32>, tensor<6x2xf32>) outs(%cons_init : tensor<4x2xf32>) {
+    ^bb0(%in: f32, %in_1: f32, %out: f32):
+      %20 = arith.mulf %in, %in_1 : f32
+      %21 = arith.addf %out, %20 : f32
+      linalg.yield %21 : f32
+    } -> tensor<4x2xf32>
+    %4 = tensor.insert_slice %mmul_cons into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+    scf.yield %4 : tensor<4x6xf32>
+  }
+  return %for : tensor<4x6xf32>
+}
+
+// CHECK: func @rank_reduced_extract_slice(
+// CHECK-SAME: %[[PROD_IN:[0-9a-z]*]]: tensor<1x6x5xf32>
+// CHECK-SAME: %[[PROD_WEIGHT:[0-9a-z]*]]: tensor<1x5x6xf32>
+// CHECK-SAME: %[[CONS_IN:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[PROD_INIT:[0-9a-z]*]]: tensor<1x6x6xf32>
+// CHECK-SAME: %[[FOR_IV_INIT:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[CONS_INIT:[0-9a-z]*]]: tensor<4x2xf32>
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+
+// The scf.for follows the constants directly; no linalg.generic remains
+// before the loop.
+// CHECK-NOT: linalg.generic
+// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[FOR_IV_INIT]])
+
+// The producer linalg.generic is now inside the loop, with its tiled
+// operands sliced before it.
+// CHECK-DAG: %[[PROD_WEIGHT_SLICE:.*]] = tensor.extract_slice %[[PROD_WEIGHT]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
+// CHECK-DAG: %[[PROD_INIT_SLICE:.*]] = tensor.extract_slice %[[PROD_INIT]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
+// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[PROD_IN]], %[[PROD_WEIGHT_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
+// CHECK-SAME: outs(%[[PROD_INIT_SLICE]] : tensor<1x6x2xf32>)
+//
+// The consumer uses a rank-reduced version of the producer result, so a
+// collapse_shape is generated.
+// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
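+//
+// (The reassociation [0, 1], [2] folds the leading unit dim into the rows,
+// 1x6x2 -> 6x2; it is spelled with the {{...}} regex form above because
+// [[...]] would otherwise be parsed as a FileCheck variable definition.)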
+// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[CONS_IN]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+// CHECK-SAME: outs(%[[CONS_INIT]] : tensor<4x2xf32>)
+// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
+// CHECK: return %[[FOR]] : tensor<4x6xf32>