
Commit 4973f02

[ConstantRange] Estimate tighter lower (upper) bounds for masked binary and (or)
1 parent 4a7673d commit 4973f02

2 files changed: +188 -6 lines changed

llvm/lib/IR/ConstantRange.cpp

Lines changed: 100 additions & 6 deletions
@@ -1520,15 +1520,102 @@ ConstantRange ConstantRange::binaryNot() const {
   return ConstantRange(APInt::getAllOnes(getBitWidth())).sub(*this);
 }
 
+/// Estimate the 'bit-masked AND' operation's lower bound.
+///
+/// E.g., given two ranges as follows (single quotes are separators and
+/// have no meaning here),
+///
+/// LHS = [10'001'010,  ; LLo
+///        10'100'000]  ; LHi
+/// RHS = [10'111'010,  ; RLo
+///        10'111'100]  ; RHi
+///
+/// we know that the higher 2 bits of the result are always '10'; and note that
+/// there is at least one bit set to 1 in LHS[3:6] (since the range is contiguous),
+/// and all bits in RHS[3:6] are 1, so we know the lower bound of the result is
+/// 10'001'000.
+///
+/// The algorithm is as follows:
+/// 1. we first calculate a mask to mask out the higher common bits by
+///    Mask = (LLo ^ LHi) | (RLo ^ RHi) | (LLo ^ RLo);
+///    Mask = set all non-leading-zero bits to 1 for Mask;
+/// 2. find the bit field with at least one 1 in LHS (i.e., bits 3:6 in the example)
+///    after applying the mask, with
+///    StartBit = BitWidth - (LLo & Mask).clz() - 1;
+///    EndBit = BitWidth - (LHi & Mask).clz();
+/// 3. check that all bits of RHS in [StartBit:EndBit] are 1, and that all bits of
+///    RLo and RHi in [StartBit:BitWidth] are the same; if so, the lower bound
+///    can be updated to
+///    LowerBound = LLo & Keep;
+///    where Keep is a mask to mask out trailing bits (the lower 3 bits in the
+///    example);
+/// 4. repeat steps 2 and 3 with LHS and RHS swapped, and update the lower
+///    bound with the smaller one.
+static APInt estimateBitMaskedAndLowerBound(const ConstantRange &LHS,
+                                            const ConstantRange &RHS) {
+  auto BitWidth = LHS.getBitWidth();
+  // If either range is the full set or is unsigned-wrapped, then it must
+  // contain '0', which forces the lower bound to 0.
+  if ((LHS.isFullSet() || RHS.isFullSet()) ||
+      (LHS.isWrappedSet() || RHS.isWrappedSet()))
+    return APInt::getZero(BitWidth);
+
+  auto LLo = LHS.getLower();
+  auto LHi = LHS.getUpper() - 1;
+  auto RLo = RHS.getLower();
+  auto RHi = RHS.getUpper() - 1;
+
+  // Calculate the mask that masks out the higher common bits.
+  auto Mask = (LLo ^ LHi) | (RLo ^ RHi) | (LLo ^ RLo);
+  unsigned LeadingZeros = Mask.countLeadingZeros();
+  Mask.setLowBits(BitWidth - LeadingZeros);
+
+  auto estimateBound =
+      [BitWidth, &Mask](const APInt &ALo, const APInt &AHi, const APInt &BLo,
+                        const APInt &BHi) -> std::optional<APInt> {
+    unsigned LeadingZeros = (ALo & Mask).countLeadingZeros();
+    if (LeadingZeros == BitWidth)
+      return std::nullopt;
+
+    unsigned StartBit = BitWidth - LeadingZeros - 1;
+
+    if (BLo.extractBits(BitWidth - StartBit, StartBit) !=
+        BHi.extractBits(BitWidth - StartBit, StartBit))
+      return std::nullopt;
+
+    unsigned EndBit = BitWidth - (AHi & Mask).countLeadingZeros();
+    if (!(BLo.extractBits(EndBit - StartBit, StartBit) &
+          BHi.extractBits(EndBit - StartBit, StartBit))
+             .isAllOnes())
+      return std::nullopt;
+
+    APInt Keep(BitWidth, 0);
+    Keep.setBits(StartBit, BitWidth);
+    return Keep & ALo;
+  };
+
+  auto LowerBoundByLHS = estimateBound(LLo, LHi, RLo, RHi);
+  auto LowerBoundByRHS = estimateBound(RLo, RHi, LLo, LHi);
+
+  if (LowerBoundByLHS && LowerBoundByRHS)
+    return LowerBoundByLHS->ult(*LowerBoundByRHS) ? *LowerBoundByLHS
+                                                  : *LowerBoundByRHS;
+  if (LowerBoundByLHS)
+    return *LowerBoundByLHS;
+  if (LowerBoundByRHS)
+    return *LowerBoundByRHS;
+  return APInt::getZero(BitWidth);
+}
+
 ConstantRange ConstantRange::binaryAnd(const ConstantRange &Other) const {
   if (isEmptySet() || Other.isEmptySet())
     return getEmpty();
 
   ConstantRange KnownBitsRange =
       fromKnownBits(toKnownBits() & Other.toKnownBits(), false);
-  ConstantRange UMinUMaxRange =
-      getNonEmpty(APInt::getZero(getBitWidth()),
-                  APIntOps::umin(Other.getUnsignedMax(), getUnsignedMax()) + 1);
+  auto LowerBound = estimateBitMaskedAndLowerBound(*this, Other);
+  ConstantRange UMinUMaxRange = getNonEmpty(
+      LowerBound, APIntOps::umin(Other.getUnsignedMax(), getUnsignedMax()) + 1);
   return KnownBitsRange.intersectWith(UMinUMaxRange);
 }
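To make the doc comment's example concrete, here is a rough standalone sketch (not part of the commit) that replays steps 1-3 of the algorithm on the example ranges using 8-bit APInts; the step-3 validity checks hold for these values, so they are only noted in comments.

// Standalone sketch, not part of the commit: steps 1-3 of the estimate on
// LHS = [0b10001010, 0b10100000], RHS = [0b10111010, 0b10111100].
#include "llvm/ADT/APInt.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  unsigned BW = 8;
  APInt LLo(BW, 0b10001010), LHi(BW, 0b10100000);
  APInt RLo(BW, 0b10111010), RHi(BW, 0b10111100);

  // Step 1: mask out the common leading bits of all four bounds.
  APInt Mask = (LLo ^ LHi) | (RLo ^ RHi) | (LLo ^ RLo);
  Mask.setLowBits(BW - Mask.countLeadingZeros()); // 0b00111111

  // Step 2: the bit field of LHS that must contain at least one set bit.
  unsigned StartBit = BW - (LLo & Mask).countLeadingZeros() - 1; // 3
  unsigned EndBit = BW - (LHi & Mask).countLeadingZeros();       // 6

  // Step 3: RHS is constant above StartBit and all-ones in [StartBit, EndBit),
  // so the AND cannot clear every bit of that field; keep bits from StartBit up.
  APInt Keep(BW, 0);
  Keep.setBits(StartBit, BW);
  APInt LowerBound = LLo & Keep; // 0b10001000 == 136
  outs() << "lower bound = " << LowerBound.getZExtValue() << "\n";
  return 0;
}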

@@ -1538,10 +1625,17 @@ ConstantRange ConstantRange::binaryOr(const ConstantRange &Other) const {
 
   ConstantRange KnownBitsRange =
       fromKnownBits(toKnownBits() | Other.toKnownBits(), false);
+
+  // ~a & ~b >= x
+  // <=> ~(~a & ~b) <= ~x
+  // <=> a | b <= ~x
+  // <=> a | b < ~x + 1
+  // thus, UpperBound(a | b) == ~LowerBound(~a & ~b) + 1
+  auto UpperBound =
+      ~estimateBitMaskedAndLowerBound(binaryNot(), Other.binaryNot()) + 1;
   // Upper wrapped range.
-  ConstantRange UMaxUMinRange =
-      getNonEmpty(APIntOps::umax(getUnsignedMin(), Other.getUnsignedMin()),
-                  APInt::getZero(getBitWidth()));
+  ConstantRange UMaxUMinRange = getNonEmpty(
+      APIntOps::umax(getUnsignedMin(), Other.getUnsignedMin()), UpperBound);
   return KnownBitsRange.intersectWith(UMaxUMinRange);
 }
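A minimal sketch, not part of the commit, of the De Morgan duality used above, fed with the ranges from @test.or in the test file below (narrowed to 8 bits for readability): the AND estimate on the complemented ranges should give 136, which bounds x | y above by ~136 = 119.

// Standalone sketch, not part of the commit: x in [95, 117], y in [67, 69].
#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  unsigned BW = 8;
  ConstantRange X(APInt(BW, 95), APInt(BW, 118)); // [0b01011111, 0b01110101]
  ConstantRange Y(APInt(BW, 67), APInt(BW, 70));  // [0b01000011, 0b01000101]

  // ~X = [138, 160] and ~Y = [186, 188]; the AND estimate gives a lower
  // bound of 136 for ~x & ~y, so x | y should be bounded above by
  // ~136 = 119 (0b01110111) instead of the 127 implied by known bits alone.
  ConstantRange NotAnd = X.binaryNot().binaryAnd(Y.binaryNot());
  ConstantRange Or = X.binaryOr(Y);
  outs() << "umin(~x & ~y) = " << NotAnd.getUnsignedMin().getZExtValue()
         << ", umax(x | y) = " << Or.getUnsignedMax().getZExtValue() << "\n";
  return 0;
}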

Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -S -passes=ipsccp %s | FileCheck %s
+
+declare void @use(i1)
+
+define i1 @test1(i64 %x) {
+; CHECK-LABEL: @test1(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[COND:%.*]] = icmp ugt i64 [[X:%.*]], 65535
+; CHECK-NEXT:    call void @llvm.assume(i1 [[COND]])
+; CHECK-NEXT:    [[MASK:%.*]] = and i64 [[X]], -65521
+; CHECK-NEXT:    ret i1 false
+;
+entry:
+  %cond = icmp ugt i64 %x, 65535
+  call void @llvm.assume(i1 %cond)
+  %mask = and i64 %x, -65521
+  %cmp = icmp eq i64 %mask, 0
+  ret i1 %cmp
+}
+
+define void @test.and(i64 %x, i64 %y) {
+; CHECK-LABEL: @test.and(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[C0:%.*]] = icmp uge i64 [[X:%.*]], 138
+; CHECK-NEXT:    [[C1:%.*]] = icmp ule i64 [[X]], 161
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C0]])
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C1]])
+; CHECK-NEXT:    [[C2:%.*]] = icmp uge i64 [[Y:%.*]], 186
+; CHECK-NEXT:    [[C3:%.*]] = icmp ule i64 [[Y]], 188
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C2]])
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C3]])
+; CHECK-NEXT:    [[AND:%.*]] = and i64 [[X]], [[Y]]
+; CHECK-NEXT:    call void @use(i1 false)
+; CHECK-NEXT:    [[R1:%.*]] = icmp ult i64 [[AND]], 137
+; CHECK-NEXT:    call void @use(i1 [[R1]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %c0 = icmp uge i64 %x, 138 ; 0b10001010
+  %c1 = icmp ule i64 %x, 161 ; 0b10100001
+  call void @llvm.assume(i1 %c0)
+  call void @llvm.assume(i1 %c1)
+  %c2 = icmp uge i64 %y, 186 ; 0b10111010
+  %c3 = icmp ule i64 %y, 188 ; 0b10111100
+  call void @llvm.assume(i1 %c2)
+  call void @llvm.assume(i1 %c3)
+  %and = and i64 %x, %y
+  %r0 = icmp ult i64 %and, 136 ; 0b10001000
+  call void @use(i1 %r0) ; false
+  %r1 = icmp ult i64 %and, 137
+  call void @use(i1 %r1) ; unknown
+  ret void
+}
+
+define void @test.or(i64 %x, i64 %y) {
+; CHECK-LABEL: @test.or(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[C0:%.*]] = icmp ule i64 [[X:%.*]], 117
+; CHECK-NEXT:    [[C1:%.*]] = icmp uge i64 [[X]], 95
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C0]])
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C1]])
+; CHECK-NEXT:    [[C2:%.*]] = icmp ule i64 [[Y:%.*]], 69
+; CHECK-NEXT:    [[C3:%.*]] = icmp uge i64 [[Y]], 67
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C2]])
+; CHECK-NEXT:    call void @llvm.assume(i1 [[C3]])
+; CHECK-NEXT:    [[OR:%.*]] = or i64 [[X]], [[Y]]
+; CHECK-NEXT:    call void @use(i1 false)
+; CHECK-NEXT:    [[R1:%.*]] = icmp ugt i64 [[OR]], 118
+; CHECK-NEXT:    call void @use(i1 [[R1]])
+; CHECK-NEXT:    ret void
+;
+entry:
+  %c0 = icmp ule i64 %x, 117 ; 0b01110101
+  %c1 = icmp uge i64 %x, 95 ; 0b01011111
+  call void @llvm.assume(i1 %c0)
+  call void @llvm.assume(i1 %c1)
+  %c2 = icmp ule i64 %y, 69 ; 0b01000101
+  %c3 = icmp uge i64 %y, 67 ; 0b01000011
+  call void @llvm.assume(i1 %c2)
+  call void @llvm.assume(i1 %c3)
+  %or = or i64 %x, %y
+  %r0 = icmp ugt i64 %or, 119 ; 0b01110111
+  call void @use(i1 %r0) ; false
+  %r1 = icmp ugt i64 %or, 118
+  call void @use(i1 %r1) ; unknown
+  ret void
+}
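The same expectations can also be phrased as a hypothetical gtest-style check against ConstantRange directly, in the spirit of llvm/unittests/IR/ConstantRangeTest.cpp; this is a sketch, not part of the commit, and the asserted values are derived from the ranges the assumes above establish.

// Hypothetical unit-test-style sketch, not part of the commit, mirroring
// @test.and and @test.or.
#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "gtest/gtest.h"

using namespace llvm;

TEST(ConstantRangeTest, MaskedAndOrBounds) {
  unsigned BW = 64;

  // @test.and: x in [138, 161], y in [186, 188] (upper bounds are exclusive).
  ConstantRange X(APInt(BW, 138), APInt(BW, 162));
  ConstantRange Y(APInt(BW, 186), APInt(BW, 189));
  // '%r0 = icmp ult i64 %and, 136' folds to false; 'ult ... 137' does not.
  EXPECT_EQ(X.binaryAnd(Y).getUnsignedMin(), APInt(BW, 136));

  // @test.or: x in [95, 117], y in [67, 69].
  ConstantRange A(APInt(BW, 95), APInt(BW, 118));
  ConstantRange B(APInt(BW, 67), APInt(BW, 70));
  // '%r0 = icmp ugt i64 %or, 119' folds to false; 'ugt ... 118' does not.
  EXPECT_EQ(A.binaryOr(B).getUnsignedMax(), APInt(BW, 119));
}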
