@@ -253,6 +253,58 @@ module attributes {transform.with_named_sequence} {
253
253
transform.yield
254
254
}
255
255
}
256
+
257
+ // -----
258
+
259
+ #map = affine_map <(d0 , d1 ) -> (d0 , d1 )>
260
+ #map1 = affine_map <(d0 , d1 , d2 ) -> (d0 + d1 + d2 )>
261
+ func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load (%arg0: tensor <8 x128 x768 xf32 >, %arg1 : index ) -> tensor <8 x1 xf32 > {
262
+ %c0 = arith.constant 0 : index
263
+ %0 = tensor.empty () : tensor <8 x1 xf32 >
264
+ %1 = linalg.generic {
265
+ indexing_maps = [#map ],
266
+ iterator_types = [" parallel" , " parallel" ]
267
+ } outs (%0 : tensor <8 x1 xf32 >) {
268
+ ^bb0 (%arg5: f32 ):
269
+ %2 = linalg.index 0 : index
270
+ %3 = linalg.index 1 : index
271
+ %4 = affine.apply #map1 (%arg1 , %3 , %arg1 )
272
+ %extracted = tensor.extract %arg0 [%2 , %c0 , %4 ] : tensor <8 x128 x768 xf32 >
273
+ linalg.yield %extracted : f32
274
+ } -> tensor <8 x1 xf32 >
275
+ return %1 : tensor <8 x1 xf32 >
276
+ }
277
+
278
+ module attributes {transform.with_named_sequence } {
279
+ transform.named_sequence @__transform_main (%arg0: !transform.any_op {transform.readonly }) {
280
+ %0 = transform.structured.match ops {[" linalg.generic" ]} in %arg0 : (!transform.any_op ) -> !transform.any_op
281
+ %1 = transform.get_parent_op %0 {isolated_from_above } : (!transform.any_op ) -> !transform.any_op
282
+ %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract } : (!transform.any_op ) -> !transform.any_op
283
+ transform.yield
284
+ }
285
+ }
286
+
287
+ // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load
288
+ // CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
289
+ // CHECK-SAME: %[[ARG1:.*]]: index
290
+ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
291
+ // CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
292
+ // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
293
+ // CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
294
+ // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
295
+ // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
296
+ // CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
297
+ // CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
298
+ // CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
299
+ // CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
300
+ // CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
301
+ // CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
302
+ // CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
303
+ // CHECK: %[[B3:.*]] = vector.broadcast %[[B2]] : vector<1xindex> to vector<8x1xindex>
304
+ // CHECK: %[[ADDI:.*]] = arith.addi %[[B3]], %[[T]] : vector<8x1xindex>
305
+ // CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
306
+ // CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
307
+
256
308
// -----
257
309
258
310
#map = affine_map <(d0 ) -> (d0 )>
0 commit comments