Commit 781d077

[InstCombine] reassociateShiftAmtsOfTwoSameDirectionShifts(): fix miscompile (PR44802)
As input, we have the following pattern:
  Sh0 (Sh1 X, Q), K
We want to rewrite that as:
  Sh x, (Q+K)  iff (Q+K) u< bitwidth(x)
While we know that originally (Q+K) would not overflow (because 2 * (N-1) u<= iN -1), we may have looked past extensions of shift amounts, so it may now overflow in the smaller bitwidth. To ensure that does not happen, we need to ensure that the total maximal shift amount is still representable in that smaller bitwidth. If the overflow were to happen, the (Q+K) u< bitwidth(x) check would be bogus.

https://bugs.llvm.org/show_bug.cgi?id=44802
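
To make the overflow concrete, here is a minimal sketch in plain C++ with no LLVM types. The helper names `addInWidth` and `naiveCheckPasses` are hypothetical and exist only for illustration. It models adding two shift amounts in an integer type that is `bits` wide, then applying the (Q+K) u< bitwidth(x) test to the wrapped sum.

```cpp
#include <cstdint>

// Hypothetical helper: add two shift amounts as if they lived in an unsigned
// integer type that is `bits` wide (assumes 1 <= bits <= 32), wrapping on
// overflow the way an iN add does.
std::uint64_t addInWidth(std::uint64_t q, std::uint64_t k, unsigned bits) {
  const std::uint64_t mask = (std::uint64_t{1} << bits) - 1;
  return (q + k) & mask;
}

// Hypothetical helper: the (Q+K) u< bitwidth(x) test, evaluated on the sum
// computed in the narrow shift-amount type.
bool naiveCheckPasses(std::uint64_t q, std::uint64_t k, unsigned amtBits,
                      unsigned valueBits) {
  return addInWidth(q, k, amtBits) < valueBits;
}
```

With 3-bit shift amounts and an 8-bit value, `naiveCheckPasses(4, 4, 3, 8)` is true because 4 + 4 wraps to 0, even though the real total of 8 is not below the bit width. With 1-bit amounts, `addInWidth(1, 1, 1)` is 0 rather than 2, which is the situation in the PR44802 test.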
1 parent 425ef99 commit 781d077

2 files changed, +27 -4 lines changed

llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp

+22 -2
@@ -23,8 +23,11 @@ using namespace PatternMatch;
 // Given pattern:
 //   (x shiftopcode Q) shiftopcode K
 // we should rewrite it as
-//   x shiftopcode (Q+K)  iff (Q+K) u< bitwidth(x)
-// This is valid for any shift, but they must be identical.
+//   x shiftopcode (Q+K)  iff (Q+K) u< bitwidth(x) and
+//
+// This is valid for any shift, but they must be identical, and we must be
+// careful in case we have (zext(Q)+zext(K)) and look past extensions,
+// (Q+K) must not overflow or else (Q+K) u< bitwidth(x) is bogus.
 //
 // AnalyzeForSignBitExtraction indicates that we will only analyze whether this
 // pattern has any 2 right-shifts that sum to 1 less than original bit width.
@@ -58,6 +61,23 @@ Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts(
   if (ShAmt0->getType() != ShAmt1->getType())
     return nullptr;
 
+  // As input, we have the following pattern:
+  //   Sh0 (Sh1 X, Q), K
+  // We want to rewrite that as:
+  //   Sh x, (Q+K)  iff (Q+K) u< bitwidth(x)
+  // While we know that originally (Q+K) would not overflow
+  // (because 2 * (N-1) u<= iN -1), we have looked past extensions of
+  // shift amounts. so it may now overflow in smaller bitwidth.
+  // To ensure that does not happen, we need to ensure that the total maximal
+  // shift amount is still representable in that smaller bit width.
+  unsigned MaximalPossibleTotalShiftAmount =
+      (Sh0->getType()->getScalarSizeInBits() - 1) +
+      (Sh1->getType()->getScalarSizeInBits() - 1);
+  APInt MaximalRepresentableShiftAmount =
+      APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits());
+  if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
+    return nullptr;
+
   // We are only looking for signbit extraction if we have two right shifts.
   bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) &&
                            match(Sh1, m_Shr(m_Value(), m_Value()));
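
The guard added in this hunk can be modelled outside LLVM roughly as follows. This is a sketch under stated assumptions, not the InstCombine code: `totalShiftFits` is a hypothetical name, plain 64-bit integers stand in for `APInt`, and the shift-amount width is assumed to be between 1 and 63 bits.

```cpp
#include <cstdint>

// Hypothetical stand-in for the new bail-out. Returns true when the fold may
// proceed, i.e. when the largest possible sum of the two shift amounts is
// still representable in the (possibly narrower) shift-amount type.
bool totalShiftFits(unsigned sh0Bits, unsigned sh1Bits, unsigned amtBits) {
  // A well-defined shift is by at most (bitwidth - 1), so this bounds Q+K.
  // Assumes sh0Bits >= 1 and sh1Bits >= 1.
  const std::uint64_t maxTotal = static_cast<std::uint64_t>(sh0Bits - 1) +
                                 static_cast<std::uint64_t>(sh1Bits - 1);
  // All-ones value of the shift-amount type (assumes 1 <= amtBits <= 63).
  const std::uint64_t maxRepresentable = (std::uint64_t{1} << amtBits) - 1;
  // Mirrors the `.ult(...)` comparison followed by `return nullptr`.
  return !(maxRepresentable < maxTotal);
}
```

Under this model, `totalShiftFits(3, 3, 1)` is false (maximal total 4, but an i1 amount can hold at most 1), which corresponds to the PR44802 test. `totalShiftFits(32, 32, 5)` is also false (31 is below 62), while `totalShiftFits(32, 32, 6)` and `totalShiftFits(32, 32, 32)` are true.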

llvm/test/Transforms/InstCombine/shift-amount-reassociation.ll

+5 -2
@@ -320,12 +320,15 @@ define i32 @n20(i32 %x, i32 %y) {
   ret i32 %t3
 }
 
-; FIXME: this is a miscompile. We should not transform this.
 ; See https://bugs.llvm.org/show_bug.cgi?id=44802
 define i3 @pr44802(i3 %t0) {
 ; CHECK-LABEL: @pr44802(
 ; CHECK-NEXT:    [[T1:%.*]] = sub i3 0, [[T0:%.*]]
-; CHECK-NEXT:    ret i3 [[T1]]
+; CHECK-NEXT:    [[T2:%.*]] = icmp ne i3 [[T0]], 0
+; CHECK-NEXT:    [[T3:%.*]] = zext i1 [[T2]] to i3
+; CHECK-NEXT:    [[T4:%.*]] = lshr i3 [[T1]], [[T3]]
+; CHECK-NEXT:    [[T5:%.*]] = lshr i3 [[T4]], [[T3]]
+; CHECK-NEXT:    ret i3 [[T5]]
 ;
   %t1 = sub i3 0, %t0
   %t2 = icmp ne i3 %t0, 0
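
To see how far the old output was from the intended one, the two versions of the CHECK lines can be compared exhaustively over all eight i3 inputs. The sketch below is plain C++ and only models the IR. The function names are hypothetical, and the instruction sequence is taken from the post-fix CHECK lines above rather than from the full test body, which is not shown in this hunk.

```cpp
// Every value is kept in the low 3 bits to model i3 arithmetic.
constexpr unsigned kMask3 = 0x7;

// What the IR computes, following the post-fix CHECK lines.
unsigned pr44802Expected(unsigned t0) {
  t0 &= kMask3;
  const unsigned t1 = (0u - t0) & kMask3;  // sub i3 0, %t0
  const unsigned t3 = (t0 != 0) ? 1u : 0u; // zext (icmp ne i3 %t0, 0) to i3
  const unsigned t4 = t1 >> t3;            // lshr i3 %t1, %t3
  return t4 >> t3;                         // lshr i3 %t4, %t3
}

// What the pre-fix CHECK lines returned: %t1 itself.
unsigned pr44802Miscompiled(unsigned t0) {
  t0 &= kMask3;
  return (0u - t0) & kMask3;
}

// Counts the i3 inputs on which the two versions disagree.
int countMismatches() {
  int mismatches = 0;
  for (unsigned t0 = 0; t0 < 8; ++t0)
    if (pr44802Expected(t0) != pr44802Miscompiled(t0))
      ++mismatches;
  return mismatches;
}
```

In this model the two versions agree only for `%t0 == 0`. For example, with `%t0 == 1` the original sequence yields 1 while the folded form yields 7.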