[MLIR][Tensor] Fix Chained tensor.cast canonicalization pattern #113551
Conversation
This commit fixes a bug in the chained tensor.cast canonicalization pattern. When the sourceType and intermediateType both have a static dimension at the same position but with different values, the joinShapes utility returns a null type, and passing that null type to the next joinShapes call causes a crash. Although such a tensor.cast is invalid, since the operand and result shapes are incompatible, the code should not crash in any case; this commit fixes exactly that. Signed-Off-By: Vivek Khandelwal <[email protected]>
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir Author: Vivek Khandelwal (vivekkhandelwal1) Changes: This commit fixes a bug in the chained tensor.cast canonicalization pattern. When the sourceType and intermediateType both have a static dimension at the same position but with different values, the joinShapes utility returns a null type, and passing that null type to the next joinShapes call causes a crash. Although such a tensor.cast is invalid, since the operand and result shapes are incompatible, the code should not crash in any case; this commit fixes exactly that. Full diff: https://github.com/llvm/llvm-project/pull/113551.diff 1 file affected:
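For illustration, a minimal IR snippet with the problematic shape combination might look like the following; the types here are made up, and the first cast is exactly the kind of IR the verifier would normally reject:

// Hypothetical chained casts: the source and intermediate types disagree on a
// static dimension (4 vs. 5), so joinShapes(sourceType, intermediateType)
// returns a null type. The old pattern then fed that null type straight into a
// second joinShapes call and crashed.
%1 = tensor.cast %0 : tensor<4xf32> to tensor<5xf32>
%2 = tensor.cast %1 : tensor<5xf32> to tensor<?xf32>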
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 603e86ca3d7668..13af1497d3790e 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -434,17 +434,23 @@ struct ChainedTensorCast : public OpRewritePattern<CastOp> {
// We can remove the intermediate cast if joining all three produces the
// same result as just joining the source and result shapes.
auto firstJoin =
- joinShapes(joinShapes(sourceType, intermediateType), resultType);
+ joinShapes(sourceType, intermediateType);
// The join might not exist if the cast sequence would fail at runtime.
if (!firstJoin)
return failure();
+ auto secondJoin = joinShapes(firstJoin, resultType);
+
+ // The join might not exist if the cast sequence would fail at runtime.
+ if (!secondJoin)
+ return failure();
+
// The newJoin always exists if the above join exists, it might just contain
// less information. If so, we cannot drop the intermediate cast, as doing
// so would remove runtime checks.
auto newJoin = joinShapes(sourceType, resultType);
- if (firstJoin != newJoin)
+ if (secondJoin != newJoin)
return failure();
rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
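For context, when both joins exist and agree with joining the source and result types directly, the pattern folds the two casts into one. A sketch of that successful case, with made-up shapes, would be:

// Before canonicalization: two compatible chained casts.
%1 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x?xf32>
%2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x8xf32>
// After canonicalization: the intermediate cast is dropped, because
// join(join(4x?, ?x?), ?x8) == join(4x?, ?x8) == 4x8, so no runtime
// shape check is lost.
%2 = tensor.cast %0 : tensor<4x?xf32> to tensor<?x8xf32>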
✅ With the latest revision this PR passed the C/C++ code formatter.
Do we want to add a test case for this? Maybe something like func @chained_wrong_tensor_cast_avoid_crash.
I did try to add that, but as I mentioned in the commit, such a case is cast-incompatible, so it fails at the verify stage itself and never even reaches that code.
You can add a test case by disabling the verifier, i.e., with …
I am confused: if the verifier does not allow the IR, then how can this hit a crash?
Doesn't that mean the program is invalid, or that there is a bug in something like shape inference?
+1, I'm confused as well.
@vivekkhandelwal1 Your change itself makes sense to me, but can you just post the IR snippet that caused you to hit this? The verifier issue seems to be a red herring.
Hi @MaheshRavishankar, I have a repro for this issue from IREE, where I was originally debugging it. To reproduce the issue, run the command below over the following IR: https://gist.github.com/vivekkhandelwal1/c587a207562fe3daa5a79b1526362db4
On running this, you will see a crash, which at least "justifies" why this change is important and required.