-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir] [vector] Add linearization pattern for vector.create_mask #138214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector Author: Nishant Patel (nbpatel) Changes: This PR is a breakdown [3 / 4] of the PR #136193. Full diff: https://github.com/llvm/llvm-project/pull/138214.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index b9cef003fa365..cdd937eed6569 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -445,6 +445,64 @@ struct LinearizeVectorSplat final
}
};
+/// Linearizes vector.create_mask for 2D masks with a unit outer dimension;
+/// any other rank or outer extent is reported as a match failure.
+/// For example,
+///   %m = vector.create_mask %d0, %d1 : vector<1x4xi1>
+/// is replaced by an arith.select between an all-false vector<4xi1> (chosen
+/// when %d0 <= 0, signed) and
+///   %m_1d = vector.create_mask %d1 : vector<4xi1>
+/// The pattern itself creates no vector.shape_cast back to vector<1x4xi1>.
+struct LinearizeVectorCreateMask final
+ : OpConversionPattern<vector::CreateMaskOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LinearizeVectorCreateMask(const TypeConverter &typeConverter,
+ MLIRContext *context, PatternBenefit benefit = 1)
+ : OpConversionPattern(typeConverter, context, benefit) {}
+
+ LogicalResult
+ matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto srcTy = createMaskOp.getType();
+ auto srcShape = srcTy.getShape();
+ if (srcShape.size() != 2)
+ return rewriter.notifyMatchFailure(createMaskOp,
+ "only 2D mask is supported.");
+
+ if (srcShape[0] != 1)
+ return rewriter.notifyMatchFailure(
+ createMaskOp, "only unit outer dimension is supported.");
+
+ auto dstTy = getTypeConverter()->convertType(srcTy);
+ if (!dstTy)
+ return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type.");
+
+ // Signed-compare the first (unit-dim) operand against 0 and splat the i1
+ // result to the linearized mask type; it is the select condition below.
+ // Both candidate masks are created unconditionally; nothing is folded here.
+ auto firstOperand = adaptor.getOperands().front();
+ auto zero =
+ rewriter.create<mlir::arith::ConstantIndexOp>(createMaskOp.getLoc(), 0);
+ auto isZeroOrNegative = rewriter.create<mlir::arith::CmpIOp>(
+ createMaskOp.getLoc(), mlir::arith::CmpIPredicate::sle, firstOperand,
+ zero);
+ auto isZeroOrNegativeSplat = rewriter.create<mlir::vector::SplatOp>(
+ createMaskOp.getLoc(), dstTy, isZeroOrNegative);
+
+ // Pick all-false where the condition is set, else the last operand's mask.
+ auto zeroMask = rewriter.create<mlir::arith::ConstantOp>(
+ createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy));
+ auto newMask = rewriter.create<mlir::vector::CreateMaskOp>(
+ createMaskOp.getLoc(), dstTy, adaptor.getOperands().back());
+ auto result = rewriter.create<mlir::arith::SelectOp>(
+ createMaskOp.getLoc(), isZeroOrNegativeSplat, zeroMask, newMask);
+
+ rewriter.replaceOp(createMaskOp, result.getResult());
+ return success();
+ }
+};
+
} // namespace
/// Return true if the operation `op` does not support scalable vectors and
@@ -530,9 +588,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter,
void mlir::vector::populateVectorLinearizeBasePatterns(
const TypeConverter &typeConverter, const ConversionTarget &target,
RewritePatternSet &patterns) {
- patterns.add<LinearizeConstantLike, LinearizeVectorizable,
- LinearizeVectorBitCast, LinearizeVectorSplat>(
- typeConverter, patterns.getContext());
+ patterns
+ .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
+ LinearizeVectorSplat, LinearizeVectorCreateMask>(
+ typeConverter, patterns.getContext());
}
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 01ad1ac48b012..01872426c77bb 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -345,3 +345,41 @@ func.func @linearize_scalable_vector_splat(%arg0: i32) -> vector<4x[2]xi32> {
%0 = vector.splat %arg0 : vector<4x[2]xi32>
return %0 : vector<4x[2]xi32>
}
+
+// -----
+// ALL-LABEL: linearize_create_mask
+func.func @linearize_create_mask() -> vector<1x16xi1> {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C20:.*]] = arith.constant 20 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+ // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<16xi1>
+ // CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<16xi1>
+ // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<16xi1>
+ // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<16xi1>, vector<16xi1>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<16xi1> to vector<1x16xi1>
+ // CHECK: return %[[CAST]] : vector<1x16xi1>
+ %c0 = arith.constant 0 : index
+ %c20 = arith.constant 20 : index
+ %0 = vector.create_mask %c0, %c20 : vector<1x16xi1>
+ return %0 : vector<1x16xi1>
+}
+
+// -----
+// ALL-LABEL: linearize_scalable_create_mask
+func.func @linearize_scalable_create_mask() -> vector<1x[16]xi1> {
+ // CHECK: %[[C0:.*]] = arith.constant 0 : index
+ // CHECK: %[[C20:.*]] = arith.constant 20 : index
+ // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+ // CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index
+ // CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1>
+ // CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<[16]xi1>
+ // CHECK: %[[MASK_1D:.*]] = vector.create_mask %[[C20]] : vector<[16]xi1>
+ // CHECK: %[[SELECT:.*]] = arith.select %[[SPLAT]], %[[CST]], %[[MASK_1D]] : vector<[16]xi1>, vector<[16]xi1>
+ // CHECK: %[[CAST:.*]] = vector.shape_cast %[[SELECT]] : vector<[16]xi1> to vector<1x[16]xi1>
+ // CHECK: return %[[CAST]] : vector<1x[16]xi1>
+ %c0 = arith.constant 0 : index
+ %c20 = arith.constant 20 : index
+ %0 = vector.create_mask %c0, %c20 : vector<1x[16]xi1>
+ return %0 : vector<1x[16]xi1>
+}
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index eda2594fbc7c7..2d5e90908d4d0 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -973,7 +973,8 @@ struct TestVectorLinearize final
return "Linearizes ND vectors for N >= 2 into 1D vectors";
}
void getDependentDialects(DialectRegistry &registry) const override {
- registry.insert<vector::VectorDialect>();
+ registry.insert<vector::VectorDialect, memref::MemRefDialect,
+ arith::ArithDialect>();
}
void runOnOperation() override {
|
Ping for review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks! Mostly some folding needed.
createMaskOp.getLoc(), dstTy, rewriter.getZeroAttr(dstTy)); | ||
auto newMask = rewriter.create<mlir::vector::CreateMaskOp>( | ||
createMaskOp.getLoc(), dstTy, adaptor.getOperands().back()); | ||
auto result = rewriter.create<mlir::arith::SelectOp>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We may want to use some createOrFold
here to get rid of redundant IR...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for updating the test @nbpatel to use non-constant operands.
nit: my guess is that most of the time, the unit dimension will have extent which is the constant 1, and so createOrFold
will still be advisable.
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index | ||
// CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index | ||
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1> | ||
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<[16]xi1> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
most of these checks should be folded with createOrFold
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alternatively, use an %arg0: index
instead of %c0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the clear PR!
// CHECK: %[[C0_0:.*]] = arith.constant 0 : index | ||
// CHECK: %[[CMP:.*]] = arith.cmpi sle, %[[C0]], %[[C0_0]] : index | ||
// CHECK: %[[SPLAT:.*]] = vector.splat %[[CMP]] : vector<[16]xi1> | ||
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<[16]xi1> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Alternatively, use an %arg0: index
instead of %c0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing the comments. A couple more:
/// %index = arith.index_cast %cmpi : i1 to index | ||
/// %mul = arith.muli %index, %arg1 : index |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This mul looks like a bitwise 'and' operation to me?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh yes, good catch. Thanks
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, LGTM. I've added 2 minor suggestions but I'm happy for this to land as is.
It looks like all review comments have been addressed, but I'll give other reviewers some more time to comment; otherwise I'll land this at the end of the week if that's ok.
|
||
// ----- | ||
// CHECK-LABEL: linearize_scalable_create_mask | ||
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index) -> vector<1x[16]xi1> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I think it is better to not capture arguments (ARG0 and ARG1) if they are not subsequently used. Thanks for reducing the size of this test! In fact, I think it could be reduced even further to
// This test is the same as linearize_create_mask
but with a scalable mask. Confirms that it linearizes.
// CHECK-LABEL: linearize_scalable_create_mask
// CHECK: vector.create_mask {{.*}} vector<[16]xi1>
Location loc = createMaskOp.getLoc(); | ||
VectorType srcTy = createMaskOp.getType(); | ||
auto srcShape = srcTy.getShape(); | ||
if (srcShape.size() != 2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding
// FIXME: add support for any vector with at most 1 non-unit dimension (like vector<1x4x1xi1>)
here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
…m#138214) This PR is a breakdown [3 / 4] of the PR llvm#136193 The PR adds linearization patterns for vector.create_mask
This PR is a breakdown [3 / 4] of the PR #136193
The PR adds linearization patterns for vector.create_mask