Skip to content

[mlir][vector] Use notifyMatchFailure instead of assert in VectorLinearize #93590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 21, 2024

Conversation

akroviakov
Copy link
Contributor

As it was suggested, the assert is replaced by notifyMatchFailure for improved consistency.

@llvmbot
Copy link
Member

llvmbot commented May 28, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Artem Kroviakov (akroviakov)

Changes

As it was suggested, the assert is replaced by notifyMatchFailure for improved consistency.


Full diff: https://github.com/llvm/llvm-project/pull/93590.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+17-13)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 156bf742f6297..840fd384894df 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -152,9 +152,10 @@ struct LinearizeVectorExtractStridedSlice final
   matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type dstType = getTypeConverter()->convertType(extractOp.getType());
-    assert(!(extractOp.getVector().getType().isScalable() ||
-             cast<VectorType>(dstType).isScalable()) &&
-           "scalable vectors are not supported.");
+    if (extractOp.getVector().getType().isScalable() ||
+        cast<VectorType>(dstType).isScalable())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -265,10 +266,11 @@ struct LinearizeVectorShuffle final
   matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type dstType = getTypeConverter()->convertType(shuffleOp.getType());
-    assert(!(shuffleOp.getV1VectorType().isScalable() ||
-             shuffleOp.getV2VectorType().isScalable() ||
-             cast<VectorType>(dstType).isScalable()) &&
-           "scalable vectors are not supported.");
+    if (shuffleOp.getV1VectorType().isScalable() ||
+        shuffleOp.getV2VectorType().isScalable() ||
+        cast<VectorType>(dstType).isScalable())
+      return rewriter.notifyMatchFailure(shuffleOp,
+                                         "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(shuffleOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           shuffleOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -336,9 +338,10 @@ struct LinearizeVectorExtract final
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type dstTy = getTypeConverter()->convertType(extractOp.getType());
-    assert(!(extractOp.getVector().getType().isScalable() ||
-             cast<VectorType>(dstTy).isScalable()) &&
-           "scalable vectors are not supported.");
+    if (extractOp.getVector().getType().isScalable() ||
+        cast<VectorType>(dstTy).isScalable())
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "scalable vectors are not supported.");
     if (!isLessThanTargetBitWidth(extractOp, targetVectorBitWidth))
       return rewriter.notifyMatchFailure(
           extractOp, "Can't flatten since targetBitWidth <= OpSize");
@@ -395,9 +398,10 @@ struct LinearizeVectorInsert final
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     Type dstTy = getTypeConverter()->convertType(insertOp.getDestVectorType());
-    assert(!(insertOp.getDestVectorType().isScalable() ||
-             cast<VectorType>(dstTy).isScalable()) &&
-           "scalable vectors are not supported.");
+    if (insertOp.getDestVectorType().isScalable() ||
+        cast<VectorType>(dstTy).isScalable())
+      return rewriter.notifyMatchFailure(insertOp,
+                                         "scalable vectors are not supported.");
 
     if (!isLessThanOrEqualTargetBitWidth(insertOp.getSourceType(),
                                          targetVectorBitWidth))

@Garra1980
Copy link

cc @Hardcode84

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for addressing my suggestion so swiftly!

Could you also add some basic tests to show that these patterns do not trigger for scalable vectors?

Comment on lines 268 to 269
assert(!(shuffleOp.getV1VectorType().isScalable() ||
shuffleOp.getV2VectorType().isScalable() ||
cast<VectorType>(dstType).isScalable()) &&
"scalable vectors are not supported.");
if (shuffleOp.getV1VectorType().isScalable() ||
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vector.shuffle does not support scalable vectors, so keeping an assert should be fine for this one:

A comment explaining the rationale for using assert rather than notifyMatchFailure would be welcome :)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on test coverage.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be addressed now

@@ -152,9 +152,10 @@ struct LinearizeVectorExtractStridedSlice final
matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type dstType = getTypeConverter()->convertType(extractOp.getType());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we just cast dstType to ShapedType here? I think it carries more information/methods v.s. Type. And you don't need to cast it in the below if condition.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please provide more detail on how ShapedType would not need a cast for cast<VectorType>(dstType).isScalable()? AFAIK ShapedType has no isScalable(), I do not see other places in the pattern where we could use ShapedType's information/methods that are not provided by Type.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the confusion, what I meant is VectorType. So it could either be

VectorType dstType = getTypeConverter()->convertType(extractOp.getType());

or

auto dstType = cast<VectorType>(getTypeConverter()->convertType(extractOp.getType()));

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can actually do getTypeConverter()->convertType<VectorType>(...). Also, it's better to check convertType result for null.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clarification, should now be addressed

@akroviakov akroviakov force-pushed the refactor_vector_linearize branch 2 times, most recently from 38689e2 to af59df9 Compare June 6, 2024 12:24
cast<VectorType>(dstTy).isScalable()) &&
"scalable vectors are not supported.");
if (extractOp.getVector().getType().isScalable() ||
cast<VectorType>(dstTy).isScalable())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the dstTy here is not always a vector type. It could be a scalar type too.
e.g., vector.extract %1 [0, 0]: f32 from vector<1024x1024xf32>. So, cast(dstTy) may cause the crash.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you say the following is better?

if ( (auto vecDstTy = cast<VectorType>(dstTy) && vecDstTy.isScalable()) || extractOp.getVector().getType().isScalable() )

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, cast(dstTy) may cause the crash.

Good point, but let's stick to one change per PR 😅 My recommendation:

  1. identify a test case that would indeed crash,
  2. fix the crash and use the test from 1. for a follow-up PR.

WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the update here?

