Skip to content

Commit b4444dc

Browse files
authored
[mlir][vector] Use DenseI64ArrayAttr for shuffle masks (llvm#101163)
Follow on from llvm#100997. This again removes from boilerplate conversions to/from IntegerAttr and int64_t (otherwise, this is a NFC).
1 parent d92a484 commit b4444dc

File tree

6 files changed

+35
-55
lines changed

6 files changed

+35
-55
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def Vector_ShuffleOp :
421421
TCresVTEtIsSameAsOpBase<0, 1>>,
422422
InferTypeOpAdaptor]>,
423423
Arguments<(ins AnyFixedVector:$v1, AnyFixedVector:$v2,
424-
I64ArrayAttr:$mask)>,
424+
DenseI64ArrayAttr:$mask)>,
425425
Results<(outs AnyVector:$vector)> {
426426
let summary = "shuffle operation";
427427
let description = [{
@@ -459,11 +459,7 @@ def Vector_ShuffleOp :
459459
: vector<f32>, vector<f32> ; yields vector<2xf32>
460460
```
461461
}];
462-
let builders = [
463-
OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayRef<int64_t>")>
464-
];
465-
let hasFolder = 1;
466-
let hasCanonicalizer = 1;
462+
467463
let extraClassDeclaration = [{
468464
VectorType getV1VectorType() {
469465
return ::llvm::cast<VectorType>(getV1().getType());
@@ -475,7 +471,10 @@ def Vector_ShuffleOp :
475471
return ::llvm::cast<VectorType>(getVector().getType());
476472
}
477473
}];
474+
478475
let assemblyFormat = "operands $mask attr-dict `:` type(operands)";
476+
477+
let hasFolder = 1;
479478
let hasVerifier = 1;
480479
let hasCanonicalizer = 1;
481480
}

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ class VectorShuffleOpConversion
994994
auto v2Type = shuffleOp.getV2VectorType();
995995
auto vectorType = shuffleOp.getResultVectorType();
996996
Type llvmType = typeConverter->convertType(vectorType);
997-
auto maskArrayAttr = shuffleOp.getMask();
997+
ArrayRef<int64_t> mask = shuffleOp.getMask();
998998

999999
// Bail if result type cannot be lowered.
10001000
if (!llvmType)
@@ -1015,7 +1015,7 @@ class VectorShuffleOpConversion
10151015
if (rank <= 1 && v1Type == v2Type) {
10161016
Value llvmShuffleOp = rewriter.create<LLVM::ShuffleVectorOp>(
10171017
loc, adaptor.getV1(), adaptor.getV2(),
1018-
LLVM::convertArrayToIndices<int32_t>(maskArrayAttr));
1018+
llvm::to_vector_of<int32_t>(mask));
10191019
rewriter.replaceOp(shuffleOp, llvmShuffleOp);
10201020
return success();
10211021
}
@@ -1029,8 +1029,7 @@ class VectorShuffleOpConversion
10291029
eltType = cast<VectorType>(llvmType).getElementType();
10301030
Value insert = rewriter.create<LLVM::UndefOp>(loc, llvmType);
10311031
int64_t insPos = 0;
1032-
for (const auto &en : llvm::enumerate(maskArrayAttr)) {
1033-
int64_t extPos = cast<IntegerAttr>(en.value()).getInt();
1032+
for (int64_t extPos : mask) {
10341033
Value value = adaptor.getV1();
10351034
if (extPos >= v1Dim) {
10361035
extPos -= v1Dim;

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -527,10 +527,7 @@ struct VectorShuffleOpConvert final
527527
return rewriter.notifyMatchFailure(shuffleOp,
528528
"unsupported result vector type");
529529

530-
SmallVector<int32_t, 4> mask = llvm::map_to_vector<4>(
531-
shuffleOp.getMask(), [](Attribute attr) -> int32_t {
532-
return cast<IntegerAttr>(attr).getValue().getZExtValue();
533-
});
530+
auto mask = llvm::to_vector_of<int32_t>(shuffleOp.getMask());
534531

535532
VectorType oldV1Type = shuffleOp.getV1VectorType();
536533
VectorType oldV2Type = shuffleOp.getV2VectorType();

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2464,11 +2464,6 @@ void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
24642464
// ShuffleOp
24652465
//===----------------------------------------------------------------------===//
24662466

2467-
void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
2468-
Value v2, ArrayRef<int64_t> mask) {
2469-
build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
2470-
}
2471-
24722467
LogicalResult ShuffleOp::verify() {
24732468
VectorType resultType = getResultVectorType();
24742469
VectorType v1Type = getV1VectorType();
@@ -2491,19 +2486,18 @@ LogicalResult ShuffleOp::verify() {
24912486
return emitOpError("dimension mismatch");
24922487
}
24932488
// Verify mask length.
2494-
auto maskAttr = getMask().getValue();
2495-
int64_t maskLength = maskAttr.size();
2489+
ArrayRef<int64_t> mask = getMask();
2490+
int64_t maskLength = mask.size();
24962491
if (maskLength <= 0)
24972492
return emitOpError("invalid mask length");
24982493
if (maskLength != resultType.getDimSize(0))
24992494
return emitOpError("mask length mismatch");
25002495
// Verify all indices.
25012496
int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
25022497
(v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
2503-
for (const auto &en : llvm::enumerate(maskAttr)) {
2504-
auto attr = llvm::dyn_cast<IntegerAttr>(en.value());
2505-
if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
2506-
return emitOpError("mask index #") << (en.index() + 1) << " out of range";
2498+
for (auto [idx, maskPos] : llvm::enumerate(mask)) {
2499+
if (maskPos < 0 || maskPos >= indexSize)
2500+
return emitOpError("mask index #") << (idx + 1) << " out of range";
25072501
}
25082502
return success();
25092503
}
@@ -2527,13 +2521,12 @@ ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
25272521
return success();
25282522
}
25292523

2530-
static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
2531-
uint64_t expected = begin;
2532-
return idxArr.size() == width &&
2533-
llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
2534-
[&expected](auto attr) {
2535-
return attr.getZExtValue() == expected++;
2536-
});
2524+
template <typename T>
2525+
static bool isStepIndexArray(ArrayRef<T> idxArr, uint64_t begin, size_t width) {
2526+
T expected = begin;
2527+
return idxArr.size() == width && llvm::all_of(idxArr, [&expected](T value) {
2528+
return value == expected++;
2529+
});
25372530
}
25382531

25392532
OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
@@ -2568,8 +2561,7 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
25682561
SmallVector<Attribute> results;
25692562
auto lhsElements = llvm::cast<DenseElementsAttr>(lhs).getValues<Attribute>();
25702563
auto rhsElements = llvm::cast<DenseElementsAttr>(rhs).getValues<Attribute>();
2571-
for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
2572-
int64_t i = index.getZExtValue();
2564+
for (int64_t i : this->getMask()) {
25732565
if (i >= lhsSize) {
25742566
results.push_back(rhsElements[i - lhsSize]);
25752567
} else {
@@ -2590,13 +2582,13 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
25902582
LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
25912583
PatternRewriter &rewriter) const override {
25922584
VectorType v1VectorType = shuffleOp.getV1VectorType();
2593-
ArrayAttr mask = shuffleOp.getMask();
2585+
ArrayRef<int64_t> mask = shuffleOp.getMask();
25942586
if (v1VectorType.getRank() > 0)
25952587
return failure();
25962588
if (mask.size() != 1)
25972589
return failure();
25982590
VectorType resType = VectorType::Builder(v1VectorType).setShape({1});
2599-
if (llvm::cast<IntegerAttr>(mask[0]).getInt() == 0)
2591+
if (mask[0] == 0)
26002592
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
26012593
shuffleOp.getV1());
26022594
else
@@ -2651,11 +2643,11 @@ class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
26512643
op, "ShuffleOp types don't match an interleave");
26522644
}
26532645

2654-
ArrayAttr shuffleMask = op.getMask();
2646+
ArrayRef<int64_t> shuffleMask = op.getMask();
26552647
int64_t resultVectorSize = resultType.getNumElements();
26562648
for (int i = 0, e = resultVectorSize / 2; i < e; ++i) {
2657-
int64_t maskValueA = cast<IntegerAttr>(shuffleMask[i * 2]).getInt();
2658-
int64_t maskValueB = cast<IntegerAttr>(shuffleMask[(i * 2) + 1]).getInt();
2649+
int64_t maskValueA = shuffleMask[i * 2];
2650+
int64_t maskValueB = shuffleMask[(i * 2) + 1];
26592651
if (maskValueA != i || maskValueB != (resultVectorSize / 2) + i)
26602652
return rewriter.notifyMatchFailure(op,
26612653
"ShuffleOp mask not interleaving");

mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,8 +225,7 @@ class Convert1DExtractStridedSliceIntoShuffle
225225
off += stride)
226226
offsets.push_back(off);
227227
rewriter.replaceOpWithNewOp<ShuffleOp>(op, dstType, op.getVector(),
228-
op.getVector(),
229-
rewriter.getI64ArrayAttr(offsets));
228+
op.getVector(), offsets);
230229
return success();
231230
}
232231
};

mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,7 @@ struct LinearizeVectorExtractStridedSlice final
232232
}
233233
// Perform a shuffle to extract the kD vector.
234234
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
235-
extractOp, dstType, srcVector, srcVector,
236-
rewriter.getI64ArrayAttr(indices));
235+
extractOp, dstType, srcVector, srcVector, indices);
237236
return success();
238237
}
239238

@@ -298,20 +297,17 @@ struct LinearizeVectorShuffle final
298297
// that needs to be shuffled to the destination vector. If shuffleSliceLen >
299298
// 1 we need to shuffle the slices (consecutive shuffleSliceLen number of
300299
// elements) instead of scalars.
301-
ArrayAttr mask = shuffleOp.getMask();
300+
ArrayRef<int64_t> mask = shuffleOp.getMask();
302301
int64_t totalSizeOfShuffledElmnts = mask.size() * shuffleSliceLen;
303302
llvm::SmallVector<int64_t, 2> indices(totalSizeOfShuffledElmnts);
304-
for (auto [i, value] :
305-
llvm::enumerate(mask.getAsValueRange<IntegerAttr>())) {
306-
307-
int64_t v = value.getZExtValue();
303+
for (auto [i, value] : llvm::enumerate(mask)) {
308304
std::iota(indices.begin() + shuffleSliceLen * i,
309305
indices.begin() + shuffleSliceLen * (i + 1),
310-
shuffleSliceLen * v);
306+
shuffleSliceLen * value);
311307
}
312308

313-
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
314-
shuffleOp, dstType, vec1, vec2, rewriter.getI64ArrayAttr(indices));
309+
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(shuffleOp, dstType, vec1,
310+
vec2, indices);
315311
return success();
316312
}
317313

@@ -368,8 +364,7 @@ struct LinearizeVectorExtract final
368364
llvm::SmallVector<int64_t, 2> indices(size);
369365
std::iota(indices.begin(), indices.end(), linearizedOffset);
370366
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
371-
extractOp, dstTy, adaptor.getVector(), adaptor.getVector(),
372-
rewriter.getI64ArrayAttr(indices));
367+
extractOp, dstTy, adaptor.getVector(), adaptor.getVector(), indices);
373368

374369
return success();
375370
}
@@ -452,8 +447,7 @@ struct LinearizeVectorInsert final
452447
// [offset+srcNumElements, end)
453448

454449
rewriter.replaceOpWithNewOp<vector::ShuffleOp>(
455-
insertOp, dstTy, adaptor.getDest(), adaptor.getSource(),
456-
rewriter.getI64ArrayAttr(indices));
450+
insertOp, dstTy, adaptor.getDest(), adaptor.getSource(), indices);
457451

458452
return success();
459453
}

0 commit comments

Comments
 (0)