Skip to content

Commit 5d8fb47

Browse files
authored
[InstCombine] Fold comparison of adding two z/sext booleans (#67895)
- Add test coverage for sext/zext boolean additions - [InstCombine] Fold comparison of adding two z/sext booleans Fixes #64859.
1 parent 185e16d commit 5d8fb47

File tree

2 files changed

+157
-285
lines changed

2 files changed

+157
-285
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

Lines changed: 73 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "llvm/IR/PatternMatch.h"
2727
#include "llvm/Support/KnownBits.h"
2828
#include "llvm/Transforms/InstCombine/InstCombiner.h"
29+
#include <bitset>
2930

3031
using namespace llvm;
3132
using namespace PatternMatch;
@@ -2895,19 +2896,89 @@ Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp,
28952896
return new ICmpInst(SwappedPred, Add, ConstantInt::get(Ty, ~C));
28962897
}
28972898

2899+
static Value *createLogicFromTable(const std::bitset<4> &Table, Value *Op0,
2900+
Value *Op1, IRBuilderBase &Builder,
2901+
bool HasOneUse) {
2902+
switch (Table.to_ulong()) {
2903+
case 0: // 0 0 0 0
2904+
return Builder.getFalse();
2905+
case 1: // 0 0 0 1
2906+
return HasOneUse ? Builder.CreateNot(Builder.CreateOr(Op0, Op1)) : nullptr;
2907+
case 2: // 0 0 1 0
2908+
return HasOneUse ? Builder.CreateAnd(Builder.CreateNot(Op0), Op1) : nullptr;
2909+
case 3: // 0 0 1 1
2910+
return Builder.CreateNot(Op0);
2911+
case 4: // 0 1 0 0
2912+
return HasOneUse ? Builder.CreateAnd(Op0, Builder.CreateNot(Op1)) : nullptr;
2913+
case 5: // 0 1 0 1
2914+
return Builder.CreateNot(Op1);
2915+
case 6: // 0 1 1 0
2916+
return Builder.CreateXor(Op0, Op1);
2917+
case 7: // 0 1 1 1
2918+
return HasOneUse ? Builder.CreateNot(Builder.CreateAnd(Op0, Op1)) : nullptr;
2919+
case 8: // 1 0 0 0
2920+
return Builder.CreateAnd(Op0, Op1);
2921+
case 9: // 1 0 0 1
2922+
return HasOneUse ? Builder.CreateNot(Builder.CreateXor(Op0, Op1)) : nullptr;
2923+
case 10: // 1 0 1 0
2924+
return Op1;
2925+
case 11: // 1 0 1 1
2926+
return HasOneUse ? Builder.CreateOr(Builder.CreateNot(Op0), Op1) : nullptr;
2927+
case 12: // 1 1 0 0
2928+
return Op0;
2929+
case 13: // 1 1 0 1
2930+
return HasOneUse ? Builder.CreateOr(Op0, Builder.CreateNot(Op1)) : nullptr;
2931+
case 14: // 1 1 1 0
2932+
return Builder.CreateOr(Op0, Op1);
2933+
case 15: // 1 1 1 1
2934+
return Builder.getTrue();
2935+
default:
2936+
llvm_unreachable("Invalid Operation");
2937+
}
2938+
return nullptr;
2939+
}
2940+
28982941
/// Fold icmp (add X, Y), C.
28992942
Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
29002943
BinaryOperator *Add,
29012944
const APInt &C) {
29022945
Value *Y = Add->getOperand(1);
2946+
Value *X = Add->getOperand(0);
2947+
2948+
Value *Op0, *Op1;
2949+
Instruction *Ext0, *Ext1;
2950+
const CmpInst::Predicate Pred = Cmp.getPredicate();
2951+
if (match(Add,
2952+
m_Add(m_CombineAnd(m_Instruction(Ext0), m_ZExtOrSExt(m_Value(Op0))),
2953+
m_CombineAnd(m_Instruction(Ext1),
2954+
m_ZExtOrSExt(m_Value(Op1))))) &&
2955+
Op0->getType()->isIntOrIntVectorTy(1) &&
2956+
Op1->getType()->isIntOrIntVectorTy(1)) {
2957+
unsigned BW = C.getBitWidth();
2958+
std::bitset<4> Table;
2959+
auto ComputeTable = [&](bool Op0Val, bool Op1Val) {
2960+
int Res = 0;
2961+
if (Op0Val)
2962+
Res += isa<ZExtInst>(Ext0) ? 1 : -1;
2963+
if (Op1Val)
2964+
Res += isa<ZExtInst>(Ext1) ? 1 : -1;
2965+
return ICmpInst::compare(APInt(BW, Res, true), C, Pred);
2966+
};
2967+
2968+
Table[0] = ComputeTable(false, false);
2969+
Table[1] = ComputeTable(false, true);
2970+
Table[2] = ComputeTable(true, false);
2971+
Table[3] = ComputeTable(true, true);
2972+
if (auto *Cond =
2973+
createLogicFromTable(Table, Op0, Op1, Builder, Add->hasOneUse()))
2974+
return replaceInstUsesWith(Cmp, Cond);
2975+
}
29032976
const APInt *C2;
29042977
if (Cmp.isEquality() || !match(Y, m_APInt(C2)))
29052978
return nullptr;
29062979

29072980
// Fold icmp pred (add X, C2), C.
2908-
Value *X = Add->getOperand(0);
29092981
Type *Ty = Add->getType();
2910-
const CmpInst::Predicate Pred = Cmp.getPredicate();
29112982

29122983
// If the add does not wrap, we can always adjust the compare by subtracting
29132984
// the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE
@@ -6410,60 +6481,6 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) {
64106481
Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULE)
64116482
return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y);
64126483

