Commit e45fc51

[Linalg][Vectorization] Add support for linalg vectorization of a tensor.extract case (llvm#107922)
In llvm#102321 we relaxed the vectorizer so that, when checking for contiguous loads, we no longer require the trailing dim to be non-unit. For example, in the test case added here we have `tensor<8x1xf32>`, which is now a valid candidate for a contiguous load. However, the contiguous-load check still assumed that only the trailing dim would be non-unit, so this PR updates that logic to find the actual non-unit dim.
1 parent c3f9b73 commit e45fc51
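
The gist of the fix, ahead of the diff: for an iteration space like `tensor<8x1xf32>` the trip counts are [8, 1], so the trailing dim is a unit dim and picking it is wrong. A minimal standalone sketch of the before/after logic, assuming a plain trip-count vector stands in for `LinalgOp::getStaticLoopRanges()` (an illustration, not the actual MLIR helper):

#include <cassert>
#include <cstdint>
#include <vector>

// Old behavior: blindly pick the trailing loop dim.
static uint64_t getTrailingLoopDim(const std::vector<int64_t> &tripCounts) {
  return tripCounts.size() - 1;
}

// New behavior: scan for the single non-unit loop dim.
static uint64_t getNonUnitLoopDim(const std::vector<int64_t> &tripCounts) {
  uint64_t nonUnitDim = 0;
  uint64_t countNonUnitDim = 0;
  for (uint64_t i = 0; i < tripCounts.size(); ++i) {
    if (tripCounts[i] != 1) {
      nonUnitDim = i;
      ++countNonUnitDim;
    }
  }
  assert(countNonUnitDim == 1 && "expected exactly one non-unit loop dim");
  return nonUnitDim;
}

int main() {
  std::vector<int64_t> tripCounts = {8, 1}; // iteration space of tensor<8x1xf32>
  assert(getTrailingLoopDim(tripCounts) == 1); // old: lands on the unit dim
  assert(getNonUnitLoopDim(tripCounts) == 0);  // new: finds the varying dim
  return 0;
}

With the old choice, a `linalg.index` over the dim that actually varies (dim 0 here) would never compare equal to the selected dim, so the contiguity check was asking its question about the wrong loop index.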

2 files changed (+79, -4 lines)


mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 27 additions & 4 deletions

@@ -810,6 +810,28 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
+/// Find the non-unit dim in a linalgOp.
+/// When executing this hook, it is expected that only one dim will be non-unit.
+/// Other cases (i.e. reading n-D vectors) should've been labelled as gather
+/// loads before calling this method. This is used for finding contiguous loads
+/// (represented as `tensor.extract`) within `linalg.generic` Ops. Note that
+/// this condition is expected to hold for statically shaped Linalg Ops only.
+static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
+  uint64_t nonUnitDim = 0;
+  uint64_t countNonUnitDim = 0;
+  for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
+    if (tripCount.value() != 1) {
+      nonUnitDim = tripCount.index();
+      countNonUnitDim++;
+    }
+  }
+
+  assert(linalgOp.hasDynamicShape() ||
+         countNonUnitDim == 1 && "For statically shaped Linalg Ops, only one "
+                                 "non-unit loop dim is expected");
+  return nonUnitDim;
+}
+
 /// Checks whether `val` can be used for calculating a loop invariant index.
 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
                                VectorType resType) {
@@ -889,11 +911,12 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  // Given the assumption on the loop ranges above, only the trailing loop
-  // index is not constant.
-  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
+  // Given the assumption on the loop ranges above, we expect only 1 non-unit
+  // loop dim.
+  auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);
+
   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
-    foundIndexOp = (indexOp.getDim() == trailingLoopDim);
+    foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
     return true;
   }
 
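A small C++ aside on the new assert: it leans on two idioms. Operator precedence makes the condition parse as `hasDynamicShape() || ((countNonUnitDim == 1) && "...")`, and a string literal decays to a non-null (truthy) pointer, so the `&&` leaves the truth value unchanged while attaching the text to the failure diagnostic. A self-contained illustration with hypothetical values (not from the patch):

#include <cassert>

int main() {
  bool hasDynamicShape = false; // hypothetical: a statically shaped op
  int countNonUnitDim = 1;      // hypothetical: trip counts like [8, 1]
  // Parses as: hasDynamicShape || ((countNonUnitDim == 1) && "message");
  // the invariant is only enforced for statically shaped ops.
  assert(hasDynamicShape ||
         countNonUnitDim == 1 && "only one non-unit loop dim is expected");
  return 0;
}

Some compilers suggest parentheses around the `&&` clause under -Wparentheses; the meaning is unchanged either way.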

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 52 additions & 0 deletions

@@ -253,6 +253,58 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1, d2) -> (d0 + d1 + d2)>
+func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load(%arg0: tensor<8x128x768xf32>, %arg1 : index) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %1 = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg5: f32):
+    %2 = linalg.index 0 : index
+    %3 = linalg.index 1 : index
+    %4 = affine.apply #map1(%arg1, %3, %arg1)
+    %extracted = tensor.extract %arg0[%2, %c0, %4] : tensor<8x128x768xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %1 : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load
+// CHECK-SAME: %[[ARG0:.*]]: tensor<8x128x768xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant dense<768> : vector<1x8xindex>
+// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK: %[[B1:.*]] = vector.broadcast %[[CST_3]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[ADDI_ARG1:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : index
+// CHECK: %[[B2:.*]] = vector.broadcast %[[ADDI_ARG1]] : index to vector<1xindex>
+// CHECK: %[[MULI_1:.*]] = arith.muli %[[B1]], %[[CST_0]] : vector<1x8xindex>
+// CHECK: %[[MULI_2:.*]] = arith.muli %[[MULI_1]], %[[CST]] : vector<1x8xindex>
+// CHECK: %[[T:.*]] = vector.transpose %[[MULI_2]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[B3:.*]] = vector.broadcast %[[B2]] : vector<1xindex> to vector<8x1xindex>
+// CHECK: %[[ADDI:.*]] = arith.addi %[[B3]], %[[T]] : vector<8x1xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[ADDI]]], %[[CST_2]], %[[PASSTHRU]] : tensor<8x128x768xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: vector.transfer_write %[[GATHER]], %[[EMPTY]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+
 // -----
 
 #map = affine_map<(d0) -> (d0)>