Skip to content

Commit d53eed3

Browse files
committed
[DAGCombine] Fold icmp with chain of or of loads
Given an `icmp eq/ne (or ...), 0`, the comparison only checks whether any bit of the `or` result is set. When the `or` combines a chain of loads at adjacent offsets from a common base, the individual loads can therefore be replaced by a single larger load.
1 parent e9702ce commit d53eed3

File tree

2 files changed

+99
-25
lines changed

2 files changed

+99
-25
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9551,6 +9551,90 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
95519551
return DAG.getNode(ISD::BSWAP, SDLoc(N), VT, ShiftedLoad);
95529552
}
95539553

9554+
// Try to find a tree of or's with leafs that are all loads that are offset from
9555+
// the same base, and can be combined to a single larger load.
9556+
static SDValue MatchOrOfLoadToLargeLoad(SDValue Root, SelectionDAG &DAG,
9557+
const TargetLowering &TLI) {
9558+
EVT VT = Root.getValueType();
9559+
SmallVector<SDValue> Worklist;
9560+
Worklist.push_back(Root);
9561+
SmallVector<std::pair<LoadSDNode *, int64_t>> Loads;
9562+
std::optional<BaseIndexOffset> Base;
9563+
LoadSDNode *BaseLoad = nullptr;
9564+
9565+
// Check up the chain of or instructions with loads at the end.
9566+
while (!Worklist.empty()) {
9567+
SDValue V = Worklist.pop_back_val();
9568+
if (!V.hasOneUse())
9569+
return SDValue();
9570+
if (V.getOpcode() == ISD::OR) {
9571+
Worklist.push_back(V.getOperand(0));
9572+
Worklist.push_back(V.getOperand(1));
9573+
} else if (V.getOpcode() == ISD::ZERO_EXTEND ||
9574+
V.getOpcode() == ISD::SIGN_EXTEND) {
9575+
Worklist.push_back(V.getOperand(0));
9576+
} else if (V.getOpcode() == ISD::LOAD) {
9577+
LoadSDNode *Ld = cast<LoadSDNode>(V.getNode());
9578+
if (!Ld->isSimple() || Ld->getMemoryVT().getSizeInBits() % 8 != 0)
9579+
return SDValue();
9580+
9581+
BaseIndexOffset Ptr = BaseIndexOffset::match(Ld, DAG);
9582+
int64_t ByteOffsetFromBase = 0;
9583+
if (!Base) {
9584+
Base = Ptr;
9585+
BaseLoad = Ld;
9586+
} else if (BaseLoad->getChain() != Ld->getChain() ||
9587+
!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
9588+
return SDValue();
9589+
Loads.push_back({Ld, ByteOffsetFromBase});
9590+
} else {
9591+
return SDValue();
9592+
}
9593+
}
9594+
9595+
// Sort nodes by increasing ByteOffsetFromBase
9596+
llvm::sort(Loads, [](auto &A, auto &B) { return A.second < B.second; });
9597+
Base = BaseIndexOffset::match(Loads[0].first, DAG);
9598+
9599+
// Check that they are all adjacent in memory
9600+
int64_t BaseOffset = 0;
9601+
for (unsigned I = 0; I < Loads.size(); ++I) {
9602+
int64_t Offset = Loads[I].second - Loads[0].second;
9603+
if (Offset != BaseOffset)
9604+
return SDValue();
9605+
BaseOffset += Loads[I].first->getMemoryVT().getSizeInBits() / 8;
9606+
}
9607+
9608+
uint64_t MemSize =
9609+
Loads[Loads.size() - 1].second - Loads[0].second +
9610+
Loads[Loads.size() - 1].first->getMemoryVT().getSizeInBits() / 8;
9611+
if (!isPowerOf2_64(MemSize) || MemSize * 8 > VT.getSizeInBits())
9612+
return SDValue();
9613+
EVT MemVT = EVT::getIntegerVT(*DAG.getContext(), MemSize * 8);
9614+
9615+
bool NeedsZext = VT.bitsGT(MemVT);
9616+
if (!TLI.isLoadExtLegal(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, VT,
9617+
MemVT))
9618+
return SDValue();
9619+
9620+
unsigned Fast = 0;
9621+
bool Allowed =
9622+
TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), MemVT,
9623+
*Loads[0].first->getMemOperand(), &Fast);
9624+
if (!Allowed || !Fast)
9625+
return SDValue();
9626+
9627+
SDValue NewLoad = DAG.getExtLoad(
9628+
NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(Root), VT,
9629+
Loads[0].first->getChain(), Loads[0].first->getBasePtr(),
9630+
Loads[0].first->getPointerInfo(), MemVT, Loads[0].first->getAlign());
9631+
9632+
// Transfer chain users from old loads to the new load.
9633+
for (auto &L : Loads)
9634+
DAG.makeEquivalentMemoryOrdering(L.first, NewLoad);
9635+
return NewLoad;
9636+
}
9637+
95549638
// If the target has andn, bsl, or a similar bit-select instruction,
95559639
// we want to unfold masked merge, with canonical pattern of:
95569640
// | A | |B|
@@ -28654,7 +28738,15 @@ SDValue DAGCombiner::SimplifySetCC(EVT VT, SDValue N0, SDValue N1,
2865428738
bool foldBooleans) {
2865528739
TargetLowering::DAGCombinerInfo
2865628740
DagCombineInfo(DAG, Level, false, this);
28657-
return TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL);
28741+
if (SDValue C =
28742+
TLI.SimplifySetCC(VT, N0, N1, Cond, foldBooleans, DagCombineInfo, DL))
28743+
return C;
28744+
28745+
if ((Cond == ISD::SETNE || Cond == ISD::SETEQ) && isNullConstant(N1) &&
28746+
N0.getOpcode() == ISD::OR)
28747+
if (SDValue Load = MatchOrOfLoadToLargeLoad(N0, DAG, TLI))
28748+
return DAG.getSetCC(DL, VT, Load, N1, Cond);
28749+
return SDValue();
2865828750
}
2865928751

