Skip to content

Commit b452dac

Browse files
committed
[InstCombine] Decompose more icmps into masks
Extend decomposeBitTestICmp() to handle cases where the resulting comparison is of the form `icmp (X & Mask) pred Cmp` with non-zero `Cmp`. Add a flag to allow code to opt-in to this behavior and use it in the "log op of icmp" fold infrastructure. This addresses regressions from llvm#97289. Proofs: https://alive2.llvm.org/ce/z/hUhdbU
1 parent eb3361d commit b452dac

File tree

5 files changed

+66
-47
lines changed

5 files changed

+66
-47
lines changed

llvm/include/llvm/Analysis/CmpInstAnalysis.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,18 +92,21 @@ namespace llvm {
9292
Constant *getPredForFCmpCode(unsigned Code, Type *OpTy,
9393
CmpInst::Predicate &Pred);
9494

95-
/// Represents the operation icmp (X & Mask) pred 0, where pred can only be
95+
/// Represents the operation icmp (X & Mask) pred Cmp, where pred can only be
9696
/// eq or ne.
9797
struct DecomposedBitTest {
9898
Value *X;
9999
CmpInst::Predicate Pred;
100100
APInt Mask;
101+
APInt Cmp;
101102
};
102103

103-
/// Decompose an icmp into the form ((X & Mask) pred 0) if possible.
104+
/// Decompose an icmp into the form ((X & Mask) pred Cmp) if possible.
105+
/// Unless \p AllowNonZeroCmp is true, Cmp will always be 0.
104106
std::optional<DecomposedBitTest>
105107
decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
106-
bool LookThroughTrunc = true);
108+
bool LookThroughTrunc = true,
109+
bool AllowNonZeroCmp = false);
107110

108111
} // end namespace llvm
109112

llvm/lib/Analysis/CmpInstAnalysis.cpp

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
7575

7676
std::optional<DecomposedBitTest>
7777
llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
78-
bool LookThruTrunc) {
78+
bool LookThruTrunc, bool AllowNonZeroCmp) {
7979
using namespace PatternMatch;
8080

8181
const APInt *OrigC;
@@ -100,29 +100,65 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
100100
switch (Pred) {
101101
default:
102102
llvm_unreachable("Unexpected predicate");
103-
case ICmpInst::ICMP_SLT:
103+
case ICmpInst::ICMP_SLT: {
104104
// X < 0 is equivalent to (X & SignMask) != 0.
105-
if (!C.isZero())
106-
return std::nullopt;
107-
Result.Mask = APInt::getSignMask(C.getBitWidth());
108-
Result.Pred = ICmpInst::ICMP_NE;
109-
break;
105+
if (C.isZero()) {
106+
Result.Mask = APInt::getSignMask(C.getBitWidth());
107+
Result.Cmp = APInt::getZero(C.getBitWidth());
108+
Result.Pred = ICmpInst::ICMP_NE;
109+
break;
110+
}
111+
112+
APInt FlippedSign = C ^ APInt::getSignMask(C.getBitWidth());
113+
if (FlippedSign.isPowerOf2()) {
114+
// X s< 10000100 is equivalent to (X & 11111100 == 10000000)
115+
Result.Mask = -FlippedSign;
116+
Result.Cmp = APInt::getSignMask(C.getBitWidth());
117+
Result.Pred = ICmpInst::ICMP_EQ;
118+
break;
119+
}
120+
121+
if (FlippedSign.isNegatedPowerOf2()) {
122+
// X s< 01111100 is equivalent to (X & 11111100 != 01111100)
123+
Result.Mask = FlippedSign;
124+
Result.Cmp = C;
125+
Result.Pred = ICmpInst::ICMP_NE;
126+
break;
127+
}
128+
129+
return std::nullopt;
130+
}
110131
case ICmpInst::ICMP_ULT:
111132
// X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
112-
if (!C.isPowerOf2())
113-
return std::nullopt;
114-
Result.Mask = -C;
115-
Result.Pred = ICmpInst::ICMP_EQ;
116-
break;
133+
if (C.isPowerOf2()) {
134+
Result.Mask = -C;
135+
Result.Cmp = APInt::getZero(C.getBitWidth());
136+
Result.Pred = ICmpInst::ICMP_EQ;
137+
break;
138+
}
139+
140+
// X u< 11111100 is equivalent to (X & 11111100 != 11111100)
141+
if (C.isNegatedPowerOf2()) {
142+
Result.Mask = C;
143+
Result.Cmp = C;
144+
Result.Pred = ICmpInst::ICMP_NE;
145+
break;
146+
}
147+
148+
return std::nullopt;
117149
}
118150

151+
if (!AllowNonZeroCmp && !Result.Cmp.isZero())
152+
return std::nullopt;
153+
119154
if (Inverted)
120155
Result.Pred = ICmpInst::getInversePredicate(Result.Pred);
121156

122157
Value *X;
123158
if (LookThruTrunc && match(LHS, m_Trunc(m_Value(X)))) {
124159
Result.X = X;
125160
Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
161+
Result.Cmp = Result.Cmp.zext(X->getType()->getScalarSizeInBits());
126162
} else {
127163
Result.X = LHS;
128164
}

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -181,14 +181,15 @@ static unsigned conjugateICmpMask(unsigned Mask) {
181181
// Adapts the external decomposeBitTestICmp for local use.
182182
static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
183183
Value *&X, Value *&Y, Value *&Z) {
184-
auto Res = llvm::decomposeBitTestICmp(LHS, RHS, Pred);
184+
auto Res = llvm::decomposeBitTestICmp(
185+
LHS, RHS, Pred, /*LookThroughTrunc=*/true, /*AllowNonZeroCmp=*/true);
185186
if (!Res)
186187
return false;
187188

188189
Pred = Res->Pred;
189190
X = Res->X;
190191
Y = ConstantInt::get(X->getType(), Res->Mask);
191-
Z = ConstantInt::get(X->getType(), 0);
192+
Z = ConstantInt::get(X->getType(), Res->Cmp);
192193
return true;
193194
}
194195

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5919,29 +5919,14 @@ Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) {
59195919
// This matches patterns corresponding to tests of the signbit as well as:
59205920
// (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?)
59215921
// (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?)
5922-
if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true)) {
5922+
if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true,
5923+
/*AllowNonZeroCmp=*/true)) {
59235924
Value *And = Builder.CreateAnd(Res->X, Res->Mask);
5924-
Constant *Zero = ConstantInt::getNullValue(Res->X->getType());
5925+
Constant *Zero = ConstantInt::get(Res->X->getType(), Res->Cmp);
59255926
return new ICmpInst(Res->Pred, And, Zero);
59265927
}
59275928

