Skip to content

Commit f0922e9

Browse files
committed
[[mlir][linalg] Refactor vectorization hooks to improve code reuse
This patch refactors two vectorization hooks in Vectorization.cpp: * `createWriteOrMaskedWrite` gains a new parameter for write indices, aligning it with its counterpart `createReadOrMaskedRead`. * `vectorizeAsInsertSliceOp` is updated to reuse both of the above hooks, rather than re-implementing similar logic. CONTEXT ------- This is effectively a refactoring of the logic for vectorizing `tensor.insert_slice`. Recent updates added masking support: * #122927 * #123031 At the time, reuse of the shared `create*` hooks wasn't feasible due to missing parameters and overly rigid assumptions. This patch resolves that and moves us closer to a more maintainable structure. CHANGES IN `vectorizeAsInsertSliceOp` ------------------------------------- * Introduces a clear distinction between the destination tensor and the vector to store, via named variables like `destType`/`vecToStoreType`, `destShape`/`vecToStoreShape`, etc. * Ensures the correct rank and shape are used for attributes like in_bounds. For example, the size of the in_bounds array now matches the source vector rank, not the tensor rank. * Drops the assumption that `vecToStoreRank == destRank` — this doesn't hold in many real examples. * Deduces mask dimensions from `vecToStoreShape` (vector) instead of `destShape` (tensor). (Eventually we should not require `inputVecSizesForLeadingDims` at all — mask shape should be inferred.) NEW HELPER: `isMaskTriviallyFoldable` ------------------------------------- Adds a utility to detect when masking is unnecessary. This avoids inserting redundant masks and reduces the burden on canonicalization to clean them up later. Example where masking is provably unnecessary: ```mlir %2 = vector.mask %1 { vector.transfer_write %0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32> } : vector<1x2x3xi1> -> tensor<9x8x7x1x2x3xf32> ``` Also, without this hook, tests are more complicated and require more matching. TEST CHANGES ----------- This patch primarily affects vectorization of: * `tensor.insert_slice`, now refactored to use shared hooks. `tensor.pad` vectorization patterns, which internally use `tensor.insert_slice`, are also _effectively_ updated. Note, only pad-with-patterns.mlir is affected. Most test updates involve the insertion of masks that were previously missing — this reflects a correctness fix, not a regression. In all cases, the added masks are indeed required. You’ll also notice more repeated constants (`arith.constant 0 : index`), due to increased use of helper hooks. This will be cleaned up separately via a constant cache (see #138265 for discussion). NOTE FOR REVIEWERS ------------------ This is a fairly substantial rewrite. You may find it easier to review `createWriteOrMaskedWrite` as a new method rather than diffing line-by-line. TODOs (future PRs) ------------------ Further alignment of `createWriteOrMaskedWrite` and `createReadOrMaskedRead`: * Move `createWriteOrMaskedWrite` next to `createReadOrMaskedRead` (in VectorUtils.cpp) * Make `createReadOrMaskedRead` leverage `isMaskTriviallyFoldable`. * Extend `isMaskTriviallyFoldable` with value-bounds-analysis. See the updated test in transform-vector.mlir for an example that would benefit from this. (* This method will eventually be moved out of Vectorization.cpp, which isn't the right long-term home for it.)
1 parent 6e98c8c commit f0922e9

File tree

8 files changed

+283
-149
lines changed

8 files changed

+283
-149
lines changed

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 166 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,20 +1506,104 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
15061506
return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
15071507
}
15081508

