Skip to content

Commit a1aad28

Browse files
committed
[mlir][vector] NFC: Improve vector type accessor methods
Plain `getVectorType()` can be quite confusing and error-prone given that, well, vector ops always work on vector types, and it can commonly involve both source and result vectors. So this commit makes various such accessor methods to be explicit w.r.t. source or result vectors. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D144159
1 parent 5382d28 commit a1aad28

File tree

10 files changed

+86
-79
lines changed

10 files changed

+86
-79
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def Vector_ReductionOp :
321321
```
322322
}];
323323
let extraClassDeclaration = [{
324-
VectorType getVectorType() {
324+
VectorType getSourceVectorType() {
325325
return getVector().getType().cast<VectorType>();
326326
}
327327
}];
@@ -449,7 +449,7 @@ def Vector_BroadcastOp :
449449
}];
450450
let extraClassDeclaration = [{
451451
Type getSourceType() { return getSource().getType(); }
452-
VectorType getVectorType() {
452+
VectorType getResultVectorType() {
453453
return getVector().getType().cast<VectorType>();
454454
}
455455

@@ -466,7 +466,7 @@ def Vector_BroadcastOp :
466466
/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
467467
/// the helper will assert. This means:
468468
/// 1. `dstShape` must not be empty.
469-
/// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
469+
/// 2. `broadcastedDims` must be confined to [0 .. rank(value.getResultVectorType)]
470470
/// 2. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
471471
// must match the `value` shape.
472472
static Value createOrFoldBroadcastOp(
@@ -537,7 +537,7 @@ def Vector_ShuffleOp :
537537
VectorType getV2VectorType() {
538538
return getV2().getType().cast<VectorType>();
539539
}
540-
VectorType getVectorType() {
540+
VectorType getResultVectorType() {
541541
return getVector().getType().cast<VectorType>();
542542
}
543543
}];
@@ -584,7 +584,7 @@ def Vector_ExtractElementOp :
584584
OpBuilder<(ins "Value":$source)>,
585585
];
586586
let extraClassDeclaration = [{
587-
VectorType getVectorType() {
587+
VectorType getSourceVectorType() {
588588
return getVector().getType().cast<VectorType>();
589589
}
590590
}];
@@ -619,7 +619,7 @@ def Vector_ExtractOp :
619619
];
620620
let extraClassDeclaration = [{
621621
static StringRef getPositionAttrStrName() { return "position"; }
622-
VectorType getVectorType() {
622+
VectorType getSourceVectorType() {
623623
return getVector().getType().cast<VectorType>();
624624
}
625625
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
@@ -996,7 +996,7 @@ def Vector_OuterProductOp :
996996
? VectorType()
997997
: (*getAcc().begin()).getType().cast<VectorType>();
998998
}
999-
VectorType getVectorType() {
999+
VectorType getResultVectorType() {
10001000
return getResult().getType().cast<VectorType>();
10011001
}
10021002
static constexpr StringRef getKindAttrStrName() {
@@ -1172,7 +1172,9 @@ def Vector_ExtractStridedSliceOp :
11721172
static StringRef getOffsetsAttrStrName() { return "offsets"; }
11731173
static StringRef getSizesAttrStrName() { return "sizes"; }
11741174
static StringRef getStridesAttrStrName() { return "strides"; }
1175-
VectorType getVectorType(){ return getVector().getType().cast<VectorType>(); }
1175+
VectorType getSourceVectorType() {
1176+
return getVector().getType().cast<VectorType>();
1177+
}
11761178
void getOffsets(SmallVectorImpl<int64_t> &results);
11771179
bool hasNonUnitStrides() {
11781180
return llvm::any_of(getStrides(), [](Attribute attr) {
@@ -2424,10 +2426,10 @@ def Vector_TransposeOp :
24242426
OpBuilder<(ins "Value":$vector, "ArrayRef<int64_t>":$transp)>
24252427
];
24262428
let extraClassDeclaration = [{
2427-
VectorType getVectorType() {
2429+
VectorType getSourceVectorType() {
24282430
return getVector().getType().cast<VectorType>();
24292431
}
2430-
VectorType getResultType() {
2432+
VectorType getResultVectorType() {
24312433
return getResult().getType().cast<VectorType>();
24322434
}
24332435
void getTransp(SmallVectorImpl<int64_t> &results);

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
203203

204204
/// Return true if this is a broadcast from scalar to a 2D vector.
205205
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
206-
return broadcastOp.getVectorType().getRank() == 2;
206+
return broadcastOp.getResultVectorType().getRank() == 2;
207207
}
208208

209209
/// Return true if this integer extend op can be folded into a contract op.
@@ -949,7 +949,7 @@ convertExtractStridedSlice(RewriterBase &rewriter,
949949

950950
SmallVector<int64_t> sizes;
951951
populateFromInt64AttrArray(op.getSizes(), sizes);
952-
ArrayRef<int64_t> warpVectorShape = op.getVectorType().getShape();
952+
ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
953953

954954
// Compute offset in vector registers. Note that the mma.sync vector registers
955955
// are shaped as numberOfFragments x numberOfRegistersPerfFragment. The vector
@@ -1045,7 +1045,7 @@ convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
10451045
assert(broadcastSupportsMMAMatrixType(op));
10461046

10471047
const char *fragType = inferFragType(op);
1048-
auto vecType = op.getVectorType();
1048+
auto vecType = op.getResultVectorType();
10491049
gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
10501050
vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
10511051
auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -939,7 +939,7 @@ class VectorShuffleOpConversion
939939
auto loc = shuffleOp->getLoc();
940940
auto v1Type = shuffleOp.getV1VectorType();
941941
auto v2Type = shuffleOp.getV2VectorType();
942-
auto vectorType = shuffleOp.getVectorType();
942+
auto vectorType = shuffleOp.getResultVectorType();
943943
Type llvmType = typeConverter->convertType(vectorType);
944944
auto maskArrayAttr = shuffleOp.getMask();
945945

@@ -1002,7 +1002,7 @@ class VectorExtractElementOpConversion
10021002
LogicalResult
10031003
matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
10041004
ConversionPatternRewriter &rewriter) const override {
1005-
auto vectorType = extractEltOp.getVectorType();
1005+
auto vectorType = extractEltOp.getSourceVectorType();
10061006
auto llvmType = typeConverter->convertType(vectorType.getElementType());
10071007

10081008
// Bail if result type cannot be lowered.

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ struct VectorBroadcastConvert final
8383
LogicalResult
8484
matchAndRewrite(vector::BroadcastOp castOp, OpAdaptor adaptor,
8585
ConversionPatternRewriter &rewriter) const override {
86-
Type resultType = getTypeConverter()->convertType(castOp.getVectorType());
86+
Type resultType =
87+
getTypeConverter()->convertType(castOp.getResultVectorType());
8788
if (!resultType)
8889
return failure();
8990

@@ -92,10 +93,10 @@ struct VectorBroadcastConvert final
9293
return success();
9394
}
9495

95-
SmallVector<Value, 4> source(castOp.getVectorType().getNumElements(),
96+
SmallVector<Value, 4> source(castOp.getResultVectorType().getNumElements(),
9697
adaptor.getSource());
9798
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
98-
castOp, castOp.getVectorType(), source);
99+
castOp, castOp.getResultVectorType(), source);
99100
return success();
100101
}
101102
};
@@ -405,7 +406,7 @@ struct VectorShuffleOpConvert final
405406
LogicalResult
406407
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
407408
ConversionPatternRewriter &rewriter) const override {
408-
auto oldResultType = shuffleOp.getVectorType();
409+
auto oldResultType = shuffleOp.getResultVectorType();
409410
if (!spirv::CompositeType::isValid(oldResultType))
410411
return failure();
411412
Type newResultType = getTypeConverter()->convertType(oldResultType);

0 commit comments

Comments
 (0)