59285929
unsigned SrcBits = X->getType()->getScalarSizeInBits();
5929-
if (Pred == ICmpInst::ICMP_ULT && C->isNegatedPowerOf2()) {
5930-
// If C is a negative power-of-2 (high-bit mask):
5931-
// (trunc X) u< C --> (X & C) != C (are any masked-high-bits clear?)
5932-
Constant *MaskC = ConstantInt::get(X->getType(), C->zext(SrcBits));
5933-
Value *And = Builder.CreateAnd(X, MaskC);
5934-
return new ICmpInst(ICmpInst::ICMP_NE, And, MaskC);
5935-
}
5936-
5937-
if (Pred == ICmpInst::ICMP_UGT && (~*C).isPowerOf2()) {
5938-
// If C is not-of-power-of-2 (one clear bit):
5939-
// (trunc X) u> C --> (X & (C+1)) == C+1 (are all masked-high-bits set?)
5940-
Constant *MaskC = ConstantInt::get(X->getType(), (*C + 1).zext(SrcBits));
5941-
Value *And = Builder.CreateAnd(X, MaskC);
5942-
return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC);
5943-
}
5944-
59455930
if (auto *II = dyn_cast<IntrinsicInst>(X)) {
59465931
if (II->getIntrinsicID() == Intrinsic::cttz ||
59475932
II->getIntrinsicID() == Intrinsic::ctlz) {

llvm/test/Transforms/InstCombine/and-or-icmps.ll

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3335,10 +3335,8 @@ define i1 @icmp_eq_or_z_or_pow2orz_fail_bad_pred2(i8 %x, i8 %y) {
33353335

33363336
define i1 @and_slt_to_mask(i8 %x) {
33373337
; CHECK-LABEL: @and_slt_to_mask(
3338-
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[X:%.*]], -124
3339-
; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2
3340-
; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
3341-
; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
3338+
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -2
3339+
; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], -128
33423340
; CHECK-NEXT: ret i1 [[AND2]]
33433341
;
33443342
%cmp = icmp slt i8 %x, -124
@@ -3365,10 +3363,8 @@ define i1 @and_slt_to_mask_off_by_one(i8 %x) {
33653363

33663364
define i1 @and_sgt_to_mask(i8 %x) {
33673365
; CHECK-LABEL: @and_sgt_to_mask(
3368-
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], 123
3369-
; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2
3370-
; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
3371-
; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
3366+
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -2
3367+
; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], 124
33723368
; CHECK-NEXT: ret i1 [[AND2]]
33733369
;
33743370
%cmp = icmp sgt i8 %x, 123
@@ -3395,10 +3391,8 @@ define i1 @and_sgt_to_mask_off_by_one(i8 %x) {
33953391

33963392
define i1 @and_ugt_to_mask(i8 %x) {
33973393
; CHECK-LABEL: @and_ugt_to_mask(
3398-
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[X:%.*]], -5
3399-
; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2
3400-
; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
3401-
; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
3394+
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -2
3395+
; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], -4
34023396
; CHECK-NEXT: ret i1 [[AND2]]
34033397
;
34043398
%cmp = icmp ugt i8 %x, -5

0 commit comments

Comments
 (0)