1509+
/// Determines whether the mask for a corresponding `vector.transfer_write` op
1510+
/// is trivially foldable (i.e., guaranteed to be all true).
1511+
///
1512+
/// Requirements:
1513+
/// * All involved shapes (destination, mask) are static.
1514+
/// * All write indices are constant.
1515+
/// * All mask sizes are constant.
1516+
///
1517+
/// Once verified, the method checks for each destination dimension `d`:
1518+
/// (1) destDimSize[rankDiff + d] <= maskShape[d]
1519+
/// (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d]
1520+
///
1521+
/// rankDiff = rank(dest) - rank(mask).
1522+
///
1523+
/// This method takes a conservative view: it may return false even if the mask
1524+
/// is technically foldable.
1525+
///
1526+
/// EXAMPLE 1 (trivially foldable):
1527+
/// %c0 = arith.constant 0 : index
1528+
/// vector.transfer_write %vecToStore_1, %dest{[%c0, %c0]
1529+
/// {in_bounds = [true, true]}
1530+
/// : vector<5x1xi32>, tensor<5x1xi32>
1531+
///
1532+
/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape):
1533+
/// %c0 = arith.constant 0 : index
1534+
/// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
1535+
/// {in_bounds = [true, true]}
1536+
/// : vector<8x1xi32>, tensor<5x1xi32>
1537+
///
1538+
/// TODO: Re-use in createReadOrMaskedRead
1539+
static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
1540+
SmallVector<Value> &writeIdxs,
1541+
ArrayRef<int64_t> destShape,
1542+
ArrayRef<int64_t> maskShape) {
1543+
// Masking is unavoidable in the case of dynamic tensors.
1544+
if (ShapedType::isDynamicShape(destShape))
1545+
return false;
1546+
1547+
// Collect all constant mask sizes.
1548+
SmallVector<int64_t, 4> cstMaskSizes;
1549+
for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
1550+
if (auto intSize = getConstantIntValue(dimSize)) {
1551+
cstMaskSizes.push_back(*intSize);
1552+
}
1553+
}
1554+
1555+
// If any of the mask sizes is non-constant, bail out.
1556+
if (cstMaskSizes.size() != maskShape.size())
1557+
return false;
1558+
1559+
// Collect all constant write indices.
1560+
SmallVector<int64_t, 4> cstWriteIdxs;
1561+
for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
1562+
APSInt intVal;
1563+
if (matchPattern(idx, m_ConstantInt(&intVal))) {
1564+
cstWriteIdxs.push_back(intVal.getSExtValue());
1565+
}
1566+
}
1567+
1568+
// If any of the write indices is non-constant, bail out.
1569+
if (cstWriteIdxs.size() != destShape.size())
1570+
return false;
1571+
1572+
// Go over all destination dims and check (1) and (2). Take into account that:
1573+
// * The number of mask sizes will match the rank of the vector to store.
1574+
// This could be lower than the rank of the destination tensor.
1575+
// * Mask sizes could be larger than the corresponding mask shape (hence
1576+
// `clamp`).
1577+
// TODO: The 2nd item should be rejected by the verifier.
1578+
int64_t rankDiff = destShape.size() - cstMaskSizes.size();
1579+
for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
1580+
if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
1581+
/*(2)*/ destShape[rankDiff + i] <
1582+
(std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
1583+
cstWriteIdxs[i]))
1584+
return false;
1585+
}
1586+
1587+
return true;
1588+
}
1589+
15091590
/// Creates an optionally masked TransferWriteOp
15101591
///
15111592
/// Generates the following operation:
15121593
/// %res = vector.transfer_write %vectorToStore into %dest
15131594
///
1514-
/// If the leading N dimensions of the destination tensor do not match
1595+
/// If the leading N dimensions of the vector to store do not match
15151596
/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
15161597
/// masking is applied to ensure correctness:
15171598
///
1518-
/// %mask = vector.create_mask(%destShape)
1599+
/// %mask = vector.create_mask(%destShape) : %vectorToStoreShape
15191600
/// %res = vector.mask %mask {
15201601
/// vector.transfer_write %vectorToStore into %dest
15211602
/// }
15221603
///
1604+
/// The mask shape is identical to `vectorToStore` (with the element type ==
1605+
/// i1), and the mask values are based on the shape of the `dest` tensor.
1606+
///
15231607
/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
15241608
/// is used instead of masking:
15251609
///
@@ -1528,75 +1612,99 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
15281612
/// %res = vector.transfer_write %input into %dest
15291613
/// {in_bounds = in_bounds_flags}
15301614
///
1531-
/// NOTE: All write offsets are set to 0.
1532-
/// TODO: Allow specyfying write offsets.
1533-
/// NOTE: When N < rank(input), the missing vector sizes are effectively
1534-
/// extracted from the trailing sizes of `destSizes`. This means those sizes
1535-
/// must be static.
1536-
/// TODO: Support cases where an arbitrary dim is dynamic - this will require
1537-
/// specifying all the vector sizes.
1615+
/// `writeIndices` specifies the offsets to use. If empty, all indices are set
1616+
/// to 0.
1617+
///
1618+
/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
1619+
/// `valueToStore`.
1620+
/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
1621+
/// already provided in `vectorToStore`.
15381622
static Operation *
15391623
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
15401624
Value dest,
15411625
ArrayRef<int64_t> inputVecSizesForLeadingDims,
1626+
SmallVector<Value> writeIndices = {},
15421627
bool useInBoundsInsteadOfMasking = false) {
15431628

15441629
ShapedType destType = cast<ShapedType>(dest.getType());
1545-
assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
1546-
static_cast<int64_t>(destType.getRank()) &&
1547-
"Rank mismatch!");
1548-
(void)destType;
1630+
int64_t destRank = destType.getRank();
1631+
auto destShape = destType.getShape();
15491632

