Skip to content

Commit 0d9b439

Browse files
authored
[mlir][vector] Use DenseI64ArrayAttr for constant_mask dim sizes (llvm#100997)
This prevents a bunch of boilerplate conversions to/from IntegerAttrs and int64_ts. Other than that this is a NFC.
1 parent 135a1e9 commit 0d9b439

File tree

5 files changed

+30
-45
lines changed

5 files changed

+30
-45
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2443,7 +2443,7 @@ def Vector_TypeCastOp :
24432443

24442444
def Vector_ConstantMaskOp :
24452445
Vector_Op<"constant_mask", [Pure]>,
2446-
Arguments<(ins I64ArrayAttr:$mask_dim_sizes)>,
2446+
Arguments<(ins DenseI64ArrayAttr:$mask_dim_sizes)>,
24472447
Results<(outs VectorOfAnyRankOf<[I1]>)> {
24482448
let summary = "creates a constant vector mask";
24492449
let description = [{

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -88,15 +88,14 @@ static MaskFormat getMaskFormat(Value mask) {
8888
// Inspect constant mask index. If the index exceeds the
8989
// dimension size, all bits are set. If the index is zero
9090
// or less, no bits are set.
91-
ArrayAttr masks = m.getMaskDimSizes();
91+
ArrayRef<int64_t> masks = m.getMaskDimSizes();
9292
auto shape = m.getType().getShape();
9393
bool allTrue = true;
9494
bool allFalse = true;
9595
for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
96-
int64_t i = llvm::cast<IntegerAttr>(maskIdx).getInt();
97-
if (i < dimSize)
96+
if (maskIdx < dimSize)
9897
allTrue = false;
99-
if (i > 0)
98+
if (maskIdx > 0)
10099
allFalse = false;
101100
}
102101
if (allTrue)
@@ -3593,8 +3592,7 @@ class StridedSliceConstantMaskFolder final
35933592
if (extractStridedSliceOp.hasNonUnitStrides())
35943593
return failure();
35953594
// Gather constant mask dimension sizes.
3596-
SmallVector<int64_t, 4> maskDimSizes;
3597-
populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
3595+
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
35983596
// Gather strided slice offsets and sizes.
35993597
SmallVector<int64_t, 4> sliceOffsets;
36003598
populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
@@ -3625,7 +3623,7 @@ class StridedSliceConstantMaskFolder final
36253623
// region.
36263624
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
36273625
extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
3628-
vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
3626+
sliceMaskDimSizes);
36293627
return success();
36303628
}
36313629
};
@@ -5410,21 +5408,19 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
54105408
}
54115409

54125410
if (constantMaskOp) {
5413-
auto maskDimSizes = constantMaskOp.getMaskDimSizes().getValue();
5411+
auto maskDimSizes = constantMaskOp.getMaskDimSizes();
54145412
auto numMaskOperands = maskDimSizes.size();
54155413

54165414
// Check every mask dim size to see whether it can be dropped
54175415
for (size_t i = numMaskOperands - 1; i >= numMaskOperands - numDimsToDrop;
54185416
--i) {
5419-
if (cast<IntegerAttr>(maskDimSizes[i]).getValue() != 1)
5417+
if (maskDimSizes[i] != 1)
54205418
return failure();
54215419
}
54225420

54235421
auto newMaskOperands = maskDimSizes.drop_back(numDimsToDrop);
5424-
ArrayAttr newMaskOperandsAttr = rewriter.getArrayAttr(newMaskOperands);
5425-
54265422
rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(shapeOp, shapeOpResTy,
5427-
newMaskOperandsAttr);
5423+
newMaskOperands);
54285424
return success();
54295425
}
54305426

@@ -5804,12 +5800,10 @@ class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
58045800

