-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[VectorCombine] Don't shrink lshr if the shamt is not less than bitwidth #108705
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Yingwei Zheng (dtcxzyw) ChangesConsider the following case:
After #104606, we shrink the lshr into:
It is incorrect since Alive2: https://alive2.llvm.org/ce/z/j_RmTa Full diff: https://github.com/llvm/llvm-project/pull/108705.diff 2 Files Affected:
diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
index d7afe2f426d392..58701bfa60a33e 100644
--- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
+++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp
@@ -2597,11 +2597,19 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
- // Check that the expression overall uses at most the same number of bits as
- // ZExted
- KnownBits KB = computeKnownBits(&I, *DL);
- if (KB.countMaxActiveBits() > BW)
- return false;
+ if (I.getOpcode() == Instruction::LShr) {
+ // Check that the shift amount is less than the number of bits in the
+ // smaller type. Otherwise, the smaller lshr will return a poison value.
+ KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
+ if (ShAmtKB.getMaxValue().uge(BW))
+ return false;
+ } else {
+ // Check that the expression overall uses at most the same number of bits as
+ // ZExted
+ KnownBits KB = computeKnownBits(&I, *DL);
+ if (KB.countMaxActiveBits() > BW)
+ return false;
+ }
// Calculate costs of leaving current IR as it is and moving ZExt operation
// later, along with adding truncates if needed
@@ -2628,7 +2636,7 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
return false;
// Check if we can propagate ZExt through its other users
- KB = computeKnownBits(UI, *DL);
+ KnownBits KB = computeKnownBits(UI, *DL);
if (KB.countMaxActiveBits() > BW)
return false;
diff --git a/llvm/test/Transforms/VectorCombine/X86/pr108698.ll b/llvm/test/Transforms/VectorCombine/X86/pr108698.ll
new file mode 100644
index 00000000000000..675cf6ed7da51f
--- /dev/null
+++ b/llvm/test/Transforms/VectorCombine/X86/pr108698.ll
@@ -0,0 +1,16 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=vector-combine -S -mtriple=x86_64 | FileCheck %s
+
+define <2 x i32> @test(<2 x i64> %x, <2 x i32> %y) {
+; CHECK-LABEL: define <2 x i32> @test(
+; CHECK-SAME: <2 x i64> [[X:%.*]], <2 x i32> [[Y:%.*]]) {
+; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i64> [[X]], zeroinitializer
+; CHECK-NEXT: [[EXT:%.*]] = zext <2 x i1> [[CMP]] to <2 x i32>
+; CHECK-NEXT: [[LSHR:%.*]] = lshr <2 x i32> [[EXT]], [[Y]]
+; CHECK-NEXT: ret <2 x i32> [[LSHR]]
+;
+ %cmp = icmp eq <2 x i64> %x, zeroinitializer
+ %ext = zext <2 x i1> %cmp to <2 x i32>
+ %lshr = lshr <2 x i32> %ext, %y
+ ret <2 x i32> %lshr
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
3a11412
to
fe91054
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM - cheers
Consider the following case:
After #104606, we shrink the lshr into:
It is incorrect since
lshr i1 X, 1
returnspoison
.This patch adds additional check on the shamt operand. The lshr will get shrunk iff we ensure that the shamt is less than bitwidth of the smaller type. As
computeKnownBits(&I, *DL).countMaxActiveBits() > BW
always evaluates to true forlshr(zext(X), Y)
, this check will only apply to bitwise logical instructions.Alive2: https://alive2.llvm.org/ce/z/j_RmTa
Fixes #108698.