1550-
int64_t rank = cast<ShapedType>(dest.getType()).getRank();
1551-
auto destShape = cast<ShapedType>(dest.getType()).getShape();
1633+
VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
1634+
int64_t vecToStoreRank = vecToStoreType.getRank();
1635+
auto vecToStoreShape = vecToStoreType.getShape();
15521636

15531637
// Compute the in_bounds attribute
1554-
SmallVector<bool> inBoundsVal(rank, true);
1638+
SmallVector<bool> inBoundsVal(vecToStoreRank, true);
15551639
if (useInBoundsInsteadOfMasking) {
15561640
// In this case, assume that all the required vector sizes have been
15571641
// provided.
15581642
assert(inputVecSizesForLeadingDims.size() ==
1559-
static_cast<size_t>(destType.getRank()) &&
1643+
static_cast<size_t>(vecToStoreType.getRank()) &&
15601644
"Insufficient number of input vector sizes!");
15611645
// Update the inBounds attribute.
1562-
for (unsigned i = 0; i < rank; i++)
1646+
for (unsigned i = 0; i < destRank; i++)
15631647
inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
15641648
!ShapedType::isDynamic(destShape[i]);
15651649
}
15661650

1651+
// If missing, initialize the write indices to 0.
1652+
assert(writeIndices.empty() ||
1653+
writeIndices.size() == static_cast<size_t>(destRank) &&
1654+
"Invalid number of write indices!");
1655+
if (writeIndices.empty()) {
1656+
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1657+
writeIndices = SmallVector<Value>(destRank, zero);
1658+
}
1659+
15671660
// Generate the xfer_write Op
1568-
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
1569-
Operation *write = builder.create<vector::TransferWriteOp>(
1570-
loc,
1571-
/*vector=*/vectorToStore,
1572-
/*source=*/dest,
1573-
/*indices=*/SmallVector<Value>(rank, zero),
1574-
/*inBounds=*/inBoundsVal);
1575-
assert(llvm::none_of(
1576-
destShape.drop_front(inputVecSizesForLeadingDims.size()),
1577-
[](int64_t size) { return size == ShapedType::kDynamic; }) &&
1578-
"Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
1661+
Operation *write =
1662+
builder.create<vector::TransferWriteOp>(loc,
1663+
/*vector=*/vectorToStore,
1664+
/*source=*/dest,
1665+
/*indices=*/writeIndices,
1666+
/*inBounds=*/inBoundsVal);
15791667

15801668
// If masking is disabled, exit.
15811669
if (useInBoundsInsteadOfMasking)
15821670
return write;
15831671

1672+
assert(llvm::none_of(
1673+
destShape.drop_front(inputVecSizesForLeadingDims.size()),
1674+
[](int64_t size) { return size == ShapedType::kDynamic; }) &&
1675+
"Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
1676+
15841677
// Check if masking is needed.
15851678
bool needMaskForWrite =
15861679
!llvm::equal(inputVecSizesForLeadingDims,
1587-
destShape.take_front(inputVecSizesForLeadingDims.size()));
1680+
destShape.take_front(destRank - vecToStoreRank +
1681+
inputVecSizesForLeadingDims.size()));
15881682

15891683
// If masking is needed, generate the mask and mask the operation.
15901684
if (needMaskForWrite) {
1685+
// Get the mask shape + type. Missing mask dimensions are taken from
1686+
// `vectorToStore`.
15911687
SmallVector<int64_t> writeMaskShape;
15921688
writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
15931689
inputVecSizesForLeadingDims.end());
1594-
writeMaskShape.append(destShape.begin() +
1595-
inputVecSizesForLeadingDims.size(),
1596-
destShape.end());
1690+
if (vecToStoreRank >
1691+
static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
1692+
writeMaskShape.append(vecToStoreShape.begin() +
1693+
inputVecSizesForLeadingDims.size(),
1694+
vecToStoreShape.end());
15971695
auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
1598-
Value maskForWrite = builder.create<vector::CreateMaskOp>(
1599-
loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
1696+
1697+
SmallVector<OpFoldResult> destSizes =
1698+
tensor::getMixedSizes(builder, loc, dest);
1699+
SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
1700+
destSizes.end());
1701+
1702+
if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
1703+
writeMaskShape))
1704+
return write;
1705+
1706+
Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
1707+
loc, writeMaskType, maskSizes);
16001708
write = mlir::vector::maskOperation(builder, write, maskForWrite);
16011709
}
16021710

