-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir] [vector] Add linearization pattern for vector.create_mask #138214
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 8 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
3a83e2d
Add linearization pattern for vector.create_mask
nbpatel 203ec82
Merge branch 'main' into vector_linearize_create_mask
nbpatel 24b0739
Use CHECKS
nbpatel 2b8a653
Add test case for scalable vector
nbpatel 8e8de7a
Clean up
nbpatel 528f913
Address Feedback
nbpatel c2c1a22
Replace select with mul
nbpatel c5b2e81
Fix typo
nbpatel 8fca9c1
Address comments
nbpatel 2b7e06e
Fix doc
nbpatel db484c9
Clang-format
nbpatel 7b714ce
Merge branch 'main' into vector_linearize_create_mask
nbpatel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -445,6 +445,64 @@ struct LinearizeVectorSplat final | |
} | ||
}; | ||
|
||
/// This pattern converts the CreateMaskOp to work on a linearized vector. | ||
/// It currently supports only 2D masks with a unit outer dimension. | ||
/// Following, | ||
/// vector.create_mask %arg0, %arg1 : vector<1x4xi1> | ||
/// is converted to: | ||
/// %zero = arith.constant 0 : index | ||
/// %cmpi = arith.cmpi sgt, %arg0, %zero : index | ||
/// %index = arith.index_cast %cmpi : i1 to index | ||
/// %mul = arith.muli %index, %arg1 : index | ||
/// %mask = vector.create_mask %mul : vector<4xi1> | ||
/// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> | ||
struct LinearizeVectorCreateMask final | ||
: OpConversionPattern<vector::CreateMaskOp> { | ||
using OpConversionPattern::OpConversionPattern; | ||
|
||
LinearizeVectorCreateMask(const TypeConverter &typeConverter, | ||
MLIRContext *context, PatternBenefit benefit = 1) | ||
: OpConversionPattern(typeConverter, context, benefit) {} | ||
|
||
LogicalResult | ||
matchAndRewrite(vector::CreateMaskOp createMaskOp, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
Location loc = createMaskOp.getLoc(); | ||
VectorType srcTy = createMaskOp.getType(); | ||
auto srcShape = srcTy.getShape(); | ||
if (srcShape.size() != 2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider adding |
||
return rewriter.notifyMatchFailure(createMaskOp, | ||
"only 2D mask is supported."); | ||
|
||
if (srcShape[0] != 1) | ||
return rewriter.notifyMatchFailure( | ||
createMaskOp, "only unit outer dimension is supported."); | ||
|
||
auto dstTy = getTypeConverter()->convertType(srcTy); | ||
if (!dstTy) | ||
return rewriter.notifyMatchFailure(createMaskOp, "cannot convert type."); | ||
|
||
// Compare the first operand with 0. If it is greater than 0, the | ||
// corresponding mask element is set to true, otherwise false. | ||
// The result of the comparison is then multiplied with | ||
// the second operand of create_mask to get the 1D mask. | ||
auto firstOperand = adaptor.getOperands().front(); | ||
auto zero = rewriter.create<mlir::arith::ConstantIndexOp>(loc, 0); | ||
auto isNonZero = rewriter.create<mlir::arith::CmpIOp>( | ||
loc, mlir::arith::CmpIPredicate::sgt, firstOperand, zero); | ||
auto isNonZeroIndex = rewriter.create<mlir::arith::IndexCastOp>( | ||
loc, rewriter.getIndexType(), isNonZero); | ||
auto secondOperand = adaptor.getOperands().back(); | ||
auto maskSize = rewriter.create<mlir::arith::MulIOp>( | ||
loc, rewriter.getIndexType(), isNonZeroIndex, secondOperand); | ||
|
||
auto newMask = rewriter.create<mlir::vector::CreateMaskOp>( | ||
loc, dstTy, maskSize.getResult()); | ||
nbpatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
rewriter.replaceOp(createMaskOp, newMask); | ||
return success(); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
/// Return true if the operation `op` does not support scalable vectors and | ||
|
@@ -530,9 +588,10 @@ void mlir::vector::populateForVectorLinearize(TypeConverter &typeConverter, | |
void mlir::vector::populateVectorLinearizeBasePatterns( | ||
const TypeConverter &typeConverter, const ConversionTarget &target, | ||
RewritePatternSet &patterns) { | ||
patterns.add<LinearizeConstantLike, LinearizeVectorizable, | ||
LinearizeVectorBitCast, LinearizeVectorSplat>( | ||
typeConverter, patterns.getContext()); | ||
patterns | ||
.add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast, | ||
LinearizeVectorSplat, LinearizeVectorCreateMask>( | ||
typeConverter, patterns.getContext()); | ||
} | ||
|
||
void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this mul looks like and bitwise 'and' operation to me?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh yes, good catch. Thanks