Skip to content

[mlir][vector] Standardise valueToStore Naming Across Vector Ops (NFC) #134206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 22 additions & 21 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,7 @@ def Vector_InsertOp :
}];

let arguments = (ins
AnyType:$source,
AnyType:$valueToStore,
AnyVectorOfAnyRank:$dest,
Variadic<Index>:$dynamic_position,
DenseI64ArrayAttr:$static_position
Expand All @@ -916,15 +916,15 @@ def Vector_InsertOp :

let builders = [
// Builder to insert a scalar/rank-0 vector into a rank-0 vector.
OpBuilder<(ins "Value":$source, "Value":$dest)>,
OpBuilder<(ins "Value":$source, "Value":$dest, "int64_t":$position)>,
OpBuilder<(ins "Value":$source, "Value":$dest, "OpFoldResult":$position)>,
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<int64_t>":$position)>,
OpBuilder<(ins "Value":$source, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
OpBuilder<(ins "Value":$valueToStore, "Value":$dest)>,
OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "int64_t":$position)>,
OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "OpFoldResult":$position)>,
OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "ArrayRef<int64_t>":$position)>,
OpBuilder<(ins "Value":$valueToStore, "Value":$dest, "ArrayRef<OpFoldResult>":$position)>,
];

let extraClassDeclaration = extraPoisonClassDeclaration # [{
Type getSourceType() { return getSource().getType(); }
Type getValueToStoreType() { return getValueToStore().getType(); }
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
}
Expand All @@ -946,8 +946,8 @@ def Vector_InsertOp :
}];

