[mlir] [vector] Add linearization pattern for vector.create_mask #138214

Merged (12 commits) on May 14, 2025

Conversation

nbpatel
Contributor

@nbpatel nbpatel commented May 1, 2025

This PR is part 3 of 4 of the breakdown of PR #136193.
It adds a linearization pattern for vector.create_mask.

@llvmbot
Member

llvmbot commented May 1, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Nishant Patel (nbpatel)

Full diff: https://github.com/llvm/llvm-project/pull/138214.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+62-3)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+38)
  • (modified) mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp (+2-1)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..cdd937eed6569 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -445,6 +445,64 @@ struct LinearizeVectorSplat final
   }
 };
 
+/// This pattern converts the CreateMaskOp to work on a
+/// linearized vector. The pattern currently
+/// supports only 2D masks with a unit outer dimension.
+/// Following,
+///   vector.create_mask %dims : vector<1x4xi1>
+/// is converted to:
+///   %out_1d = vector.create_mask %dims : vector<4xi1>
+///   %out_nd = vector.shape_cast %out_1d : vector<4xi1> to vector<1x4xi1>
+struct LinearizeVectorCreateMask final
+    : OpConversionPattern<vector::CreateMaskOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+                            MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+
+  LogicalResult
+  matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto srcTy = createMaskOp.getType();
+    auto srcShape = srcTy.getShape();
+    if (srcShape.size() != 2)
+      return rewriter.notifyMatchFailure(createMaskOp,
+                                         "only 2D mask is supported.");
+
+    if (srcShape[0] != 1)
+      return rewriter.notifyMatchFailure(
+          createMaskOp, "only unit outer dimension is supported.");
+
+    auto dstTy = getTypeConverter()->convertType(srcTy);
+    if (!dstTy)
+      return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+    // Compare the first operand with 0. If it's less than or equal to 0,
+    // create a zero mask, else strip the first operand and create a mask
+    // using the second operand.
+    auto firstOperand = adaptor.getOperands().front();
+    auto zero =
+        rewriter.create<mlir::arith::ConstantIndexOp>(createMaskOp.getLoc(), 0);
+    auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>(
+        createMaskOp.getLoc(), mlir::arith::CmpIPredicate::sle, firstOperand,
+        zero);
+    auto isZeroOrNegativeSplat = rewriter.create<mlir::vector::SplatOp>(
+        createMaskOp.getLoc(), dstTy, isZeroOrNegative);
+
+    // Use a select operation to choose between the masks.
+    auto zeroMask = rewriter.create<mlir::arith::ConstantOp>(
+        createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy));
+    auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
+        createMaskOp.getLoc(), dstTy, adaptor.getOperands().back());
+    auto result = rewriter.create<mlir::arith::SelectOp>(
+        createMaskOp.getLoc(), isZeroOrNegativeSplat, zeroMask, newMask);
+
+    rewriter.replaceOp(createMaskOp, result.getResult());
+    return success();
+  }
+};
+
 } // namespace
 
 /// Return true if the operation `op` does not support scalable vectors and
@@ -530,9 +588,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
 void mlir::vector::populateVectorLinearizeBasePatterns(
     const TypeConverter &typeConverter, const ConversionTarget &target,
     RewritePatternSet &patterns) {
-  patterns.add<LinearizeConstantLike, LinearizeVectorizable,
-               LinearizeVectorBitCast, LinearizeVectorSplat>(
-      typeConverter, patterns.getContext());
+  patterns
+      .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
+           LinearizeVectorSplat, LinearizeVectorCreateMask>(
+          typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 01ad1ac48b012..01872426c77bb 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -345,3 +345,41 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
   %0 = vector.splat %arg0 : vector<4x[2]xi32>
   return %0 : vector<4x[2]xi32>
 }
+
+// -----
+// ALL-LABEL: linearize_create_mask
+func.func @linearize_create_mask() -> vector<1x16xi1> {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C20:.*]] = arith.constant 20 : index
+  // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+  // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+  // CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+  // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+  // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+  // CHECK: return %[[CAST]] : vector<1x16xi1>
+  %c0 = arith.constant 0 : index
+  %c20 = arith.constant 20 : index
+  %0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
+  return %0 : vector<1x16xi1>
+}
+
+// -----
+// ALL-LABEL: linearize_scalable_create_mask
+func.func @linearize_scalable_create_mask() -> vector<1x[16]xi1> {
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C20:.*]] = arith.constant 20 : index
+  // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+  // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1>
+  // CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<[16]xi1>
+  // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<[16]xi1>
+  // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<[16]xi1>, vector<[16]xi1>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<[16]xi1> to vector<1x[16]xi1>
+  // CHECK: return %[[CAST]] : vector<1x[16]xi1>
+  %c0 = arith.constant 0 : index
+  %c20 = arith.constant 20 : index
+  %0 = vector.create_mask %c0, %c20 : vector<1x[16]xi1>
+  return %0 : vector<1x[16]xi1>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index eda2594fbc7c7..2d5e90908d4d0 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -973,7 +973,8 @@ struct TestVectorLinearize final
     return "Linearizes ND vectors for N >= 2 into 1D vectors";
   }
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<vector::VectorDialect>();
+    registry.insert<vector::VectorDialect, memref::MemRefDialect,
+                    arith::ArithDialect>();
   }
 
   void runOnOperation() override {
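
To see why the rewrite in the diff above preserves the mask, here is a small standalone C++ model. It is plain C++, not MLIR API code, and every function name in it is made up for illustration. It assumes the create_mask semantics in which a lane is set only when its index is below every operand, so a zero or negative operand yields an all-false mask.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Assumed reference semantics of a 1-D create_mask: the first `bound` lanes
// are set, with the bound effectively clamped to [0, size].
std::vector<bool> createMask1D(int64_t bound, int64_t size) {
  assert(size >= 0 && "vector size must be non-negative");
  std::vector<bool> mask(size, false);
  for (int64_t i = 0; i < size && i < bound; ++i)
    mask[i] = true;
  return mask;
}

// Assumed reference semantics of create_mask on a 1 x N vector, flattened in
// row-major order: element (0, j) is set iff 0 < d0 and j < d1.
std::vector<bool> createMaskUnitOuter(int64_t d0, int64_t d1, int64_t n) {
  std::vector<bool> mask(n, false);
  for (int64_t j = 0; j < n; ++j)
    mask[j] = (0 < d0) && (j < d1);
  return mask;
}

// Model of the rewrite: select(d0 <= 0, all-false mask, 1-D mask of d1).
std::vector<bool> linearizedMask(int64_t d0, int64_t d1, int64_t n) {
  bool isZeroOrNegative = d0 <= 0;
  std::vector<bool> zeroMask(n, false);
  std::vector<bool> newMask = createMask1D(d1, n);
  return isZeroOrNegative ? zeroMask : newMask;
}

// Compares the two models over a small range of operands, including
// negative and out-of-range values.
bool rewriteMatchesReference(int64_t n) {
  for (int64_t d0 = -2; d0 <= 3; ++d0)
    for (int64_t d1 = -2; d1 <= n + 2; ++d1)
      if (createMaskUnitOuter(d0, d1, n) != linearizedMask(d0, d1, n))
        return false;
  return true;
}
```

Calling rewriteMatchesReference(16) checks the vector<1x16xi1> case used in the test; linearizedMask(0, 20, 16) corresponds to the %c0, %c20 operands there and produces an all-false mask.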

@nbpatel
Contributor Author

nbpatel commented May 1, 2025

@Hardcode84

@nbpatel
Contributor Author

nbpatel commented May 6, 2025

Ping for review

Contributor

@dcaballe dcaballe left a comment

LGTM, thanks! Mostly some folding needed.

createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy));
auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
createMaskOp.getLoc(), dstTy, adaptor.getOperands().back());
auto result = rewriter.create<mlir::arith::SelectOp>(
Contributor

We may want to use createOrFold here to get rid of redundant IR...

Contributor

@newling newling May 12, 2025

Thanks for updating the test @nbpatel to use non-constant operands.

nit: my guess is that most of the time the extent given for the unit dimension will be the constant 1, so createOrFold will still be advisable.
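
As a rough illustration of why folding matters here, the toy builder below contrasts always emitting a comparison with folding it when both operands are known constants. `ToyBuilder`, `Value`, and both method names are hypothetical and are not the MLIR `OpBuilder` API. As far as I understand, a real fold in MLIR would typically still materialize the folded result as a constant op; this sketch only tracks whether the comparison itself gets emitted.

```cpp
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// A toy SSA value: holds a constant when its value is statically known.
struct Value {
  std::optional<int64_t> constant;
};

// Hypothetical builder for illustration only; not the MLIR OpBuilder.
struct ToyBuilder {
  std::vector<std::string> emitted; // ops materialized so far

  // Always emits the comparison, like a plain `create` call would.
  Value createSle(const Value &, const Value &) {
    emitted.push_back("cmpi sle");
    return Value{}; // result is not statically known
  }

  // Folds `lhs <= rhs` when both sides are constants; otherwise emits the op.
  Value createOrFoldSle(const Value &lhs, const Value &rhs) {
    if (lhs.constant && rhs.constant) {
      Value folded;
      folded.constant = (*lhs.constant <= *rhs.constant) ? 1 : 0;
      return folded;
    }
    return createSle(lhs, rhs);
  }
};
```

With two constant operands, as in the first version of the test where %c0 is compared against 0, createOrFoldSle emits nothing. With a function argument on one side, it emits one comparison.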

// CHECK: %[[C0_0:.*]] = arith.constant 0 : index
// CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1>
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<[16]xi1>
Contributor

Most of these checks should fold away with createOrFold.

Contributor

Alternatively, use an %arg0: index instead of %c0

Contributor

@newling newling left a comment

Thanks for the clear PR!

Contributor

@dcaballe dcaballe left a comment

Thanks for addressing the comments. A couple more:

Comment on lines 455 to 456
/// %index = arith.index_cast %cmpi : i1 to index
/// %mul = arith.muli %index, %arg1 : index
Contributor

This mul looks like a bitwise 'and' operation to me?

Contributor Author

oh yes, good catch. Thanks
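
Whether a multiply and a bitwise 'and' are interchangeable in the two quoted lines depends on how the i1 comparison result is widened to index, which the thread does not spell out. The standalone C++ sketch below models both widenings so the difference is visible. All names are made up for illustration. My understanding is that arith.index_cast sign-extends when widening, but treat that as an assumption to check against the arith dialect documentation.

```cpp
#include <cstdint>

// Models widening an i1 with sign extension: true becomes -1 (all bits set).
int64_t signExtendI1(bool b) { return b ? -1 : 0; }

// Models widening an i1 with zero extension: true becomes 1.
int64_t zeroExtendI1(bool b) { return b ? 1 : 0; }

// The intended 1-D mask bound: d1 when the outer operand is positive, else 0.
int64_t expectedBound(int64_t d0, int64_t d1) { return d0 > 0 ? d1 : 0; }

// Candidate ways to combine the widened flag with the inner operand.
int64_t boundViaAnd(int64_t flag, int64_t d1) { return flag & d1; }
int64_t boundViaMul(int64_t flag, int64_t d1) { return flag * d1; }
```

Under sign extension, boundViaAnd(signExtendI1(d0 > 0), d1) equals expectedBound(d0, d1), while the multiply negates d1. Under zero extension the roles swap: the multiply gives the expected bound and the 'and' keeps only the lowest bit of d1.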

@nbpatel
Contributor Author

nbpatel commented May 13, 2025

@newling @dcaballe addressed the comments. Thanks for the feedback.

github-actions bot commented May 13, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@newling newling left a comment

Thanks, LGTM. I've added 2 minor suggestions but I'm happy for this to land as is.

It looks like all review comments have been addressed, but I'll give other reviewers some more time to comment. Otherwise I'll land this at the end of the week, if that's ok.


// -----
// CHECK-LABEL: linearize_scalable_create_mask
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x[16]xi1>
Contributor

nit: I think it is better not to capture arguments (ARG0 and ARG1) if they are not subsequently used. Thanks for reducing the size of this test! In fact I think it could be reduced even further to

// This test is the same as linearize_create_mask but with a scalable mask. Confirms that it linearizes.
// CHECK-LABEL: linearize_scalable_create_mask
// CHECK: vector.create_mask {{.*}} vector<[16]xi1>

Location loc = createMaskOp.getLoc();
VectorType srcTy = createMaskOp.getType();
auto srcShape = srcTy.getShape();
if (srcShape.size() != 2)
Contributor

Consider adding
// FIXME: add support for any vector with at most 1 non-unit dimension (like vector<1x4x1xi1>)
here

Contributor

@dcaballe dcaballe left a comment

Thanks!

@nbpatel
Contributor Author

nbpatel commented May 14, 2025

@dcaballe or @newling can one of you help me land this?

@newling
Contributor

newling commented May 14, 2025

@dcaballe or @newling can one of you help me land this?

Will do 👍

@newling newling merged commit 1778d3b into llvm:main May 14, 2025
11 checks passed
TIFitis pushed a commit to TIFitis/llvm-project that referenced this pull request May 19, 2025