@@ -244,6 +244,48 @@ func.func @extract_from_non_constant_create_mask(%dim0: index) -> vector<[2]xi1>
244
244
245
245
// -----
246
246
247
+ // CHECK-LABEL: @extract_elementwise
248
+ // CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
249
+ func.func @extract_elementwise (%arg0: vector <4 xf32 >, %arg1: vector <4 xf32 >) -> f32 {
250
+ // CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : f32 from vector<4xf32>
251
+ // CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : f32 from vector<4xf32>
252
+ // CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : f32
253
+ // CHECK: return %[[RES]] : f32
254
+ %0 = arith.addf %arg0 , %arg1 : vector <4 xf32 >
255
+ %1 = vector.extract %0 [1 ] : f32 from vector <4 xf32 >
256
+ return %1 : f32
257
+ }
258
+
259
+ // -----
260
+
261
+ // CHECK-LABEL: @extract_vec_elementwise
262
+ // CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<2x4xf32>)
263
+ func.func @extract_vec_elementwise (%arg0: vector <2 x4 xf32 >, %arg1: vector <2 x4 xf32 >) -> vector <4 xf32 > {
264
+ // CHECK: %[[EXT0:.*]] = vector.extract %[[ARG0]][1] : vector<4xf32> from vector<2x4xf32>
265
+ // CHECK: %[[EXT1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf32> from vector<2x4xf32>
266
+ // CHECK: %[[RES:.*]] = arith.addf %[[EXT0]], %[[EXT1]] : vector<4xf32>
267
+ // CHECK: return %[[RES]] : vector<4xf32>
268
+ %0 = arith.addf %arg0 , %arg1 : vector <2 x4 xf32 >
269
+ %1 = vector.extract %0 [1 ] : vector <4 xf32 > from vector <2 x4 xf32 >
270
+ return %1 : vector <4 xf32 >
271
+ }
272
+
273
+ // -----
274
+
275
+ // CHECK-LABEL: @extract_elementwise_use
276
+ // CHECK-SAME: (%[[ARG0:.*]]: vector<4xf32>, %[[ARG1:.*]]: vector<4xf32>)
277
+ func.func @extract_elementwise_use (%arg0: vector <4 xf32 >, %arg1: vector <4 xf32 >) -> (f32 , vector <4 xf32 >) {
278
+ // Dop not propagate extract, as elementwise has other uses
279
+ // CHECK: %[[ELT:.*]] = arith.addf %[[ARG0]], %[[ARG1]] : vector<4xf32>
280
+ // CHECK: %[[EXT:.*]] = vector.extract %[[ELT]][1] : f32 from vector<4xf32>
281
+ // CHECK: return %[[EXT]], %[[ELT]] : f32, vector<4xf32>
282
+ %0 = arith.addf %arg0 , %arg1 : vector <4 xf32 >
283
+ %1 = vector.extract %0 [1 ] : f32 from vector <4 xf32 >
284
+ return %1 , %0 : f32 , vector <4 xf32 >
285
+ }
286
+
287
+ // -----
288
+
247
289
// CHECK-LABEL: constant_mask_transpose_to_transposed_constant_mask
248
290
func.func @constant_mask_transpose_to_transposed_constant_mask () -> (vector <2 x3 x4 xi1 >, vector <4 x2 x3 xi1 >) {
249
291
// CHECK: vector.constant_mask [1, 2, 3] : vector<2x3x4xi1>
0 commit comments