let assemblyFormat = [{
$source `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
attr-dict `:` type($source) `into` type($dest)
$valueToStore `,` $dest custom<DynamicIndexList>($dynamic_position, $static_position)
attr-dict `:` type($valueToStore) `into` type($dest)
}];

let hasCanonicalizer = 1;
Expand All @@ -957,13 +957,13 @@ def Vector_InsertOp :

def Vector_ScalableInsertOp :
Vector_Op<"scalable.insert", [Pure,
AllElementTypesMatch<["source", "dest"]>,
AllElementTypesMatch<["valueToStore", "dest"]>,
AllTypesMatch<["dest", "res"]>,
PredOpTrait<"position is a multiple of the source length.",
CPred<
"(getPos() % getSourceVectorType().getNumElements()) == 0"
>>]>,
Arguments<(ins VectorOfRank<[1]>:$source,
Arguments<(ins VectorOfRank<[1]>:$valueToStore,
ScalableVectorOfRank<[1]>:$dest,
I64Attr:$pos)>,
Results<(outs ScalableVectorOfRank<[1]>:$res)> {
Expand Down Expand Up @@ -999,12 +999,12 @@ def Vector_ScalableInsertOp :
}];

let assemblyFormat = [{
$source `,` $dest `[` $pos `]` attr-dict `:` type($source) `into` type($dest)
$valueToStore `,` $dest `[` $pos `]` attr-dict `:` type($valueToStore) `into` type($dest)
}];

let extraClassDeclaration = extraPoisonClassDeclaration # [{
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getSource().getType());
return ::llvm::cast<VectorType>(getValueToStore().getType());
}
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
Expand Down Expand Up @@ -1068,20 +1068,20 @@ def Vector_InsertStridedSliceOp :
PredOpTrait<"operand #0 and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
AllTypesMatch<["dest", "res"]>]>,
Arguments<(ins AnyVectorOfNonZeroRank:$source, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
Arguments<(ins AnyVectorOfNonZeroRank:$valueToStore, AnyVectorOfNonZeroRank:$dest, I64ArrayAttr:$offsets,
I64ArrayAttr:$strides)>,
Results<(outs AnyVectorOfNonZeroRank:$res)> {
let summary = "strided_slice operation";
let description = [{
Takes a k-D source vector, an n-D destination vector (n >= k), n-sized
Takes a k-D valueToStore vector, an n-D destination vector (n >= k), n-sized
`offsets` integer array attribute, a k-sized `strides` integer array attribute
and inserts the k-D source vector as a strided subvector at the proper offset
and inserts the k-D valueToStore vector as a strided subvector at the proper offset
into the n-D destination vector.

At the moment strides must contain only 1s.

Returns an n-D vector that is a copy of the n-D destination vector in which
the last k-D dimensions contain the k-D source vector elements strided at
the last k-D dimensions contain the k-D valueToStore vector elements strided at
the proper location as specified by the offsets.

Example:
Expand All @@ -1094,16 +1094,17 @@ def Vector_InsertStridedSliceOp :
}];

let assemblyFormat = [{
$source `,` $dest attr-dict `:` type($source) `into` type($dest)
$valueToStore `,` $dest attr-dict `:` type($valueToStore) `into` type($dest)
}];

let builders = [
OpBuilder<(ins "Value":$source, "Value":$dest,
OpBuilder<(ins "Value":$valueToStore, "Value":$dest,
"ArrayRef<int64_t>":$offsets, "ArrayRef<int64_t>":$strides)>
];
let extraClassDeclaration = [{
// TODO: Rename
VectorType getSourceVectorType() {
return ::llvm::cast<VectorType>(getSource().getType());
return ::llvm::cast<VectorType>(getValueToStore().getType());
}
VectorType getDestVectorType() {
return ::llvm::cast<VectorType>(getDest().getType());
Expand Down Expand Up @@ -1520,7 +1521,7 @@ def Vector_TransferWriteOp :
AttrSizedOperandSegments,
DestinationStyleOpInterface
]>,
Arguments<(ins AnyVectorOfAnyRank:$vector,
Arguments<(ins AnyVectorOfAnyRank:$valueToStore,
AnyShaped:$source,
Variadic<Index>:$indices,
AffineMapAttr:$permutation_map,
Expand Down
14 changes: 9 additions & 5 deletions mlir/include/mlir/Interfaces/VectorInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*methodName=*/"getVector",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
Return the type of the vector that this operation operates on.
}],
/*retTy=*/"::mlir::VectorType",
/*methodName=*/"getVectorType",
/*args=*/(ins)
>,
InterfaceMethod<
/*desc=*/[{
Return the indices that specify the starting offsets into the source
Expand All @@ -133,6 +141,7 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
/*methodName=*/"getIndices",
/*args=*/(ins)
>,

InterfaceMethod<
/*desc=*/[{
Return the permutation map that describes the mapping of vector
Expand Down Expand Up @@ -202,11 +211,6 @@ def VectorTransferOpInterface : OpInterface<"VectorTransferOpInterface"> {
return $_op.getPermutationMap().getNumResults();
}

/// Return the type of the vector that this operation operates on.
::mlir::VectorType getVectorType() {
return ::llvm::cast<::mlir::VectorType>($_op.getVector().getType());
}

/// Return "true" if at least one of the vector dimensions is a broadcasted
/// dimension.
bool hasBroadcastDim() {
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ struct VectorInsertToArmSMELowering
auto loc = insertOp.getLoc();
auto position = insertOp.getMixedPosition();

Value source = insertOp.getSource();
Value source = insertOp.getValueToStore();

// Overwrite entire vector with value. Should be handled by folder, but
// just to be safe.
Expand Down
7 changes: 4 additions & 3 deletions mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1257,7 +1257,7 @@ class VectorInsertOpConversion
// We are going to mutate this 1D vector until it is either the final
// result (in the non-aggregate case) or the value that needs to be
// inserted into the aggregate result.
Value sourceAggregate = adaptor.getSource();
Value sourceAggregate = adaptor.getValueToStore();
if (insertIntoInnermostDim) {
// Scalar-into-1D-vector case, so we know we will have to create a
// InsertElementOp. The question is into what destination.
Expand All @@ -1279,7 +1279,8 @@ class VectorInsertOpConversion
}
// Insert the scalar into the 1D vector.
sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
loc, sourceAggregate.getType(), sourceAggregate, adaptor.getSource(),
loc, sourceAggregate.getType(), sourceAggregate,
adaptor.getValueToStore(),
getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
}

Expand All @@ -1305,7 +1306,7 @@ struct VectorScalableInsertOpLowering
matchAndRewrite(vector::ScalableInsertOp insOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::vector_insert>(
insOp, adaptor.getDest(), adaptor.getSource(), adaptor.getPos());
insOp, adaptor.getDest(), adaptor.getValueToStore(), adaptor.getPos());
return success();
}
};
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ struct PrepareTransferWriteConversion
buffers.dataBuffer);
auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
rewriter.modifyOpInPlace(xferOp, [&]() {
xferOp.getVectorMutable().assign(loadedVec);
xferOp.getValueToStoreMutable().assign(loadedVec);
xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
});

Expand Down
11 changes: 6 additions & 5 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,16 +287,16 @@ struct VectorInsertOpConvert final
LogicalResult
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (isa<VectorType>(insertOp.getSourceType()))
if (isa<VectorType>(insertOp.getValueToStoreType()))
return rewriter.notifyMatchFailure(insertOp, "unsupported vector source");
if (!getTypeConverter()->convertType(insertOp.getDestVectorType()))
return rewriter.notifyMatchFailure(insertOp,
"unsupported dest vector type");

// Special case for inserting scalar values into size-1 vectors.
if (insertOp.getSourceType().isIntOrFloat() &&
if (insertOp.getValueToStoreType().isIntOrFloat() &&
insertOp.getDestVectorType().getNumElements() == 1) {
rewriter.replaceOp(insertOp, adaptor.getSource());
rewriter.replaceOp(insertOp, adaptor.getValueToStore());
return success();
}

Expand All @@ -307,14 +307,15 @@ struct VectorInsertOpConvert final
insertOp,
"Static use of poison index handled elsewhere (folded to poison)");
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
insertOp, adaptor.getSource(), adaptor.getDest(), id.value());
insertOp, adaptor.getValueToStore(), adaptor.getDest(), id.value());
} else {
Value sanitizedIndex = sanitizeDynamicIndex(
rewriter, insertOp.getLoc(), adaptor.getDynamicPosition()[0],
vector::InsertOp::kPoisonIndex,
insertOp.getDestVectorType().getNumElements());
rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
insertOp, insertOp.getDest(), adaptor.getSource(), sanitizedIndex);
insertOp, insertOp.getDest(), adaptor.getValueToStore(),
sanitizedIndex);
}
return success();
}
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ struct LegalizeTransferWriteOpsByDecomposition

auto loc = writeOp.getLoc();
auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
auto inputSMETiles = adaptor.getVector();
auto inputSMETiles = adaptor.getValueToStore();

Value destTensorOrMemref = writeOp.getSource();
for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles(
Expand Down Expand Up @@ -464,7 +464,7 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop
rewriter.setInsertionPointToStart(storeLoop.getBody());

// For each sub-tile of the multi-tile `vectorType`.
auto inputSMETiles = adaptor.getVector();
auto inputSMETiles = adaptor.getValueToStore();
auto tileSliceIndex = storeLoop.getInductionVar();
for (auto [index, smeTile] : llvm::enumerate(
decomposeToSMETiles(rewriter, vectorType, smeTileType))) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(Operation *root,
if (failed(maybeNewLoop))
return WalkResult::interrupt();

transferWrite.getVectorMutable().assign(
transferWrite.getValueToStoreMutable().assign(
maybeNewLoop->getOperation()->getResults().back());
changed = true;
// Need to interrupt and restart because erasing the loop messes up
Expand Down
4 changes: 2 additions & 2 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3177,8 +3177,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
rewriter.create<vector::TransferWriteOp>(
xferOp.getLoc(), vector, out, xferOp.getIndices(),
xferOp.getPermutationMapAttr(), xferOp.getMask(),
rewriter.getBoolArrayAttr(
SmallVector<bool>(vector.getType().getRank(), false)));
rewriter.getBoolArrayAttr(SmallVector<bool>(
dyn_cast<VectorType>(vector.getType()).getRank(), false)));

rewriter.eraseOp(copyOp);
rewriter.eraseOp(xferOp);
Expand Down
Loading