Skip to content

Commit fe91054

Browse files
committed
[VectorCombine] Don't shrink lshr if the shamt is not less than bitwidth
1 parent 96f3521 commit fe91054

File tree

2 files changed: +16 -9 lines changed

2 files changed: +16 -9 lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 14 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2597,11 +2597,19 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
25972597
auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
25982598
unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
25992599

2600-
// Check that the expression overall uses at most the same number of bits as
2601-
// ZExted
2602-
KnownBits KB = computeKnownBits(&I, *DL);
2603-
if (KB.countMaxActiveBits() > BW)
2604-
return false;
2600+
if (I.getOpcode() == Instruction::LShr) {
2601+
// Check that the shift amount is less than the number of bits in the
2602+
// smaller type. Otherwise, the smaller lshr will return a poison value.
2603+
KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
2604+
if (ShAmtKB.getMaxValue().uge(BW))
2605+
return false;
2606+
} else {
2607+
// Check that the expression overall uses at most the same number of bits as
2608+
// ZExted
2609+
KnownBits KB = computeKnownBits(&I, *DL);
2610+
if (KB.countMaxActiveBits() > BW)
2611+
return false;
2612+
}
26052613

26062614
// Calculate costs of leaving current IR as it is and moving ZExt operation
26072615
// later, along with adding truncates if needed
@@ -2628,7 +2636,7 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
26282636
return false;
26292637

26302638
// Check if we can propagate ZExt through its other users
2631-
KB = computeKnownBits(UI, *DL);
2639+
KnownBits KB = computeKnownBits(UI, *DL);
26322640
if (KB.countMaxActiveBits() > BW)
26332641
return false;
26342642

llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -103,9 +103,8 @@ vector.body:
103103
define <2 x i32> @pr108698(<2 x i64> %x, <2 x i32> %y) {
104104
; CHECK-LABEL: @pr108698(
105105
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i64> [[X:%.*]], zeroinitializer
106-
; CHECK-NEXT: [[TMP1:%.*]] = trunc <2 x i32> [[Y:%.*]] to <2 x i1>
107-
; CHECK-NEXT: [[TMP2:%.*]] = lshr <2 x i1> [[CMP]], [[TMP1]]
108-
; CHECK-NEXT: [[LSHR:%.*]] = zext <2 x i1> [[TMP2]] to <2 x i32>
106+
; CHECK-NEXT: [[EXT:%.*]] = zext <2 x i1> [[CMP]] to <2 x i32>
107+
; CHECK-NEXT: [[LSHR:%.*]] = lshr <2 x i32> [[EXT]], [[Y:%.*]]
109108
; CHECK-NEXT: ret <2 x i32> [[LSHR]]
110109
;
111110
%cmp = icmp eq <2 x i64> %x, zeroinitializer

0 commit comments

Comments (0)