Skip to content

[DAGCombine] Fold icmp with chain of or of loads #139165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 92 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9551,6 +9551,89 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
}

// Try to find a tree of or's with leafs that are all loads that are offset from
// the same base, and can be combined to a single larger load.
static SDValue matchOrOfLoadToLargeLoad(SDValue Root, SelectionDAG &DAG,
const TargetLowering &TLI) {
EVT VT = Root.getValueType();
SmallVector<SDValue> Worklist = {Root};
SmallVector<std::pair<LoadSDNode *, int64_t>> Loads;
std::optional<BaseIndexOffset> Base;
LoadSDNode *BaseLoad = nullptr;

// Check up the chain of or instructions with loads at the end.
while (!Worklist.empty()) {
SDValue V = Worklist.pop_back_val();
if (!V.hasOneUse())
return SDValue();
if (V.getOpcode() == ISD::OR) {
Worklist.push_back(V.getOperand(0));
Worklist.push_back(V.getOperand(1));
} else if (V.getOpcode() == ISD::ZERO_EXTEND ||
V.getOpcode() == ISD::SIGN_EXTEND) {
Worklist.push_back(V.getOperand(0));
} else if (V.getOpcode() == ISD::LOAD) {
LoadSDNode *Ld = cast<LoadSDNode>(V.getNode());
if (!Ld->isSimple() || Ld->getMemoryVT().getSizeInBits() % 8 != 0)
return SDValue();

BaseIndexOffset Ptr = BaseIndexOffset::match(Ld, DAG);
int64_t ByteOffsetFromBase = 0;
if (!Base) {
Base = Ptr;
BaseLoad = Ld;
} else if (BaseLoad->getChain() != Ld->getChain() ||
!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
return SDValue();
Loads.push_back({Ld, ByteOffsetFromBase});
} else {
return SDValue();
}
}

// Sort nodes by increasing ByteOffsetFromBase
llvm::sort(Loads, [](auto &A, auto &B) { return A.second < B.second; });
Base = BaseIndexOffset::match(Loads[0].first, DAG);

// Check that they are all adjacent in memory
int64_t BaseOffset = 0;
for (unsigned I = 0; I < Loads.size(); ++I) {
int64_t Offset = Loads[I].second - Loads[0].second;
if (Offset != BaseOffset)
return SDValue();
BaseOffset += Loads[I].first->getMemoryVT().getSizeInBits() / 8;
}

uint64_t MemSize =
Loads[Loads.size() - 1].second - Loads[0].second +
Loads[Loads.size() - 1].first->getMemoryVT().getSizeInBits() / 8;
if (!isPowerOf2_64(MemSize) || MemSize * 8 > VT.getSizeInBits())
return SDValue();
EVT MemVT = EVT::getIntegerVT(*DAG.getContext(), MemSize * 8);

bool NeedsZext = VT.bitsGT(MemVT);
if (!TLI.isLoadExtLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, VT,
MemVT))
return SDValue();

unsigned Fast = 0;
bool Allowed =
TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
*Loads[0].first->getMemOperand(), &Fast);
if (!Allowed || !Fast)
return SDValue();

SDValue NewLoad = DAG.getExtLoad(
NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(Root), VT,
Loads[0].first->getChain(), Loads[0].first->getBasePtr(),
Loads[0].first->getPointerInfo(), MemVT, Loads[0].first->getAlign());

// Transfer chain users from old loads to the new load.
for (auto &L : Loads)
DAG.makeEquivalentMemoryOrdering(L.first, NewLoad);
return NewLoad;
}

// If the target has andn, bsl, or a similar bit-select instruction,
// we want to unfold masked merge, with canonical pattern of:
// | A | |B|
Expand Down Expand Up @@ -28649,7 +28732,15 @@ SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
bool foldBooleans) {
TargetLowering::DAGCombinerInfo
DagCombineInfo(DAG, Level, false, this);
return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
if (SDValue C =
TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL))
return C;