2866028752
/// Given an ISD::SDIV node expressing a divide by constant, return

llvm/test/CodeGen/AArch64/icmp-or-load.ll

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@
44
define i1 @loadzext_i8i8(ptr %p) {
55
; CHECK-LABEL: loadzext_i8i8:
66
; CHECK: // %bb.0:
7-
; CHECK-NEXT: ldrb w8, [x0]
8-
; CHECK-NEXT: ldrb w9, [x0, #1]
9-
; CHECK-NEXT: orr w8, w8, w9
7+
; CHECK-NEXT: ldrh w8, [x0]
108
; CHECK-NEXT: cmp w8, #0
119
; CHECK-NEXT: cset w0, eq
1210
; CHECK-NEXT: ret
@@ -23,9 +21,7 @@ define i1 @loadzext_i8i8(ptr %p) {
2321
define i1 @loadzext_c_i8i8(ptr %p) {
2422
; CHECK-LABEL: loadzext_c_i8i8:
2523
; CHECK: // %bb.0:
26-
; CHECK-NEXT: ldrb w8, [x0]
27-
; CHECK-NEXT: ldrb w9, [x0, #1]
28-
; CHECK-NEXT: orr w8, w9, w8
24+
; CHECK-NEXT: ldrh w8, [x0]
2925
; CHECK-NEXT: cmp w8, #0
3026
; CHECK-NEXT: cset w0, eq
3127
; CHECK-NEXT: ret
@@ -85,13 +81,7 @@ define i1 @loadzext_i8i8i8(ptr %p) {
8581
define i1 @loadzext_i8i8i8i8(ptr %p) {
8682
; CHECK-LABEL: loadzext_i8i8i8i8:
8783
; CHECK: // %bb.0:
88-
; CHECK-NEXT: ldrb w8, [x0]
89-
; CHECK-NEXT: ldrb w9, [x0, #1]
90-
; CHECK-NEXT: ldrb w10, [x0, #2]
91-
; CHECK-NEXT: ldrb w11, [x0, #3]
92-
; CHECK-NEXT: orr w8, w8, w9
93-
; CHECK-NEXT: orr w9, w10, w11
94-
; CHECK-NEXT: orr w8, w8, w9
84+
; CHECK-NEXT: ldr w8, [x0]
9585
; CHECK-NEXT: cmp w8, #0
9686
; CHECK-NEXT: cset w0, eq
9787
; CHECK-NEXT: ret
@@ -116,9 +106,7 @@ define i1 @loadzext_i8i8i8i8(ptr %p) {
116106
define i1 @load_i8i8(ptr %p) {
117107
; CHECK-LABEL: load_i8i8:
118108
; CHECK: // %bb.0:
119-
; CHECK-NEXT: ldrb w8, [x0]
120-
; CHECK-NEXT: ldrb w9, [x0, #1]
121-
; CHECK-NEXT: orr w8, w8, w9
109+
; CHECK-NEXT: ldrh w8, [x0]
122110
; CHECK-NEXT: cmp w8, #0
123111
; CHECK-NEXT: cset w0, eq
124112
; CHECK-NEXT: ret
@@ -133,9 +121,7 @@ define i1 @load_i8i8(ptr %p) {
133121
define i1 @load_i16i16(ptr %p) {
134122
; CHECK-LABEL: load_i16i16:
135123
; CHECK: // %bb.0:
136-
; CHECK-NEXT: ldrh w8, [x0]
137-
; CHECK-NEXT: ldrh w9, [x0, #2]
138-
; CHECK-NEXT: orr w8, w8, w9
124+
; CHECK-NEXT: ldr w8, [x0]
139125
; CHECK-NEXT: cmp w8, #0
140126
; CHECK-NEXT: cset w0, eq
141127
; CHECK-NEXT: ret
@@ -182,11 +168,7 @@ define i1 @load_i64i64(ptr %p) {
182168
define i1 @load_i8i16i8(ptr %p) {
183169
; CHECK-LABEL: load_i8i16i8:
184170
; CHECK: // %bb.0:
185-
; CHECK-NEXT: ldrb w8, [x0]
186-
; CHECK-NEXT: ldrb w9, [x0, #3]
187-
; CHECK-NEXT: ldurh w10, [x0, #1]
188-
; CHECK-NEXT: orr w8, w8, w9
189-
; CHECK-NEXT: orr w8, w8, w10
171+
; CHECK-NEXT: ldr w8, [x0]
190172
; CHECK-NEXT: cmp w8, #0
191173
; CHECK-NEXT: cset w0, eq
192174
; CHECK-NEXT: ret

0 commit comments

Comments
 (0)