
Commit 5362092
Reapply "[mlir][linalg] Relax tensor.extract vectorization" (llvm#102232)
This reverts commit 6662523d6b2ca0198141c94ee80ebbb41601df9f.

Simplifies the vectorization of tensor.extract so that:

* all cases that read into a genuinely multi-dim vector (*) are considered
  a gather load,
* all other cases are considered as potential contiguous loads.

This change means that the following extraction from a "column" tensor is
correctly identified as a scalar load followed by a broadcast (rather than
a gather load):

```mlir
func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
  %c4 = arith.constant 4 : index
  %c0 = arith.constant 0 : index
  %cst = arith.constant dense<[...]> : tensor<15x1xi32>

  %out = linalg.generic {
    indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
    iterator_types = ["parallel", "parallel", "parallel"]
  } outs(%in : tensor<1x1x4xi32>) {
  ^bb0(%out: i32):
    %8 = linalg.index 0 : index
    %idx_0 = linalg.index 0 : index
    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
    linalg.yield %extracted : i32
  } -> tensor<1x1x4xi32>

  return %out : tensor<1x1x4xi32>
}
```

Overview of the delta when compared to the original submission:

* removed an assert representing a condition that is being relaxed here,
* added a test (reading from a column tensor) based on a repro from @hanhanW.

(*) `vector<1x4x1xf32>` is considered as a 1D vector in this context.
1 parent 96d824d · commit 5362092
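To make the footnote in the commit message concrete, here is a minimal standalone sketch of the "effectively 1D" test the patch uses. This is plain C++ for illustration only: `isEffectively1D` is an invented helper name; the patch expresses the same check inline via `llvm::count_if` over `linalgOp.getStaticLoopRanges()`.

```cpp
// A vector counts as "1D" when exactly one dimension is greater than one,
// regardless of where the unit dims sit. Illustrative helper, not MLIR API.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

static bool isEffectively1D(const std::vector<int64_t> &shape) {
  return std::count_if(shape.begin(), shape.end(),
                       [](int64_t dimSize) { return dimSize > 1; }) == 1;
}

int main() {
  std::cout << isEffectively1D({1, 4, 1}) << '\n'; // 1: vector<1x4x1xf32> counts as 1D
  std::cout << isEffectively1D({1, 1, 4}) << '\n'; // 1: the column-tensor case
  std::cout << isEffectively1D({2, 1, 4}) << '\n'; // 0: genuinely multi-dim
}
```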

File tree

  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
  mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

2 files changed: +121 -37 lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 16 additions & 25 deletions
@@ -814,11 +814,9 @@ enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
 
   auto targetShape = linalgOp.getStaticLoopRanges();
-  assert(((llvm::count_if(targetShape,
-                          [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
+  assert(llvm::count_if(targetShape,
+                        [](int64_t dimSize) { return dimSize > 1; }) == 1 &&
          "n-D vectors are not yet supported");
-  assert(targetShape.back() != 1 &&
-         "1-D vectors with the trailing dim eqaual 1 are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
   // However, analysing block arguments for _this_ linalg.generic Op is a bit
@@ -879,8 +877,6 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
   assert(((llvm::count_if(targetShape,
                           [](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
          "n-D vectors are not yet supported");
-  assert(targetShape.back() != 1 &&
-         "1-D vectors with the trailing dim 1 are not yet supported");
 
   // Blocks outside _this_ linalg.generic are effectively loop invariant.
   // However, analysing block arguments for _this_ linalg.generic Op is a bit
@@ -946,27 +942,22 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   if (linalgOp.hasDynamicShape())
     return VectorMemoryAccessKind::Gather;
 
-  // 1. Assume that it's a gather load when reading _into_:
-  //    * an n-D "vector", like `tensor<1x2x4xi32` or `tensor<2x1x4xi32>`, or
-  //    * a 1-D "vector" with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
-  // TODO: Relax these conditions.
-  // FIXME: This condition assumes non-dynamic sizes.
-  if ((llvm::count_if(targetShape,
-                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
-      targetShape.back() == 1)
-    return VectorMemoryAccessKind::Gather;
+  // True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
+  // otherwise.
+  bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
+                             return dimSize > 1;
+                           }) == 1);
 
-  // 2. Assume that it's a gather load when reading _from_ a tensor for which
-  // the trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
-  // TODO: Relax this condition.
-  if (inputShape.getShape().back() == 1)
+  // 1. Assume that it's a gather load when reading non-1D vector.
+  if (!isOutput1DVector)
     return VectorMemoryAccessKind::Gather;
 
   bool leadingIdxsLoopInvariant = true;
 
-  // 3. Analyze the leading indices of `extractOp`.
+  // 2. Analyze the leading indices of `extractOp`.
   // Look at the way each index is calculated and decide whether it is suitable
-  // for a contiguous load, i.e. whether it's loop invariant.
+  // for a contiguous load, i.e. whether it's loop invariant. If not, it's a
+  // gather load.
   auto indices = extractOp.getIndices();
   auto leadIndices = indices.drop_back(1);
 
@@ -982,13 +973,13 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::Gather;
   }
 
-  // 4. Analyze the trailing index for `extractOp`.
+  // 3. Analyze the trailing index for `extractOp`.
   // At this point we know that the leading indices are loop invariant. This
   // means that is potentially a scalar or a contiguous load. We can decide
   // based on the trailing idx.
   auto extractOpTrailingIdx = indices.back();
 
-  // 4a. Scalar broadcast load
+  // 3a. Scalar broadcast load
   // If the trailing index is loop invariant then this is a scalar load.
   if (leadingIdxsLoopInvariant &&
       isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
@@ -997,7 +988,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::ScalarBroadcast;
   }
 
-  // 4b. Contiguous loads
+  // 3b. Contiguous loads
   // The trailing `extractOp` index should increment with every loop iteration.
   // This effectively means that it must be based on the trailing loop index.
   // This is what the following bool captures.
@@ -1011,7 +1002,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
     return VectorMemoryAccessKind::Contiguous;
  }
 
-  // 5. Fallback case - gather load.
+  // 4. Fallback case - gather load.
   LDBG("Found gather load: " << extractOp);
   return VectorMemoryAccessKind::Gather;
 }
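For readers skimming the patch: the net effect is a simpler, strictly ordered decision procedure. The sketch below is a hand-written, self-contained restatement of that order in plain C++; the boolean parameters stand in for the `isLoopInvariantIdx` / `isContiguousLoadIdx` analyses, and the `classify` name and signature are illustrative, not the MLIR API. Note what is absent: the old step that forced a gather whenever the input tensor's trailing dimension was 1, which is exactly the condition this commit relaxes.

```cpp
// Illustrative restatement of getTensorExtractMemoryAccessPattern's control
// flow after this patch. Plain C++17, no MLIR headers; compiles as-is.
#include <algorithm>
#include <cstdint>
#include <vector>

enum class VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };

// The bools stand in for the per-index analyses done by isLoopInvariantIdx /
// isContiguousLoadIdx in the real implementation.
VectorMemoryAccessKind
classify(const std::vector<int64_t> &targetShape, bool hasDynamicShape,
         bool leadingIdxsLoopInvariant, bool trailingIdxLoopInvariant,
         bool trailingIdxContiguous) {
  if (hasDynamicShape)
    return VectorMemoryAccessKind::Gather;

  // True for vectors that are effectively 1D, e.g. vector<1x4x1xi32>.
  bool isOutput1DVector =
      std::count_if(targetShape.begin(), targetShape.end(),
                    [](int64_t dimSize) { return dimSize > 1; }) == 1;

  // 1. Reading into a genuinely multi-dim vector -> gather load.
  if (!isOutput1DVector)
    return VectorMemoryAccessKind::Gather;

  // (The pre-patch step 2, "gather when the input tensor's trailing dim is
  // 1", is intentionally gone; that is the relaxation.)

  // 2. All leading indices must be loop invariant; otherwise gather.
  if (!leadingIdxsLoopInvariant)
    return VectorMemoryAccessKind::Gather;

  // 3a. Loop-invariant trailing index: scalar load + broadcast.
  if (trailingIdxLoopInvariant)
    return VectorMemoryAccessKind::ScalarBroadcast;

  // 3b. Trailing index advances with the innermost loop: contiguous load.
  if (trailingIdxContiguous)
    return VectorMemoryAccessKind::Contiguous;

  // 4. Fallback: gather load.
  return VectorMemoryAccessKind::Gather;
}
```

Under this ordering, the column-tensor example from the commit message reaches step 3a, which is why it now vectorizes as `tensor.extract` plus `vector.broadcast` rather than `vector.gather`.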

mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir

Lines changed: 105 additions & 12 deletions
@@ -37,6 +37,7 @@ module attributes {transform.with_named_sequence} {
 }
 
 // -----
+
 #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
   %c0 = arith.constant 1 : index
@@ -74,20 +75,24 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
-#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
-  %1 = linalg.generic {
-    indexing_maps = [#map1],
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @vectorize_nd_tensor_extract_transfer_read_basic(
+    %arg0: tensor<3x3x3xf32>,
+    %arg1: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+
+  %res = linalg.generic {
+    indexing_maps = [#map],
     iterator_types = ["parallel", "parallel", "parallel"]
-  } outs(%arg2 : tensor<1x1x3xf32>) {
-  ^bb0(%arg4: f32):
-    %2 = linalg.index 0 : index
-    %3 = linalg.index 1 : index
-    %4 = linalg.index 2 : index
-    %5 = tensor.extract %arg0[%2, %3, %4] : tensor<3x3x3xf32>
-    linalg.yield %5 : f32
+  } outs(%arg1 : tensor<1x1x3xf32>) {
+  ^bb0(%out: f32):
+    %1 = linalg.index 0 : index
+    %2 = linalg.index 1 : index
+    %3 = linalg.index 2 : index
+    %4 = tensor.extract %arg0[%1, %2, %3] : tensor<3x3x3xf32>
+    linalg.yield %4 : f32
   } -> tensor<1x1x3xf32>
-  return %1 : tensor<1x1x3xf32>
+
+  return %res : tensor<1x1x3xf32>
 }
 
 // CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic
@@ -104,6 +109,38 @@ func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf
 // CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
 // CHECK: vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
+// Same as example above, but reading into a column tensor. Note that after the
+// vectorizatoin, the `TransferOpReduceRank` will replace
+// `vector.transfer_read` with `tensor.extract -> scalar`.
+
+// TODO: Currently this fails to vectorise when the indices are non-constant.
+
+func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
+    %input: tensor<3x3x3xf32>,
+    %output: tensor<3x1x1xf32>) -> tensor<3x1x1xf32> {
+
+  %c0 = arith.constant 0 : index
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel", "parallel"]
+  } outs(%output : tensor<3x1x1xf32>) {
+  ^bb0(%out: f32):
+    %5 = tensor.extract %input[%c0, %c0, %c0] : tensor<3x3x3xf32>
+    linalg.yield %5 : f32
+  } -> tensor<3x1x1xf32>
+
+  return %res : tensor<3x1x1xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic_column(
+// CHECK-SAME: %[[INPUT:.*]]: tensor<3x3x3xf32>,
+// CHECK-SAME: %[[OUTPUT:.*]]: tensor<3x1x1xf32>)
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[INPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] : tensor<3x3x3xf32>
+// CHECK: %[[BCAST:.*]] = vector.broadcast %[[EXTRACT]] : f32 to vector<3x1x1xf32>
+// CHECK: %[[RES:.*]] = vector.transfer_write %[[BCAST]], %[[OUTPUT]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<3x1x1xf32>, tensor<3x1x1xf32>
+// CHECK: return %[[RES]] : tensor<3x1x1xf32>
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
     %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
@@ -595,3 +632,59 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+
+// -----
+
+func.func @vectorize_scalar_broadcast_column_tensor(%in: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+  %c4 = arith.constant 4 : index
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<[[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+
+  %out = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} outs(%in : tensor<1x1x4xi32>) {
+  ^bb0(%out: i32):
+    %8 = linalg.index 0 : index
+    %idx_0 = linalg.index 0 : index
+    %extracted = tensor.extract %cst[%idx_0, %c0] : tensor<15x1xi32>
+    linalg.yield %extracted : i32
+  } -> tensor<1x1x4xi32>
+
+  return %out:tensor<1x1x4xi32>
+}
+
+// CHECK: #[[$ATTR_1:.+]] = affine_map<(d0, d1) -> (0, 0, 0)>
+// CHECK-LABEL: func.func @vectorize_scalar_broadcast_column_tensor(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x4xi32>) -> tensor<1x1x4xi32> {
+// CHECK: %[[VAL_1:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant dense<{{\[\[}}0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11], [12], [13], [14]]> : tensor<15x1xi32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 4 : index
+// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_9:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_7]], %[[VAL_7]], %[[VAL_7]]], %[[VAL_8]] : tensor<1x1x4xi32>, vector<1x1x4xi32>
+// CHECK: %[[VAL_10:.*]] = vector.step : vector<1xindex>
+// CHECK: %[[VAL_11:.*]] = vector.broadcast %[[VAL_10]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK: %[[VAL_12:.*]] = vector.transpose %[[VAL_11]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK: %[[VAL_13:.*]] = vector.step : vector<1xindex>
+// CHECK: %[[VAL_14:.*]] = vector.broadcast %[[VAL_13]] : vector<1xindex> to vector<4x1x1xindex>
+// CHECK: %[[VAL_15:.*]] = vector.transpose %[[VAL_14]], [2, 1, 0] : vector<4x1x1xindex> to vector<1x1x4xindex>
+// CHECK: %[[VAL_16:.*]] = arith.constant dense<true> : vector<1x1x4xi1>
+// CHECK: %[[VAL_17:.*]] = arith.constant dense<0> : vector<1x1x4xi32>
+// CHECK: %[[VAL_18:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_19:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_20:.*]] = vector.shape_cast %[[VAL_15]] : vector<1x1x4xindex> to vector<4xindex>
+// CHECK: %[[VAL_21:.*]] = vector.extractelement %[[VAL_20]]{{\[}}%[[VAL_19]] : i32] : vector<4xindex>
+// CHECK: %[[VAL_22:.*]] = arith.constant 0 : i32
+// CHECK: %[[VAL_23:.*]] = vector.transfer_read %[[VAL_3]]{{\[}}%[[VAL_21]], %[[VAL_2]]], %[[VAL_22]] {in_bounds = [true, true, true], permutation_map = #[[$ATTR_1]]} : tensor<15x1xi32>, vector<1x1x4xi32>
+// CHECK: %[[VAL_24:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_25:.*]] = vector.transfer_write %[[VAL_23]], %[[VAL_0]]{{\[}}%[[VAL_24]], %[[VAL_24]], %[[VAL_24]]] : vector<1x1x4xi32>, tensor<1x1x4xi32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [1, 1, 4]{ vectorize_nd_extract } : !transform.any_op
+    transform.yield
+  }
+}
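For context on how these checks are exercised: each `// -----`-delimited snippet carries its own `transform.named_sequence`, and the file is driven by `mlir-opt`'s transform interpreter together with `FileCheck`. Assuming the file's standard RUN line (not shown in this excerpt), the invocation is along the lines of `mlir-opt -split-input-file -transform-interpreter %s | FileCheck %s`.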
