Skip to content

Commit 8fe3975

Browse files
committed
[mlir][vector] Standardise valueToStore Naming Across Vector Ops (NFC)
This change standardises the naming convention for the argument representing the value to store in various vector operations. Specifically, it ensures that all vector ops storing a value—whether into memory, a tensor, or another vector — use `valueToStore` for the corresponding argument name. Updated operations: * `vector.transfer_write`, `vector.insert`, `vector.scalable_insert`, `vector.insert_strided_slice`. For reference, here are operations that currently use `valueToStore`: * `vector.store` `vector.scatter`, `vector.compressstore`, `vector.maskedstore`. This change is non-functional (NFC) and does not affect the functionality of these operations. Implements #131602
1 parent acc6bcd commit 8fe3975

File tree

16 files changed

+119
-86
lines changed

16 files changed

+119
-86
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -907,7 +907,7 @@ def Vector_InsertOp :
907907
}];
908908

909909
let arguments = (ins
910-
AnyType:$source,
910+
AnyType:$valueToStore,
911911
AnyVectorOfAnyRank:$dest,
912912
Variadic<Index>:$dynamic_position,
913913
DenseI64ArrayAttr:$static_position
@@ -916,15 +916,15 @@ def Vector_InsertOp :
916916

917917
let builders = [
918918
// Builder to insert a scalar/rank-0 vector into a rank-0 vector.
919-
OpBuilder<(ins "Value":$source, "Value":$dest)>,
920-
OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
921-
OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
922-
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
923-
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
919+
OpBuilder<(ins "Value":$valueToStore, "Value":$dest)>,
920+
OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "int64_t":$position)>,
921+
OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "OpFoldResult":$position)>,
922+
OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "ArrayRef<int64_t>":$position)>,
923+
OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
924924
];
925925

926926
let extraClassDeclaration = extraPoisonClassDeclaration # [{
927-
Type getSourceType() { return getSource().getType(); }
927+
Type getValueToStoreType() { return getValueToStore().getType(); }
928928
VectorType getDestVectorType() {
929929
return ::llvm::cast<VectorType>(getDest().getType());
930930
}
@@ -946,8 +946,8 @@ def Vector_InsertOp :
946946
}];
947947

