//      CHECK:   %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME:       ins(%[[EXPAND_ARG0]] :
//      CHECK:   return %[[GENERIC]]
// -----

// The pad only touches expanded dimensions (0, 3, 7) that each form a
// unit reassociation group of the producing expand_shape, so the pattern
// can move the pad before the expansion and pad the collapsed source
// directly (2 -> 8, 5 -> 17, 9 -> 14).
func.func @fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x3x4x17x6x7x8x14xi32> {
  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
  %cst = arith.constant 0 : i32
  %padded_0 = tensor.pad %expand low[1, 0, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
      tensor.yield %cst : i32
  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x3x4x17x6x7x8x14xi32>
  return %padded_0 : tensor<8x3x4x17x6x7x8x14xi32>
}
//      CHECK: func @fuse_by_collapsing_pad(
// CHECK-SAME:     %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
//      CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]]
// CHECK-SAME:       low[1, 0, 8, 0, 3] high[5, 0, 4, 0, 2]
//      CHECK:       tensor<2x12x5x336x9xi32> to tensor<8x12x17x336x14xi32>
//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
// CHECK-SAME:       output_shape [8, 3, 4, 17, 6, 7, 8, 14] : tensor<8x12x17x336x14xi32> into tensor<8x3x4x17x6x7x8x14xi32>
//      CHECK:   return %[[EXPAND]]
// -----

// Negative test: the pad also touches expanded dimension 1 (low[1] = 2),
// which belongs to the non-unit reassociation group [1, 2]. Padding inside
// a collapsed group cannot be expressed on the collapsed source, so the
// expand_shape and pad must stay in their original order.
func.func @no_fuse_by_collapsing_pad(%arg0 : tensor<2x12x5x336x9xi32>) -> tensor<8x5x4x17x6x7x8x14xi32> {
  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
  %cst = arith.constant 0 : i32
  %padded_0 = tensor.pad %expand low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index,
       %arg5: index, %arg6: index, %arg7: index, %arg8: index):
      tensor.yield %cst : i32
  } : tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
  return %padded_0 : tensor<8x5x4x17x6x7x8x14xi32>
}
//      CHECK: func @no_fuse_by_collapsing_pad(
// CHECK-SAME:     %[[ARG0:.+]]: tensor<2x12x5x336x9xi32>)
//      CHECK:   %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]]
// CHECK-SAME:       output_shape [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32>
//      CHECK:   %[[PAD:.+]] = tensor.pad %[[EXPAND_ARG0]]
// CHECK-SAME:       low[1, 2, 0, 8, 0, 0, 0, 3] high[5, 0, 0, 4, 0, 0, 0, 2]
//      CHECK:       tensor<2x3x4x5x6x7x8x9xi32> to tensor<8x5x4x17x6x7x8x14xi32>
//      CHECK:   return %[[PAD]]
// -----

// Dynamic variant: padded dims 0 and 3 are unit reassociation groups, so
// the pad sinks below the expand_shape. The new output_shape entries for
// the padded dims are computed as size + low + high via affine.apply.
func.func @fuse_by_collapsing_dynamic_pad(%arg0 : tensor<?x?x?x?xf32>,
    %s0 : index, %s1 : index, %s2 : index, %s3 : index, %s4 : index, %s5 : index,
    %l0 : index, %l1 : index, %h0 : index, %h1 : index) -> tensor<?x?x?x?x?x?xf32> {
  %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5]] output_shape [%s0, %s1, %s2, %s3, %s4, %s5] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
  %cst = arith.constant 0.0 : f32
  %padded_0 = tensor.pad %expand low[%l0, 0, 0, %l1, 0, 0] high[%h0, 0, 0, %h1, 0, 0] {
  ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index):
      tensor.yield %cst : f32
  } : tensor<?x?x?x?x?x?xf32> to tensor<?x?x?x?x?x?xf32>
  return %padded_0 : tensor<?x?x?x?x?x?xf32>
}
//  CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 + s2)>
//      CHECK: func @fuse_by_collapsing_dynamic_pad(
// CHECK-SAME:     %[[ARG0:.+]]: tensor<?x?x?x?xf32>
// CHECK-SAME:     %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index, %[[S4:.+]]: index, %[[S5:.+]]: index, %[[L0:.+]]: index, %[[L1:.+]]: index, %[[H0:.+]]: index, %[[H1:.+]]: index
//      CHECK:   %[[PAD_SIZE0:.+]] = affine.apply #[[MAP]]()[%[[L0]], %[[H0]], %[[S0]]]
//      CHECK:   %[[PAD_SIZE1:.+]] = affine.apply #[[MAP]]()[%[[L1]], %[[H1]], %[[S3]]]
//      CHECK:   %[[PAD:.+]] = tensor.pad %[[ARG0]]
// CHECK-SAME:       low[%[[L0]], 0, %[[L1]], 0] high[%[[H0]], 0, %[[H1]], 0]
//      CHECK:       tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[PAD]] {{\[}}[0], [1, 2], [3], [4, 5]]
// CHECK-SAME:       output_shape [%[[PAD_SIZE0]], %[[S1]], %[[S2]], %[[PAD_SIZE1]], %[[S4]], %[[S5]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
//      CHECK:   return %[[EXPAND]]