if ((Cond == ISD::SETNE || Cond == ISD::SETEQ) && N0.getOpcode() == ISD::OR &&
isNullConstant(N1))
if (SDValue Load = matchOrOfLoadToLargeLoad(N0, DAG, TLI))
return DAG.getSetCC(DL, VT, Load, N1, Cond);
return SDValue();
}

/// Given an ISD::SDIV node expressing a divide by constant, return
Expand Down
30 changes: 6 additions & 24 deletions llvm/test/CodeGen/AArch64/icmp-or-load.ll
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
define i1 @loadzext_i8i8(ptr %p) {
; CHECK-LABEL: loadzext_i8i8:
; CHECK: // %bb.0:
; CHECK-NEXT: ldrb w8, [x0]
; CHECK-NEXT: ldrb w9, [x0, #1]
; CHECK-NEXT: orr w8, w8, w9
; CHECK-NEXT: ldrh w8, [x0]
; CHECK-NEXT: cmp w8, #0
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
Expand All @@ -23,9 +21,7 @@ define i1 @loadzext_i8i8(ptr %p) {
define i1 @loadzext_c_i8i8(ptr %p) {
; CHECK-LABEL: loadzext_c_i8i8:
; CHECK: // %bb.0:
; CHECK-NEXT: ldrb w8, [x0]
; CHECK-NEXT: ldrb w9, [x0, #1]
; CHECK-NEXT: orr w8, w9, w8
; CHECK-NEXT: ldrh w8, [x0]
; CHECK-NEXT: cmp w8, #0
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
Expand Down Expand Up @@ -85,13 +81,7 @@ define i1 @loadzext_i8i8i8(ptr %p) {
define i1 @loadzext_i8i8i8i8(ptr %p) {
; CHECK-LABEL: loadzext_i8i8i8i8:
; CHECK: // %bb.0:
; CHECK-NEXT: ldrb w8, [x0]
; CHECK-NEXT: ldrb w9, [x0, #1]
; CHECK-NEXT: ldrb w10, [x0, #2]
; CHECK-NEXT: ldrb w11, [x0, #3]
; CHECK-NEXT: orr w8, w8, w9
; CHECK-NEXT: orr w9, w10, w11
; CHECK-NEXT: orr w8, w8, w9
; CHECK-NEXT: ldr w8, [x0]
; CHECK-NEXT: cmp w8, #0
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
Expand All @@ -116,9 +106,7 @@ define i1 @loadzext_i8i8i8i8(ptr %p) {
define i1 @load_i8i8(ptr %p) {
; CHECK-LABEL: load_i8i8:
; CHECK: // %bb.0:
; CHECK-NEXT: ldrb w8, [x0]
; CHECK-NEXT: ldrb w9, [x0, #1]
; CHECK-NEXT: orr w8, w8, w9
; CHECK-NEXT: ldrh w8, [x0]
; CHECK-NEXT: cmp w8, #0
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
Expand All @@ -133,9 +121,7 @@ define i1 @load_i8i8(ptr %p) {
define i1 @load_i16i16(ptr %p) {
; CHECK-LABEL: load_i16i16:
; CHECK: // %bb.0:
; CHECK-NEXT: ldrh w8, [x0]
; CHECK-NEXT: ldrh w9, [x0, #2]
; CHECK-NEXT: orr w8, w8, w9
; CHECK-NEXT: ldr w8, [x0]
; CHECK-NEXT: cmp w8, #0
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
Expand Down Expand Up @@ -182,11 +168,7 @@ define i1 @load_i64i64(ptr %p) {
define i1 @load_i8i16i8(ptr %p) {
; CHECK-LABEL: load_i8i16i8:
; CHECK: // %bb.0:
; CHECK-NEXT: ldrb w8, [x0]
; CHECK-NEXT: ldrb w9, [x0, #3]
; CHECK-NEXT: ldurh w10, [x0, #1]
; CHECK-NEXT: orr w8, w8, w9
; CHECK-NEXT: orr w8, w8, w10
; CHECK-NEXT: ldr w8, [x0]
; CHECK-NEXT: cmp w8, #0
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
Expand Down
Loading