948948
let assemblyFormat = [{
949-
$source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
950-
attr-dict `:` type($source) `into` type($dest)
949+
$valueToStore `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
950+
attr-dict `:` type($valueToStore) `into` type($dest)
951951
}];
952952

953953
let hasCanonicalizer = 1;
@@ -957,13 +957,13 @@ def Vector_InsertOp :
957957

958958
def Vector_ScalableInsertOp :
959959
Vector_Op<"scalable.insert", [Pure,
960-
AllElementTypesMatch<["source", "dest"]>,
960+
AllElementTypesMatch<["valueToStore", "dest"]>,
961961
AllTypesMatch<["dest", "res"]>,
962962
PredOpTrait<"position is a multiple of the source length.",
963963
CPred<
964964
"(getPos() % getSourceVectorType().getNumElements()) == 0"
965965
>>]>,
966-
Arguments<(ins VectorOfRank<[1]>:$source,
966+
Arguments<(ins VectorOfRank<[1]>:$valueToStore,
967967
ScalableVectorOfRank<[1]>:$dest,
968968
I64Attr:$pos)>,
969969
Results<(outs ScalableVectorOfRank<[1]>:$res)> {
@@ -999,12 +999,12 @@ def Vector_ScalableInsertOp :
999999
}];
10001000

10011001
let assemblyFormat = [{
1002-
$source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
1002+
$valueToStore `,` $dest `[` $pos `]` attr-dict `:` type($valueToStore) `into` type($dest)
10031003
}];
10041004

10051005
let extraClassDeclaration = extraPoisonClassDeclaration # [{
10061006
VectorType getSourceVectorType() {
1007-
return ::llvm::cast<VectorType>(getSource().getType());
1007+
return ::llvm::cast<VectorType>(getValueToStore().getType());
10081008
}
10091009
VectorType getDestVectorType() {
10101010
return ::llvm::cast<VectorType>(getDest().getType());
@@ -1068,20 +1068,20 @@ def Vector_InsertStridedSliceOp :
10681068
PredOpTrait<"operand #0 and result have same element type",
10691069
TCresVTEtIsSameAsOpBase<0, 0>>,
10701070
AllTypesMatch<["dest", "res"]>]>,
1071-
Arguments<(ins AnyVectorOfNonZeroRank:$source, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
1071+
Arguments<(ins AnyVectorOfNonZeroRank:$valueToStore, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
10721072
I64ArrayAttr:$strides)>,
10731073
Results<(outs AnyVectorOfNonZeroRank:$res)> {
10741074
let summary = "strided_slice operation";
10751075
let description = [{
1076-
Takes a k-D source vector, an n-D destination vector (n >= k), n-sized
1076+
Takes a k-D valueToStore vector, an n-D destination vector (n >= k), n-sized
10771077
`offsets` integer array attribute, a k-sized `strides` integer array attribute
1078-
and inserts the k-D source vector as a strided subvector at the proper offset
1078+
and inserts the k-D valueToStore vector as a strided subvector at the proper offset
10791079
into the n-D destination vector.
10801080

10811081
At the moment strides must contain only 1s.
10821082

10831083
Returns an n-D vector that is a copy of the n-D destination vector in which
1084-
the last k-D dimensions contain the k-D source vector elements strided at
1084+
the last k-D dimensions contain the k-D valueToStore vector elements strided at
10851085
the proper location as specified by the offsets.
10861086

10871087
Example:
@@ -1094,16 +1094,17 @@ def Vector_InsertStridedSliceOp :
10941094
}];
10951095

10961096
let assemblyFormat = [{
1097-
$source `,` $dest attr-dict `:` type($source) `into` type($dest)
1097+
$valueToStore `,` $dest attr-dict `:` type($valueToStore) `into` type($dest)
10981098
}];
10991099

11001100
let builders = [
1101-
OpBuilder<(ins "Value":$source, "Value":$dest,
1101+
OpBuilder<(ins "Value":$valueToStore, "Value":$dest,
11021102
"ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides)>
11031103
];
11041104
let extraClassDeclaration = [{
1105+
// TODO: Rename
11051106
VectorType getSourceVectorType() {
1106-
return ::llvm::cast<VectorType>(getSource().getType());
1107+
return ::llvm::cast<VectorType>(getValueToStore().getType());
11071108
}
11081109
VectorType getDestVectorType() {
11091110
return ::llvm::cast<VectorType>(getDest().getType());
@@ -1520,7 +1521,7 @@ def Vector_TransferWriteOp :
15201521
AttrSizedOperandSegments,
15211522
DestinationStyleOpInterface
15221523
]>,
1523-
Arguments<(ins AnyVectorOfAnyRank:$vector,
1524+
Arguments<(ins AnyVectorOfAnyRank:$valueToStore,
15241525
AnyShaped:$source,
15251526
Variadic<Index>:$indices,
15261527
AffineMapAttr:$permutation_map,

mlir/include/mlir/Interfaces/VectorInterfaces.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,14 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
124124
/*methodName=*/"getVector",
125125
/*args=*/(ins)
126126
>,
127+
InterfaceMethod<
128+
/*desc=*/[{
129+
Return the type of the vector that this operation operates on.
130+
}],
131+
/*retTy=*/"::mlir::VectorType",
132+
/*methodName=*/"getVectorType",
133+
/*args=*/(ins)
134+
>,
127135
InterfaceMethod<
128136
/*desc=*/[{
129137
Return the indices that specify the starting offsets into the source
@@ -133,6 +141,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
133141
/*methodName=*/"getIndices",
134142
/*args=*/(ins)
135143
>,
144+
136145
InterfaceMethod<
137146
/*desc=*/[{
138147
Return the permutation map that describes the mapping of vector
@@ -202,11 +211,6 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
202211
return $_op.getPermutationMap().getNumResults();
203212
}
204213

205-
/// Return the type of the vector that this operation operates on.
206-
::mlir::VectorType getVectorType() {
207-
return ::llvm::cast<::mlir::VectorType>($_op.getVector().getType());
208-
}
209-
210214
/// Return "true" if at least one of the vector dimensions is a broadcasted
211215
/// dimension.
212216
bool hasBroadcastDim() {

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ struct VectorInsertToArmSMELowering
579579
auto loc = insertOp.getLoc();
580580
auto position = insertOp.getMixedPosition();
581581

582-
Value source = insertOp.getSource();
582+
Value source = insertOp.getValueToStore();
583583

584584
// Overwrite entire vector with value. Should be handled by folder, but
585585
// just to be safe.

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1257,7 +1257,7 @@ class VectorInsertOpConversion
12571257
// We are going to mutate this 1D vector until it is either the final
12581258
// result (in the non-aggregate case) or the value that needs to be
12591259
// inserted into the aggregate result.
1260-
Value sourceAggregate = adaptor.getSource();
1260+
Value sourceAggregate = adaptor.getValueToStore();
12611261
if (insertIntoInnermostDim) {
12621262
// Scalar-into-1D-vector case, so we know we will have to create a
12631263
// InsertElementOp. The question is into what destination.
@@ -1279,7 +1279,8 @@ class VectorInsertOpConversion
12791279
}
12801280
// Insert the scalar into the 1D vector.
12811281
sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
1282-
loc, sourceAggregate.getType(), sourceAggregate, adaptor.getSource(),
1282+
loc, sourceAggregate.getType(), sourceAggregate,
1283+
adaptor.getValueToStore(),
12831284
getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
12841285
}
12851286

@@ -1305,7 +1306,7 @@ struct VectorScalableInsertOpLowering
13051306
matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
13061307
ConversionPatternRewriter &rewriter) const override {
13071308
rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
1308-
insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
1309+
insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
13091310
return success();
13101311
}
13111312
};

mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ struct PrepareTransferWriteConversion
661661
buffers.dataBuffer);
662662
auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
663663
rewriter.modifyOpInPlace(xferOp, [&]() {
664-
xferOp.getVectorMutable().assign(loadedVec);
664+
xferOp.getValueToStoreMutable().assign(loadedVec);
665665
xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
666666
});
667667

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,16 +287,16 @@ struct VectorInsertOpConvert final
287287
LogicalResult
288288
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
289289
ConversionPatternRewriter &rewriter) const override {
290-
if (isa<VectorType>(insertOp.getSourceType()))
290+
if (isa<VectorType>(insertOp.getValueToStoreType()))
291291
return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
292292
if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
293293
return rewriter.notifyMatchFailure(insertOp,
294294
"unsupported dest vector type");
295295

296296
// Special case for inserting scalar values into size-1 vectors.
297-
if (insertOp.getSourceType().isIntOrFloat() &&
297+
if (insertOp.getValueToStoreType().isIntOrFloat() &&
298298
insertOp.getDestVectorType().getNumElements() == 1) {
299-
rewriter.replaceOp(insertOp, adaptor.getSource());
299+
rewriter.replaceOp(insertOp, adaptor.getValueToStore());
300300
return success();
301301
}
302302

@@ -307,14 +307,15 @@ struct VectorInsertOpConvert final
307307
insertOp,
308308
"Static use of poison index handled elsewhere (folded to poison)");
309309
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
310-
insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
310+
insertOp, adaptor.getValueToStore(), adaptor.getDest(), id.value());
311311
} else {
312312
Value sanitizedIndex = sanitizeDynamicIndex(
313313
rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
314314
vector::InsertOp::kPoisonIndex,
315315
insertOp.getDestVectorType().getNumElements());
316316
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
317-
insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
317+
insertOp, insertOp.getDest(), adaptor.getValueToStore(),
318+
sanitizedIndex);
318319
}
319320
return success();
320321
}

mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ struct LegalizeTransferWriteOpsByDecomposition
357357

358358
auto loc = writeOp.getLoc();
359359
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
360-
auto inputSMETiles = adaptor.getVector();
360+
auto inputSMETiles = adaptor.getValueToStore();
361361

362362
Value destTensorOrMemref = writeOp.getSource();
363363
for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
@@ -464,7 +464,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
464464
rewriter.setInsertionPointToStart(storeLoop.getBody());
465465

466466
// For each sub-tile of the multi-tile `vectorType`.
467-
auto inputSMETiles = adaptor.getVector();
467+
auto inputSMETiles = adaptor.getValueToStore();
468468
auto tileSliceIndex = storeLoop.getInductionVar();
469469
for (auto [index, smeTile] : llvm::enumerate(
470470
decomposeToSMETiles(rewriter, vectorType, smeTileType))) {

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
371371
if (failed(maybeNewLoop))
372372
return WalkResult::interrupt();
373373

374-
transferWrite.getVectorMutable().assign(
374+
transferWrite.getValueToStoreMutable().assign(
375375
maybeNewLoop->getOperation()->getResults().back());
376376
changed = true;
377377
// Need to interrupt and restart because erasing the loop messes up

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3177,8 +3177,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
31773177
rewriter.create<vector::TransferWriteOp>(
31783178
xferOp.getLoc(), vector, out, xferOp.getIndices(),
31793179
xferOp.getPermutationMapAttr(), xferOp.getMask(),
3180-
rewriter.getBoolArrayAttr(
3181-
SmallVector<bool>(vector.getType().getRank(), false)));
3180+
rewriter.getBoolArrayAttr(SmallVector<bool>(
3181+
dyn_cast<VectorType>(vector.getType()).getRank(), false)));
31823182

31833183
rewriter.eraseOp(copyOp);
31843184
rewriter.eraseOp(xferOp);

0 commit comments

Comments
 (0)