[mlir][linalg] Simplify createWriteOrMaskedWrite (NFC) #141567
base: users/banach-space/vector/update_vectorize_insert_slice
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector
Author: Andrzej Warzyński (banach-space)
Changes
Full diff: https://github.com/llvm/llvm-project/pull/141567.diff (1 file affected):
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0113ba86a5ae3..2abb2f0ea467c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1590,61 +1590,46 @@ static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
/// Creates an optionally masked TransferWriteOp
///
/// Generates the following operation:
-/// %res = vector.transfer_write %vectorToStore into %dest
+/// %res = vector.transfer_write %vecToStore into %dest
///
-/// If the leading N dimensions of the vector to store do not match
-/// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
-/// masking is applied to ensure correctness:
+/// If shape(vecToStore) != shape(dest), masking is used to ensure correctness:
///
-/// %mask = vector.create_mask(%destShape) : %vectorToStoreShape
+/// %mask = vector.create_mask(%destShape) : %vecToStoreShape
/// %res = vector.mask %mask {
-/// vector.transfer_write %vectorToStore into %dest
+/// vector.transfer_write %vecToStore into %dest
/// }
///
-/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// The mask shape is identical to `vecToStore` (with the element type ==
/// i1), and the mask values are based on the shape of the `dest` tensor.
///
/// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
/// is used instead of masking:
///
-/// %write = vector.transfer_write %vectorToStore into %dest
+/// %write = vector.transfer_write %vecToStore into %dest
/// in_bounds_flags = (...)
/// %res = vector.transfer_write %input into %dest
/// {in_bounds = in_bounds_flags}
///
-/// `writeIndices` specifies the offsets to use. If empty, all indices are set
-/// to 0.
-///
-/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
-/// `valueToStore`.
-/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
-/// already provided in `vectorToStore`.
+/// Finally, `writeIndices` specifies the offsets to use. If empty, all indices
+/// are set to 0.
static Operation *
-createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
- Value dest,
- ArrayRef<int64_t> inputVecSizesForLeadingDims,
- SmallVector<Value> writeIndices = {},
+createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore,
+ Value dest, SmallVector<Value> writeIndices = {},
bool useInBoundsInsteadOfMasking = false) {
ShapedType destType = cast<ShapedType>(dest.getType());
int64_t destRank = destType.getRank();
auto destShape = destType.getShape();
- VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+ VectorType vecToStoreType = cast<VectorType>(vecToStore.getType());
int64_t vecToStoreRank = vecToStoreType.getRank();
auto vecToStoreShape = vecToStoreType.getShape();
// Compute the in_bounds attribute
SmallVector<bool> inBoundsVal(vecToStoreRank, true);
if (useInBoundsInsteadOfMasking) {
- // In this case, assume that all the required vector sizes have been
- // provided.
- assert(inputVecSizesForLeadingDims.size() ==
- static_cast<size_t>(vecToStoreType.getRank()) &&
- "Insufficient number of input vector sizes!");
- // Update the inBounds attribute.
for (unsigned i = 0; i < destRank; i++)
- inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
+ inBoundsVal[i] = (destShape[i] == vecToStoreShape[i]) &&
!ShapedType::isDynamic(destShape[i]);
}
@@ -1660,7 +1645,7 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
// Generate the xfer_write Op
Operation *write =
builder.create<vector::TransferWriteOp>(loc,
- /*vector=*/vectorToStore,
+ /*vector=*/vecToStore,
/*source=*/dest,
/*indices=*/writeIndices,
/*inBounds=*/inBoundsVal);
@@ -1669,46 +1654,25 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
if (useInBoundsInsteadOfMasking)
return write;
- assert(llvm::none_of(
- destShape.drop_front(inputVecSizesForLeadingDims.size()),
- [](int64_t size) { return size == ShapedType::kDynamic; }) &&
- "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
-
- // Check if masking is needed.
- bool needMaskForWrite =
- !llvm::equal(inputVecSizesForLeadingDims,
- destShape.take_front(destRank - vecToStoreRank +
- inputVecSizesForLeadingDims.size()));
-
- // If masking is needed, generate the mask and mask the operation.
- if (needMaskForWrite) {
- // Get the mask shape + type. Missing mask dimensions are taken from
- // `vectorToStore`.
- SmallVector<int64_t> writeMaskShape;
- writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
- inputVecSizesForLeadingDims.end());
- if (vecToStoreRank >
- static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
- writeMaskShape.append(vecToStoreShape.begin() +
- inputVecSizesForLeadingDims.size(),
- vecToStoreShape.end());
- auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-
- SmallVector<OpFoldResult> destSizes =
- tensor::getMixedSizes(builder, loc, dest);
- SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
- destSizes.end());
-
- if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
- writeMaskShape))
- return write;
-
- Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
- loc, writeMaskType, maskSizes);
- write = mlir::vector::maskOperation(builder, write, maskForWrite);
- }
+ // Check if masking is needed. If not, exit.
+ if (llvm::equal(vecToStoreShape, destShape.take_back(vecToStoreRank)))
+ return write;
+
+ // Compute the mask and mask the write Op.
+ auto writeMaskType = VectorType::get(vecToStoreShape, builder.getI1Type());
+
+ SmallVector<OpFoldResult> destSizes =
+ tensor::getMixedSizes(builder, loc, dest);
+ SmallVector<OpFoldResult> maskSizes(destSizes.end() - vecToStoreRank,
+ destSizes.end());
+
+ if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+ vecToStoreShape))
+ return write;
- return write;
+ Value maskForWrite =
+ builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
+ return mlir::vector::maskOperation(builder, write, maskForWrite);
}
/// Vectorize linalg::PackOp with (1) static inner_tiles (2) constant
@@ -1808,10 +1772,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
Value dest = rewriter.create<tensor::EmptyOp>(
loc, reifiedReturnShapes[0],
transposeOp.getResult().getType().getElementType());
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, transposeOp.getResult(), dest,
- /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
- /*useInBoundsInsteadOfMasking=*/false);
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
+ /*writeIndices=*/{},
+ /*useInBoundsInsteadOfMasking=*/false);
newResults.push_back(write->getResult(0));
return success();
}
@@ -1949,7 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
shapeCastOp.getResult().getType().getElementType());
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), dest,
- /*inputVecSizesForLeadingDims=*/writeVectorSizes,
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
newResults.push_back(write->getResult(0));
return success();
@@ -1982,10 +1945,9 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
// Create Xfer write Op
Value dest = rewriter.create<tensor::EmptyOp>(
loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, maskedRead, dest,
- /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
- /*useInBoundsInsteadOfMasking=*/false);
+ Operation *write =
+ createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest, {},
+ /*useInBoundsInsteadOfMasking=*/false);
newResults.push_back(write->getResult(0));
return success();
}
@@ -3041,8 +3003,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
// Create write
auto writeIndices =
getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
- Operation *write = createWriteOrMaskedWrite(
- rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
+ Operation *write = createWriteOrMaskedWrite(rewriter, loc, read,
+ sliceOp.getDest(), writeIndices);
// 4. Finalize
newResults.push_back(write->getResult(0));
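As a sanity check of the simplified condition in the patch above, here is a small standalone snippet. It is not part of the patch; it only exercises the same `llvm::equal` plus `ArrayRef::take_back` comparison the new code performs, using the shapes from the `vector<1x2x3xf32>` into `tensor<9x8x7x1x2x3xf32>` example quoted later in this thread:

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include <cassert>
#include <cstdint>

int main() {
  // dest: tensor<9x8x7x1x2x3xf32>; vecToStore: vector<1x2x3xf32>.
  const int64_t dest[] = {9, 8, 7, 1, 2, 3};
  const int64_t vec[] = {1, 2, 3};
  llvm::ArrayRef<int64_t> destShape(dest);
  llvm::ArrayRef<int64_t> vecToStoreShape(vec);

  // The patch's check: masking is skipped when the stored vector's shape
  // equals the trailing dimensions of the destination.
  bool needsMask = !llvm::equal(vecToStoreShape,
                                destShape.take_back(vecToStoreShape.size()));
  assert(!needsMask && "shapes match, so the write stays unmasked");
  (void)needsMask;
  return 0;
}
```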
createWriteOrMaskedWrite (NFC)

This patch refactors two vectorization hooks in Vectorization.cpp:

* `createWriteOrMaskedWrite` gains a new parameter for write indices, aligning it with its counterpart `createReadOrMaskedRead`.
* `vectorizeAsInsertSliceOp` is updated to reuse both of the above hooks, rather than re-implementing similar logic.

CONTEXT
-------
This is effectively a refactoring of the logic for vectorizing `tensor.insert_slice`. Recent updates added masking support:

* #122927
* #123031

At the time, reuse of the shared `create*` hooks wasn't feasible due to missing parameters and overly rigid assumptions. This patch resolves that and moves us closer to a more maintainable structure.

CHANGES IN `vectorizeAsInsertSliceOp`
-------------------------------------
* Introduces a clear distinction between the destination tensor and the vector to store, via named variables like `destType`/`vecToStoreType`, `destShape`/`vecToStoreShape`, etc.
* Ensures the correct rank and shape are used for attributes like `in_bounds`. For example, the size of the `in_bounds` array now matches the source vector rank, not the tensor rank.
* Drops the assumption that `vecToStoreRank == destRank`; this doesn't hold in many real examples.
* Deduces mask dimensions from `vecToStoreShape` (vector) instead of `destShape` (tensor). (Eventually we should not require `inputVecSizesForLeadingDims` at all; the mask shape should be inferred.)

NEW HELPER: `isMaskTriviallyFoldable`
-------------------------------------
Adds a utility to detect when masking is unnecessary. This avoids inserting redundant masks and reduces the burden on canonicalization to clean them up later. (A minimal, hedged sketch of this check follows this description.)

Example where masking is provably unnecessary:

```mlir
%2 = vector.mask %1 {
  vector.transfer_write %0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0]
    {in_bounds = [true, true, true]}
    : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
} : vector<1x2x3xi1> -> tensor<9x8x7x1x2x3xf32>
```

Also, without this hook, tests are more complicated and require more matching.

TEST CHANGES
------------
This patch primarily affects vectorization of:

* `tensor.insert_slice`, now refactored to use shared hooks.
* `tensor.pad` vectorization patterns, which internally use `tensor.insert_slice`, are also effectively updated. Note that only pad-with-patterns.mlir is affected.

Most test updates involve the insertion of masks that were previously missing; this reflects a correctness fix, not a regression. In all cases, the added masks are indeed required.

You'll also notice more repeated constants (`arith.constant 0 : index`), due to increased use of helper hooks. This will be cleaned up separately via a constant cache (see #138265 for discussion).

NOTE FOR REVIEWERS
------------------
This is a fairly substantial rewrite. You may find it easier to review `createWriteOrMaskedWrite` as a new method rather than diffing line-by-line.

TODOs (future PRs)
------------------
Further alignment of `createWriteOrMaskedWrite` and `createReadOrMaskedRead`:

* Move `createWriteOrMaskedWrite` next to `createReadOrMaskedRead` (in VectorUtils.cpp). (*)
* Make `createReadOrMaskedRead` leverage `isMaskTriviallyFoldable`.
* Extend `isMaskTriviallyFoldable` with value-bounds analysis. See the updated test in transform-vector.mlir for an example that would benefit from this.

(*) This method will eventually be moved out of Vectorization.cpp, which isn't the right long-term home for it.
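As promised above, here is a minimal sketch of the idea behind `isMaskTriviallyFoldable`. The helper's implementation is not shown in this diff, so treat every name and detail below as an illustrative assumption rather than the actual code: a mask folds away trivially when it would be all-true.

```cpp
// Hedged sketch only; the real isMaskTriviallyFoldable in Vectorization.cpp
// may cover more (or different) cases. Intuition: the mask is all-true when
// every write index is the constant 0 and every mask size is a static value
// equal to the corresponding mask-vector dimension.
// Assumed headers: mlir/IR/Matchers.h (matchPattern, m_Zero) and
// mlir/Dialect/Utils/StaticValueUtils.h (getConstantIntValue).
static bool maskIsTriviallyAllTrue(ArrayRef<OpFoldResult> maskSizes,
                                   ArrayRef<Value> writeIndices,
                                   ArrayRef<int64_t> maskShape) {
  // All write offsets must be the constant 0.
  for (Value idx : writeIndices)
    if (!matchPattern(idx, m_Zero()))
      return false;
  // Every mask size must be a static value equal to the mask dimension.
  for (auto [size, dim] : llvm::zip_equal(maskSizes, maskShape)) {
    std::optional<int64_t> cst = getConstantIntValue(size);
    if (!cst || *cst != dim)
      return false;
  }
  return true;
}
```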
This patch removes `inputVecSizesForLeadingDims` from the parameter list of `createWriteOrMaskedWrite`. That argument is unnecessary: the vector sizes can be obtained from the `vecToStore` parameter. Since this doesn't change behavior or test results, it's marked as NFC.

Additional cleanups:

* Renamed `vectorToStore` to `vecToStore` for consistency and brevity.
* Rewrote the conditional at the end of the function to use an early exit, improving readability:

```cpp
// BEFORE:
if (maskingRequired) {
  Value maskForWrite = ...;
  write = maskOperation(write, maskForWrite);
}
return write;

// AFTER:
if (!maskingRequired)
  return write;

Value maskForWrite = ...;
return vector::maskOperation(builder, write, maskForWrite);
```

This change addresses a TODO from #141244.
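For reference, here is one updated call site taken from the diff above (in `vectorizeAsTensorPackOp`); after this patch the callee derives the vector sizes from the `VectorType` of the value being stored, so no explicit size list is passed:

```cpp
// Updated call in vectorizeAsTensorPackOp (from the diff above): the
// inputVecSizesForLeadingDims argument is gone; the sizes now come from
// transposeOp.getResult()'s VectorType inside createWriteOrMaskedWrite.
Operation *write =
    createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
                             /*writeIndices=*/{},
                             /*useInBoundsInsteadOfMasking=*/false);
```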
Nice clean up, thanks!