6413-
const APInt *C;
6414-
if (match(I.getOperand(0), m_c_Add(m_ZExt(m_Value(X)), m_SExt(m_Value(Y)))) &&
6415-
match(I.getOperand(1), m_APInt(C)) &&
6416-
X->getType()->isIntOrIntVectorTy(1) &&
6417-
Y->getType()->isIntOrIntVectorTy(1)) {
6418-
unsigned BitWidth = C->getBitWidth();
6419-
Pred = I.getPredicate();
6420-
APInt Zero = APInt::getZero(BitWidth);
6421-
APInt MinusOne = APInt::getAllOnes(BitWidth);
6422-
APInt One(BitWidth, 1);
6423-
if ((C->sgt(Zero) && Pred == ICmpInst::ICMP_SGT) ||
6424-
(C->slt(Zero) && Pred == ICmpInst::ICMP_SLT))
6425-
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
6426-
if ((C->sgt(One) && Pred == ICmpInst::ICMP_SLT) ||
6427-
(C->slt(MinusOne) && Pred == ICmpInst::ICMP_SGT))
6428-
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
6429-
6430-
if (I.getOperand(0)->hasOneUse()) {
6431-
APInt NewC = *C;
6432-
// canonicalize predicate to eq/ne
6433-
if ((*C == Zero && Pred == ICmpInst::ICMP_SLT) ||
6434-
(*C != Zero && *C != MinusOne && Pred == ICmpInst::ICMP_UGT)) {
6435-
// x s< 0 in [-1, 1] --> x == -1
6436-
// x u> 1(or any const !=0 !=-1) in [-1, 1] --> x == -1
6437-
NewC = MinusOne;
6438-
Pred = ICmpInst::ICMP_EQ;
6439-
} else if ((*C == MinusOne && Pred == ICmpInst::ICMP_SGT) ||
6440-
(*C != Zero && *C != One && Pred == ICmpInst::ICMP_ULT)) {
6441-
// x s> -1 in [-1, 1] --> x != -1
6442-
// x u< -1 in [-1, 1] --> x != -1
6443-
Pred = ICmpInst::ICMP_NE;
6444-
} else if (*C == Zero && Pred == ICmpInst::ICMP_SGT) {
6445-
// x s> 0 in [-1, 1] --> x == 1
6446-
NewC = One;
6447-
Pred = ICmpInst::ICMP_EQ;
6448-
} else if (*C == One && Pred == ICmpInst::ICMP_SLT) {
6449-
// x s< 1 in [-1, 1] --> x != 1
6450-
Pred = ICmpInst::ICMP_NE;
6451-
}
6452-
6453-
if (NewC == MinusOne) {
6454-
if (Pred == ICmpInst::ICMP_EQ)
6455-
return BinaryOperator::CreateAnd(Builder.CreateNot(X), Y);
6456-
if (Pred == ICmpInst::ICMP_NE)
6457-
return BinaryOperator::CreateOr(X, Builder.CreateNot(Y));
6458-
} else if (NewC == One) {
6459-
if (Pred == ICmpInst::ICMP_EQ)
6460-
return BinaryOperator::CreateAnd(X, Builder.CreateNot(Y));
6461-
if (Pred == ICmpInst::ICMP_NE)
6462-
return BinaryOperator::CreateOr(Builder.CreateNot(X), Y);
6463-
}
6464-
}
6465-
}
6466-
64676484
return nullptr;
64686485
}
64696486

0 commit comments

Comments
 (0)