@@ -129,6 +129,26 @@ func.func @static_mixed_data_low_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
129
129
130
130
// -----
131
131
132
+ // CHECK-LABEL: @static_rank_reduce
133
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<8x16x4xf32>, %[[PADVAL:.*]]: f32
134
+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 14, 4] [1, 1, 1] : tensor<8x16x4xf32> to tensor<1x14x4xf32>
135
+ // CHECK: %[[PADDED:.*]] = tensor.pad %[[SLICE]] low[0, 2, 0] high[0, 0, 0] {
136
+ // CHECK: } : tensor<1x14x4xf32> to tensor<1x16x4xf32>
137
+ // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[PADDED]][0, 0, 0] [1, 16, 4] [1, 1, 1] : tensor<1x16x4xf32> to tensor<16x4xf32>
138
+ // CHECK: return %[[RESULT]]
139
+ func.func @static_rank_reduce (%arg0: tensor <8 x16 x4 xf32 >, %pad: f32 )
140
+ -> tensor <16 x4 xf32 > {
141
+ %0 = tensor.pad %arg0 low [0 , 2 , 0 ] high [0 , 0 , 0 ] {
142
+ ^bb0 (%i: index , %j: index , %k: index ):
143
+ tensor.yield %pad : f32
144
+ } : tensor <8 x16 x4 xf32 > to tensor <8 x18 x4 xf32 >
145
+ %1 = tensor.extract_slice %0 [0 , 0 , 0 ] [1 , 16 , 4 ] [1 , 1 , 1 ]
146
+ : tensor <8 x18 x4 xf32 > to tensor <16 x4 xf32 >
147
+ return %1 : tensor <16 x4 xf32 >
148
+ }
149
+
150
+ // -----
151
+
132
152
// CHECK-LABEL: @dynamic_high_pad
133
153
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x5xf32>
134
154
// CHECK-NOT: tensor.pad
@@ -217,6 +237,27 @@ func.func @dynamic_zero_high_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
217
237
return %1 : tensor <?x?xf32 >
218
238
}
219
239
240
+ // -----
241
+
242
+ // CHECK-LABEL: @dynamic_rank_reduce
243
+ // CHECK: %[[TEMP:.*]] = scf.if %{{.*}} -> (tensor<1x4xf32>) {
244
+ // CHECK: tensor.generate
245
+ // CHECK: } else {
246
+ // CHECK: %[[SLICE:.*]] = tensor.extract_slice %{{.*}} : tensor<?x5xf32> to tensor<?x1xf32>
247
+ // CHECK: tensor.pad %[[SLICE]] low[0, 0] high[%{{.*}}, 3] {
248
+ // CHECK: } : tensor<?x1xf32> to tensor<1x4xf32>
249
+ // CHECK: }
250
+ // CHECK: %[[RESULT:.*]] = tensor.extract_slice %[[TEMP]]{{.*}} : tensor<1x4xf32> to tensor<4xf32>
251
+ // CHECK: return %[[RESULT]]
252
+ func.func @dynamic_rank_reduce (%arg0 : tensor <?x5 xf32 >, %s1: index , %pad : f32 ) -> tensor <4 xf32 > {
253
+ %0 = tensor.pad %arg0 low [0 , 0 ] high [7 , 8 ] {
254
+ ^bb0 (%arg1: index , %arg2: index ):
255
+ tensor.yield %pad : f32
256
+ } : tensor <?x5 xf32 > to tensor <?x13 xf32 >
257
+ %1 = tensor.extract_slice %0 [2 , 4 ] [1 , 4 ] [1 , 1 ] : tensor <?x13 xf32 > to tensor <4 xf32 >
258
+ return %1 : tensor <4 xf32 >
259
+ }
260
+
220
261
// -----
221
262
// CHECK-LABEL: @nopaddim_with_dynamic_extract(
222
263
// CHECK-SAME: %[[ARG0:.*]]: tensor<3x4x5xf32>
0 commit comments