Skip to content

Commit d20796d

Browse files
authored
[InstCombine] Offset both sides of an equality icmp (#134086)
Proof: https://alive2.llvm.org/ce/z/zQ2UW4 Closes #134024
1 parent ad66e54 commit d20796d

File tree

4 files changed

+365
-6
lines changed

4 files changed

+365
-6
lines changed

llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp

+126
Original file line numberDiff line numberDiff line change
@@ -5808,6 +5808,128 @@ static Instruction *foldICmpPow2Test(ICmpInst &I,
58085808
return nullptr;
58095809
}
58105810

5811+
/// Find all possible pairs (BinOp, RHS) that BinOp V, RHS can be simplified.
5812+
using OffsetOp = std::pair<Instruction::BinaryOps, Value *>;
5813+
static void collectOffsetOp(Value *V, SmallVectorImpl<OffsetOp> &Offsets,
5814+
bool AllowRecursion) {
5815+
Instruction *Inst = dyn_cast<Instruction>(V);
5816+
if (!Inst || !Inst->hasOneUse())
5817+
return;
5818+
5819+
switch (Inst->getOpcode()) {
5820+
case Instruction::Add:
5821+
Offsets.emplace_back(Instruction::Sub, Inst->getOperand(1));
5822+
Offsets.emplace_back(Instruction::Sub, Inst->getOperand(0));
5823+
break;
5824+
case Instruction::Sub:
5825+
Offsets.emplace_back(Instruction::Add, Inst->getOperand(1));
5826+
break;
5827+
case Instruction::Xor:
5828+
Offsets.emplace_back(Instruction::Xor, Inst->getOperand(1));
5829+
Offsets.emplace_back(Instruction::Xor, Inst->getOperand(0));
5830+
break;
5831+
case Instruction::Select:
5832+
if (AllowRecursion) {
5833+
collectOffsetOp(Inst->getOperand(1), Offsets, /*AllowRecursion=*/false);
5834+
collectOffsetOp(Inst->getOperand(2), Offsets, /*AllowRecursion=*/false);
5835+
}
5836+
break;
5837+
default:
5838+
break;
5839+
}
5840+
}
5841+
5842+
enum class OffsetKind { Invalid, Value, Select };
5843+
5844+
struct OffsetResult {
5845+
OffsetKind Kind;
5846+
Value *V0, *V1, *V2;
5847+
5848+
static OffsetResult invalid() {
5849+
return {OffsetKind::Invalid, nullptr, nullptr, nullptr};
5850+
}
5851+
static OffsetResult value(Value *V) {
5852+
return {OffsetKind::Value, V, nullptr, nullptr};
5853+
}
5854+
static OffsetResult select(Value *Cond, Value *TrueV, Value *FalseV) {
5855+
return {OffsetKind::Select, Cond, TrueV, FalseV};
5856+
}
5857+
bool isValid() const { return Kind != OffsetKind::Invalid; }
5858+
Value *materialize(InstCombiner::BuilderTy &Builder) const {
5859+
switch (Kind) {
5860+
case OffsetKind::Invalid:
5861+
llvm_unreachable("Invalid offset result");
5862+
case OffsetKind::Value:
5863+
return V0;
5864+
case OffsetKind::Select:
5865+
return Builder.CreateSelect(V0, V1, V2);
5866+
}
5867+
}
5868+
};
5869+
5870+
/// Offset both sides of an equality icmp to see if we can save some
5871+
/// instructions: icmp eq/ne X, Y -> icmp eq/ne X op Z, Y op Z.
5872+
/// Note: This operation should not introduce poison.
5873+
static Instruction *foldICmpEqualityWithOffset(ICmpInst &I,
5874+
InstCombiner::BuilderTy &Builder,
5875+
const SimplifyQuery &SQ) {
5876+
assert(I.isEquality() && "Expected an equality icmp");
5877+
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
5878+
if (!Op0->getType()->isIntOrIntVectorTy())
5879+
return nullptr;
5880+
5881+
SmallVector<OffsetOp, 4> OffsetOps;
5882+
collectOffsetOp(Op0, OffsetOps, /*AllowRecursion=*/true);
5883+
collectOffsetOp(Op1, OffsetOps, /*AllowRecursion=*/true);
5884+
5885+
auto ApplyOffsetImpl = [&](Value *V, unsigned BinOpc, Value *RHS) -> Value * {
5886+
Value *Simplified = simplifyBinOp(BinOpc, V, RHS, SQ);
5887+
// Avoid infinite loops by checking if RHS is an identity for the BinOp.
5888+
if (!Simplified || Simplified == V)
5889+
return nullptr;
5890+
// Reject constant expressions as they don't simplify things.
5891+
if (isa<Constant>(Simplified) && !match(Simplified, m_ImmConstant()))
5892+
return nullptr;
5893+
// Check if the transformation introduces poison.
5894+
return impliesPoison(RHS, V) ? Simplified : nullptr;
5895+
};
5896+
5897+
auto ApplyOffset = [&](Value *V, unsigned BinOpc,
5898+
Value *RHS) -> OffsetResult {
5899+
if (auto *Sel = dyn_cast<SelectInst>(V)) {
5900+
if (!Sel->hasOneUse())
5901+
return OffsetResult::invalid();
5902+
Value *TrueVal = ApplyOffsetImpl(Sel->getTrueValue(), BinOpc, RHS);
5903+
if (!TrueVal)
5904+
return OffsetResult::invalid();
5905+
Value *FalseVal = ApplyOffsetImpl(Sel->getFalseValue(), BinOpc, RHS);
5906+
if (!FalseVal)
5907+
return OffsetResult::invalid();
5908+
return OffsetResult::select(Sel->getCondition(), TrueVal, FalseVal);
5909+
}
5910+
if (Value *Simplified = ApplyOffsetImpl(V, BinOpc, RHS))
5911+
return OffsetResult::value(Simplified);
5912+
return OffsetResult::invalid();
5913+
};
5914+
5915+
for (auto [BinOp, RHS] : OffsetOps) {
5916+
auto BinOpc = static_cast<unsigned>(BinOp);
5917+
5918+
auto Op0Result = ApplyOffset(Op0, BinOpc, RHS);
5919+
if (!Op0Result.isValid())
5920+
continue;
5921+
auto Op1Result = ApplyOffset(Op1, BinOpc, RHS);
5922+
if (!Op1Result.isValid())
5923+
continue;
5924+
5925+
Value *NewLHS = Op0Result.materialize(Builder);
5926+
Value *NewRHS = Op1Result.materialize(Builder);
5927+
return new ICmpInst(I.getPredicate(), NewLHS, NewRHS);
5928+
}
5929+
5930+
return nullptr;
5931+
}
5932+
58115933
Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
58125934
if (!I.isEquality())
58135935
return nullptr;
@@ -6054,6 +6176,10 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
60546176
: ConstantInt::getNullValue(A->getType()));
60556177
}
60566178