@@ -1700,10 +1808,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
17001808
Value dest = rewriter.create<tensor::EmptyOp>(
17011809
loc, reifiedReturnShapes[0],
17021810
transposeOp.getResult().getType().getElementType());
1703-
Operation *write =
1704-
createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
1705-
/*inputVecSizesForLeadingDims=*/inputVectorSizes,
1706-
/*useInBoundsInsteadOfMasking=*/false);
1811+
Operation *write = createWriteOrMaskedWrite(
1812+
rewriter, loc, transposeOp.getResult(), dest,
1813+
/*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
1814+
/*useInBoundsInsteadOfMasking=*/false);
17071815
newResults.push_back(write->getResult(0));
17081816
return success();
17091817
}
@@ -1839,10 +1947,10 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
18391947
Value dest = rewriter.create<tensor::EmptyOp>(
18401948
loc, reifiedRetShapes[0],
18411949
shapeCastOp.getResult().getType().getElementType());
1842-
Operation *write =
1843-
createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
1844-
/*inputVecSizesForLeadingDims=*/writeVectorSizes,
1845-
useInBoundsInsteadOfMasking);
1950+
Operation *write = createWriteOrMaskedWrite(
1951+
rewriter, loc, shapeCastOp.getResult(), dest,
1952+
/*inputVecSizesForLeadingDims=*/writeVectorSizes,
1953+
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
18461954
newResults.push_back(write->getResult(0));
18471955
return success();
18481956
}
@@ -1874,10 +1982,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
18741982
// Create Xfer write Op
18751983
Value dest = rewriter.create<tensor::EmptyOp>(
18761984
loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
1877-
Operation *write =
1878-
createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
1879-
/*inputVecSizesForLeadingDims=*/inputVectorSizes,
1880-
/*useInBoundsInsteadOfMasking=*/false);
1985+
Operation *write = createWriteOrMaskedWrite(
1986+
rewriter, loc, maskedRead, dest,
1987+
/*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
1988+
/*useInBoundsInsteadOfMasking=*/false);
18811989
newResults.push_back(write->getResult(0));
18821990
return success();
18831991
}
@@ -2922,53 +3030,19 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
29223030
auto vecType = VectorType::get(vecShape, sourceType.getElementType());
29233031

