-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][vector] Use notifyMatchFailure instead of assert in VectorLinearize #93590
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Artem Kroviakov (akroviakov) ChangesAs it was suggested, the Full diff: https://github.com/llvm/llvm-project/pull/93590.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 156bf742f6297..840fd384894df 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -152,9 +152,10 @@ struct LinearizeVectorExtractStridedSlice final
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstType = getTypeConverter()->convertType(extractOp.getType());
- assert(!(extractOp.getVector().getType().isScalable() ||
- cast<VectorType>(dstType).isScalable()) &&
- "scalable vectors are not supported.");
+ if (extractOp.getVector().getType().isScalable() ||
+ cast<VectorType>(dstType).isScalable())
+ return rewriter.notifyMatchFailure(extractOp,
+ "scalable vectors are not supported.");
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -265,10 +266,11 @@ struct LinearizeVectorShuffle final
matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
- assert(!(shuffleOp.getV1VectorType().isScalable() ||
- shuffleOp.getV2VectorType().isScalable() ||
- cast<VectorType>(dstType).isScalable()) &&
- "scalable vectors are not supported.");
+ if (shuffleOp.getV1VectorType().isScalable() ||
+ shuffleOp.getV2VectorType().isScalable() ||
+ cast<VectorType>(dstType).isScalable())
+ return rewriter.notifyMatchFailure(shuffleOp,
+ "scalable vectors are not supported.");
if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -336,9 +338,10 @@ struct LinearizeVectorExtract final
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstTy = getTypeConverter()->convertType(extractOp.getType());
- assert(!(extractOp.getVector().getType().isScalable() ||
- cast<VectorType>(dstTy).isScalable()) &&
- "scalable vectors are not supported.");
+ if (extractOp.getVector().getType().isScalable() ||
+ cast<VectorType>(dstTy).isScalable())
+ return rewriter.notifyMatchFailure(extractOp,
+ "scalable vectors are not supported.");
if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
return rewriter.notifyMatchFailure(
extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -395,9 +398,10 @@ struct LinearizeVectorInsert final
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
- assert(!(insertOp.getDestVectorType().isScalable() ||
- cast<VectorType>(dstTy).isScalable()) &&
- "scalable vectors are not supported.");
+ if (insertOp.getDestVectorType().isScalable() ||
+ cast<VectorType>(dstTy).isScalable())
+ return rewriter.notifyMatchFailure(insertOp,
+ "scalable vectors are not supported.");
if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
targetVectorBitWidth))
|
cc @Hardcode84 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for addressing my suggestion so swiftly!
Could you also add some basic tests to show that these patterns do not trigger for scalable vectors?
assert(!(shuffleOp.getV1VectorType().isScalable() || | ||
shuffleOp.getV2VectorType().isScalable() || | ||
cast<VectorType>(dstType).isScalable()) && | ||
"scalable vectors are not supported."); | ||
if (shuffleOp.getV1VectorType().isScalable() || |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
vector.shuffle
does not support scalable vectors, so keeping an assert should be fine for this one:
A comment explaining the rationale for using assert
rather than notifyMatchFailure
would be welcome :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 on test coverage.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should be addressed now
@@ -152,9 +152,10 @@ struct LinearizeVectorExtractStridedSlice final | |||
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, | |||
ConversionPatternRewriter &rewriter) const override { | |||
Type dstType = getTypeConverter()->convertType(extractOp.getType()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why don't we just cast dstType to ShapedType here? I think it carries more information/methods v.s. Type
. And you don't need to cast it in the below if condition.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please provide more detail on how ShapedType would not need a cast for cast<VectorType>(dstType).isScalable()
? AFAIK ShapedType has no isScalable()
, I do not see other places in the pattern where we could use ShapedType's information/methods that are not provided by Type.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the confusion, what I meant is VectorType
. So it could either be
VectorType dstType = getTypeConverter()->convertType(extractOp.getType());
or
auto dstType = cast<VectorType>(getTypeConverter()->convertType(extractOp.getType()));
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can actually do getTypeConverter()->convertType<VectorType>(...)
. Also, it's better to check convertType
result for null.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the clarification, should now be addressed
38689e2
to
af59df9
Compare
cast<VectorType>(dstTy).isScalable()) && | ||
"scalable vectors are not supported."); | ||
if (extractOp.getVector().getType().isScalable() || | ||
cast<VectorType>(dstTy).isScalable()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually the dstTy
here is not always a vector type. It could be a scalar type too.
e.g., vector.extract %1 [0, 0]: f32 from vector<1024x1024xf32>
. So, cast(dstTy) may cause the crash.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you say the following is better?
if ( (auto vecDstTy = cast<VectorType>(dstTy) && vecDstTy.isScalable()) || extractOp.getVector().getType().isScalable() )
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, cast(dstTy) may cause the crash.
Good point, but let's stick to one change per PR 😅 My recommendation:
- identify a test case that would indeed crash,
- fix the crash and use the test from 1. for a follow-up PR.
WDYT?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the update here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test case is there:
%0 = vector.extract %arg0[1,1,1]: f32 from vector<2x8x2xf32>
but fixing it can get a bit tricky because right now there are a lot of vector result assumptions (e.g., isLessThanTargetBitWidth()
, populateVectorLinearizeShuffleLikeOpsPatterns()
), so yes, it should be another PR.
Any suggestions to make the fix least invasive are welcomed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any suggestions to make the fix least invasive are welcomed.
You should disable this pattern when the rank of the output is <= 1.
"scalable vectors are not supported."); | ||
VectorType dstTy = getTypeConverter()->convertType<VectorType>( | ||
insertOp.getDestVectorType()); | ||
assert(dstTy && "vector type destination expected."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comments as above regarding to assert.
cast<VectorType>(dstTy).isScalable()) && | ||
"scalable vectors are not supported."); | ||
if (extractOp.getVector().getType().isScalable() || | ||
cast<VectorType>(dstTy).isScalable()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So, cast(dstTy) may cause the crash.
Good point, but let's stick to one change per PR 😅 My recommendation:
- identify a test case that would indeed crash,
- fix the crash and use the test from 1. for a follow-up PR.
WDYT?
// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> { | ||
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> { | ||
// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>, %[[VAL_1:.*]]: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> { | ||
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>, %arg1: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I introduced this "function" to complement test_linearize
:
llvm-project/mlir/test/Dialect/Vector/linearize.mlir
Lines 7 to 31 in e090bac
func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> { | |
// DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32> | |
// DEFAULT: %[[CST:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32> | |
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32> | |
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32> | |
// BW-128: %[[CST:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32> | |
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32> | |
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32> | |
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32> | |
// DEFAULT: %{{.*}} = math.sin %[[ARG]] : vector<4xf32> | |
// BW-128: %{{.*}} = math.sin %[[ARG]] : vector<4xf32> | |
// BW-0: %{{.*}} = math.sin %{{.*}} : vector<2x2xf32> | |
%1 = math.sin %arg0 : vector<2x2xf32> | |
// DEFAULT: %{{.*}} = arith.addf %[[ARG]], %[[CST]] : vector<4xf32> | |
// BW-128: %{{.*}} = arith.addf %[[ARG]], %[[CST]] : vector<4xf32> | |
// BW-0: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32> | |
%2 = arith.addf %arg0, %0 : vector<2x2xf32> | |
// ALL: return %[[RES]] : vector<2x2xf32> | |
return %0 : vector<2x2xf32> | |
} |
test_linearize
to make this clear - that's my bad, sorry for that!
With this in mind, would you be OK writing:
@test_extract_strided_slice_1_scalable
,- `@test_extract_strided_slice_2_scalable,
and so on? The check lines could be as simple as:
// CHECK-LABEL:
// CHECK-NOT: vector.shuffle
// CHECK-NOT: vector.shape_cast
// CHECK: vector.extract_strided_slice
(as in, the main thing to check would be that e.g. vector.shuffle
Ops are not inserted).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I hope it is addressed now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes to the function signature should be reverted.
af59df9
to
b1b0384
Compare
// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> { | ||
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> { | ||
// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>, %[[VAL_1:.*]]: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> { | ||
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>, %arg1: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The changes to the function signature should be reverted.
@@ -246,6 +257,16 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> { | |||
return %0 : vector<8x2xf32> | |||
} | |||
|
|||
// ALL-LABEL: func.func @test_vector_extract_scalable( | |||
// ALL-SAME: %[[VAL_0:.*]]: vector<2x[2]xf32>) -> f32 { | |||
func.func @test_vector_extract_scalable(%arg1: vector<2x[2]xf32>) -> f32 { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are types inside this tens and @test_vector_extract
different? Is this in any way significant?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed the types for scalable tests to closely resemble normal ones.
b1b0384
to
ce9522c
Compare
cast<VectorType>(dstTy).isScalable()) && | ||
"scalable vectors are not supported."); | ||
if (extractOp.getVector().getType().isScalable() || | ||
cast<VectorType>(dstTy).isScalable()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the update here?
ce9522c
to
3392e43
Compare
Any further notes or can it be merged? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you for addressing my comments and for following up on this 🙏🏻
3392e43
to
71cfc85
Compare
…arize (llvm#93590) As it was [suggested](llvm#92370 (comment)), the `assert` is replaced by `notifyMatchFailure` for improved consistency.
As it was suggested, the
assert
is replaced bynotifyMatchFailure
for improved consistency.