Reapply "[mlir][linalg] Relax tensor.extract vectorization" (#102232) #102321
Conversation
[This reverts commit 6662523d6b2ca0198141c94ee80ebbb41601df9f]

Simplifies the vectorization of tensor.extract so that:

* all cases that read into a genuinely multi-dim vector (*) are considered a gather load,
* all other cases are considered as potential contiguous loads.

This change means that the following extraction from a "column" tensor is correctly identified as a scalar load followed by a broadcast (rather than a gather load).

```mlir
func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
  %c4 = arith.constant 4 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<[...]> : tensor<15x1xi32>

  %out = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]}
    outs(%in : tensor<1x1x4xi32>) {
  ^bb0(%out: i32):
    %8 = linalg.index 0 : index
    %idx_0 = linalg.index 0 : index
    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
    linalg.yield %extracted : i32
  } -> tensor<1x1x4xi32>

  return %out : tensor<1x1x4xi32>
}
```

Overview of the delta when compared to the original submission (#99299):

* removed an assert representing a condition that is being relaxed here,
* added a test (reading from a column tensor) based on a repro from @hanhanW.

(*) `vector<1x4x1xf32>` is considered a 1-D vector in this context.
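For reference, here is a minimal standalone C++ sketch of the relaxed check, mirroring the `isOutput1DVector` computation in the diff below (the free-function wrapper and the use of the standard library instead of `llvm::count_if` are illustrative, not the upstream code):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// True when exactly one loop dimension is larger than 1, i.e. the target
// vector is effectively 1-D (e.g. vector<1x4x1xf32>). Only such reads are
// candidates for scalar-broadcast or contiguous loads; everything else is
// classified as a gather load.
static bool isEffectively1DVector(const std::vector<int64_t> &targetShape) {
  return std::count_if(targetShape.begin(), targetShape.end(),
                       [](int64_t dimSize) { return dimSize > 1; }) == 1;
}
```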
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Full diff: https://github.com/llvm/llvm-project/pull/102321.diff

2 Files Affected:

* (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
* (modified) mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
```diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3d0d6abf702d7..63dcda78d0f2b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -814,11 +814,9 @@ enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
auto targetShape = linalgOp.getStaticLoopRanges();
- assert(((llvm::count_if(targetShape,
- [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
+ assert(llvm::count_if(targetShape,
+ [](int64_t dimSize) { return dimSize > 1; }) == 1 &&
"n-D vectors are not yet supported");
- assert(targetShape.back() != 1 &&
- "1-D vectors with the trailing dim eqaual 1 are not yet supported");
// Blocks outside _this_ linalg.generic are effectively loop invariant.
// However, analysing block arguments for _this_ linalg.generic Op is a bit
@@ -879,8 +877,6 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
assert(((llvm::count_if(targetShape,
[](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
"n-D vectors are not yet supported");
- assert(targetShape.back() != 1 &&
- "1-D vectors with the trailing dim 1 are not yet supported");
// Blocks outside _this_ linalg.generic are effectively loop invariant.
// However, analysing block arguments for _this_ linalg.generic Op is a bit
@@ -946,27 +942,22 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
if (linalgOp.hasDynamicShape())
return VectorMemoryAccessKind::Gather;
- // 1. Assume that it's a gather load when reading _into_:
- // * an n-D "vector", like `tensor<1x2x4xi32` or `tensor<2x1x4xi32>`, or
- // * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
- // TODO: Relax these conditions.
- // FIXME: This condition assumes non-dynamic sizes.
- if ((llvm::count_if(targetShape,
- [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
- targetShape.back() == 1)
- return VectorMemoryAccessKind::Gather;
+ // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
+ // otherwise.
+ bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
+ return dimSize > 1;
+ }) == 1);
- // 2. Assume that it's a gather load when reading _from_ a tensor for which
- // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
- // TODO: Relax this condition.
- if (inputShape.getShape().back() == 1)
+ // 1. Assume that it's a gather load when reading non-1D vector.
+ if (!isOutput1DVector)
return VectorMemoryAccessKind::Gather;
bool leadingIdxsLoopInvariant = true;
- // 3. Analyze the leading indices of `extractOp`.
+ // 2. Analyze the leading indices of `extractOp`.
// Look at the way each index is calculated and decide whether it is suitable
- // for a contiguous load, i.e. whether it's loop invariant.
+ // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
+ // gather load.
auto indices = extractOp.getIndices();
auto leadIndices = indices.drop_back(1);
@@ -982,13 +973,13 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
return VectorMemoryAccessKind::Gather;
}
- // 4. Analyze the trailing index for `extractOp`.
+ // 3. Analyze the trailing index for `extractOp`.
// At this point we know that the leading indices are loop invariant. This
// means that is potentially a scalar or a contiguous load. We can decide
// based on the trailing idx.
auto extractOpTrailingIdx = indices.back();
- // 4a. Scalar broadcast load
+ // 3a. Scalar broadcast load
// If the trailing index is loop invariant then this is a scalar load.
if (leadingIdxsLoopInvariant &&
isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
@@ -997,7 +988,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
return VectorMemoryAccessKind::ScalarBroadcast;
}
- // 4b. Contiguous loads
+ // 3b. Contiguous loads
// The trailing `extractOp` index should increment with every loop iteration.
// This effectively means that it must be based on the trailing loop index.
// This is what the following bool captures.
@@ -1011,7 +1002,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
return VectorMemoryAccessKind::Contiguous;
}
- // 5. Fallback case - gather load.
+ // 4. Fallback case - gather load.
LDBG("Found gather load: " << extractOp);
return VectorMemoryAccessKind::Gather;
}
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index 85e1c56dd45a0..bdaa20c3bf971 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -37,6 +37,7 @@ module attributes {transform.with_named_sequence} {
}
// -----
+
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
%c0 = arith.constant 1 : index
@@ -74,20 +75,24 @@ module attributes {transform.with_named_sequence} {
// -----
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
- %1 = linalg.generic {
- indexing_maps = [#map1],
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_nd_tensor_extract_transfer_read_basic(
+ %arg0: tensor<3x3x3xf32>,
+ %arg1: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+
+ %res = linalg.generic {
+ indexing_maps = [#map],
iterator_types = ["parallel", "parallel", "parallel"]
- } outs(%arg2 : tensor<1x1x3xf32>) {
- ^bb0(%arg4: f32):
- %2 = linalg.index 0 : index
- %3 = linalg.index 1 : index
- %4 = linalg.index 2 : index
- %5 = tensor.extract %arg0[%2, %3, %4] : tensor<3x3x3xf32>
- linalg.yield %5 : f32
+ } outs(%arg1 : tensor<1x1x3xf32>) {
+ ^bb0(%out: f32):
+ %1 = linalg.index 0 : index
+ %2 = linalg.index 1 : index
+ %3 = linalg.index 2 : index
+ %4 = tensor.extract %arg0[%1, %2, %3] : tensor<3x3x3xf32>
+ linalg.yield %4 : f32
} -> tensor<1x1x3xf32>
- return %1 : tensor<1x1x3xf32>
+
+ return %res : tensor<1x1x3xf32>
}
// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic
@@ -104,6 +109,38 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
// CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
+// Same as example above, but reading into a column tensor. Note that after the
+// vectorization, the `TransferOpReduceRank` will replace
+// `vector.transfer_read` with `tensor.extract -> scalar`.
+
+// TODO: Currently this fails to vectorise when the indices are non-constant.
+
+func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
+ %input: tensor<3x3x3xf32>,
+ %output: tensor<3x1x1xf32>) -> tensor<3x1x1xf32> {
+
+ %c0 = arith.constant 0 : index
+ %res = linalg.generic {
+ indexing_maps = [#map],
+ iterator_types = ["parallel", "parallel", "parallel"]
+ } outs(%output : tensor<3x1x1xf32>) {
+ ^bb0(%out: f32):
+ %5 = tensor.extract %input[%c0, %c0, %c0] : tensor<3x3x3xf32>
+ linalg.yield %5 : f32
+ } -> tensor<3x1x1xf32>
+
+ return %res : tensor<3x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
+// CHECK-SAME: %[[INPUT:.*]]: tensor<3x3x3xf32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<3x1x1xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : tensor<3x3x3xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f32 to vector<3x1x1xf32>
+// CHECK: %[[RES:.*]] = vector.transfer_write %[[BCAST]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x1x1xf32>, tensor<3x1x1xf32>
+// CHECK: return %[[RES]] : tensor<3x1x1xf32>
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -595,3 +632,59 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+
+// -----
+
+func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+ %c4 = arith.constant 4 : index
+ %c0 = arith.constant 0 : index
+ %cst = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+
+ %out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) {
+ ^bb0(%out: i32):
+ %8 = linalg.index 0 : index
+ %idx_0 = linalg.index 0 : index
+ %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
+ linalg.yield %extracted : i32
+ } -> tensor<1x1x4xi32>
+
+ return %out:tensor<1x1x4xi32>
+}
+
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
+// CHECK: %[[VAL_10:.*]] = vector.step : vector<1xindex>
+// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK: %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK: %[[VAL_13:.*]] = vector.step : vector<1xindex>
+// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK: %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
+// CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
+// CHECK: %[[VAL_18:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_19:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_21:.*]] = vector.extractelement %[[VAL_20]]{{\[}}%[[VAL_19]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_22:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
+// CHECK: %[[VAL_24:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ transform.structured.vectorize %0 vector_sizes [1, 1, 4]{ vectorize_nd_extract } : !transform.any_op
+ transform.yield
+ }
+}
```
thanks!
…sor.extract case (#107922)

In #102321 we relaxed the vectorizer so that, when checking for contiguous loads, we don't always have a trailing non-unit dim. For example, in the test case added we have `tensor<8x1xf32>`, which is now a valid candidate for a contiguous load. However, the logic checking for a contiguous load assumed that only the trailing dim would be non-unit, so this PR updates that logic to find the actual non-unit dim.
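A hedged sketch of that fix, in the same standalone C++ style (the upstream helper is called `getNonUnitLoopDim`; this exact signature and the `std::optional` return are assumptions for illustration):

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Scan for the single loop dimension that is actually > 1 instead of
// assuming it is the trailing one. Returns std::nullopt when there is no
// non-unit dim, or more than one (i.e. the shape is not effectively 1-D).
static std::optional<size_t>
getNonUnitLoopDim(const std::vector<int64_t> &loopRanges) {
  std::optional<size_t> nonUnitDim;
  for (size_t i = 0; i < loopRanges.size(); ++i) {
    if (loopRanges[i] == 1)
      continue;
    if (nonUnitDim) // a second non-unit dim: bail out
      return std::nullopt;
    nonUnitDim = i;
  }
  return nonUnitDim;
}
```

For `tensor<8x1xf32>` this returns dim 0 rather than the trailing dim, which is exactly the case the old trailing-dim assumption mishandled.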
This PR fixes a bug in `isLoopInvariantIdx`. It makes sure that the following case is vectorised as `vector.gather` (as opposed to attempting a contiguous load):

```mlir
func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
  %c0 = arith.constant 0 : index
  %0 = tensor.empty() : tensor<8x1xf32>
  %res = linalg.generic {
    indexing_maps = [#map],
    iterator_types = ["parallel", "parallel"]
  } outs(%0 : tensor<8x1xf32>) {
  ^bb0(%arg1: f32):
    %1 = linalg.index 0 : index
    %extracted = tensor.extract %src[%1, %c0] : tensor<8x128xf32>
    linalg.yield %extracted : f32
  } -> tensor<8x1xf32>
  return %res : tensor<8x1xf32>
}
```

Specifically, when looking for loop-invariant indices in `tensor.extract` Ops, any `linalg.index` Op used in the address calculation should only access loop dims that are == 1. In the example above, the following does not meet that criterion:

```mlir
%1 = linalg.index 0 : index
```

Note that this PR also effectively addresses the issue fixed in #107922, i.e. exercised by:

* `@vectorize_nd_tensor_extract_load_1d_column_vector_using_gather_load`

`getNonUnitLoopDim` introduced in #107922 is still valid though. In fact, it is required to identify that the following case is a contiguous load:

```mlir
func.func @index_from_output_column_vector_contiguous_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
  %c0 = arith.constant 0 : index
  %0 = tensor.empty() : tensor<8x1xf32>
  %res = linalg.generic {
    indexing_maps = [#map],
    iterator_types = ["parallel", "parallel"]
  } outs(%0 : tensor<8x1xf32>) {
  ^bb0(%arg1: f32):
    %1 = linalg.index 0 : index
    %extracted = tensor.extract %src[%c0, %1] : tensor<8x128xf32>
    linalg.yield %extracted : f32
  } -> tensor<8x1xf32>
  return %res : tensor<8x1xf32>
}
```

Some logic is still missing to lower the above to `vector.transfer_read`, so it is conservatively lowered to `vector.gather` instead (see TODO in `getTensorExtractMemoryAccessPattern`).

There are a few additional changes:

* `getNonUnitLoopDim` is simplified and renamed as `getTrailingNonUnitLoopDimIdx`, and additional comments are added (note that the functionality didn't change);
* extra comments in a few places, and variable names in comments updated to use Markdown (which is the preferred approach in MLIR).

This is a follow-on for:

* #107922
* #102321
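The invariance rule itself is small. As a sketch under the same assumptions (static loop ranges, illustrative names; not the verbatim upstream implementation):

```cpp
#include <cstdint>
#include <vector>

// A `linalg.index <dim>` feeding the extract's address calculation is only
// loop-invariant if the loop dimension it reads has trip count 1; otherwise
// its value changes across iterations and the access must be a gather.
static bool isLinalgIndexLoopInvariant(int64_t indexedDim,
                                       const std::vector<int64_t> &loopRanges) {
  return loopRanges[static_cast<size_t>(indexedDim)] == 1;
}
```

In the gather example above, `%1 = linalg.index 0` reads loop dim 0 of the `tensor<8x1xf32>` output (trip count 8), so it is not invariant and the leading index disqualifies a contiguous load.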