29243032
// 3. Generate TransferReadOp + TransferWriteOp
2925-
ReifiedRankedShapedTypeDims reifiedSrcSizes;
2926-
Value maskOp;
2927-
2928-
// If vector sizes are user provided, make sure to mask. First, generate the
2929-
// mask.
2930-
if (!inputVectorSizes.empty()) {
2931-
auto *srcDefOp = source.getDefiningOp();
2932-
if (!srcDefOp) {
2933-
LDBG("Unable to get the defining Op of " << sliceOp);
2934-
return failure();
2935-
}
2936-
2937-
LogicalResult status =
2938-
cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
2939-
rewriter, reifiedSrcSizes);
2940-
if (status.failed()) {
2941-
LDBG("Unable to reify result shapes of " << srcDefOp);
2942-
return failure();
2943-
}
2944-
2945-
// Create the mask
2946-
auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
2947-
maskOp = rewriter.create<vector::CreateMaskOp>(
2948-
sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
2949-
}
3033+
auto loc = sliceOp.getLoc();
29503034

3035+
// Create read
29513036
SmallVector<Value> readIndices(
2952-
vecType.getRank(),
2953-
rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
2954-
Operation *read = rewriter.create<vector::TransferReadOp>(
2955-
sliceOp.getLoc(), vecType, source, readIndices, padValue,
2956-
ArrayRef<bool>{readInBounds});
2957-
2958-
if (maskOp) {
2959-
read = mlir::vector::maskOperation(rewriter, read, maskOp);
2960-
}
2961-
2962-
auto writeIndices = getValueOrCreateConstantIndexOp(
2963-
rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
2964-
2965-
Operation *write = rewriter.create<vector::TransferWriteOp>(
2966-
sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
2967-
ArrayRef<bool>{writeInBounds});
2968-
2969-
if (maskOp) {
2970-
write = mlir::vector::maskOperation(rewriter, write, maskOp);
2971-
}
3037+
vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
3038+
Value read = mlir::vector::createReadOrMaskedRead(
3039+
rewriter, loc, source, vecType.getShape(), padValue);
3040+
3041+
// Create write
3042+
auto writeIndices =
3043+
getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
3044+
Operation *write = createWriteOrMaskedWrite(
3045+
rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
29723046

29733047
// 4. Finalize
29743048
newResults.push_back(write->getResult(0));

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,13 +337,13 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
337337
auto sourceShape = sourceShapedType.getShape();
338338
assert(sourceShape.size() == inputVectorSizes.size() &&
339339
"expected same ranks.");
340-
auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
341340
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
342341
assert(padValue.getType() == sourceShapedType.getElementType() &&
343342
"expected same pad element type to match source element type");
344343
int64_t readRank = inputVectorSizes.size();
345344
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
346345
SmallVector<bool> inBoundsVal(readRank, true);
346+
347347
if (useInBoundsInsteadOfMasking) {
348348
// Update the inBounds attribute.
349349
for (unsigned i = 0; i < readRank; i++)
@@ -362,6 +362,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
362362
return transferReadOp;
363363
SmallVector<OpFoldResult> mixedSourceDims =
364364
tensor::getMixedSizes(builder, loc, source);
365+
366+
auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
365367
Value mask =
366368
builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
367369
return mlir::vector::maskOperation(builder, transferReadOp, mask)

0 commit comments

Comments
 (0)