6179+
if (auto *Res = foldICmpEqualityWithOffset(
6180+
I, Builder, getSimplifyQuery().getWithInstruction(&I)))
6181+
return Res;
6182+
60576183
return nullptr;
60586184
}
60596185

llvm/test/Transforms/InstCombine/icmp-add.ll

+2-4
Original file line numberDiff line numberDiff line change
@@ -2380,8 +2380,7 @@ define <2 x i1> @icmp_eq_add_non_splat(<2 x i32> %a) {
23802380

23812381
define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) {
23822382
; CHECK-LABEL: @icmp_eq_add_undef2(
2383-
; CHECK-NEXT: [[ADD:%.*]] = add <2 x i32> [[A:%.*]], splat (i32 5)
2384-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], <i32 10, i32 undef>
2383+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 undef>
23852384
; CHECK-NEXT: ret <2 x i1> [[CMP]]
23862385
;
23872386
%add = add <2 x i32> %a, <i32 5, i32 5>
@@ -2391,8 +2390,7 @@ define <2 x i1> @icmp_eq_add_undef2(<2 x i32> %a) {
23912390

23922391
define <2 x i1> @icmp_eq_add_non_splat2(<2 x i32> %a) {
23932392
; CHECK-LABEL: @icmp_eq_add_non_splat2(
2394-
; CHECK-NEXT: [[ADD:%.*]] = add <2 x i32> [[A:%.*]], splat (i32 5)
2395-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[ADD]], <i32 10, i32 11>
2393+
; CHECK-NEXT: [[CMP:%.*]] = icmp eq <2 x i32> [[A:%.*]], <i32 5, i32 6>
23962394
; CHECK-NEXT: ret <2 x i1> [[CMP]]
23972395
;
23982396
%add = add <2 x i32> %a, <i32 5, i32 5>

llvm/test/Transforms/InstCombine/icmp-equality-xor.ll

+1-2
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,7 @@ define i1 @foo2(i32 %x, i32 %y) {
136136
define <2 x i1> @foo3(<2 x i8> %x) {
137137
; CHECK-LABEL: @foo3(
138138
; CHECK-NEXT: entry:
139-
; CHECK-NEXT: [[XOR:%.*]] = xor <2 x i8> [[X:%.*]], <i8 -2, i8 -1>
140-
; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[XOR]], <i8 9, i8 79>
139+
; CHECK-NEXT: [[CMP:%.*]] = icmp ne <2 x i8> [[X:%.*]], <i8 -9, i8 -80>
141140
; CHECK-NEXT: ret <2 x i1> [[CMP]]
142141
;
143142
entry:

0 commit comments

Comments
 (0)