Commit b47d178

[mlir][vector] Refine vectorisation of tensor.extract (#109580)
This PR fixes a bug in `isLoopInvariantIdx`. It makes sure that the following case is vectorised as `vector.gather` (as opposed to attempting a contiguous load):

```mlir
func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
  %c0 = arith.constant 0 : index
  %0 = tensor.empty() : tensor<8x1xf32>
  %res = linalg.generic {
    indexing_maps = [#map],
    iterator_types = ["parallel", "parallel"]
  } outs(%0 : tensor<8x1xf32>) {
  ^bb0(%arg1: f32):
    %1 = linalg.index 0 : index
    %extracted = tensor.extract %src[%1, %c0] : tensor<8x128xf32>
    linalg.yield %extracted : f32
  } -> tensor<8x1xf32>
  return %res : tensor<8x1xf32>
}
```

Specifically, when looking for loop-invariant indices in `tensor.extract` Ops, any `linalg.index` Op that's used in address calculation should only access loop dims that are == 1. In the example above, the following does not meet that criterion:

```mlir
%1 = linalg.index 0 : index
```

Note that this PR also effectively addresses the issue fixed in #107922, i.e. exercised by:

* `@vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load`

`getNonUnitLoopDim` introduced in #107922 is still valid though. In fact, it is required to identify that the following case is a contiguous load:

```mlir
func.func @index_from_output_column_vector_contiguous_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
  %c0 = arith.constant 0 : index
  %0 = tensor.empty() : tensor<8x1xf32>
  %res = linalg.generic {
    indexing_maps = [#map],
    iterator_types = ["parallel", "parallel"]
  } outs(%0 : tensor<8x1xf32>) {
  ^bb0(%arg1: f32):
    %1 = linalg.index 0 : index
    %extracted = tensor.extract %src[%c0, %1] : tensor<8x128xf32>
    linalg.yield %extracted : f32
  } -> tensor<8x1xf32>
  return %res : tensor<8x1xf32>
}
```

Some logic is still missing to lower the above to `vector.transfer_read`, so it is conservatively lowered to `vector.gather` instead (see TODO in `getTensorExtractMemoryAccessPattern`).

There are a few additional changes:

* `getNonUnitLoopDim` is simplified and renamed as `getTrailingNonUnitLoopDimIdx`, and additional comments are added (note that the functionality didn't change);
* extra comments in a few places, and variable names in comments are updated to use Markdown (which is the preferred approach in MLIR).

This is a follow-on for:

* #107922
* #102321
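To make the decision logic easier to follow outside of the vectoriser sources, below is a minimal, self-contained C++ sketch of the two checks this patch touches. The helper names and the plain `std::vector<int64_t>` model of the static loop ranges are illustrative only; the actual MLIR implementation is in the diff further down.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Simplified model: the static loop trip counts of a linalg.generic op.
using LoopRanges = std::vector<int64_t>;

// A `linalg.index <dim>` result is loop-invariant only if the loop along
// <dim> has a trip count of 1, i.e. the index value can never change.
bool isIndexOpLoopInvariant(const LoopRanges &ranges, size_t dim) {
  return ranges[dim] == 1;
}

// Index of the trailing non-unit loop dim - the dim whose `linalg.index`
// increments by 1 on every iteration of the (effectively 1-D) output loop.
size_t getTrailingNonUnitLoopDimIdx(const LoopRanges &ranges) {
  size_t idx = ranges.size() - 1;
  while (idx > 0 && ranges[idx] == 1)
    --idx;
  return idx;
}

int main() {
  // Loop ranges of the 8x1 outputs in the examples above.
  LoopRanges ranges = {8, 1};

  // `%1 = linalg.index 0` is NOT loop-invariant (trip count 8), so
  // `tensor.extract %src[%1, %c0]` must be vectorised as a gather.
  assert(!isIndexOpLoopInvariant(ranges, /*dim=*/0));

  // `linalg.index 1` would be loop-invariant (trip count 1).
  assert(isIndexOpLoopInvariant(ranges, /*dim=*/1));

  // Dim 0 is the dim that increments by 1 each iteration, not the trailing
  // dim 1 - which is what makes the second example a contiguous-load candidate.
  assert(getTrailingNonUnitLoopDimIdx(ranges) == 0);
  return 0;
}
```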
1 parent 12033e5 commit b47d178

File tree: 2 files changed (+140, -35 lines)

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 50 additions & 35 deletions
@@ -810,27 +810,35 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
-/// Find the non-unit dim in a linalgOp.
-/// When executing this hook, it is expected that only one dim will be non-unit.
-/// Other cases (i.e. reading n-D vectors) should've been labelled as gather
-/// loads before calling this method. This is used for finding contiguous loads
-/// (represented as `tensor.extract`) within `linalg.generic` Ops. Note that
-/// this condition is expected to hold for statically shaped Linalg Ops only.
-static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
-  uint64_t nonUnitDim = 0;
-  uint64_t countNonUnitDim = 0;
-  for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
-    if (tripCount.value() != 1) {
-      nonUnitDim = tripCount.index();
-      countNonUnitDim++;
-    }
-  }
-
+/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
+/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
+/// represents a contiguous load operation.
+///
+/// Note that when calling this hook, it is assumed that the output vector is
+/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
+/// labelled as a gather load before entering this method.
+///
+/// Following on from the above, it is assumed that:
+///  * for statically shaped loops, when no masks are used, only one dim is !=
+///    1 (that's what the shape of the output vector is based on).
+///  * for dynamically shaped loops, there might be more non-unit dims
+///    as the output vector type is user-specified.
+///
+/// TODO: Statically shaped loops + vector masking
+static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
+  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
   assert(linalgOp.hasDynamicShape() ||
-         countNonUnitDim == 1 && "For statically shaped Linalg Ops, only one "
-                                 "non-unit loop dim is expected");
-  (void)countNonUnitDim;
-  return nonUnitDim;
+         llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) ==
+                 1 &&
+             "For statically shaped Linalg Ops, only one "
+             "non-unit loop dim is expected");
+
+  size_t idx = loopRanges.size() - 1;
+  for (; idx >= 0; idx--)
+    if (loopRanges[idx] != 1)
+      break;
+
+  return idx;
 }
 
 /// Checks whether `val` can be used for calculating a loop invariant index.
@@ -854,11 +862,11 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
   assert(defOp && "This is neither a block argument nor an operation result");
 
   // IndexOp is loop invariant as long as its result remains constant across
-  // iterations. Given the assumptions on the loop ranges above, only the
-  // trailing loop dim ever changes.
-  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
-  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
-    return (indexOp.getDim() != trailingLoopDim);
+  // iterations. Note that for dynamic shapes, the corresponding dim will also
+  // be conservatively treated as != 1.
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
+    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
+  }
 
   auto *ancestor = block->findAncestorOpInBlock(*defOp);
 
@@ -877,7 +885,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
   return result;
 }
 
-/// Check whether \p val could be used for calculating the trailing index for a
+/// Check whether `val` could be used for calculating the trailing index for a
 /// contiguous load operation.
 ///
 /// There are currently 3 types of values that are allowed here:
@@ -886,13 +894,14 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
 /// 3. results of basic arithmetic operations (linear and continuous)
 ///    involving 1., 2. and 3.
 /// This method returns True if indeed only such values are used in calculating
-/// \p val.
+/// `val.`
 ///
 /// Additionally, the trailing index for a contiguous load operation should
 /// increment by 1 with every loop iteration, i.e. be based on:
 ///   * `linalg.index <dim>` ,
-/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
-/// updated to `true` when such an op is found.
+/// where <dim> is the trailing non-unit dim of the iteration space (this way,
+/// `linalg.index <dim>` increments by 1 with every loop iteration).
+/// `foundIndexOp` is updated to `true` when such Op is found.
 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                 bool &foundIndexOp, VectorType resType) {
 
@@ -912,12 +921,10 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  // Given the assumption on the loop ranges above, we expect only 1 non-unit
-  // loop dim.
-  auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);
-
   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
-    foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
+    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
+
+    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
     return true;
   }
 
@@ -1012,7 +1019,10 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   bool foundIndexOp = false;
   bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                               foundIndexOp, resType);
-  isContiguousLoad &= foundIndexOp;
+  // TODO: Support generating contiguous loads for column vectors - that will
+  // require adding a permutation map to tranfer_read Ops.
+  bool isRowVector = resType.getShape().back() != 1;
+  isContiguousLoad &= (foundIndexOp && isRowVector);
 
   if (isContiguousLoad) {
     LDBG("Found contigous load: " << extractOp);
@@ -1073,6 +1083,11 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   //  b. contiguous loads.
   // Both cases use vector.transfer_read.
 
+  assert(llvm::count_if(resultType.getShape(),
+                        [](uint64_t dim) { return dim != 1; }) &&
+         "Contiguous loads and scalar loads + broadcast only support 1-D "
+         "vectors ATM!");
+
   // Collect indices for `vector.transfer_read`. At this point, the indices will
   // either be scalars or would have been broadcast to vectors matching the
   // result type. For indices that are vectors, there are two options:
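For completeness, here is a similarly hedged, self-contained C++ model of the final contiguous-vs-gather decision in `getTensorExtractMemoryAccessPattern`. Only the `foundIndexOp && isRowVector` guard mirrors the diff above; the enum, the function signature and the simplified inputs are illustrative.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

enum class AccessKind { Contiguous, Gather };

// A tensor.extract is vectorised as a contiguous transfer_read only if its
// trailing index is driven by `linalg.index` over the trailing non-unit loop
// dim (isContiguousIdx && foundIndexOp) and the result is a row vector
// (trailing dim != 1). Column vectors fall back to vector.gather for now.
AccessKind classify(bool isContiguousIdx, bool foundIndexOp,
                    const std::vector<int64_t> &resultShape) {
  bool isRowVector = resultShape.back() != 1;
  return (isContiguousIdx && foundIndexOp && isRowVector)
             ? AccessKind::Contiguous
             : AccessKind::Gather;
}

int main() {
  // @index_from_output_column_vector_contiguous_load: indices are in the
  // right order, but the 8x1 result is a column vector -> gather (for now).
  assert(classify(true, true, {8, 1}) == AccessKind::Gather);
  // A 1x8 row-vector read driven by a matching linalg.index -> contiguous.
  assert(classify(true, true, {1, 8}) == AccessKind::Contiguous);
  return 0;
}
```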

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 90 additions & 0 deletions
@@ -307,6 +307,96 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Reading a 1D column vector (hence a candidate for a contiguous load), but given
+// %1, it's a gather load.
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg1: f32):
+    %1 = linalg.index 0 : index
+    %extracted = tensor.extract %src[%1, %c0] : tensor<8x128xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %res : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @index_from_output_column_vector_gather_load(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
+// CHECK:           %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK:           %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK:           %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
+// CHECK:           %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
+// CHECK:           %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK:           %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK:           %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+// CHECK:           return %[[RES]] : tensor<8x1xf32>
+
+// -----
+
+// Same as above, but the access indices have been swapped and hence this _is_
+// a contiguous load. Currently not supported and lowered as vector.gather
+// instead.
+// TODO: Make sure that this is lowered as a contiguous load.
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @index_from_output_column_vector_contiguous_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg1: f32):
+    %1 = linalg.index 0 : index
+    %extracted = tensor.extract %src[%c0, %1] : tensor<8x128xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %res : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL:   func.func @index_from_output_column_vector_contiguous_load(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK:           %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK:           %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK:           %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK:           %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
+// CHECK:           %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK:           %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK:           %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+// CHECK:           return %[[RES]] : tensor<8x1xf32>
+
+// -----
+
 #map = affine_map<(d0) -> (d0)>
 func.func @vectorize_nd_tensor_extract_contiguous_and_gather(%arg0: tensor<6xf32>, %arg1: tensor<5xi32>) -> tensor<5xf32> {
   %c5 = arith.constant 5 : index
