Skip to content

Commit 315ba77

Browse files
authored
[mlir][linalg] Vectorisation of tensor.extract - dynamic shapes (#100582)
This PR removes the assumption that reading from a dynamic tensor is always a gather load: ```mlir %extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32> ``` That assumption was originally introduced to simplify the implementation and to reduce the number of cases to consider. Now that the vectorisation of `tensor.extract` has been around for > 1 year and has been quite stable, we can safely relax it. This is a relatively small change - rather than using the parent linalg Op to infer the target output shape (not possible with dynamic shapes), the vectorizer will use the (previously constructed) output vector shape instead. As expected, the following test required updating (`vector.gather` -> `vector.transfer_read`): * @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous Similar test for scalable vectors is also added.
1 parent d1335fb commit 315ba77

File tree

3 files changed

+123
-72
lines changed

3 files changed

+123
-72
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 25 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -810,12 +810,12 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
810810

811811
enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
812812

813-
/// Checks whether /p val can be used for calculating a loop invariant index.
814-
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
813+
/// Checks whether `val` can be used for calculating a loop invariant index.
814+
static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
815+
VectorType resType) {
815816

816-
auto targetShape = linalgOp.getStaticLoopRanges();
817-
assert(llvm::count_if(targetShape,
818-
[](int64_t dimSize) { return dimSize > 1; }) == 1 &&
817+
assert(((llvm::count_if(resType.getShape(),
818+
[](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
819819
"n-D vectors are not yet supported");
820820

821821
// Blocks outside _this_ linalg.generic are effectively loop invariant.
@@ -849,7 +849,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
849849

850850
bool result = true;
851851
for (auto op : ancestor->getOperands())
852-
result &= isLoopInvariantIdx(linalgOp, op);
852+
result &= isLoopInvariantIdx(linalgOp, op, resType);
853853

854854
return result;
855855
}
@@ -871,10 +871,9 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val) {
871871
/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
872872
/// updated to `true` when such an op is found.
873873
static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
874-
bool &foundIndexOp) {
874+
bool &foundIndexOp, VectorType resType) {
875875

876-
auto targetShape = linalgOp.getStaticLoopRanges();
877-
assert(((llvm::count_if(targetShape,
876+
assert(((llvm::count_if(resType.getShape(),
878877
[](int64_t dimSize) { return dimSize > 1; }) == 1)) &&
879878
"n-D vectors are not yet supported");
880879

@@ -910,44 +909,38 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
910909

911910
bool result = false;
912911
for (auto op : ancestor->getOperands())
913-
result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp);
912+
result |= isContiguousLoadIdx(linalgOp, op, foundIndexOp, resType);
914913

915914
return result;
916915
}
917916

918917
/// Infer the memory access pattern for the input ExtractOp
919918
///
920-
/// Based on the operation shapes and indices (usually based on the iteration
921-
/// space of the parent `linalgOp` operation), decides whether the input
922-
/// ExtractOp is a contiguous load (including a broadcast of a scalar) or a
923-
/// gather load.
919+
/// Based on the ExtractOp result shape and the access indices, decides whether
920+
/// this Op corresponds to a contiguous load (including a broadcast of a scalar)
921+
/// or a gather load. When analysing the ExtractOp indices (to identify
922+
/// contiguous loads), this method looks for "loop" invariant indices (e.g.
923+
/// block arguments) and indices that change linearly (e.g. via `linalg.index`
924+
/// Op).
924925
///
925926
/// Note that it is always safe to use gather load operations for contiguous
926927
/// loads (albeit slow), but not vice-versa. When in doubt, bail out and assume
927928
/// that `extractOp` is a gather load.
928929
static VectorMemoryAccessKind
929930
getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
930-
LinalgOp &linalgOp) {
931+
LinalgOp &linalgOp, VectorType resType) {
931932

932-
auto targetShape = linalgOp.getStaticLoopRanges();
933933
auto inputShape = cast<ShapedType>(extractOp.getTensor().getType());
934934

935-
// 0.1 Is this a 0-D vector? If yes then this is a scalar broadcast.
935+
// 0. Is this a 0-D vector? If yes then this is a scalar broadcast.
936936
if (inputShape.getShape().empty())
937937
return VectorMemoryAccessKind::ScalarBroadcast;
938938

939-
// 0.2 In the case of dynamic shapes just bail-out and assume that it's a
940-
// gather load.
941-
// TODO: Relax this condition.
942-
if (linalgOp.hasDynamicShape())
943-
return VectorMemoryAccessKind::Gather;
944-
945939
// True for vectors that are effectively 1D, e.g. `vector<1x4x1xi32>`, false
946940
// otherwise.
947-
bool isOutput1DVector = (llvm::count_if(targetShape, [](int64_t dimSize) {
948-
return dimSize > 1;
949-
}) == 1);
950-
941+
bool isOutput1DVector =
942+
(llvm::count_if(resType.getShape(),
943+
[](int64_t dimSize) { return dimSize > 1; }) == 1);
951944
// 1. Assume that it's a gather load when reading non-1D vector.
952945
if (!isOutput1DVector)
953946
return VectorMemoryAccessKind::Gather;
@@ -965,7 +958,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
965958
if (inputShape.getShape()[i] == 1)
966959
continue;
967960

968-
leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal);
961+
leadingIdxsLoopInvariant &= isLoopInvariantIdx(linalgOp, indexVal, resType);
969962
}
970963

971964
if (!leadingIdxsLoopInvariant) {
@@ -982,7 +975,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
982975
// 3a. Scalar broadcast load
983976
// If the trailing index is loop invariant then this is a scalar load.
984977
if (leadingIdxsLoopInvariant &&
985-
isLoopInvariantIdx(linalgOp, extractOpTrailingIdx)) {
978+
isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) {
986979
LDBG("Found scalar broadcast load: " << extractOp);
987980

988981
return VectorMemoryAccessKind::ScalarBroadcast;
@@ -993,8 +986,8 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
993986
// This effectively means that it must be based on the trailing loop index.
994987
// This is what the following bool captures.
995988
bool foundIndexOp = false;
996-
bool isContiguousLoad =
997-
isContiguousLoadIdx(linalgOp, extractOpTrailingIdx, foundIndexOp);
989+
bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
990+
foundIndexOp, resType);
998991
isContiguousLoad &= foundIndexOp;
999992

1000993
if (isContiguousLoad) {
@@ -1035,7 +1028,7 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
10351028
rewriter.create<arith::ConstantIndexOp>(loc, 0));
10361029

10371030
VectorMemoryAccessKind memAccessKind =
1038-
getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
1031+
getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType);
10391032

10401033
// 1. Handle gather access
10411034
if (memAccessKind == VectorMemoryAccessKind::Gather) {

mlir/test/Dialect/Linalg/vectorization-scalable.mlir

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,17 +162,14 @@ func.func @vectorize_linalg_index(%arg0: tensor<3x3x?xf32>, %arg1: tensor<1x1x?x
162162

163163
// CHECK-LABEL: @vectorize_linalg_index
164164
// CHECK-SAME: %[[SRC:.*]]: tensor<3x3x?xf32>, %[[DST:.*]]: tensor<1x1x?xf32>
165-
// CHECK-DAG: %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x[4]xf32>
166-
// CHECK-DAG: %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x[4]xi1>
167165
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
168166
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
169167
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
170168
// CHECK: %[[DST_DIM2:.*]] = tensor.dim %[[DST]], %[[C2]] : tensor<1x1x?xf32>
171-
// CHECK: %[[DST_MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
169+
// CHECK: %[[MASK:.*]] = vector.create_mask %[[C1]], %[[C1]], %[[DST_DIM2]] : vector<1x1x[4]xi1>
172170
// CHECK: %[[INDEX_VEC:.*]] = vector.step : vector<[4]xindex>
173-
// CHECK: %[[INDEX_VEC_BCAST:.*]] = vector.broadcast %[[INDEX_VEC]] : vector<[4]xindex> to vector<1x1x[4]xindex>
174-
// CHECK: %[[GATHER:.*]] = vector.mask %[[DST_MASK]] { vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {{\[}}%[[INDEX_VEC_BCAST]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x?xf32>, vector<1x1x[4]xindex>, vector<1x1x[4]xi1>, vector<1x1x[4]xf32> into vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
175-
// CHECK: %[[OUT:.*]] = vector.mask %[[DST_MASK]] { vector.transfer_write %[[GATHER]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
171+
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]][%c0, %c0, %2], %cst {in_bounds = [true, true, true]} : tensor<3x3x?xf32>, vector<1x1x[4]xf32> } : vector<1x1x[4]xi1> -> vector<1x1x[4]xf32>
172+
// CHECK: %[[OUT:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[DST]]{{\[}}%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x[4]xf32>, tensor<1x1x?xf32> } : vector<1x1x[4]xi1> -> tensor<1x1x?xf32>
176173
// CHECK: return %[[OUT]] : tensor<1x1x?xf32>
177174

178175
module attributes {transform.with_named_sequence} {

mlir/test/Dialect/Linalg/vectorize-tensor-extract-masked.mlir

Lines changed: 95 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -120,52 +120,54 @@ module attributes {transform.with_named_sequence} {
120120

121121
// -----
122122

123-
func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(%6: tensor<?x?xf32>, %arg0: index, %extracted_slice : tensor<?x?xf32>) -> tensor<?x?xf32> {
123+
func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
124+
%src: tensor<?x?xf32>,
125+
%output : tensor<?x?xf32>,
126+
%idx: index) -> tensor<?x?xf32> {
127+
124128
%c79 = arith.constant 79 : index
125129
%1 = linalg.generic {
126130
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
127131
iterator_types = ["parallel", "parallel"]
128-
} outs(%extracted_slice : tensor<?x?xf32>) {
132+
} outs(%output : tensor<?x?xf32>) {
129133
^bb0(%out: f32):
130134
%2 = linalg.index 1 : index
131-
%3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %arg0)
132-
%extracted = tensor.extract %6[%c79, %3] : tensor<?x?xf32>
135+
%3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
136+
%extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
133137
linalg.yield %extracted : f32
134138
} -> tensor<?x?xf32>
135139
return %1 : tensor<?x?xf32>
136140
}
137141

138142
// CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous(
139-
// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
140-
// CHECK-SAME: %[[VAL_1:.*]]: index,
141-
// CHECK-SAME: %[[VAL_2:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
142-
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 79 : index
143-
// CHECK-DAG: %[[VAL_4:.*]] = arith.constant 0 : index
144-
// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>
145-
// CHECK: %[[VAL_6:.*]] = arith.constant 1 : index
146-
// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_2]], %[[VAL_6]] : tensor<?x?xf32>
147-
// CHECK-DAG: %[[VAL_8:.*]] = arith.constant 0 : index
148-
// CHECK-DAG: %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
149-
// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<1x4xi1>
150-
// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_2]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
151-
// CHECK: %[[VAL_12:.*]] = vector.step : vector<4xindex>
152-
// CHECK: %[[VAL_13:.*]] = vector.broadcast %[[VAL_1]] : index to vector<4xindex>
153-
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : vector<4xindex>
154-
// CHECK-DAG: %[[VAL_15:.*]] = arith.constant dense<true> : vector<1x4xi1>
155-
// CHECK-DAG: %[[VAL_16:.*]] = arith.constant dense<0.000000e+00> : vector<1x4xf32>
156-
// CHECK-DAG: %[[VAL_17:.*]] = arith.constant 0 : index
157-
// CHECK-DAG: %[[VAL_18:.*]] = arith.constant dense<79> : vector<1x4xindex>
158-
// CHECK-DAG: %[[VAL_19:.*]] = arith.constant 1 : index
159-
// CHECK: %[[VAL_20:.*]] = tensor.dim %[[VAL_0]], %[[VAL_19]] : tensor<?x?xf32>
160-
// CHECK: %[[VAL_21:.*]] = vector.broadcast %[[VAL_20]] : index to vector<1x4xindex>
161-
// CHECK: %[[VAL_22:.*]] = arith.muli %[[VAL_18]], %[[VAL_21]] : vector<1x4xindex>
162-
// CHECK: %[[VAL_23:.*]] = vector.broadcast %[[VAL_14]] : vector<4xindex> to vector<1x4xindex>
163-
// CHECK: %[[VAL_24:.*]] = arith.addi %[[VAL_23]], %[[VAL_22]] : vector<1x4xindex>
164-
// CHECK: %[[VAL_25:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_17]], %[[VAL_17]]] {{\[}}%[[VAL_24]]], %[[VAL_15]], %[[VAL_16]] : tensor<?x?xf32>, vector<1x4xindex>, vector<1x4xi1>, vector<1x4xf32> into vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
165-
// CHECK: %[[VAL_26:.*]] = arith.constant 0 : index
166-
// CHECK: %[[VAL_27:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_25]], %[[VAL_2]]{{\[}}%[[VAL_26]], %[[VAL_26]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
167-
// CHECK: return %[[VAL_27]] : tensor<?x?xf32>
168-
// CHECK: }
143+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
144+
// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
145+
// CHECK-SAME: %[[IDX:.*]]: index)
146+
147+
/// Create the mask
148+
// CHECK: %[[C79:.*]] = arith.constant 79 : index
149+
// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
150+
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
151+
// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
152+
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
153+
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x4xi1>
154+
155+
/// TODO: This transfer_read is redundant - remove
156+
// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
157+
158+
/// Calculate the index vector
159+
// CHECK: %[[STEP:.*]] = vector.step : vector<4xindex>
160+
// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<4xindex>
161+
// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<4xindex>
162+
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
163+
// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<4xindex> to vector<4xindex>
164+
165+
/// Extract the starting point from the index vector
166+
// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<4xindex>
167+
168+
// Final read and write
169+
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x4xf32> } : vector<1x4xi1> -> vector<1x4xf32>
170+
// CHECK: %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x4xf32>, tensor<?x?xf32> } : vector<1x4xi1> -> tensor<?x?xf32>
169171

170172
module attributes {transform.with_named_sequence} {
171173
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -177,6 +179,65 @@ module attributes {transform.with_named_sequence} {
177179

178180
// -----
179181

182+
func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
183+
%src: tensor<?x?xf32>,
184+
%output : tensor<?x?xf32>,
185+
%idx: index) -> tensor<?x?xf32> {
186+
187+
%c79 = arith.constant 79 : index
188+
%1 = linalg.generic {
189+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
190+
iterator_types = ["parallel", "parallel"]
191+
} outs(%output : tensor<?x?xf32>) {
192+
^bb0(%out: f32):
193+
%2 = linalg.index 1 : index
194+
%3 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%2, %idx)
195+
%extracted = tensor.extract %src[%c79, %3] : tensor<?x?xf32>
196+
linalg.yield %extracted : f32
197+
} -> tensor<?x?xf32>
198+
return %1 : tensor<?x?xf32>
199+
}
200+
201+
// CHECK-LABEL: func.func @masked_dynamic_vectorize_nd_tensor_extract_with_affine_apply_contiguous_scalable(
202+
// CHECK-SAME: %[[SRC:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
203+
// CHECK-SAME: %[[OUTPUT:[a-zA-Z0-9]*]]: tensor<?x?xf32>,
204+
// CHECK-SAME: %[[IDX:.*]]: index)
205+
206+
/// Create the mask
207+
// CHECK: %[[C79:.*]] = arith.constant 79 : index
208+
// CHECK: %[[DIM_0_IDX:.*]] = arith.constant 0 : index
209+
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_0_IDX]] : tensor<?x?xf32>
210+
// CHECK: %[[DIM_1_IDX:.*]] = arith.constant 1 : index
211+
// CHECK: %[[DIM_1:.*]] = tensor.dim %[[OUTPUT]], %[[DIM_1_IDX]] : tensor<?x?xf32>
212+
// CHECK: %[[MASK:.*]] = vector.create_mask %[[DIM_0]], %[[DIM_1]] : vector<1x[4]xi1>
213+
214+
/// TODO: This transfer_read is redundant - remove
215+
// CHECK: vector.mask %[[MASK]] { vector.transfer_read %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
216+
217+
/// Calculate the index vector
218+
// CHECK: %[[STEP:.*]] = vector.step : vector<[4]xindex>
219+
// CHECK: %[[IDX_BC:.*]] = vector.broadcast %[[IDX]] : index to vector<[4]xindex>
220+
// CHECK: %[[IDX_VEC:.*]] = arith.addi %[[STEP]], %[[IDX_BC]] : vector<[4]xindex>
221+
// CHECK: %[[C0:.*]] = arith.constant 0 : i32
222+
// CHECK: %[[SC:.*]] = vector.shape_cast %[[IDX_VEC]] : vector<[4]xindex> to vector<[4]xindex>
223+
224+
/// Extract the starting point from the index vector
225+
// CHECK: %[[IDX_START:.*]] = vector.extractelement %[[SC]]{{\[}}%[[C0]] : i32] : vector<[4]xindex>
226+
227+
// Final read and write
228+
// CHECK: %[[READ:.*]] = vector.mask %[[MASK]] { vector.transfer_read %[[SRC]]{{\[}}%[[C79]], %[[IDX_START]]], {{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<1x[4]xf32> } : vector<1x[4]xi1> -> vector<1x[4]xf32>
229+
// CHECK: %[[VAL_24:.*]] = vector.mask %[[MASK]] { vector.transfer_write %[[READ]], %[[OUTPUT]]{{.*}} {in_bounds = [true, true]} : vector<1x[4]xf32>, tensor<?x?xf32> } : vector<1x[4]xi1> -> tensor<?x?xf32>
230+
231+
module attributes {transform.with_named_sequence} {
232+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
233+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
234+
transform.structured.vectorize %0 vector_sizes [1, [4]] {vectorize_nd_extract} : !transform.any_op
235+
transform.yield
236+
}
237+
}
238+
239+
// -----
240+
180241
func.func @masked_vectorize_nd_tensor_extract_with_affine_apply_gather(%6: tensor<80x16xf32>, %arg0: index, %extracted_slice : tensor<1x3xf32>) -> tensor<1x3xf32> {
181242
%c16 = arith.constant 16 : index
182243
%1 = linalg.generic {

0 commit comments

Comments
 (0)