
[MLIR][Tensor] Fix Chained tensor.cast canonicalization pattern #113551


Open
wants to merge 2 commits into main

Conversation

vivekkhandelwal1
Contributor

This commit fixes a bug in the chained tensor.cast canonicalization pattern. When the sourceType and the intermediateType both contain a dim that is static but not equal, the joinShapes utility returns a null value. Passing this null value to the next joinShapes call then results in a crash.

Although this instance of tensor.cast is invalid, since the operand shape and result shape are incompatible, the code should not crash in any case; this commit fixes exactly that kind of case.
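For illustration, a minimal sketch of the kind of chain that triggers the crash. This IR is hypothetical and already verifier-invalid (the static dims 4 and 5 are cast-incompatible), so it only reaches the pattern when verification is skipped:

// Hypothetical example: sourceType = tensor<4xf32>, intermediateType =
// tensor<5xf32>. joinShapes(tensor<4xf32>, tensor<5xf32>) is null because
// the static dims conflict; feeding that null into the next joinShapes
// call crashed before this fix.
func.func @chained_cast_crash(%arg0: tensor<4xf32>) -> tensor<?xf32> {
  %0 = tensor.cast %arg0 : tensor<4xf32> to tensor<5xf32>
  %1 = tensor.cast %0 : tensor<5xf32> to tensor<?xf32>
  return %1 : tensor<?xf32>
}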


Signed-Off-By: Vivek Khandelwal <[email protected]>
@llvmbot
Member

llvmbot commented Oct 24, 2024

@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: Vivek Khandelwal (vivekkhandelwal1)

Changes

This commit fixes a bug in the chained tensor.cast canonicalization pattern. When the sourceType and the intermediateType both contain a dim that is static but not equal, the joinShapes utility returns a null value. Passing this null value to the next joinShapes call then results in a crash.

Although this instance of tensor.cast is invalid, since the operand shape and result shape are incompatible, the code should not crash in any case; this commit fixes exactly that kind of case.


Full diff: https://github.com/llvm/llvm-project/pull/113551.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+8-2)
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 603e86ca3d7668..13af1497d3790e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -434,17 +434,23 @@ struct ChainedTensorCast : public OpRewritePattern<CastOp> {
     // We can remove the intermediate cast if joining all three produces the
     // same result as just joining the source and result shapes.
     auto firstJoin =
-        joinShapes(joinShapes(sourceType, intermediateType), resultType);
+        joinShapes(sourceType, intermediateType);
 
     // The join might not exist if the cast sequence would fail at runtime.
     if (!firstJoin)
       return failure();
 
+    auto secondJoin = joinShapes(firstJoin, resultType);
+
+    // The join might not exist if the cast sequence would fail at runtime.
+    if (!secondJoin)
+      return failure();
+
     // The newJoin always exists if the above join exists, it might just contain
     // less information. If so, we cannot drop the intermediate cast, as doing
     // so would remove runtime checks.
     auto newJoin = joinShapes(sourceType, resultType);
-    if (firstJoin != newJoin)
+    if (secondJoin != newJoin)
       return failure();
 
     rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
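
For context on the checks above: joinShapes computes the most static shape compatible with both inputs, and returns a null type when no such shape exists. A sketch of the assumed contract (illustrative, not verbatim from TensorOps.cpp):

// join(tensor<?x8xf32>, tensor<4x?xf32>) -> tensor<4x8xf32>   (keeps the most static info)
// join(tensor<4x8xf32>, tensor<5x8xf32>) -> null type         (conflicting static dims)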


github-actions bot commented Oct 24, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Member

@pashu123 pashu123 left a comment

Do we want to add a test case for this? Maybe something like func @chained_wrong_tensor_cast_avoid_crash.

@vivekkhandelwal1
Contributor Author

Do we want to add a test case for this? Maybe something like func @chained_wrong_tensor_cast_avoid_crash.

I did try to add that, but as I mentioned in the commit, such a case is cast-incompatible, so it fails at the verification stage itself and never even reaches that code.

@pashu123
Member

Do we want to add a test case for this? Maybe something like func @chained_wrong_tensor_cast_avoid_crash.

I did try to add that, but as I mentioned in the commit, such a case is cast-incompatible, so it fails at the verification stage itself and never even reaches that code.

You can add a test case disabling the verifier, i.e., with verify-each=0, but I am unsure whether we want to do that. Maybe @MaheshRavishankar or @joker-eph would comment more on this.
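
A sketch of what such a test could look like, reusing the invalid chain from the description above. The RUN line and flag usage here are assumptions, not a verified recipe, and as noted in this thread the IR may still be rejected at parse time:

// RUN: mlir-opt %s -canonicalize --verify-each=0 | FileCheck %s
// CHECK-LABEL: @chained_wrong_tensor_cast_avoid_crash
// The pattern should return failure() on the invalid chain instead of
// crashing inside joinShapes.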

@joker-eph
Collaborator

joker-eph commented Oct 24, 2024

I am confused: if the verifier does not allow the IR, then how can this hit a crash?

@hanhanW
Contributor

hanhanW commented Oct 24, 2024

When the sourceType and the intermediateType both contain a dim that is static but not equal, the joinShapes utility returns a null value.

Doesn't that mean the program is invalid, or that there is a bug in something like shape inference?

I am confused: if the verifier does not allow the IR, then how can this hit a crash?

+1, I'm confused as well.

@MaheshRavishankar
Contributor

@vivekkhandelwal1 your change itself makes sense to me, but can you post the IR snippet that causes you to hit this? The verifier issue seems to be a red herring.

@vivekkhandelwal1
Contributor Author

vivekkhandelwal1 commented Nov 15, 2024

@vivekkhandelwal1 your change itself makes sense to me, but can you post the IR snippet that causes you to hit this? The verifier issue seems to be a red herring.

Hi @MaheshRavishankar, I have a repro for this issue from IREE, since that is where I was originally debugging it. To reproduce the issue, run the command below on the following IR: https://gist.github.com/vivekkhandelwal1/c587a207562fe3daa5a79b1526362db4

iree-opt --iree-util-fold-globals repro_ir.mlir

Running this produces a crash, which should at least "justify" why this change is important and required.
