
[mlir][linalg] Refactor vectorization hooks to improve code reuse #141244


Open
wants to merge 1 commit into base: main

Conversation

@banach-space (Contributor) commented May 23, 2025

This patch refactors two vectorization hooks in Vectorization.cpp:

  • createWriteOrMaskedWrite gains a new parameter for write indices,
    aligning it with its counterpart createReadOrMaskedRead (see the
    updated signature below).
  • vectorizeAsInsertSliceOp is updated to reuse both of the above
    hooks, rather than re-implementing similar logic.
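
For reference, the updated signature (as it appears in the diff below) takes
the write indices as an optional argument that defaults to all-zero offsets:

static Operation *
createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
                         Value dest,
                         ArrayRef<int64_t> inputVecSizesForLeadingDims,
                         // New parameter; when empty, all indices are set to 0.
                         SmallVector<Value> writeIndices = {},
                         bool useInBoundsInsteadOfMasking = false);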

CONTEXT

This is effectively a refactoring of the logic for vectorizing
tensor.insert_slice. Recent updates added masking support for it.

At the time, reuse of the shared create* hooks wasn't feasible due to
missing parameters and overly rigid assumptions. This patch resolves
that and moves us closer to a more maintainable structure.

CHANGES IN createWriteOrMaskedWrite

  • Introduces a clear distinction between the destination tensor and the
    vector to store, via named variables like destType/vecToStoreType,
    destShape/vecToStoreShape, etc.
  • Ensures the correct rank and shape are used for attributes like
    in_bounds. For example, the size of the in_bounds attr now matches
    the source vector rank, not the tensor rank.
  • Drops the assumption that vecToStoreRank == destRank - this doesn't
    hold in many real examples.
  • Deduces mask dimensions from vecToStoreShape (vector) instead of
    destShape (tensor) - see the snippet below. (Eventually we should not
    require inputVecSizesForLeadingDims at all - mask shape should be
    inferred.)
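
The last point is visible in how the mask shape is now computed (condensed
from the diff below):

SmallVector<int64_t> writeMaskShape;
// Leading mask dims come from the user-provided vector sizes...
writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
                      inputVecSizesForLeadingDims.end());
// ...and any missing trailing dims are taken from the vector to store, not
// from the destination tensor.
if (vecToStoreRank >
    static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
  writeMaskShape.append(vecToStoreShape.begin() +
                            inputVecSizesForLeadingDims.size(),
                        vecToStoreShape.end());
auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());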

NEW HELPER: isMaskTriviallyFoldable

Adds a utility to detect when masking is unnecessary. This avoids
inserting redundant masks and reduces the burden on canonicalization to
clean them up later.

Example where masking is provably unnecessary:

%2 = vector.mask %1 {
  vector.transfer_write %0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0]
    {in_bounds = [true, true, true]}
    : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
} : vector<1x2x3xi1> -> tensor<9x8x7x1x2x3xf32>

Without this helper, the affected tests would also be more complex and
require more CHECK matching.
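
For context, this is how the helper is wired into createWriteOrMaskedWrite
(condensed from the diff below) - the mask is only materialized when it
cannot be proven all-true:

// Mask sizes are the trailing destination sizes that line up with the mask
// shape.
SmallVector<OpFoldResult> destSizes =
    tensor::getMixedSizes(builder, loc, dest);
SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
                                    destSizes.end());

// Skip the mask entirely when it is provably all-true.
if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
                            writeMaskShape))
  return write;

Value maskForWrite =
    builder.createOrFold<vector::CreateMaskOp>(loc, writeMaskType, maskSizes);
write = mlir::vector::maskOperation(builder, write, maskForWrite);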

TEST CHANGES

This patch primarily affects vectorization of:

  • tensor.insert_slice, now refactored to use the shared hooks.
  • tensor.pad vectorization patterns, which internally use
    tensor.insert_slice and are therefore also affected. Note that only
    pad-with-patterns.mlir changes.

Most test updates involve the insertion of masks that were previously
missing - this reflects a correctness fix, not a regression. In all
cases, the added masks are indeed required.

You’ll also notice more repeated constants (arith.constant 0 : index),
due to increased use of helper hooks. This will be cleaned up separately
via a constant cache (see #138265 for discussion).
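
As a rough illustration of that idea only - a hypothetical sketch, not part of
this patch or of the design discussed in #138265 - such a cache could look
roughly like this (ignoring insertion-point and dominance concerns):

// Hypothetical sketch: share one `arith.constant ... : index` per value
// instead of materializing a fresh constant in every helper call.
struct IndexConstantCache {
  llvm::DenseMap<int64_t, Value> cache;

  Value get(OpBuilder &builder, Location loc, int64_t value) {
    auto it = cache.find(value);
    if (it != cache.end())
      return it->second;
    Value cst = builder.create<arith::ConstantIndexOp>(loc, value);
    cache.try_emplace(value, cst);
    return cst;
  }
};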

NOTE FOR REVIEWERS

This is a fairly substantial rewrite. You may find it easier to review
createWriteOrMaskedWrite as a new method rather than diffing
line-by-line.

TODOs (future PRs)

Further alignment of createWriteOrMaskedWrite and
createReadOrMaskedRead:

  • Move createWriteOrMaskedWrite next to createReadOrMaskedRead (in
    VectorUtils.cpp).
  • Make createReadOrMaskedRead leverage isMaskTriviallyFoldable.
  • Extend isMaskTriviallyFoldable with value-bounds-analysis. See the
    updated test in transform-vector.mlir for an example that would
    benefit from this.

(createWriteOrMaskedWrite will eventually be moved out of Vectorization.cpp,
which isn't the right long-term home for it.)

@llvmbot (Member) commented May 23, 2025

@llvm/pr-subscribers-mlir-vector

Author: Andrzej Warzyński (banach-space)

Patch is 34.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/141244.diff

7 Files Affected:

  • (modified) mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp (+166-92)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+3-1)
  • (modified) mlir/test/Dialect/LLVM/transform-e2e.mlir (+6-4)
  • (modified) mlir/test/Dialect/Linalg/vectorization.mlir (-1)
  • (modified) mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir (+10-1)
  • (modified) mlir/test/Dialect/Linalg/vectorization/insert-slice.mlir (+51-30)
  • (modified) mlir/test/Dialect/Linalg/vectorization/pad-with-patterns.mlir (+17-10)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index c5b62227777a7..0113ba86a5ae3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1506,20 +1506,104 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
   return applyPermutation(destShape, linalg::getPackInverseDestPerm(packOp));
 }
 
+/// Determines whether the mask for a corresponding `vector.transfer_write` op
+/// is trivially foldable (i.e., guaranteed to be all true).
+///
+/// Requirements:
+///   * All involved shapes (destination, mask) are static.
+///   * All write indices are constant.
+///   * All mask sizes are constant.
+///
+/// Once verified, the method checks for each destination dimension `d`:
+///   (1) destDimSize[rankDiff + d] <= maskShape[d]
+///   (2) destDimSize[rankDiff + d] <= writeIndex[d] + maskSize[d]
+///
+/// rankDiff = rank(dest) - rank(mask).
+///
+/// This method takes a conservative view: it may return false even if the mask
+/// is technically foldable.
+///
+/// EXAMPLE 1 (trivially foldable):
+///   %c0 = arith.constant 0 : index
+///   vector.transfer_write %vecToStore_1, %dest[%c0, %c0]
+///     {in_bounds = [true, true]}
+///   : vector<5x1xi32>, tensor<5x1xi32>
+///
+/// EXAMPLE 2 (not trivially foldable - vector shape exceeds the tensor shape):
+///   %c0 = arith.constant 0 : index
+///   vector.transfer_write %vecToStore_2, %dest[%c0, %c0]
+///     {in_bounds = [true, true]}
+///   : vector<8x1xi32>, tensor<5x1xi32>
+///
+/// TODO: Re-use in createReadOrMaskedRead
+static bool isMaskTriviallyFoldable(SmallVector<OpFoldResult> &maskSizes,
+                                    SmallVector<Value> &writeIdxs,
+                                    ArrayRef<int64_t> destShape,
+                                    ArrayRef<int64_t> maskShape) {
+  // Masking is unavoidable in the case of dynamic tensors.
+  if (ShapedType::isDynamicShape(destShape))
+    return false;
+
+  // Collect all constant mask sizes.
+  SmallVector<int64_t, 4> cstMaskSizes;
+  for (auto [i, dimSize] : llvm::enumerate(maskSizes)) {
+    if (auto intSize = getConstantIntValue(dimSize)) {
+      cstMaskSizes.push_back(*intSize);
+    }
+  }
+
+  // If any of the mask sizes is non-constant, bail out.
+  if (cstMaskSizes.size() != maskShape.size())
+    return false;
+
+  // Collect all constant write indices.
+  SmallVector<int64_t, 4> cstWriteIdxs;
+  for (auto [i, idx] : llvm::enumerate(writeIdxs)) {
+    APSInt intVal;
+    if (matchPattern(idx, m_ConstantInt(&intVal))) {
+      cstWriteIdxs.push_back(intVal.getSExtValue());
+    }
+  }
+
+  // If any of the write indices is non-constant, bail out.
+  if (cstWriteIdxs.size() != destShape.size())
+    return false;
+
+  // Go over all destination dims and check (1) and (2). Take into account that:
+  //  * The number of mask sizes will match the rank of the vector to store.
+  //    This could be lower than the rank of the destination tensor.
+  //  * Mask sizes could be larger than the corresponding mask shape (hence
+  //  `clamp`).
+  // TODO: The 2nd item should be rejected by the verifier.
+  int64_t rankDiff = destShape.size() - cstMaskSizes.size();
+  for (auto [i, idx] : llvm::enumerate(cstMaskSizes)) {
+    if (/*(1)*/ maskShape[i] > destShape[rankDiff + i] ||
+        /*(2)*/ destShape[rankDiff + i] <
+            (std::clamp(cstMaskSizes[i], int64_t(0), maskShape[i]) +
+             cstWriteIdxs[i]))
+      return false;
+  }
+
+  return true;
+}
+
 /// Creates an optionally masked TransferWriteOp
 ///
 /// Generates the following operation:
 ///   %res = vector.transfer_write %vectorToStore into %dest
 ///
-/// If the leading N dimensions of the destination tensor do not match
+/// If the leading N dimensions of the vector to store do not match
 /// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
 /// masking is applied to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape)
+///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
 ///   %res = vector.mask %mask {
 ///     vector.transfer_write %vectorToStore into %dest
 ///   }
 ///
+/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// i1), and the mask values are based on the shape of the `dest` tensor.
+///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
@@ -1528,75 +1612,99 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// NOTE: All write offsets are set to 0.
-/// TODO: Allow specyfying write offsets.
-/// NOTE: When N < rank(input), the missing vector sizes are effectively
-/// extracted from the trailing sizes of `destSizes`. This means those sizes
-/// must be static.
-/// TODO: Support cases where an arbitrary dim is dynamic - this will require
-/// specifying all the vector sizes.
+/// `writeIndices` specifies the offsets to use. If empty, all indices are set
+/// to 0.
+///
+/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
+/// `valueToStore`.
+/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
+/// already provided in `vectorToStore`.
 static Operation *
 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
                          Value dest,
                          ArrayRef<int64_t> inputVecSizesForLeadingDims,
+                         SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
-  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
-             static_cast<int64_t>(destType.getRank()) &&
-         "Rank mismatch!");
-  (void)destType;
+  int64_t destRank = destType.getRank();
+  auto destShape = destType.getShape();
 
-  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
-  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  int64_t vecToStoreRank = vecToStoreType.getRank();
+  auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
-  SmallVector<bool> inBoundsVal(rank, true);
+  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
     // In this case, assume that all the required vector sizes have been
     // provided.
     assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(destType.getRank()) &&
+               static_cast<size_t>(vecToStoreType.getRank()) &&
            "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
-    for (unsigned i = 0; i < rank; i++)
+    for (unsigned i = 0; i < destRank; i++)
       inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
+  // If missing, initialize the write indices to 0.
+  assert(writeIndices.empty() ||
+         writeIndices.size() == static_cast<size_t>(destRank) &&
+             "Invalid number of write indices!");
+  if (writeIndices.empty()) {
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    writeIndices = SmallVector<Value>(destRank, zero);
+  }
+
   // Generate the xfer_write Op
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  Operation *write = builder.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/vectorToStore,
-      /*source=*/dest,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/inBoundsVal);
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+  Operation *write =
+      builder.create<vector::TransferWriteOp>(loc,
+                                              /*vector=*/vectorToStore,
+                                              /*source=*/dest,
+                                              /*indices=*/writeIndices,
+                                              /*inBounds=*/inBoundsVal);
 
   // If masking is disabled, exit.
   if (useInBoundsInsteadOfMasking)
     return write;
 
+  assert(llvm::none_of(
+             destShape.drop_front(inputVecSizesForLeadingDims.size()),
+             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
+         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+
   // Check if masking is needed.
   bool needMaskForWrite =
       !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(inputVecSizesForLeadingDims.size()));
+                   destShape.take_front(destRank - vecToStoreRank +
+                                        inputVecSizesForLeadingDims.size()));
 
   // If masking is needed, generate the mask and mask the operation.
   if (needMaskForWrite) {
+    // Get the mask shape + type. Missing mask dimensions are taken from
+    // `vectorToStore`.
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
                           inputVecSizesForLeadingDims.end());
-    writeMaskShape.append(destShape.begin() +
-                              inputVecSizesForLeadingDims.size(),
-                          destShape.end());
+    if (vecToStoreRank >
+        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
+      writeMaskShape.append(vecToStoreShape.begin() +
+                                inputVecSizesForLeadingDims.size(),
+                            vecToStoreShape.end());
     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-    Value maskForWrite = builder.create<vector::CreateMaskOp>(
-        loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
+
+    SmallVector<OpFoldResult> destSizes =
+        tensor::getMixedSizes(builder, loc, dest);
+    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
+                                        destSizes.end());
+
+    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                                writeMaskShape))
+      return write;
+
+    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
+        loc, writeMaskType, maskSizes);
     write = mlir::vector::maskOperation(builder, write, maskForWrite);
   }
 
@@ -1700,10 +1808,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1839,10 +1947,10 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedRetShapes[0],
       shapeCastOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/writeVectorSizes,
-                               useInBoundsInsteadOfMasking);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, shapeCastOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1874,10 +1982,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, maskedRead, dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -2922,53 +3030,19 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
   // 3. Generate TransferReadOp + TransferWriteOp
-  ReifiedRankedShapedTypeDims reifiedSrcSizes;
-  Value maskOp;
-
-  // If vector sizes are user provided, make sure to mask. First, generate the
-  // mask.
-  if (!inputVectorSizes.empty()) {
-    auto *srcDefOp = source.getDefiningOp();
-    if (!srcDefOp) {
-      LDBG("Unable to get the defining Op of " << sliceOp);
-      return failure();
-    }
-
-    LogicalResult status =
-        cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
-            rewriter, reifiedSrcSizes);
-    if (status.failed()) {
-      LDBG("Unable to reify result shapes of " << srcDefOp);
-      return failure();
-    }
-
-    // Create the mask
-    auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-    maskOp = rewriter.create<vector::CreateMaskOp>(
-        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
-  }
+  auto loc = sliceOp.getLoc();
 
+  // Create read
   SmallVector<Value> readIndices(
-      vecType.getRank(),
-      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-  Operation *read = rewriter.create<vector::TransferReadOp>(
-      sliceOp.getLoc(), vecType, source, readIndices, padValue,
-      ArrayRef<bool>{readInBounds});
-
-  if (maskOp) {
-    read = mlir::vector::maskOperation(rewriter, read, maskOp);
-  }
-
-  auto writeIndices = getValueOrCreateConstantIndexOp(
-      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
-
-  Operation *write = rewriter.create<vector::TransferWriteOp>(
-      sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
-      ArrayRef<bool>{writeInBounds});
-
-  if (maskOp) {
-    write = mlir::vector::maskOperation(rewriter, write, maskOp);
-  }
+      vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
+  Value read = mlir::vector::createReadOrMaskedRead(
+      rewriter, loc, source, vecType.getShape(), padValue);
+
+  // Create write
+  auto writeIndices =
+      getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d5dd6f2027be8..dda4856596bba 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -337,13 +337,13 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
   auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   SmallVector<bool> inBoundsVal(readRank, true);
+
   if (useInBoundsInsteadOfMasking) {
     // Update the inBounds attribute.
     for (unsigned i = 0; i < readRank; i++)
@@ -362,6 +362,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
       tensor::getMixedSizes(builder, loc, source);
+
+  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
   Value mask =
       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index c00b47fb936e9..98cfaf249c898 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -18,16 +18,14 @@ module attributes {transform.with_named_sequence} {
     %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     %2 = transform.get_parent_op %1 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     transform.structured.vectorize_children_and_apply_patterns %2 : (!transform.any_op) -> !transform.any_op
-    %b = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap}
-        %module_op {bufferize_function_boundaries = true}
-        : (!transform.any_op) -> !transform.any_op
 
-    %f = transform.structured.match ops{["func.func"]} in %b
+    %f = transform.structured.match ops{["func.func"]} in %module_op
       : (!transform.any_op) -> !transform.any_op
 
     // TODO: group these lower-level controls into various properly named vector
     // lowering TD macros.
     transform.apply_patterns to %f {
+      transform.apply_patterns.vector.lower_masked_transfers
       transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
       transform.apply_patterns.vector.transfer_permutation_patterns
       transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
@@ -37,6 +35,10 @@ module attributes {transform.with_named_sequence} {
       transform.apply_patterns.vector.lower_shape_cast
       transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
     } : !transform.any_op
+
+    %b = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap}
+        %module_op {bufferize_function_boundaries = true}
+        : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 8c6760fa50325..9a18f040d57cd 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1027,4 +1027,3 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf
     transform.yield
   }
  }
-
diff --git a/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir
index f7764be9be73f..d1f2ed194f6ce 100644
--- a/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir
@@ -67,10 +67,19 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:      %[[ARG_0:.*]]: tensor<1x?x3xf32>,
 // CHECK-SAME:      %[[PAD:.*]]: f32,
 // CHECK-SAME:      %[[SIZE:.*]]: index) -> tensor<9x8x7x1x2x3xf32> {
+// CHECK:           %[[C3:.*]] = arith.constant 3 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[EMPTY:.*]] = tensor.empty() : tensor<9x8x7x1x2x3xf32>
 // CHECK:           %[[BC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<9x8x7x1x2x3xf32>
 // CHECK:           %[[WRITE:.*]] = vector.transfer_write %[[BC]], %[[EMPTY]]{{.*}} {in_bounds = [true, true, true, true, true, true]} : vector<9x8x7...
[truncated]

+/// If the leading N dimensions of the vector to store do not match
 /// `inputVecSizesForLeadingDims` (N = rank(inputVecSizesForLeadingDims)),
 /// masking is applied to ensure correctness:
 ///
-///   %mask = vector.create_mask(%destShape)
+///   %mask = vector.create_mask(%destShape) : %vectorToStoreShape
 ///   %res = vector.mask %mask {
 ///     vector.transfer_write %vectorToStore into %dest
 ///   }
 ///
+/// The mask shape is identical to `vectorToStore` (with the element type ==
+/// i1), and the mask values are based on the shape of the `dest` tensor.
+///
 /// If `useInBoundsInsteadOfMasking` is set to `true`, the `in_bounds` attribute
 /// is used instead of masking:
 ///
@@ -1528,75 +1612,99 @@ static SmallVector<int64_t> getTiledPackShape(linalg::PackOp packOp,
 ///   %res = vector.transfer_write %input into %dest
 ///       {in_bounds = in_bounds_flags}
 ///
-/// NOTE: All write offsets are set to 0.
-/// TODO: Allow specyfying write offsets.
-/// NOTE: When N < rank(input), the missing vector sizes are effectively
-/// extracted from the trailing sizes of `destSizes`. This means those sizes
-/// must be static.
-/// TODO: Support cases where an arbitrary dim is dynamic - this will require
-/// specifying all the vector sizes.
+/// `writeIndices` specifies the offsets to use. If empty, all indices are set
+/// to 0.
+///
+/// NOTE: When N < rank(vectorToStore), the missing vector sizes are taken from
+/// `valueToStore`.
+/// TODO: `inputVecSizesForLeadingDims` should not be required - these sizes are
+/// already provided in `vectorToStore`.
 static Operation *
 createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vectorToStore,
                          Value dest,
                          ArrayRef<int64_t> inputVecSizesForLeadingDims,
+                         SmallVector<Value> writeIndices = {},
                          bool useInBoundsInsteadOfMasking = false) {
 
   ShapedType destType = cast<ShapedType>(dest.getType());
-  assert(cast<VectorType>(vectorToStore.getType()).getRank() ==
-             static_cast<int64_t>(destType.getRank()) &&
-         "Rank mismatch!");
-  (void)destType;
+  int64_t destRank = destType.getRank();
+  auto destShape = destType.getShape();
 
-  int64_t rank = cast<ShapedType>(dest.getType()).getRank();
-  auto destShape = cast<ShapedType>(dest.getType()).getShape();
+  VectorType vecToStoreType = cast<VectorType>(vectorToStore.getType());
+  int64_t vecToStoreRank = vecToStoreType.getRank();
+  auto vecToStoreShape = vecToStoreType.getShape();
 
   // Compute the in_bounds attribute
-  SmallVector<bool> inBoundsVal(rank, true);
+  SmallVector<bool> inBoundsVal(vecToStoreRank, true);
   if (useInBoundsInsteadOfMasking) {
     // In this case, assume that all the required vector sizes have been
     // provided.
     assert(inputVecSizesForLeadingDims.size() ==
-               static_cast<size_t>(destType.getRank()) &&
+               static_cast<size_t>(vecToStoreType.getRank()) &&
            "Insufficient number of input vector sizes!");
     // Update the inBounds attribute.
-    for (unsigned i = 0; i < rank; i++)
+    for (unsigned i = 0; i < destRank; i++)
       inBoundsVal[i] = (destShape[i] == inputVecSizesForLeadingDims[i]) &&
                        !ShapedType::isDynamic(destShape[i]);
   }
 
+  // If missing, initialize the write indices to 0.
+  assert(writeIndices.empty() ||
+         writeIndices.size() == static_cast<size_t>(destRank) &&
+             "Invalid number of write indices!");
+  if (writeIndices.empty()) {
+    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+    writeIndices = SmallVector<Value>(destRank, zero);
+  }
+
   // Generate the xfer_write Op
-  auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-  Operation *write = builder.create<vector::TransferWriteOp>(
-      loc,
-      /*vector=*/vectorToStore,
-      /*source=*/dest,
-      /*indices=*/SmallVector<Value>(rank, zero),
-      /*inBounds=*/inBoundsVal);
-  assert(llvm::none_of(
-             destShape.drop_front(inputVecSizesForLeadingDims.size()),
-             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
-         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+  Operation *write =
+      builder.create<vector::TransferWriteOp>(loc,
+                                              /*vector=*/vectorToStore,
+                                              /*source=*/dest,
+                                              /*indices=*/writeIndices,
+                                              /*inBounds=*/inBoundsVal);
 
   // If masking is disabled, exit.
   if (useInBoundsInsteadOfMasking)
     return write;
 
+  assert(llvm::none_of(
+             destShape.drop_front(inputVecSizesForLeadingDims.size()),
+             [](int64_t size) { return size == ShapedType::kDynamic; }) &&
+         "Only dims aligned with inputVecSizesForLeadingDims may be dynamic");
+
   // Check if masking is needed.
   bool needMaskForWrite =
       !llvm::equal(inputVecSizesForLeadingDims,
-                   destShape.take_front(inputVecSizesForLeadingDims.size()));
+                   destShape.take_front(destRank - vecToStoreRank +
+                                        inputVecSizesForLeadingDims.size()));
 
   // If masking is needed, generate the mask and mask the operation.
   if (needMaskForWrite) {
+    // Get the mask shape + type. Missing mask dimensions are taken from
+    // `vectorToStore`.
     SmallVector<int64_t> writeMaskShape;
     writeMaskShape.append(inputVecSizesForLeadingDims.begin(),
                           inputVecSizesForLeadingDims.end());
-    writeMaskShape.append(destShape.begin() +
-                              inputVecSizesForLeadingDims.size(),
-                          destShape.end());
+    if (vecToStoreRank >
+        static_cast<int64_t>(inputVecSizesForLeadingDims.size()))
+      writeMaskShape.append(vecToStoreShape.begin() +
+                                inputVecSizesForLeadingDims.size(),
+                            vecToStoreShape.end());
     auto writeMaskType = VectorType::get(writeMaskShape, builder.getI1Type());
-    Value maskForWrite = builder.create<vector::CreateMaskOp>(
-        loc, writeMaskType, tensor::getMixedSizes(builder, loc, dest));
+
+    SmallVector<OpFoldResult> destSizes =
+        tensor::getMixedSizes(builder, loc, dest);
+    SmallVector<OpFoldResult> maskSizes(destSizes.end() - writeMaskShape.size(),
+                                        destSizes.end());
+
+    if (isMaskTriviallyFoldable(maskSizes, writeIndices, destShape,
+                                writeMaskShape))
+      return write;
+
+    Value maskForWrite = builder.createOrFold<vector::CreateMaskOp>(
+        loc, writeMaskType, maskSizes);
     write = mlir::vector::maskOperation(builder, write, maskForWrite);
   }
 
@@ -1700,10 +1808,10 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0],
       transposeOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, transposeOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, /*writeIndices=*/{},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1839,10 +1947,10 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedRetShapes[0],
       shapeCastOp.getResult().getType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, shapeCastOp.getResult(), dest,
-                               /*inputVecSizesForLeadingDims=*/writeVectorSizes,
-                               useInBoundsInsteadOfMasking);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, shapeCastOp.getResult(), dest,
+      /*inputVecSizesForLeadingDims=*/writeVectorSizes,
+      /*writeIndices=*/{}, useInBoundsInsteadOfMasking);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -1874,10 +1982,10 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   // Create Xfer write Op
   Value dest = rewriter.create<tensor::EmptyOp>(
       loc, reifiedReturnShapes[0], padOp.getResultType().getElementType());
-  Operation *write =
-      createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest,
-                               /*inputVecSizesForLeadingDims=*/inputVectorSizes,
-                               /*useInBoundsInsteadOfMasking=*/false);
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, maskedRead, dest,
+      /*inputVecSizesForLeadingDims=*/inputVectorSizes, {},
+      /*useInBoundsInsteadOfMasking=*/false);
   newResults.push_back(write->getResult(0));
   return success();
 }
@@ -2922,53 +3030,19 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
   auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
   // 3. Generate TransferReadOp + TransferWriteOp
-  ReifiedRankedShapedTypeDims reifiedSrcSizes;
-  Value maskOp;
-
-  // If vector sizes are user provided, make sure to mask. First, generate the
-  // mask.
-  if (!inputVectorSizes.empty()) {
-    auto *srcDefOp = source.getDefiningOp();
-    if (!srcDefOp) {
-      LDBG("Unable to get the defining Op of " << sliceOp);
-      return failure();
-    }
-
-    LogicalResult status =
-        cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
-            rewriter, reifiedSrcSizes);
-    if (status.failed()) {
-      LDBG("Unable to reify result shapes of " << srcDefOp);
-      return failure();
-    }
-
-    // Create the mask
-    auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
-    maskOp = rewriter.create<vector::CreateMaskOp>(
-        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
-  }
+  auto loc = sliceOp.getLoc();
 
+  // Create read
   SmallVector<Value> readIndices(
-      vecType.getRank(),
-      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-  Operation *read = rewriter.create<vector::TransferReadOp>(
-      sliceOp.getLoc(), vecType, source, readIndices, padValue,
-      ArrayRef<bool>{readInBounds});
-
-  if (maskOp) {
-    read = mlir::vector::maskOperation(rewriter, read, maskOp);
-  }
-
-  auto writeIndices = getValueOrCreateConstantIndexOp(
-      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
-
-  Operation *write = rewriter.create<vector::TransferWriteOp>(
-      sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
-      ArrayRef<bool>{writeInBounds});
-
-  if (maskOp) {
-    write = mlir::vector::maskOperation(rewriter, write, maskOp);
-  }
+      vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0));
+  Value read = mlir::vector::createReadOrMaskedRead(
+      rewriter, loc, source, vecType.getShape(), padValue);
+
+  // Create write
+  auto writeIndices =
+      getValueOrCreateConstantIndexOp(rewriter, loc, sliceOp.getMixedOffsets());
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, read, sliceOp.getDest(), vecType.getShape(), writeIndices);
 
   // 4. Finalize
   newResults.push_back(write->getResult(0));
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index d5dd6f2027be8..dda4856596bba 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -337,13 +337,13 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
   auto sourceShape = sourceShapedType.getShape();
   assert(sourceShape.size() == inputVectorSizes.size() &&
          "expected same ranks.");
-  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
   auto vectorType = VectorType::get(inputVectorSizes, padValue.getType());
   assert(padValue.getType() == sourceShapedType.getElementType() &&
          "expected same pad element type to match source element type");
   int64_t readRank = inputVectorSizes.size();
   auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
   SmallVector<bool> inBoundsVal(readRank, true);
+
   if (useInBoundsInsteadOfMasking) {
     // Update the inBounds attribute.
     for (unsigned i = 0; i < readRank; i++)
@@ -362,6 +362,8 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc,
     return transferReadOp;
   SmallVector<OpFoldResult> mixedSourceDims =
       tensor::getMixedSizes(builder, loc, source);
+
+  auto maskType = VectorType::get(inputVectorSizes, builder.getI1Type());
   Value mask =
       builder.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   return mlir::vector::maskOperation(builder, transferReadOp, mask)
diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir
index c00b47fb936e9..98cfaf249c898 100644
--- a/mlir/test/Dialect/LLVM/transform-e2e.mlir
+++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir
@@ -18,16 +18,14 @@ module attributes {transform.with_named_sequence} {
     %1, %loops:3 = transform.structured.tile_using_for %0 tile_sizes [2, 2, 2] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
     %2 = transform.get_parent_op %1 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
     transform.structured.vectorize_children_and_apply_patterns %2 : (!transform.any_op) -> !transform.any_op
-    %b = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap}
-        %module_op {bufferize_function_boundaries = true}
-        : (!transform.any_op) -> !transform.any_op
 
-    %f = transform.structured.match ops{["func.func"]} in %b
+    %f = transform.structured.match ops{["func.func"]} in %module_op
       : (!transform.any_op) -> !transform.any_op
 
     // TODO: group these lower-level controls into various properly named vector
     // lowering TD macros.
     transform.apply_patterns to %f {
+      transform.apply_patterns.vector.lower_masked_transfers
       transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
       transform.apply_patterns.vector.transfer_permutation_patterns
       transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel"
@@ -37,6 +35,10 @@ module attributes {transform.with_named_sequence} {
       transform.apply_patterns.vector.lower_shape_cast
       transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d"
     } : !transform.any_op
+
+    %b = transform.bufferization.one_shot_bufferize layout{IdentityLayoutMap}
+        %module_op {bufferize_function_boundaries = true}
+        : (!transform.any_op) -> !transform.any_op
     transform.yield
   }
 }
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index 8c6760fa50325..9a18f040d57cd 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1027,4 +1027,3 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf
     transform.yield
   }
  }
-
diff --git a/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir
index f7764be9be73f..d1f2ed194f6ce 100644
--- a/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization/insert-slice-with-patterns.mlir
@@ -67,10 +67,19 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:      %[[ARG_0:.*]]: tensor<1x?x3xf32>,
 // CHECK-SAME:      %[[PAD:.*]]: f32,
 // CHECK-SAME:      %[[SIZE:.*]]: index) -> tensor<9x8x7x1x2x3xf32> {
+// CHECK:           %[[C3:.*]] = arith.constant 3 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
 // CHECK:           %[[EMPTY:.*]] = tensor.empty() : tensor<9x8x7x1x2x3xf32>
 // CHECK:           %[[BC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<9x8x7x1x2x3xf32>
 // CHECK:           %[[WRITE:.*]] = vector.transfer_write %[[BC]], %[[EMPTY]]{{.*}} {in_bounds = [true, true, true, true, true, true]} : vector<9x8x7...
[truncated]


banach-space changed the title from "[[mlir][linalg] Refactor vectorization hooks to improve code reuse" to "[mlir][linalg] Refactor vectorization hooks to improve code reuse" on May 23, 2025
This patch refactors two vectorization hooks in Vectorization.cpp:
 * `createWriteOrMaskedWrite` gains a new parameter for write indices,
   aligning it with its counterpart `createReadOrMaskedRead`.
 * `vectorizeAsInsertSliceOp` is updated to reuse both of the above
   hooks, rather than re-implementing similar logic.

CONTEXT
-------
This is effectively a refactoring of the logic for vectorizing
`tensor.insert_slice`. Recent updates added masking support:
  * #122927
  * #123031

At the time, reuse of the shared `create*` hooks wasn't feasible due to
missing parameters and overly rigid assumptions. This patch resolves
that and moves us closer to a more maintainable structure.

CHANGES IN `createWriteOrMaskedWrite`
-------------------------------------
* Introduces a clear distinction between the destination tensor and the
  vector to store, via named variables like `destType`/`vecToStoreType`,
  `destShape`/`vecToStoreShape`, etc.
* Ensures the correct rank and shape are used for attributes like
  in_bounds. For example, the size of the in_bounds array now matches
  the source vector rank, not the tensor rank.
* Drops the assumption that `vecToStoreRank == destRank` — this doesn't
  hold in many real examples.
*  Deduces mask dimensions from `vecToStoreShape` (vector) instead of
   `destShape` (tensor). (Eventually we should not require
   `inputVecSizesForLeadingDims` at all — mask shape should be inferred.)

NEW HELPER: `isMaskTriviallyFoldable`
-------------------------------------
Adds a utility to detect when masking is unnecessary. This avoids
inserting redundant masks and reduces the burden on canonicalization to
clean them up later.

Example where masking is provably unnecessary:
```mlir
%2 = vector.mask %1 {
  vector.transfer_write %0, %arg1[%c0, %c0, %c0, %c0, %c0, %c0]
    {in_bounds = [true, true, true]}
    : vector<1x2x3xf32>, tensor<9x8x7x1x2x3xf32>
} : vector<1x2x3xi1> -> tensor<9x8x7x1x2x3xf32>
```

Also, without this hook, tests are more complicated and require more
matching.

TEST CHANGES
-----------
This patch primarily affects vectorization of:
  * `tensor.insert_slice`, now refactored to use shared hooks.

`tensor.pad` vectorization patterns, which internally use
`tensor.insert_slice`, are also _effectively_ updated. Note, only
pad-with-patterns.mlir is affected.

Most test updates involve the insertion of masks that were previously
missing — this reflects a correctness fix, not a regression. In all
cases, the added masks are indeed required.

You’ll also notice more repeated constants (`arith.constant 0 : index`),
due to increased use of helper hooks. This will be cleaned up separately
via a constant cache (see #138265 for discussion).

NOTE FOR REVIEWERS
------------------
This is a fairly substantial rewrite. You may find it easier to review
`createWriteOrMaskedWrite` as a new method rather than diffing
line-by-line.

TODOs (future PRs)
------------------
Further alignment of `createWriteOrMaskedWrite` and
`createReadOrMaskedRead`:
  * Move `createWriteOrMaskedWrite` next to `createReadOrMaskedRead` (in
    VectorUtils.cpp)
  * Make `createReadOrMaskedRead` leverage `isMaskTriviallyFoldable`.
  * Extend `isMaskTriviallyFoldable` with value-bounds-analysis. See the
    updated test in transform-vector.mlir for an example that would
    benefit from this.

(* This method will eventually be moved out of Vectorization.cpp, which isn't the right long-term home for it.)
banach-space force-pushed the users/banach-space/vector/update_vectorize_insert_slice branch from eccff09 to f0922e9 on May 24, 2025 at 11:49
banach-space added a commit that referenced this pull request May 27, 2025
This patch removes `inputVecSizesForLeadingDims` from the parameter list
of `createWriteOrMaskedWrite`. That argument is unnecessary — vector sizes
can be obtained from the `vecToStore` parameter. Since this doesn't change
behavior or test results, it's marked as NFC.

Additional cleanups:
  * Renamed `vectorToStore` to `vecToStore` for consistency and brevity.
  * Rewrote a conditional at the end of the function to use early exit,
    improving readability:

```cpp
  // BEFORE:
  if (maskingRequired) {
    Value maskForWrite = ...;
    write = vector::maskOperation(builder, write, maskForWrite);
  }
  return write;

  // AFTER:
  if (!maskingRequired)
    return write;

  Value maskForWrite = ...;
  return vector::maskOperation(builder, write, maskForWrite);
```

This change addresses a TODO from #141244.