-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector] Use notifyMatchFailure instead of assert in VectorLinearize #93590
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -151,10 +151,12 @@ struct LinearizeVectorExtractStridedSlice final | |
LogicalResult | ||
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
Type dstType = getTypeConverter()->convertType(extractOp.getType()); | ||
assert(!(extractOp.getVector().getType().isScalable() || | ||
cast<VectorType>(dstType).isScalable()) && | ||
"scalable vectors are not supported."); | ||
VectorType dstType = | ||
getTypeConverter()->convertType<VectorType>(extractOp.getType()); | ||
assert(dstType && "vector type destination expected."); | ||
if (extractOp.getVector().getType().isScalable() || dstType.isScalable()) | ||
return rewriter.notifyMatchFailure(extractOp, | ||
"scalable vectors are not supported."); | ||
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) | ||
return rewriter.notifyMatchFailure( | ||
extractOp, "Can't flatten since targetBitWidth <= OpSize"); | ||
|
@@ -264,10 +266,14 @@ struct LinearizeVectorShuffle final | |
LogicalResult | ||
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
Type dstType = getTypeConverter()->convertType(shuffleOp.getType()); | ||
VectorType dstType = | ||
getTypeConverter()->convertType<VectorType>(shuffleOp.getType()); | ||
assert(dstType && "vector type destination expected."); | ||
// The assert is used because vector.shuffle does not support scalable | ||
// vectors. | ||
akroviakov marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert(!(shuffleOp.getV1VectorType().isScalable() || | ||
shuffleOp.getV2VectorType().isScalable() || | ||
cast<VectorType>(dstType).isScalable()) && | ||
dstType.isScalable()) && | ||
"scalable vectors are not supported."); | ||
if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth)) | ||
return rewriter.notifyMatchFailure( | ||
|
@@ -336,9 +342,10 @@ struct LinearizeVectorExtract final | |
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
Type dstTy = getTypeConverter()->convertType(extractOp.getType()); | ||
assert(!(extractOp.getVector().getType().isScalable() || | ||
cast<VectorType>(dstTy).isScalable()) && | ||
"scalable vectors are not supported."); | ||
if (extractOp.getVector().getType().isScalable() || | ||
cast<VectorType>(dstTy).isScalable()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would you say the following is better? if ( (auto vecDstTy = cast<VectorType>(dstTy) && vecDstTy.isScalable()) || extractOp.getVector().getType().isScalable() ) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Good point, but let's stick to one change per PR 😅 My recommendation:
WDYT? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the update here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test case is there: %0 = vector.extract %arg0[1,1,1]: f32 from vector<2x8x2xf32> but fixing it can get a bit tricky because right now there are a lot of vector result assumptions (e.g., There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
You should disable this pattern when the rank of the output is <= 1. |
||
return rewriter.notifyMatchFailure(extractOp, | ||
"scalable vectors are not supported."); | ||
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth)) | ||
return rewriter.notifyMatchFailure( | ||
extractOp, "Can't flatten since targetBitWidth <= OpSize"); | ||
|
@@ -394,10 +401,12 @@ struct LinearizeVectorInsert final | |
LogicalResult | ||
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType()); | ||
assert(!(insertOp.getDestVectorType().isScalable() || | ||
cast<VectorType>(dstTy).isScalable()) && | ||
"scalable vectors are not supported."); | ||
VectorType dstTy = getTypeConverter()->convertType<VectorType>( | ||
insertOp.getDestVectorType()); | ||
assert(dstTy && "vector type destination expected."); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comments as above regarding to assert. |
||
if (insertOp.getDestVectorType().isScalable() || dstTy.isScalable()) | ||
return rewriter.notifyMatchFailure(insertOp, | ||
"scalable vectors are not supported."); | ||
|
||
if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(), | ||
targetVectorBitWidth)) | ||
|
Uh oh!
There was an error while loading. Please reload this page.