@@ -390,3 +390,68 @@ func.func @parallel_insert_slice_of_insert_slice_dynamic(
390
390
}
391
391
return %0: tensor <12 x34 xf32 >
392
392
}
393

// -----

// The insert_slice writes the whole destination (offsets 0, unit strides, sizes
// equal to the dest shape), so it is a pure unit-dim cast of the extract_slice
// result; the pair folds into one extract_slice producing tensor<8x1x8xf32>.
func.func @fold_casting_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<8x1x8xf32>) -> tensor<8x1x8xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [8, 1, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<8x1x8xf32>
  return %inserted_slice : tensor<8x1x8xf32>
}
// CHECK-LABEL: func.func @fold_casting_insert_slice_of_extract_slice(
// CHECK-SAME:    %[[ARG0:.*]]: tensor<?x8x2x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1]
// CHECK-SAME:      : tensor<?x8x2x8xf32> to tensor<8x1x8xf32>
// CHECK:         return %[[EXTRACTED_SLICE]] : tensor<8x1x8xf32>

// -----

// Same unit-dim-cast fold as above, but the extract_slice itself is strided
// ([1, 2, 1, 1]); the stride stays on the folded extract_slice, which now
// directly produces the destination type tensor<1x4x8xf32>.
func.func @fold_casting_insert_slice_of_strided_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x8xf32>) -> tensor<1x4x8xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1] : tensor<?x8x2x8xf32> to tensor<4x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 4, 8] [1, 1, 1] : tensor<4x8xf32> into tensor<1x4x8xf32>
  return %inserted_slice : tensor<1x4x8xf32>
}
// CHECK-LABEL: func.func @fold_casting_insert_slice_of_strided_extract_slice(
// CHECK-SAME:    %[[ARG0:.*]]: tensor<?x8x2x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0, 0] [1, 4, 1, 8] [1, 2, 1, 1]
// CHECK-SAME:      : tensor<?x8x2x8xf32> to tensor<1x4x8xf32>
// CHECK:         return %[[EXTRACTED_SLICE]] : tensor<1x4x8xf32>

// -----

// Negative test: the extract_slice drops one unit dim (?x8x8 -> 8x8) but the
// insert_slice adds two (8x8 -> 1x1x8x8), so the pair is not a simple cast and
// must not fold; both ops are expected to remain.
func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(%in : tensor<?x8x8xf32>, %dest : tensor<1x1x8x8xf32>) -> tensor<1x1x8x8xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<?x8x8xf32> to tensor<8x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0, 0] [1, 1, 8, 8] [1, 1, 1, 1] : tensor<8x8xf32> into tensor<1x1x8x8xf32>
  return %inserted_slice : tensor<1x1x8x8xf32>
}
// CHECK-LABEL: func.func @no_fold_more_unit_dims_insert_slice_of_extract_slice(
// CHECK-SAME:    %[[ARG0:.*]]: tensor<?x8x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
// CHECK:         return %[[INSERTED_SLICE]] : tensor<1x1x8x8xf32>

// -----

// Negative test: the insert_slice is strided ([1, 2, 2]), so it scatters the
// data rather than performing a pure cast of the extracted slice; no fold is
// expected and both ops must remain.
func.func @no_fold_strided_insert_slice_of_extract_slice(%in : tensor<?x8x2x8xf32>, %dest : tensor<1x4x4xf32>) -> tensor<1x4x4xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0] [1, 8, 1, 8] [1, 1, 1, 1] : tensor<?x8x2x8xf32> to tensor<8x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 2, 2] : tensor<8x8xf32> into tensor<1x4x4xf32>
  return %inserted_slice : tensor<1x4x4xf32>
}
// CHECK-LABEL: func.func @no_fold_strided_insert_slice_of_extract_slice(
// CHECK-SAME:    %[[ARG0:.*]]: tensor<?x8x2x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
// CHECK:         return %[[INSERTED_SLICE]] : tensor<1x4x4xf32>

// -----

// Negative test: the insert_slice covers only part of the destination
// (sizes [1, 8, 8] into tensor<2x8x8xf32>), so it is not a cast of the whole
// dest and the extract/insert pair must not fold away.
func.func @no_fold_non_casting_insert_slice_of_extract_slice(%in : tensor<1x1x1x8x8xf32>, %dest : tensor<2x8x8xf32>) -> tensor<2x8x8xf32> {
  %extracted_slice = tensor.extract_slice %in[0, 0, 0, 0, 0] [1, 1, 1, 8, 8] [1, 1, 1, 1, 1] : tensor<1x1x1x8x8xf32> to tensor<8x8xf32>
  %inserted_slice = tensor.insert_slice %extracted_slice into %dest[0, 0, 0] [1, 8, 8] [1, 1, 1] : tensor<8x8xf32> into tensor<2x8x8xf32>
  return %inserted_slice : tensor<2x8x8xf32>
}
// CHECK-LABEL: func.func @no_fold_non_casting_insert_slice_of_extract_slice(
// CHECK-SAME:    %[[ARG0:.*]]: tensor<1x1x1x8x8xf32>
// CHECK:         %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]]
// CHECK:         %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[EXTRACTED_SLICE]]
// CHECK:         return %[[INSERTED_SLICE]] : tensor<2x8x8xf32>