Copy link
Contributor Author

@akroviakov akroviakov Jun 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test case is there:

%0 = vector.extract %arg0[1,1,1]: f32 from vector<2x8x2xf32>

but fixing it can get a bit tricky because right now there are a lot of vector result assumptions (e.g., isLessThanTargetBitWidth(), populateVectorLinearizeShuffleLikeOpsPatterns()), so yes, it should be another PR.
Any suggestions to make the fix least invasive are welcomed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any suggestions to make the fix least invasive are welcomed.

You should disable this pattern when the rank of the output is <= 1.

"scalable vectors are not supported.");
VectorType dstTy = getTypeConverter()->convertType<VectorType>(
insertOp.getDestVectorType());
assert(dstTy && "vector type destination expected.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comments as above regarding to assert.

cast<VectorType>(dstTy).isScalable()) &&
"scalable vectors are not supported.");
if (extractOp.getVector().getType().isScalable() ||
cast<VectorType>(dstTy).isScalable())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, cast(dstTy) may cause the crash.

Good point, but let's stick to one change per PR 😅 My recommendation:

  1. identify a test case that would indeed crash,
  2. fix the crash and use the test from 1. for a follow-up PR.

WDYT?

// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>, %[[VAL_1:.*]]: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> {
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>, %arg1: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I introduced this "function" to complement test_linearize:

func.func @test_linearize(%arg0: vector<2x2xf32>) -> vector<2x2xf32> {
// DEFAULT: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
// DEFAULT: %[[CST:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
// DEFAULT: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
// BW-128: %[[ARG:.*]] = vector.shape_cast %[[ORIG_ARG]] : vector<2x2xf32> to vector<4xf32>
// BW-128: %[[CST:.*]] = arith.constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
// BW-128: %[[RES:.*]] = vector.shape_cast %[[CST]] : vector<4xf32> to vector<2x2xf32>
// BW-0: %[[RES:.*]] = arith.constant dense<{{.*}}> : vector<2x2xf32>
%0 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
// DEFAULT: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
// BW-128: %{{.*}} = math.sin %[[ARG]] : vector<4xf32>
// BW-0: %{{.*}} = math.sin %{{.*}} : vector<2x2xf32>
%1 = math.sin %arg0 : vector<2x2xf32>
// DEFAULT: %{{.*}} = arith.addf %[[ARG]], %[[CST]] : vector<4xf32>
// BW-128: %{{.*}} = arith.addf %[[ARG]], %[[CST]] : vector<4xf32>
// BW-0: %{{.*}} = arith.addf %{{.*}} : vector<2x2xf32>
%2 = arith.addf %arg0, %0 : vector<2x2xf32>
// ALL: return %[[RES]] : vector<2x2xf32>
return %0 : vector<2x2xf32>
}
. This way, for a set of tests for "fixed width" vectors it's quite easy to find to corresponding tests for "scalable vectors". I should've moved it next to test_linearize to make this clear - that's my bad, sorry for that!

With this in mind, would you be OK writing:

  • @test_extract_strided_slice_1_scalable,
  • `@test_extract_strided_slice_2_scalable,

and so on? The check lines could be as simple as:

// CHECK-LABEL:
// CHECK-NOT: vector.shuffle
// CHECK-NOT: vector.shape_cast
// CHECK: vector.extract_strided_slice

(as in, the main thing to check would be that e.g. vector.shuffle Ops are not inserted).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope it is addressed now

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes to the function signature should be reverted.

@akroviakov akroviakov force-pushed the refactor_vector_linearize branch from af59df9 to b1b0384 Compare June 12, 2024 13:33
// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>) -> vector<[2]x[2]xf32> {
// ALL-SAME: %[[VAL_0:.*]]: vector<[2]x[2]xf32>, %[[VAL_1:.*]]: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> {
func.func @test_scalable_no_linearize(%arg0: vector<[2]x[2]xf32>, %arg1: vector<2x[2]xf32>) -> vector<[2]x[2]xf32> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes to the function signature should be reverted.

@@ -246,6 +257,16 @@ func.func @test_vector_extract(%arg0: vector<2x8x2xf32>) -> vector<8x2xf32> {
return %0 : vector<8x2xf32>
}

// ALL-LABEL: func.func @test_vector_extract_scalable(
// ALL-SAME: %[[VAL_0:.*]]: vector<2x[2]xf32>) -> f32 {
func.func @test_vector_extract_scalable(%arg1: vector<2x[2]xf32>) -> f32 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are types inside this tens and @test_vector_extract different? Is this in any way significant?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the types for scalable tests to closely resemble normal ones.

@akroviakov akroviakov force-pushed the refactor_vector_linearize branch from b1b0384 to ce9522c Compare June 13, 2024 09:26
cast<VectorType>(dstTy).isScalable()) &&
"scalable vectors are not supported.");
if (extractOp.getVector().getType().isScalable() ||
cast<VectorType>(dstTy).isScalable())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the update here?

@akroviakov akroviakov force-pushed the refactor_vector_linearize branch from ce9522c to 3392e43 Compare June 13, 2024 19:47
@akroviakov
Copy link
Contributor Author

Any further notes or can it be merged?

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you for addressing my comments and for following up on this 🙏🏻

@akroviakov akroviakov force-pushed the refactor_vector_linearize branch from 3392e43 to 71cfc85 Compare June 20, 2024 16:35
@chencha3 chencha3 merged commit 74a105a into llvm:main Jun 21, 2024
7 checks passed
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
…arize (llvm#93590)

As it was [suggested](llvm#92370 (comment)), the `assert` is replaced by `notifyMatchFailure` for improved consistency.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants