Skip to content

Commit 87663fd

Browse files
authored
[VectorCombine] Don't shrink lshr if the shamt is not less than bitwidth (#108705)
Consider the following case: ``` define <2 x i32> @test(<2 x i64> %vec.ind16, <2 x i32> %broadcast.splat20) { %19 = icmp eq <2 x i64> %vec.ind16, zeroinitializer %20 = zext <2 x i1> %19 to <2 x i32> %21 = lshr <2 x i32> %20, %broadcast.splat20 ret <2 x i32> %21 } ``` After #104606, we shrink the lshr into: ``` define <2 x i32> @test(<2 x i64> %vec.ind16, <2 x i32> %broadcast.splat20) { %1 = icmp eq <2 x i64> %vec.ind16, zeroinitializer %2 = trunc <2 x i32> %broadcast.splat20 to <2 x i1> %3 = lshr <2 x i1> %1, %2 %4 = zext <2 x i1> %3 to <2 x i32> ret <2 x i32> %4 } ``` It is incorrect since `lshr i1 X, 1` returns `poison`. This patch adds additional check on the shamt operand. The lshr will get shrunk iff we ensure that the shamt is less than bitwidth of the smaller type. As `computeKnownBits(&I, *DL).countMaxActiveBits() > BW` always evaluates to true for `lshr(zext(X), Y)`, this check will only apply to bitwise logical instructions. Alive2: https://alive2.llvm.org/ce/z/j_RmTa Fixes #108698.
1 parent ba8e424 commit 87663fd

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2597,11 +2597,19 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
25972597
auto *SmallTy = cast<FixedVectorType>(ZExted->getType());
25982598
unsigned BW = SmallTy->getElementType()->getPrimitiveSizeInBits();
25992599

2600-
// Check that the expression overall uses at most the same number of bits as
2601-
// ZExted
2602-
KnownBits KB = computeKnownBits(&I, *DL);
2603-
if (KB.countMaxActiveBits() > BW)
2604-
return false;
2600+
if (I.getOpcode() == Instruction::LShr) {
2601+
// Check that the shift amount is less than the number of bits in the
2602+
// smaller type. Otherwise, the smaller lshr will return a poison value.
2603+
KnownBits ShAmtKB = computeKnownBits(I.getOperand(1), *DL);
2604+
if (ShAmtKB.getMaxValue().uge(BW))
2605+
return false;
2606+
} else {
2607+
// Check that the expression overall uses at most the same number of bits as
2608+
// ZExted
2609+
KnownBits KB = computeKnownBits(&I, *DL);
2610+
if (KB.countMaxActiveBits() > BW)
2611+
return false;
2612+
}
26052613

26062614
// Calculate costs of leaving current IR as it is and moving ZExt operation
26072615
// later, along with adding truncates if needed
@@ -2628,7 +2636,7 @@ bool VectorCombine::shrinkType(llvm::Instruction &I) {
26282636
return false;
26292637

26302638
// Check if we can propagate ZExt through its other users
2631-
KB = computeKnownBits(UI, *DL);
2639+
KnownBits KB = computeKnownBits(UI, *DL);
26322640
if (KB.countMaxActiveBits() > BW)
26332641
return false;
26342642

llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,17 @@ vector.body:
100100
ret i32 %2
101101
}
102102

103+
define <2 x i32> @pr108698(<2 x i64> %x, <2 x i32> %y) {
104+
; CHECK-LABEL: @pr108698(
105+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i64> [[X:%.*]], zeroinitializer
106+
; CHECK-NEXT: [[EXT:%.*]] = zext <2 x i1> [[CMP]] to <2 x i32>
107+
; CHECK-NEXT: [[LSHR:%.*]] = lshr <2 x i32> [[EXT]], [[Y:%.*]]
108+
; CHECK-NEXT: ret <2 x i32> [[LSHR]]
109+
;
110+
%cmp = icmp eq <2 x i64> %x, zeroinitializer
111+
%ext = zext <2 x i1> %cmp to <2 x i32>
112+
%lshr = lshr <2 x i32> %ext, %y
113+
ret <2 x i32> %lshr
114+
}
115+
103116
declare i32 @llvm.vector.reduce.add.v16i32(<16 x i32>)

0 commit comments

Comments
 (0)