[mlir][linalg] Simplify createWriteOrMaskedWrite (NFC) #141567

Status: Open. Wants to merge 2 commits into base: users/banach-space/vector/update_vectorize_insert_slice.

278 changes: 165 additions & 113 deletions mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1506,101 +1506,189 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
}

/// Determines whether a mask for xfer_write is trivially "all true"
///
/// Given all the inputs required to generate a mask (mask sizes and shapes),
/// and an xfer_write operation (write indices and the destination tensor
/// shape), determines whether the corresponding mask would be trivially
/// foldable (i.e., trivially "all true").
///
/// Use this method to avoid generating spurious masks and relying on
/// vectorization post-processing to remove them.
///
/// Pre-conditions for a mask to be trivially foldable:
/// * All involved shapes (mask + destination tensor) are static.
/// * All write indices are constant.
/// * All mask sizes are constant (including `arith.constant`).
///
/// If the pre-conditions are met, the mask is trivially foldable when, for
/// each mask dimension `d`:
///   (1) maskShape[d] <= destDimSize[rankDiff + d]
///   (2) writeIndex[d] + maskSize[d] <= destDimSize[rankDiff + d]
///
/// where rankDiff = rank(dest) - rank(mask).
///
/// This method takes a conservative view: it may return false even if the mask
/// is technically foldable.
///
/// EXAMPLE 1 (trivially foldable, all shapes match, mask sizes match the shape
/// of the dest tensor):
/// %c0 = arith.constant 0 : index
/// %mask = vector.create_mask 5, 1
/// vector.mask %mask {
/// vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
/// {in_bounds = [true, true]}
/// : vector<5x1xi32>, tensor<5x1xi32>
/// }
///
/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape,
/// mask is required to avoid out-of-bounds write):
/// %c0 = arith.constant 0 : index
/// %mask = vector.create_mask 5, 1
/// vector.mask %mask {
/// vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
/// {in_bounds = [true, true]}
/// : vector<8x1xi32>, tensor<5x1xi32>
/// }
///
/// TODO: Re-use in createReadOrMaskedRead
static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
SmallVector<Value> &writeIdxs,
ArrayRef<int64_t> destShape,
ArrayRef<int64_t> maskShape) {
// Masking is unavoidable in the case of dynamic tensors.
if (ShapedType::isDynamicShape(destShape))
return false;

// Collect all constant mask sizes.
SmallVector<int64_t, 4> cstMaskSizes;
for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
if (auto intSize = getConstantIntValue(dimSize)) {
cstMaskSizes.push_back(*intSize);
}
}

// If any of the mask sizes is non-constant, bail out.
if (cstMaskSizes.size() != maskShape.size())
return false;

// Collect all constant write indices.
SmallVector<int64_t, 4> cstWriteIdxs;
for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
APSInt intVal;
if (matchPattern(idx, m_ConstantInt(&intVal))) {
cstWriteIdxs.push_back(intVal.getSExtValue());
}
}

// If any of the write indices is non-constant, bail out.
if (cstWriteIdxs.size() != destShape.size())
return false;

// Go over all destination dims and check (1) and (2). Take into account that:
// * The number of mask sizes will match the rank of the vector to store.
// This could be lower than the rank of the destination tensor.
// * Mask sizes could be larger than the corresponding mask shape (hence
// `clamp`).
// TODO: The 2nd item should be rejected by the verifier.
int64_t rankDiff = destShape.size() - cstMaskSizes.size();
for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
/*(2)*/ destShape[rankDiff + i] <
(std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
cstWriteIdxs[i]))
return false;
}

return true;
}
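
To make the two documented examples above concrete, here is a minimal standalone sketch (not part of the patch) of the per-dimension checks (1) and (2), using plain integers in place of OpFoldResult/Value. The helper name maskIsTriviallyAllTrue and the use of std::vector are assumptions made purely for illustration.

  #include <algorithm>
  #include <cstdint>
  #include <vector>

  // Illustrative re-statement of checks (1) and (2) from
  // isMaskTriviallyFoldable, assuming all mask sizes and write indices are
  // already known to be constant.
  static bool maskIsTriviallyAllTrue(const std::vector<int64_t> &maskSizes,
                                     const std::vector<int64_t> &writeIdxs,
                                     const std::vector<int64_t> &destShape,
                                     const std::vector<int64_t> &maskShape) {
    size_t rankDiff = destShape.size() - maskShape.size();
    for (size_t i = 0; i < maskShape.size(); ++i) {
      if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
          /*(2)*/ destShape[rankDiff + i] <
                      std::clamp(maskSizes[i], int64_t(0), maskShape[i]) +
                          writeIdxs[i])
        return false;
    }
    return true;
  }

  // EXAMPLE 1 above: vector<5x1> written at [0, 0] into tensor<5x1> with mask
  // sizes (5, 1) -> the mask is trivially "all true":
  //   maskIsTriviallyAllTrue({5, 1}, {0, 0}, {5, 1}, {5, 1}) == true
  // EXAMPLE 2 above: vector<8x1> written at [0, 0] into tensor<5x1> -> a real
  // mask is needed to avoid an out-of-bounds write:
  //   maskIsTriviallyAllTrue({5, 1}, {0, 0}, {5, 1}, {8, 1}) == false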

/// Creates an optionally masked TransferWriteOp
///
/// Generates the following operation:
/// %res = vector.transfer_write %vectorToStore into %dest
/// %res = vector.transfer_write %vecToStore into %dest
///
/// If the leading N dimensions of the destination tensor do not match
/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
/// masking is applied to ensure correctness:
/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
///
/// %mask = vector.create_mask(%destShape)
/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
/// %res = vector.mask %mask {
/// vector.transfer_write %vectorToStore into %dest
/// vector.transfer_write %vecToStore into %dest
/// }
///
/// The mask shape is identical to `vecToStore` (with the element type ==
/// i1), and the mask values are based on the shape of the `dest` tensor.
///
/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
/// is used instead of masking:
///
/// %write = vector.transfer_write %vectorToStore into %dest
/// %write = vector.transfer_write %vecToStore into %dest
/// in_bounds_flags = (...)
/// %res = vector.transfer_write %input into %dest
/// {in_bounds = in_bounds_flags}
///
/// NOTE: All write offsets are set to 0.
/// TODO: Allow specifying write offsets.
/// NOTE: When N < rank(input), the missing vector sizes are effectively
/// extracted from the trailing sizes of `destSizes`. This means those sizes
/// must be static.
/// TODO: Support cases where an arbitrary dim is dynamic - this will require
/// specifying all the vector sizes.
/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
/// are set to 0.
static Operation *
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
Value dest,
ArrayRef<int64_t> inputVecSizesForLeadingDims,
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
Value dest, SmallVector<Value> writeIndices = {},
bool useInBoundsInsteadOfMasking = false) {

ShapedType destType = cast<ShapedType>(dest.getType());
assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
static_cast<int64_t>(destType.getRank()) &&
"Rank mismatch!");
(void)destType;
int64_t destRank = destType.getRank();
auto destShape = destType.getShape();

int64_t rank = cast<ShapedType>(dest.getType()).getRank();
auto destShape = cast<ShapedType>(dest.getType()).getShape();
VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
int64_t vecToStoreRank = vecToStoreType.getRank();
auto vecToStoreShape = vecToStoreType.getShape();

// Compute the in_bounds attribute
SmallVector<bool> inBoundsVal(rank, true);
SmallVector<bool> inBoundsVal(vecToStoreRank, true);
if (useInBoundsInsteadOfMasking) {
// In this case, assume that all the required vector sizes have been
// provided.
assert(inputVecSizesForLeadingDims.size() ==
static_cast<size_t>(destType.getRank()) &&
"Insufficient number of input vector sizes!");
// Update the inBounds attribute.
for (unsigned i = 0; i < rank; i++)
inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
for (unsigned i = 0; i < destRank; i++)
inBoundsVal[i] = (destShape[i] == vecToStoreShape[i]) &&
!ShapedType::isDynamic(destShape[i]);
}

// If missing, initialize the write indices to 0.
assert(writeIndices.empty() ||
writeIndices.size() == static_cast<size_t>(destRank) &&
"Invalid number of write indices!");
if (writeIndices.empty()) {
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
writeIndices = SmallVector<Value>(destRank, zero);
}

// Generate the xfer_write Op
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
Operation *write = builder.create<vector::TransferWriteOp>(
loc,
/*vector=*/vectorToStore,
/*source=*/dest,
/*indices=*/SmallVector<Value>(rank, zero),
/*inBounds=*/inBoundsVal);
assert(llvm::none_of(
destShape.drop_front(inputVecSizesForLeadingDims.size()),
[](int64_t size) { return size == ShapedType::kDynamic; }) &&
"Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
Operation *write =
builder.create<vector::TransferWriteOp>(loc,
/*vector=*/vecToStore,
/*source=*/dest,
/*indices=*/writeIndices,
/*inBounds=*/inBoundsVal);

// If masking is disabled, exit.
if (useInBoundsInsteadOfMasking)
return write;

// Check if masking is needed.
bool needMaskForWrite =
!llvm::equal(inputVecSizesForLeadingDims,
destShape.take_front(inputVecSizesForLeadingDims.size()));
// Check if masking is needed. If not, exit.
if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
return write;

// Compute the mask and mask the write Op.
auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());

// If masking is needed, generate the mask and mask the operation.
if (needMaskForWrite) {
SmallVector<int64_t> writeMaskShape;
writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
inputVecSizesForLeadingDims.end());
writeMaskShape.append(destShape.begin() +
inputVecSizesForLeadingDims.size(),
destShape.end());
auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
Value maskForWrite = builder.create<vector::CreateMaskOp>(
loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
write = mlir::vector::maskOperation(builder, write, maskForWrite);
}
SmallVector<OpFoldResult> destSizes =
tensor::getMixedSizes(builder, loc, dest);
SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
destSizes.end());

if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
vecToStoreShape))
return write;

return write;
Value maskForWrite =
builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
return mlir::vector::maskOperation(builder, write, maskForWrite);
}
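
A brief usage sketch (not part of the patch) of the simplified signature, mirroring the call sites updated later in this diff. The variable names transposedVec, dest and sliceOp are placeholders; only the vector and the destination are required, while the write indices and the in-bounds mode are optional.

  // Default: all write offsets are 0; a mask is generated only when
  // shape(vecToStore) differs from the trailing dims of shape(dest) and the
  // mask is not trivially foldable.
  Operation *write =
      createWriteOrMaskedWrite(rewriter, loc, transposedVec, dest);

  // With explicit write offsets (e.g. an insert_slice's mixed offsets):
  SmallVector<Value> writeIndices =
      getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
  Operation *offsetWrite = createWriteOrMaskedWrite(
      rewriter, loc, transposedVec, sliceOp.getDest(), writeIndices);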

/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1702,7 +1790,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
transposeOp.getResult().getType().getElementType());
Operation *write =
createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
/*inputVecSizesForLeadingDims=*/inputVectorSizes,
/*writeIndices=*/{},
/*useInBoundsInsteadOfMasking=*/false);
newResults.push_back(write->getResult(0));
return success();
Expand Down Expand Up @@ -1839,10 +1927,9 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
Value dest = rewriter.create<tensor::EmptyOp>(
loc, reifiedRetShapes[0],
shapeCastOp.getResult().getType().getElementType());
Operation *write =
createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
/*inputVecSizesForLeadingDims=*/writeVectorSizes,
useInBoundsInsteadOfMasking);
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), dest,
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
newResults.push_back(write->getResult(0));
return success();
}
Expand Down Expand Up @@ -1875,8 +1962,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
Value dest = rewriter.create<tensor::EmptyOp>(
loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
Operation *write =
createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
/*inputVecSizesForLeadingDims=*/inputVectorSizes,
createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest, {},
/*useInBoundsInsteadOfMasking=*/false);
newResults.push_back(write->getResult(0));
return success();
Expand Down Expand Up @@ -2922,53 +3008,19 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
auto vecType = VectorType::get(vecShape, sourceType.getElementType());

// 3. Generate TransferReadOp + TransferWriteOp
ReifiedRankedShapedTypeDims reifiedSrcSizes;
Value maskOp;

// If vector sizes are user provided, make sure to mask. First, generate the
// mask.
if (!inputVectorSizes.empty()) {
auto *srcDefOp = source.getDefiningOp();
if (!srcDefOp) {
LDBG("Unable to get the defining Op of " << sliceOp);
return failure();
}

LogicalResult status =
cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
rewriter, reifiedSrcSizes);
if (status.failed()) {
LDBG("Unable to reify result shapes of " << srcDefOp);
return failure();
}

// Create the mask
auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
maskOp = rewriter.create<vector::CreateMaskOp>(
sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
}
auto loc = sliceOp.getLoc();

// Create read
SmallVector<Value> readIndices(
vecType.getRank(),
rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
Operation *read = rewriter.create<vector::TransferReadOp>(
sliceOp.getLoc(), vecType, source, readIndices, padValue,
ArrayRef<bool>{readInBounds});

if (maskOp) {
read = mlir::vector::maskOperation(rewriter, read, maskOp);
}

auto writeIndices = getValueOrCreateConstantIndexOp(
rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());

Operation *write = rewriter.create<vector::TransferWriteOp>(
sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
ArrayRef<bool>{writeInBounds});

if (maskOp) {
write = mlir::vector::maskOperation(rewriter, write, maskOp);
}
vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
Value read = mlir::vector::createReadOrMaskedRead(
rewriter, loc, source, vecType.getShape(), padValue);

// Create write
auto writeIndices =
getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
sliceOp.getDest(), writeIndices);

// 4. Finalize
newResults.push_back(write->getResult(0));
4 changes: 3 additions & 1 deletion mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -337,13 +337,13 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
auto sourceShape = sourceShapedType.getShape();
assert(sourceShape.size() == inputVectorSizes.size() &&
"expected same ranks.");
auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
assert(padValue.getType() == sourceShapedType.getElementType() &&
"expected same pad element type to match source element type");
int64_t readRank = inputVectorSizes.size();
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
SmallVector<bool> inBoundsVal(readRank, true);

if (useInBoundsInsteadOfMasking) {
// Update the inBounds attribute.
for (unsigned i = 0; i < readRank; i++)
@@ -362,6 +362,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
return transferReadOp;
SmallVector<OpFoldResult> mixedSourceDims =
tensor::getMixedSizes(builder, loc, source);

auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
Value mask =
builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
return mlir::vector::maskOperation(builder, transferReadOp, mask)
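
For completeness, a small hedged sketch (not part of the patch) of the two modes of createReadOrMaskedRead; the source value, the {8, 16} vector sizes and padValue are placeholders. With useInBoundsInsteadOfMasking set, the helper relies on per-dimension in_bounds flags and, after this change, the i1 mask type is not constructed at all on the early-exit paths.

  // Masked mode (default): if the static source dims do not all match
  // inputVectorSizes, the transfer_read is wrapped in vector.mask whose sizes
  // are the reified source dims.
  Value maskedVec = vector::createReadOrMaskedRead(
      builder, loc, source, /*inputVectorSizes=*/{8, 16}, padValue);

  // In-bounds mode: no mask is created; each in_bounds flag records whether
  // the requested vector size matches the corresponding static source size.
  Value inBoundsVec = vector::createReadOrMaskedRead(
      builder, loc, source, /*inputVectorSizes=*/{8, 16}, padValue,
      /*useInBoundsInsteadOfMasking=*/true);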