[InstCombine] Try optimizing with knownbits which determined from Cond #91762

Closed
wants to merge 7 commits into from
4 changes: 4 additions & 0 deletions llvm/include/llvm/Analysis/ValueTracking.h
@@ -94,6 +94,10 @@ void computeKnownBitsFromRangeMetadata(const MDNode &Ranges, KnownBits &Known);
void computeKnownBitsFromContext(const Value *V, KnownBits &Known,
unsigned Depth, const SimplifyQuery &Q);

void computeKnownBitsFromCond(const Value *V, Value *Cond, KnownBits &Known,
unsigned Depth, const SimplifyQuery &SQ,
bool Invert);

/// Using KnownBits LHS/RHS produce the known bits for logic op (and/xor/or).
KnownBits analyzeKnownBitsFromAndXorOr(const Operator *I,
const KnownBits &KnownLHS,
7 changes: 7 additions & 0 deletions llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -438,6 +438,13 @@ class LLVM_LIBRARY_VISIBILITY InstCombiner {
return llvm::computeKnownBits(V, Depth, SQ.getWithInstruction(CxtI));
}

void computeKnownBitsFromCond(const Value *V, ICmpInst *Cmp, KnownBits &Known,
unsigned Depth, const Instruction *CxtI,
bool Invert) const {
llvm::computeKnownBitsFromCond(V, Cmp, Known, Depth,
SQ.getWithInstruction(CxtI), Invert);
}

bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero = false,
unsigned Depth = 0,
const Instruction *CxtI = nullptr) {
6 changes: 3 additions & 3 deletions llvm/lib/Analysis/ValueTracking.cpp
@@ -747,9 +747,9 @@ static void computeKnownBitsFromICmpCond(const Value *V, ICmpInst *Cmp,
computeKnownBitsFromCmp(V, Pred, LHS, RHS, Known, SQ);
}

static void computeKnownBitsFromCond(const Value *V, Value *Cond,
KnownBits &Known, unsigned Depth,
const SimplifyQuery &SQ, bool Invert) {
void llvm::computeKnownBitsFromCond(const Value *V, Value *Cond,
KnownBits &Known, unsigned Depth,
const SimplifyQuery &SQ, bool Invert) {
Value *A, *B;
if (Depth < MaxAnalysisRecursionDepth &&
match(Cond, m_LogicalOp(m_Value(A), m_Value(B)))) {
196 changes: 196 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1809,6 +1809,198 @@ static Instruction *foldSelectICmpEq(SelectInst &SI, ICmpInst *ICI,
return nullptr;
}

// ICmpInst of SelectInst is not included in the calculation of KnownBits
// so we are missing the opportunity to optimize the Value of the True or
// False Condition via ICmpInst with KnownBits.
Contributor

Contributor Author @ParkHanbum, May 12, 2024

IMO, both use computeKnownBitsFromCond in the same way, but computeKnownBitsFromOperator computes the KnownBits of the SelectInst itself. I use the Cond of the SelectInst to estimate the KnownBits of the variables that make up Cond. The key difference is that if we can use those to estimate the KnownBits of TrueVal, we can perform this optimization.

Contributor

Oh, I see what you mean. This seems very similar to #88298. I think this code would generally fit better as an extension to foldSelectValueEquivalence.

Contributor

I still feel like the code fits better in foldSelectValueEquivalence.

Contributor Author

@goldsteinn
I changed it so that it is called inside foldSelectValueEquivalence.

This changes the order: the optimizations I added now run first, which prevents some of the previously performed optimizations from being applied. When this happens, should I change the order of the other optimizations?

//
// Consider:
// %or = or i32 %x, %y
// %or0 = icmp eq i32 %or, 0
// %and = and i32 %x, %y
// %cond = select i1 %or0, i32 %and, i32 %or
// ret i32 %cond
//
// Expect:
// %or = or i32 %x, %y
// ret i32 %or
//
// We could know what bit was enabled for %x, %y by ICmpInst in SelectInst.
static Instruction *foldSelectICmpBinOp(SelectInst &SI, ICmpInst *ICI,
Value *CmpLHS, Value *CmpRHS,
Value *TVal, Value *FVal,
InstCombinerImpl &IC) {
Contributor

I would think you'd want to call this twice, the second time with TVal/FVal swapped and an Invert bool.

Contributor Author

Sorry, I didn't understand your point. Could you explain in more detail for a newbie?

As far as I know, swapping the true and false values also requires inverting the compare. So we test each case with the "equal" compare changed to "not equal", and we also test with %x and %y exchanged.

These are the tests added with this commit:

define i32 @src_or_disjoint_xor(i32 %x, i32 %y) {
; CHECK-LABEL: @src_or_disjoint_xor(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    ret i32 -1
;
entry:
  %or.disjoint = or disjoint i32 %x, %y
  %cmp = icmp eq i32 %or.disjoint, -1
  %xor = xor i32 %x, %y
  %cond = select i1 %cmp, i32 %xor, i32 -1
  ret i32 %cond
}

define i32 @src_or_disjoint_xor_comm(i32 %x, i32 %y) {
; CHECK-LABEL: @src_or_disjoint_xor_comm(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    ret i32 -1
;
entry:
  %or.disjoint = or disjoint i32 %y, %x
  %cmp = icmp eq i32 -1, %or.disjoint
  %xor = xor i32 %y, %x
  %cond = select i1 %cmp, i32 %xor, i32 -1
  ret i32 %cond
}

define i32 @src_or_disjoint_xor_ne(i32 %x, i32 %y) {
; CHECK-LABEL: @src_or_disjoint_xor_ne(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    ret i32 -1
;
entry:
  %or.disjoint = or disjoint i32 %x, %y
  %cmp = icmp ne i32 %or.disjoint, -1
  %xor = xor i32 %x, %y
  %cond = select i1 %cmp, i32 -1, i32 %xor
  ret i32 %cond
}

define i32 @src_or_disjoint_xor_ne_comm(i32 %x, i32 %y) {
; CHECK-LABEL: @src_or_disjoint_xor_ne_comm(
; CHECK-NEXT:  entry:
; CHECK-NEXT:    ret i32 -1
;
entry:
  %or.disjoint = or disjoint i32 %y, %x
  %cmp = icmp ne i32 %or.disjoint, -1
  %xor = xor i32 %y, %x
  %cond = select i1 %cmp, i32 -1, i32 %xor
  ret i32 %cond
}

If you explain what you have in mind, it would be very helpful for improving my knowledge of LLVM.
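
One possible reading of the "call this twice with TVal/FVal swapped and an Invert bool" suggestion, sketched in self-contained C++ rather than against the LLVM API: a single helper reasons about one arm of the select, and the caller invokes it once for the true arm and once for the false arm with the predicate inverted. Every name here (`Pred`, `foldArm`, `foldSelect`, `Folded`) is hypothetical, and the model is reduced to `select (icmp Pred V, C), TVal, FVal` where an arm may simply be `V` itself. This is an assumption about what the reviewer meant, not the reviewer's own code.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical stand-in for the two icmp predicates used in the tests above.
enum class Pred { EQ, NE };

Pred inverse(Pred P) { return P == Pred::EQ ? Pred::NE : Pred::EQ; }

// Reason about one arm of `select (icmp P V, C), TVal, FVal`.
// Invert == false: the arm is the true arm, so `P` holds there.
// Invert == true:  the arm is the false arm, so the inverse of `P` holds there.
// If the arm is V itself and `V == C` is known in it, the arm folds to C.
std::optional<uint8_t> foldArm(Pred P, uint8_t C, bool ArmIsV, bool Invert) {
  const Pred Effective = Invert ? inverse(P) : P;
  if (ArmIsV && Effective == Pred::EQ)
    return C;
  return std::nullopt;
}

struct Folded {
  std::optional<uint8_t> TVal; // constant replacing the true arm, if any
  std::optional<uint8_t> FVal; // constant replacing the false arm, if any
};

// Call the same helper twice: once per arm, with Invert set for the false arm.
Folded foldSelect(Pred P, uint8_t C, bool TValIsV, bool FValIsV) {
  return {foldArm(P, C, TValIsV, /*Invert=*/false),
          foldArm(P, C, FValIsV, /*Invert=*/true)};
}
```

With this shape the `eq` and `ne` forms need no separate code paths: `foldSelect(Pred::EQ, C, true, false)` folds the true arm, and `foldSelect(Pred::NE, C, false, true)` folds the false arm.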

Value *X, *Y;
const APInt *C;
unsigned CmpLHSOpc;
bool IsDisjoint = false;
// Specially handling for X^Y==0 transformed to X==Y
if (match(TVal, m_c_BitwiseLogic(m_Specific(CmpLHS), m_Specific(CmpRHS)))) {
X = CmpLHS;
Y = CmpRHS;
APInt ZeroVal = APInt::getZero(CmpLHS->getType()->getScalarSizeInBits());
C = const_cast<APInt *>(&ZeroVal);
CmpLHSOpc = Instruction::Xor;
} else if ((match(CmpLHS, m_BinOp(m_Value(X), m_Value(Y))) &&
match(CmpRHS, m_APInt(C))) &&
(match(TVal, m_c_BinOp(m_Specific(X), m_Value())) ||
match(TVal, m_c_BinOp(m_Specific(Y), m_Value())))) {
Contributor

afaict, these should only be matching bitwise logic

if (auto Inst = dyn_cast<PossiblyDisjointInst>(CmpLHS)) {
if (Inst->isDisjoint())
IsDisjoint = true;
CmpLHSOpc = Instruction::Or;
Contributor

Can you at the very least assert that this is correct? It's not an issue now, but we may have disjoint on other instructions in the future.

} else
CmpLHSOpc = cast<BinaryOperator>(CmpLHS)->getOpcode();
} else
return nullptr;

enum SpecialKnownBits {
NothingSpecial = 0,
NoCommonBits = 1 << 1,
AllCommonBits = 1 << 2,
AllBitsEnabled = 1 << 3,
};

// We cannot know exactly what bits is known in X Y.
// Instead, we just know what relationship exist for.
auto isSpecialKnownBitsFor = [&]() -> unsigned {
if (CmpLHSOpc == Instruction::And) {
if (C->isZero())
return NoCommonBits;
} else if (CmpLHSOpc == Instruction::Xor) {
if (C->isAllOnes())
return NoCommonBits | AllBitsEnabled;
if (C->isZero())
return AllCommonBits;
} else if (CmpLHSOpc == Instruction::Or && IsDisjoint) {
if (C->isAllOnes())
return NoCommonBits | AllBitsEnabled;
return NoCommonBits;
}
Contributor

For or disjoint you can do NoCommonBits, and with isAllOnes you can do AllBitsEnabled.

Contributor

For (x | y) == -1 you can do AllBitsEnabled irrespective of disjoint.
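
A minimal sketch of the classification with both suggestions applied, written as self-contained C++ on 8-bit values instead of `APInt`. The names `Op`, `classify`, and `flagsAreSound` are hypothetical, and the flag meanings in the comments (`X & Y == 0`, `X == Y`, `X | Y == -1`) are my reading of the enum in the patch, not something the patch states. The second function brute-forces every `(X, Y)` pair to check that the flags returned for a given condition really hold.

```cpp
#include <cstdint>

enum class Op { And, Or, Xor }; // hypothetical stand-in for the opcode

constexpr unsigned NothingSpecial = 0;
constexpr unsigned NoCommonBits = 1u << 1;   // assumed meaning: X & Y == 0
constexpr unsigned AllCommonBits = 1u << 2;  // assumed meaning: X == Y
constexpr unsigned AllBitsEnabled = 1u << 3; // assumed meaning: X | Y == -1

// Flags implied by the condition `(X Opc Y) == C`.
unsigned classify(Op Opc, bool IsDisjoint, uint8_t C) {
  unsigned Flags = NothingSpecial;
  switch (Opc) {
  case Op::And:
    if (C == 0x00)
      Flags |= NoCommonBits;
    break;
  case Op::Xor:
    if (C == 0xFF)
      Flags |= NoCommonBits | AllBitsEnabled;
    else if (C == 0x00)
      Flags |= AllCommonBits;
    break;
  case Op::Or:
    if (IsDisjoint)
      Flags |= NoCommonBits; // holds for any C
    if (C == 0xFF)
      Flags |= AllBitsEnabled; // holds with or without disjoint
    break;
  }
  return Flags;
}

// Exhaustively check that every (X, Y) satisfying the condition also
// satisfies the facts claimed by the returned flags.
bool flagsAreSound(Op Opc, bool IsDisjoint, uint8_t C) {
  const unsigned Flags = classify(Opc, IsDisjoint, C);
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned Y = 0; Y < 256; ++Y) {
      // `or disjoint` with overlapping operands is poison, so skip it.
      if (Opc == Op::Or && IsDisjoint && (X & Y) != 0)
        continue;
      const unsigned R = Opc == Op::And  ? (X & Y)
                         : Opc == Op::Or ? (X | Y)
                                         : (X ^ Y);
      if (R != C)
        continue;
      if ((Flags & NoCommonBits) && (X & Y) != 0)
        return false;
      if ((Flags & AllCommonBits) && X != Y)
        return false;
      if ((Flags & AllBitsEnabled) && (X | Y) != 0xFF)
        return false;
    }
  return true;
}
```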


return NothingSpecial;
};

auto hasOperandAt = [&](Instruction *I, Value *Op) -> int {
for (unsigned Idx = 0; Idx < I->getNumOperands(); Idx++) {
if (I->getOperand(Idx) == Op)
return Idx + 1;
}
return 0;
};

Type *TValTy = TVal->getType();
unsigned BitWidth = TVal->getType()->getScalarSizeInBits();
auto TValBop = cast<BinaryOperator>(TVal);
unsigned XOrder = hasOperandAt(TValBop, X);
unsigned YOrder = hasOperandAt(TValBop, Y);
unsigned SKB = isSpecialKnownBitsFor();

KnownBits Known;
if (TValBop->isBitwiseLogicOp()) {
Contributor

Since you only support bitwise ops, can you just return early if TVal is not bitwise, to reduce the level of nested ifs?

// We handle if we know specific knownbits from cond of selectinst.
// ex) X&Y==-1 ? X^Y : False
if (SKB != SpecialKnownBits::NothingSpecial && XOrder && YOrder) {
// No common bits between X, Y
if (SKB & SpecialKnownBits::NoCommonBits) {
if (SKB & (SpecialKnownBits::AllBitsEnabled)) {
// If X op Y == -1, then XOR must be -1
if (TValBop->getOpcode() == Instruction::Xor)
Known = KnownBits::makeConstant(APInt(BitWidth, -1));
}
// If Trueval is X&Y then it should be 0.
if (TValBop->getOpcode() == Instruction::And)
Known = KnownBits::makeConstant(APInt(BitWidth, 0));
// X|Y can be replace with X^Y, X^Y can be replace with X|Y
// This replacing is meaningful when falseval is same.
else if ((match(TVal, m_c_Or(m_Specific(X), m_Specific(Y))) &&
match(FVal, m_c_Xor(m_Specific(X), m_Specific(Y)))) ||
(match(TVal, m_c_Xor(m_Specific(X), m_Specific(Y))) &&
match(FVal, m_c_Or(m_Specific(X), m_Specific(Y)))))
return IC.replaceInstUsesWith(SI, FVal);
// All common bits between X, Y
} else if (SKB & SpecialKnownBits::AllCommonBits) {
// We can replace (X&Y) and (X|Y) to X or Y
if (TValBop->getOpcode() == Instruction::And ||
TValBop->getOpcode() == Instruction::Or)
if (TValBop->hasOneUse())
return IC.replaceOperand(SI, 1, X);
} else if (SKB & SpecialKnownBits::AllBitsEnabled) {
// We can replace (X|Y) to -1
if (TValBop->getOpcode() == Instruction::Or)
Known = KnownBits::makeConstant(APInt(BitWidth, -1));
Contributor

I would organize this code a bit differently: instead of going through the knownbits cases and then applying rules to ops, I would switch on TValBop->getOpcode() and then apply the knownbits cases that are relevant to that opcode.

I think that will be easier to follow and extend.
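
A rough sketch of that reorganization, in self-contained C++ with hypothetical names (`Op`, `foldTValConstant`) and 8-bit constants standing in for `KnownBits::makeConstant`. It covers only the three constant-producing rules that appear in the patch, and the flag values mirror the `SpecialKnownBits` enum there.

```cpp
#include <cstdint>
#include <optional>

enum class Op { And, Or, Xor }; // hypothetical stand-in for the TVal opcode

constexpr unsigned NoCommonBits = 1u << 1;
constexpr unsigned AllCommonBits = 1u << 2;
constexpr unsigned AllBitsEnabled = 1u << 3;

// Switch on the opcode of TVal first, then apply whichever known-bits
// facts are relevant to that opcode. Returns the constant TVal folds to,
// or nothing if the flags do not determine it.
std::optional<uint8_t> foldTValConstant(Op TValOpc, unsigned Flags) {
  switch (TValOpc) {
  case Op::And:
    // No common bits between X and Y: X & Y is 0.
    if (Flags & NoCommonBits)
      return uint8_t{0x00};
    break;
  case Op::Or:
    // Every bit is set in X or in Y: X | Y is -1.
    if (Flags & AllBitsEnabled)
      return uint8_t{0xFF};
    break;
  case Op::Xor:
    // No common bits and every bit set somewhere: X ^ Y is -1.
    if ((Flags & NoCommonBits) && (Flags & AllBitsEnabled))
      return uint8_t{0xFF};
    break;
  }
  return std::nullopt;
}
```

Adding a new rule then means touching one `case`, rather than finding the right branch in a chain of flag tests.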

}
} else {
KnownBits XKnown, YKnown, Temp;
KnownBits TValBop0KB, TValBop1KB;
// computeKnowBits calculates the KnownBits in the branching condition
// that the specified variable passes in the execution flow. however, it
// does not contain the SelectInst condition, so there is an optimization
// opportunity to update the knownbits obtained by calculating KnownBits
// with the SelectInst condition.
XKnown = IC.computeKnownBits(X, 0, &SI);
IC.computeKnownBitsFromCond(X, ICI, XKnown, 0, &SI, false);
YKnown = IC.computeKnownBits(Y, 0, &SI);
IC.computeKnownBitsFromCond(Y, ICI, YKnown, 0, &SI, false);
Contributor

Please comment the constants here and above.

CmpInst::Predicate Pred = ICI->getPredicate();
if (Pred == ICmpInst::ICMP_EQ) {
// Estimate additional KnownBits from the relationship between X and Y
if (CmpLHSOpc == Instruction::And) {
// The bit that are set to 1 at `~C&Y` must be 0 in X
// The bit that are set to 1 at `~C&X` must be 0 in Y
XKnown.Zero |= ~*C & YKnown.One;
YKnown.Zero |= ~*C & XKnown.One;
}
if (CmpLHSOpc == Instruction::Or) {
// The bit that are set to 0 at `C&Y` must be 1 in X
// The bit that are set to 0 at `C&X` must be 1 in Y
XKnown.One |= *C & YKnown.Zero;
YKnown.One |= *C & XKnown.Zero;
Contributor

I think we should be handling the or/and cases in computeKnownBitsFromCond, no? Otherwise, I think we should fix that there.

Contributor

Likewise for the xor case tbh.

Contributor @goldsteinn, Jun 14, 2024

Okay, we can't really do that w/ current infra. Although I think a better way to do this would be to create a new helper in ValueTracking for computing known x/y under the assumption of a cmp involving x/y is true/false.

edit: The difference from existing APIs is that we have knownx/knowny

Contributor Author

Actually, this way of discovering known bits existed in ValueTracking before, but it was removed recently. I think it can still be used for this commit, so I'm not confident about creating a new API for this in ValueTracking.

But that is a beginner's opinion. I'll update the code soon.
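
To make the helper idea concrete, here is a small self-contained sketch of the and/or propagation from the hunk above, taking the known bits of both operands as input. `Known8` is a hypothetical 8-bit stand-in for `llvm::KnownBits`, and `refineFromAndEq`, `refineFromOrEq`, `agrees`, and `andRefinementIsSound` are illustrative names, not LLVM API. Only the two cross-operand rules from the patch are modeled.

```cpp
#include <cstdint>

// Hypothetical 8-bit stand-in for llvm::KnownBits.
// A bit set in Zero is known to be 0; a bit set in One is known to be 1.
struct Known8 {
  uint8_t Zero = 0;
  uint8_t One = 0;
};

// True if the concrete value V agrees with everything K claims.
bool agrees(Known8 K, uint8_t V) {
  return (V & K.Zero) == 0 && (V & K.One) == K.One;
}

// Assume (X & Y) == C. Where C has a 0 and one operand is known 1,
// the other operand must be 0 in that bit.
void refineFromAndEq(Known8 &X, Known8 &Y, uint8_t C) {
  const uint8_t NotC = static_cast<uint8_t>(~C);
  X.Zero |= NotC & Y.One;
  Y.Zero |= NotC & X.One;
}

// Assume (X | Y) == C. Where C has a 1 and one operand is known 0,
// the other operand must be 1 in that bit.
void refineFromOrEq(Known8 &X, Known8 &Y, uint8_t C) {
  X.One |= C & Y.Zero;
  Y.One |= C & X.Zero;
}

// Exhaustive check: every (x, y) consistent with the inputs and with
// (x & y) == C must still be consistent with the refined known bits.
bool andRefinementIsSound(Known8 XIn, Known8 YIn, uint8_t C) {
  Known8 X = XIn, Y = YIn;
  refineFromAndEq(X, Y, C);
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B) {
      const uint8_t XV = static_cast<uint8_t>(A);
      const uint8_t YV = static_cast<uint8_t>(B);
      if (!agrees(XIn, XV) || !agrees(YIn, YV) || (XV & YV) != C)
        continue;
      if (!agrees(X, XV) || !agrees(Y, YV))
        return false;
    }
  return true;
}
```

For instance, if the low four bits of Y are known to be 1 and `(X & Y) == 5`, the refinement marks bits 1 and 3 of X as known zero.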

}
if (CmpLHSOpc == Instruction::Xor) {
// If X^Y==C, then X and Y must be either (1,0) or (0,1) for the
// enabled bits in C.
XKnown.One |= *C & YKnown.Zero;
XKnown.Zero |= *C & YKnown.One;
YKnown.One |= *C & XKnown.Zero;
YKnown.Zero |= *C & XKnown.One;
// If X^Y==C, then X and Y must be either (0,0) or (1,1) for the
// disabled bits in C.
XKnown.Zero |= ~*C & YKnown.Zero;
XKnown.One |= ~*C & YKnown.One;
YKnown.Zero |= ~*C & XKnown.Zero;
YKnown.One |= ~*C & XKnown.One;
Contributor

I really don't understand what you are going for here...

Contributor Author

When we handle a SelectInst whose condition is x ^ y == c, we can discover known bits from the condition.

To explain:
a. XKnown.One |= *C & YKnown.Zero;
If the same bit is set in both C and Y.Zero, then that bit must be set in X.One. For that bit, X = 1, Y = 0, C = 1 satisfies X ^ Y == C.

b. XKnown.Zero |= *C & YKnown.One;
Likewise, if the same bit is set in both C and Y.One, then that bit of X must be 0.

Once we have discovered this, we can put it to use on the true value.

This is one of the tests contained in the pull request:

define i8 @src_xor_bit(i8 %x, i8 %y) {
; CHECK-LABEL: @src_xor_bit(
; CHECK-NEXT:    [[AND:%.*]] = and i8 [[Y:%.*]], 12
; CHECK-NEXT:    [[XOR:%.*]] = xor i8 [[AND]], [[X:%.*]]
; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i8 [[XOR]], 3
; CHECK-NEXT:    [[COND:%.*]] = select i1 [[CMP]], i8 3, i8 1
; CHECK-NEXT:    ret i8 [[COND]]
;
  %and = and i8 %y, 12
  %xor = xor i8 %and, %x
  %cmp = icmp eq i8 %xor, 3
  %and1 = and i8 %x, 3
  %cond = select i1 %cmp, i8 %and1, i8 1
  ret i8 %cond
}

In this select the condition is (Y & 12) ^ X == 3.
We can discover the known bits of X with the following steps:
(Y & 12)'s KnownBits: One = 0000??00, Zero = 1111??11
(Y & 12) ^ X == 3 holds in the true value
X.One = 00000011 (3) & 1111??11 (Zero of Y & 12)
X.Zero = 00000011 (3) & 0000??00 (One of Y & 12)

So in the true value of this select we can treat the low two bits of X as 3, which makes it possible to replace %and1 with 3, as you can see.
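
The steps above can be replayed with a small self-contained sketch. `Known8` is a hypothetical 8-bit stand-in for `llvm::KnownBits`, and `refineFromXorEq` and `knownAnd` are illustrative helpers rather than LLVM functions. The eight update lines mirror the xor block in the patch.

```cpp
#include <cstdint>

// Hypothetical 8-bit stand-in for llvm::KnownBits.
// A bit set in Zero is known to be 0; a bit set in One is known to be 1.
struct Known8 {
  uint8_t Zero = 0;
  uint8_t One = 0;
};

// Refine X and Y under the assumption that (X ^ Y) == C holds.
void refineFromXorEq(Known8 &X, Known8 &Y, uint8_t C) {
  const uint8_t NotC = static_cast<uint8_t>(~C);
  // Where C has a 1, X and Y differ in that bit.
  X.One |= C & Y.Zero;
  X.Zero |= C & Y.One;
  Y.One |= C & X.Zero;
  Y.Zero |= C & X.One;
  // Where C has a 0, X and Y agree in that bit.
  X.Zero |= NotC & Y.Zero;
  X.One |= NotC & Y.One;
  Y.Zero |= NotC & X.Zero;
  Y.One |= NotC & X.One;
}

// Known bits of (V & Mask): every bit cleared in Mask is known zero,
// and a bit is known one only if it is known one in V and set in Mask.
Known8 knownAnd(Known8 V, uint8_t Mask) {
  Known8 R;
  R.Zero = static_cast<uint8_t>(V.Zero | static_cast<uint8_t>(~Mask));
  R.One = static_cast<uint8_t>(V.One & Mask);
  return R;
}
```

Starting from fully unknown X and Y, `knownAnd(Y, 12)` followed by `refineFromXorEq(X, YMasked, 3)` leaves `X.One == 0x03` and `X.Zero == 0xF0`; bits 2 and 3 of X stay unknown. `knownAnd(X, 3)` is then fully known and equal to 3, which matches the `select i1 [[CMP]], i8 3, i8 1` in the CHECK lines.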

}
}

// If TrueVal has X or Y, return the corresponding KnownBits, otherwise
// compute and return new KnownBits.
auto getTValBopKB = [&](unsigned OpNum) -> KnownBits {
unsigned Order = OpNum + 1;
if (Order == XOrder)
return XKnown;
else if (Order == YOrder)
return YKnown;

Value *V = TValBop->getOperand(OpNum);
KnownBits Known = IC.computeKnownBits(V, 0, &SI);
return Known;
};
TValBop0KB = getTValBopKB(0);
TValBop1KB = getTValBopKB(1);
Known = analyzeKnownBitsFromAndXorOr(
cast<Operator>(TValBop), TValBop0KB, TValBop1KB, 0,
IC.getSimplifyQuery().getWithInstruction(&SI));
}
}
Contributor

There is a desperate need for comments in the above code.

Contributor Author

I'll update soon!


if (Known.isConstant()) {
auto Const = ConstantInt::get(TValTy, Known.getConstant());
return IC.replaceOperand(SI, 1, Const);
}

return nullptr;
}

/// Visit a SelectInst that has an ICmpInst as its first operand.
Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
ICmpInst *ICI) {
@@ -1951,6 +2143,10 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI,
if (Value *V = foldAbsDiff(ICI, TrueVal, FalseVal, Builder))
return replaceInstUsesWith(SI, V);

if (Instruction *NewSel = foldSelectICmpBinOp(SI, ICI, CmpLHS, CmpRHS,
TrueVal, FalseVal, *this))
return NewSel;

return Changed ? &SI : nullptr;
}