58055801
// ConstantMaskOp case.
58065802
auto maskDimSizes = constantMaskOp.getMaskDimSizes();
5807-
SmallVector<Attribute> newMaskDimSizes(maskDimSizes.getValue());
5808-
applyPermutationToVector(newMaskDimSizes, permutation);
5803+
auto newMaskDimSizes = applyPermutation(maskDimSizes, permutation);
58095804

58105805
rewriter.replaceOpWithNewOp<vector::ConstantMaskOp>(
5811-
transpOp, transpOp.getResultVectorType(),
5812-
ArrayAttr::get(transpOp.getContext(), newMaskDimSizes));
5806+
transpOp, transpOp.getResultVectorType(), newMaskDimSizes);
58135807
return success();
58145808
}
58155809
};
@@ -5832,7 +5826,7 @@ LogicalResult ConstantMaskOp::verify() {
58325826
if (resultType.getRank() == 0) {
58335827
if (getMaskDimSizes().size() != 1)
58345828
return emitError("array attr must have length 1 for 0-D vectors");
5835-
auto dim = llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt();
5829+
auto dim = getMaskDimSizes()[0];
58365830
if (dim != 0 && dim != 1)
58375831
return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
58385832
return success();
@@ -5846,17 +5840,15 @@ LogicalResult ConstantMaskOp::verify() {
58465840
// result dimension size.
58475841
auto resultShape = resultType.getShape();
58485842
auto resultScalableDims = resultType.getScalableDims();
5849-
SmallVector<int64_t, 4> maskDimSizes;
5850-
for (const auto [index, intAttr] : llvm::enumerate(getMaskDimSizes())) {
5851-
int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
5843+
ArrayRef<int64_t> maskDimSizes = getMaskDimSizes();
5844+
for (const auto [index, maskDimSize] : llvm::enumerate(maskDimSizes)) {
58525845
if (maskDimSize < 0 || maskDimSize > resultShape[index])
58535846
return emitOpError(
58545847
"array attr of size out of bounds of vector result dimension size");
58555848
if (resultScalableDims[index] && maskDimSize != 0 &&
58565849
maskDimSize != resultShape[index])
58575850
return emitOpError(
58585851
"only supports 'none set' or 'all set' scalable dimensions");
5859-
maskDimSizes.push_back(maskDimSize);
58605852
}
58615853
// Verify that if one mask dim size is zero, they all should be zero (because
58625854
// the mask region is a conjunction of each mask dimension interval).
@@ -5873,11 +5865,10 @@ bool ConstantMaskOp::isAllOnesMask() {
58735865
// Check the corner case of 0-D vectors first.
58745866
if (resultType.getRank() == 0) {
58755867
assert(getMaskDimSizes().size() == 1 && "invalid sizes for zero rank mask");
5876-
return llvm::cast<IntegerAttr>(getMaskDimSizes()[0]).getInt() == 1;
5868+
return getMaskDimSizes()[0] == 1;
58775869
}
5878-
for (const auto [resultSize, intAttr] :
5870+
for (const auto [resultSize, maskDimSize] :
58795871
llvm::zip_equal(resultType.getShape(), getMaskDimSizes())) {
5880-
int64_t maskDimSize = llvm::cast<IntegerAttr>(intAttr).getInt();
58815872
if (maskDimSize < resultSize)
58825873
return false;
58835874
}
@@ -6007,9 +5998,8 @@ class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
60075998
}
60085999

60096000
// Replace 'createMaskOp' with ConstantMaskOp.
6010-
rewriter.replaceOpWithNewOp<ConstantMaskOp>(
6011-
createMaskOp, retTy,
6012-
vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
6001+
rewriter.replaceOpWithNewOp<ConstantMaskOp>(createMaskOp, retTy,
6002+
maskDimSizes);
60136003
return success();
60146004
}
60156005
};

mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,15 +111,15 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
111111
if (rank == 0) {
112112
assert(dimSizes.size() == 1 &&
113113
"Expected exactly one dim size for a 0-D vector");
114-
bool value = cast<IntegerAttr>(dimSizes[0]).getInt() == 1;
114+
bool value = dimSizes.front() == 1;
115115
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
116116
op, dstType,
117117
DenseIntElementsAttr::get(VectorType::get({}, rewriter.getI1Type()),
118118
value));
119119
return success();
120120
}
121121

122-
int64_t trueDimSize = cast<IntegerAttr>(dimSizes[0]).getInt();
122+
int64_t trueDimSize = dimSizes.front();
123123

124124
if (rank == 1) {
125125
if (trueDimSize == 0 || trueDimSize == dstType.getDimSize(0)) {
@@ -147,7 +147,7 @@ class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
147147

148148
VectorType lowType = VectorType::Builder(dstType).dropDim(0);
149149
Value trueVal = rewriter.create<vector::ConstantMaskOp>(
150-
loc, lowType, rewriter.getArrayAttr(dimSizes.getValue().drop_front()));
150+
loc, lowType, dimSizes.drop_front());
151151
Value result = rewriter.create<arith::ConstantOp>(
152152
loc, dstType, rewriter.getZeroAttr(dstType));
153153
for (int64_t d = 0; d < trueDimSize; d++)

mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -550,9 +550,7 @@ struct CastAwayConstantMaskLeadingOneDim
550550
return failure();
551551

552552
int64_t dropDim = oldType.getRank() - newType.getRank();
553-
SmallVector<int64_t> dimSizes;
554-
for (auto attr : mask.getMaskDimSizes())
555-
dimSizes.push_back(llvm::cast<IntegerAttr>(attr).getInt());
553+
ArrayRef<int64_t> dimSizes = mask.getMaskDimSizes();
556554

557555
// If any of the dropped unit dims has a size of `0`, the entire mask is a
558556
// zero mask, else the unit dim has no effect on the mask.
@@ -563,7 +561,7 @@ struct CastAwayConstantMaskLeadingOneDim
563561
newDimSizes.append(dimSizes.begin() + dropDim + 1, dimSizes.end());
564562

565563
auto newMask = rewriter.create<vector::ConstantMaskOp>(
566-
mask.getLoc(), newType, rewriter.getI64ArrayAttr(newDimSizes));
564+
mask.getLoc(), newType, newDimSizes);
567565
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(mask, oldType, newMask);
568566
return success();
569567
}

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,14 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter,
8383
newMask = rewriter.create<vector::CreateMaskOp>(loc, newMaskType,
8484
newMaskOperands);
8585
} else if (constantMaskOp) {
86-
ArrayRef<Attribute> maskDimSizes =
87-
constantMaskOp.getMaskDimSizes().getValue();
86+
ArrayRef<int64_t> maskDimSizes = constantMaskOp.getMaskDimSizes();
8887
size_t numMaskOperands = maskDimSizes.size();
89-
auto origIndex =
90-
cast<IntegerAttr>(maskDimSizes[numMaskOperands - 1]).getInt();
91-
IntegerAttr maskIndexAttr =
92-
rewriter.getI64IntegerAttr((origIndex + scale - 1) / scale);
93-
SmallVector<Attribute> newMaskDimSizes(maskDimSizes.drop_back());
94-
newMaskDimSizes.push_back(maskIndexAttr);
95-
newMask = rewriter.create<vector::ConstantMaskOp>(
96-
loc, newMaskType, rewriter.getArrayAttr(newMaskDimSizes));
88+
int64_t origIndex = maskDimSizes[numMaskOperands - 1];
89+
int64_t maskIndex = (origIndex + scale - 1) / scale;
90+
SmallVector<int64_t> newMaskDimSizes(maskDimSizes.drop_back());
91+
newMaskDimSizes.push_back(maskIndex);
92+
newMask = rewriter.create<vector::ConstantMaskOp>(loc, newMaskType,
93+
newMaskDimSizes);
9794
}
9895

9996
while (!extractOps.empty()) {

0 commit comments

Comments
 (0)