Skip to content

Commit e0ac087

Browse files
authored
[LoopUnroll] Consider convergence control tokens when unrolling (#91715)
- There is no restriction on a loop with controlled convergent operations when the relevant tokens are defined and used within the loop.
- When a token defined outside a loop is used inside (also called a loop convergence heart), unrolling is allowed only in the absence of remainder or runtime checks.
- When a token defined inside a loop is used outside, such a loop is said to be "extended". This loop can only be unrolled by also duplicating the extended part lying outside the loop. Such unrolling is disabled for now.
- Clean up loop hearts: When unrolling a loop with a heart, duplicating the heart will introduce multiple static uses of a convergence control token in a cycle that does not contain its definition. This violates the static rules for tokens, and needs to be cleaned up into a single occurrence of the intrinsic.
- Spell out the initializer for UnrollLoopOptions to improve readability.

Original implementation [D85605] by Nicolai Haehnle <[email protected]>.
1 parent 16e2ec8 commit e0ac087

File tree

15 files changed

+748
-69
lines changed

15 files changed

+748
-69
lines changed

llvm/include/llvm/Analysis/CodeMetrics.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@
2020
namespace llvm {
2121
class AssumptionCache;
2222
class BasicBlock;
23+
class Instruction;
2324
class Loop;
2425
class Function;
2526
template <class T> class SmallPtrSetImpl;
2627
class TargetTransformInfo;
2728
class Value;
2829

30+
enum struct ConvergenceKind { None, Controlled, ExtendedLoop, Uncontrolled };
31+
2932
/// Utility to calculate the size and a few similar metrics for a set
3033
/// of basic blocks.
3134
struct CodeMetrics {
@@ -42,8 +45,8 @@ struct CodeMetrics {
4245
/// one or more 'noduplicate' instructions.
4346
bool notDuplicatable = false;
4447

45-
/// True if this function contains a call to a convergent function.
46-
bool convergent = false;
48+
/// The kind of convergence specified in this function.
49+
ConvergenceKind Convergence = ConvergenceKind::None;
4750

4851
/// True if this function calls alloca (in the C sense).
4952
bool usesDynamicAlloca = false;
@@ -77,7 +80,7 @@ struct CodeMetrics {
7780
/// Add information about a block to the current state.
7881
void analyzeBasicBlock(const BasicBlock *BB, const TargetTransformInfo &TTI,
7982
const SmallPtrSetImpl<const Value *> &EphValues,
80-
bool PrepareForLTO = false);
83+
bool PrepareForLTO = false, const Loop *L = nullptr);
8184

8285
/// Collect a loop's ephemeral values (those used only by an assume
8386
/// or similar intrinsics in the loop).

llvm/include/llvm/Analysis/LoopInfo.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,9 @@ int getIntLoopAttribute(const Loop *TheLoop, StringRef Name, int Default = 0);
649649
std::optional<const MDOperand *> findStringMetadataForLoop(const Loop *TheLoop,
650650
StringRef Name);
651651

652+
/// Find the convergence heart of the loop.
653+
CallBase *getLoopConvergenceHeart(const Loop *TheLoop);
654+
652655
/// Look for the loop attribute that requires progress within the loop.
653656
/// Note: Most consumers probably want "isMustProgress" which checks
654657
/// the containing function attribute too.

llvm/include/llvm/IR/InstrTypes.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1588,6 +1588,14 @@ class CallBase : public Instruction {
15881588
static CallBase *removeOperandBundle(CallBase *CB, uint32_t ID,
15891589
BasicBlock::iterator InsertPt);
15901590

1591+
/// Return the convergence control token for this call, if it exists.
1592+
Value *getConvergenceControlToken() const {
1593+
if (auto Bundle = getOperandBundle(llvm::LLVMContext::OB_convergencectrl)) {
1594+
return Bundle->Inputs[0].get();
1595+
}
1596+
return nullptr;
1597+
}
1598+
15911599
static bool classof(const Instruction *I) {
15921600
return I->getOpcode() == Instruction::Call ||
15931601
I->getOpcode() == Instruction::Invoke ||

llvm/include/llvm/IR/IntrinsicInst.h

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1799,17 +1799,14 @@ class ConvergenceControlInst : public IntrinsicInst {
17991799
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
18001800
}
18011801

1802-
// Returns the convergence intrinsic referenced by |I|'s convergencectrl
1803-
// attribute if any.
1804-
static IntrinsicInst *getParentConvergenceToken(Instruction *I) {
1805-
auto *CI = dyn_cast<llvm::CallInst>(I);
1806-
if (!CI)
1807-
return nullptr;
1808-
1809-
auto Bundle = CI->getOperandBundle(llvm::LLVMContext::OB_convergencectrl);
1810-
assert(Bundle->Inputs.size() == 1 &&
1811-
Bundle->Inputs[0]->getType()->isTokenTy());
1812-
return dyn_cast<llvm::IntrinsicInst>(Bundle->Inputs[0].get());
1802+
bool isAnchor() {
1803+
return getIntrinsicID() == Intrinsic::experimental_convergence_anchor;
1804+
}
1805+
bool isEntry() {
1806+
return getIntrinsicID() == Intrinsic::experimental_convergence_entry;
1807+
}
1808+
bool isLoop() {
1809+
return getIntrinsicID() == Intrinsic::experimental_convergence_loop;
18131810
}
18141811
};
18151812

llvm/include/llvm/Transforms/Utils/UnrollLoop.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#define LLVM_TRANSFORMS_UTILS_UNROLLLOOP_H
1717

1818
#include "llvm/ADT/DenseMap.h"
19+
#include "llvm/Analysis/CodeMetrics.h"
1920
#include "llvm/Analysis/TargetTransformInfo.h"
2021
#include "llvm/Support/InstructionCost.h"
2122

@@ -73,6 +74,7 @@ struct UnrollLoopOptions {
7374
bool AllowExpensiveTripCount;
7475
bool UnrollRemainder;
7576
bool ForgetAllSCEV;
77+
const Instruction *Heart = nullptr;
7678
};
7779

7880
LoopUnrollResult UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
@@ -128,14 +130,15 @@ class UnrollCostEstimator {
128130

129131
public:
130132
unsigned NumInlineCandidates;
131-
bool Convergent;
133+
ConvergenceKind Convergence;
134+
bool ConvergenceAllowsRuntime;
132135

133136
UnrollCostEstimator(const Loop *L, const TargetTransformInfo &TTI,
134137
const SmallPtrSetImpl<const Value *> &EphValues,
135138
unsigned BEInsns);
136139

137140
/// Whether it is legal to unroll this loop.
138-
bool canUnroll() const { return LoopSize.isValid() && !NotDuplicatable; }
141+
bool canUnroll() const;
139142

140143
uint64_t getRolledLoopSize() const { return *LoopSize.getValue(); }
141144

llvm/lib/Analysis/CodeMetrics.cpp

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "llvm/Analysis/LoopInfo.h"
1717
#include "llvm/Analysis/TargetTransformInfo.h"
1818
#include "llvm/IR/Function.h"
19+
#include "llvm/IR/IntrinsicInst.h"
1920
#include "llvm/Support/Debug.h"
2021
#include "llvm/Support/InstructionCost.h"
2122

@@ -111,11 +112,24 @@ void CodeMetrics::collectEphemeralValues(
111112
completeEphemeralValues(Visited, Worklist, EphValues);
112113
}
113114

115+
static bool extendsConvergenceOutsideLoop(const Instruction &I, const Loop *L) {
116+
if (!L)
117+
return false;
118+
if (!isa<ConvergenceControlInst>(I))
119+
return false;
120+
for (const auto *U : I.users()) {
121+
if (!L->contains(cast<Instruction>(U)))
122+
return true;
123+
}
124+
return false;
125+
}
126+
114127
/// Fill in the current structure with information gleaned from the specified
115128
/// block.
116129
void CodeMetrics::analyzeBasicBlock(
117130
const BasicBlock *BB, const TargetTransformInfo &TTI,
118-
const SmallPtrSetImpl<const Value *> &EphValues, bool PrepareForLTO) {
131+
const SmallPtrSetImpl<const Value *> &EphValues, bool PrepareForLTO,
132+
const Loop *L) {
119133
++NumBlocks;
120134
InstructionCost NumInstsBeforeThisBB = NumInsts;
121135
for (const Instruction &I : *BB) {
@@ -163,19 +177,38 @@ void CodeMetrics::analyzeBasicBlock(
163177
if (isa<ExtractElementInst>(I) || I.getType()->isVectorTy())
164178
++NumVectorInsts;
165179

166-
if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB))
180+
if (I.getType()->isTokenTy() && !isa<ConvergenceControlInst>(I) &&
181+
I.isUsedOutsideOfBlock(BB)) {
182+
LLVM_DEBUG(dbgs() << I
183+
<< "\n Cannot duplicate a token value used outside "
184+
"the current block (except convergence control).\n");
167185
notDuplicatable = true;
168-
169-
if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
170-
if (CI->cannotDuplicate())
171-
notDuplicatable = true;
172-
if (CI->isConvergent())
173-
convergent = true;
174186
}
175187

176-
if (const InvokeInst *InvI = dyn_cast<InvokeInst>(&I))
177-
if (InvI->cannotDuplicate())
188+
if (const CallBase *CB = dyn_cast<CallBase>(&I)) {
189+
if (CB->cannotDuplicate())
178190
notDuplicatable = true;
191+
// Compute a meet over the visited blocks for the following partial order:
192+
//
193+
// None -> { Controlled, ExtendedLoop, Uncontrolled}
194+
// Controlled -> ExtendedLoop
195+
if (Convergence <= ConvergenceKind::Controlled && CB->isConvergent()) {
196+
if (isa<ConvergenceControlInst>(CB) ||
197+
CB->getConvergenceControlToken()) {
198+
assert(Convergence != ConvergenceKind::Uncontrolled);
199+
LLVM_DEBUG(dbgs() << "Found controlled convergence:\n" << I << "\n");
200+
if (extendsConvergenceOutsideLoop(I, L))
201+
Convergence = ConvergenceKind::ExtendedLoop;
202+
else {
203+
assert(Convergence != ConvergenceKind::ExtendedLoop);
204+
Convergence = ConvergenceKind::Controlled;
205+
}
206+
} else {
207+
assert(Convergence == ConvergenceKind::None);
208+
Convergence = ConvergenceKind::Uncontrolled;
209+
}
210+
}
211+
}
179212

180213
NumInsts += TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
181214
}

llvm/lib/Analysis/LoopInfo.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,26 @@ int llvm::getIntLoopAttribute(const Loop *TheLoop, StringRef Name,
11051105
return getOptionalIntLoopAttribute(TheLoop, Name).value_or(Default);
11061106
}
11071107

1108+
CallBase *llvm::getLoopConvergenceHeart(const Loop *TheLoop) {
1109+
BasicBlock *H = TheLoop->getHeader();
1110+
for (Instruction &II : *H) {
1111+
if (auto *CB = dyn_cast<CallBase>(&II)) {
1112+
if (!CB->isConvergent())
1113+
continue;
1114+
// This is the heart if it uses a token defined outside the loop. The
1115+
// verifier has already checked that only the loop intrinsic can use such
1116+
// a token.
1117+
if (auto *Token = CB->getConvergenceControlToken()) {
1118+
auto *TokenDef = cast<Instruction>(Token);
1119+
if (!TheLoop->contains(TokenDef->getParent()))
1120+
return CB;
1121+
}
1122+
return nullptr;
1123+
}
1124+
}
1125+
return nullptr;
1126+
}
1127+
11081128
bool llvm::isFinite(const Loop *L) {
11091129
return L->getHeader()->getParent()->willReturn();
11101130
}

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3961,7 +3961,7 @@ static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
39613961
UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);
39623962

39633963
// Loop is not unrollable if the loop contains certain instructions.
3964-
if (!UCE.canUnroll() || UCE.Convergent) {
3964+
if (!UCE.canUnroll()) {
39653965
LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
39663966
return 1;
39673967
}

llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,8 @@ struct TransformDFA {
827827
return false;
828828
}
829829

830-
if (Metrics.convergent) {
830+
// FIXME: Allow jump threading with controlled convergence.
831+
if (Metrics.Convergence != ConvergenceKind::None) {
831832
LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, contains "
832833
<< "convergent instructions.\n");
833834
ORE->emit([&]() {

llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,8 +327,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
327327
UnrollCostEstimator OuterUCE(L, TTI, EphValues, UP.BEInsns);
328328

329329
if (!InnerUCE.canUnroll() || !OuterUCE.canUnroll()) {
330-
LLVM_DEBUG(dbgs() << " Not unrolling loop which contains instructions"
331-
<< " which cannot be duplicated or have invalid cost.\n");
330+
LLVM_DEBUG(dbgs() << " Loop not considered unrollable\n");
332331
return LoopUnrollResult::Unmodified;
333332
}
334333

@@ -341,7 +340,10 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
341340
LLVM_DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n");
342341
return LoopUnrollResult::Unmodified;
343342
}
344-
if (InnerUCE.Convergent || OuterUCE.Convergent) {
343+
// FIXME: The call to canUnroll() allows some controlled convergent
344+
// operations, but we still block them here; supporting them is left to future changes.
345+
if (InnerUCE.Convergence != ConvergenceKind::None ||
346+
OuterUCE.Convergence != ConvergenceKind::None) {
345347
LLVM_DEBUG(
346348
dbgs() << " Not unrolling loop with convergent instructions.\n");
347349
return LoopUnrollResult::Unmodified;

llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -684,11 +684,15 @@ UnrollCostEstimator::UnrollCostEstimator(
684684
const SmallPtrSetImpl<const Value *> &EphValues, unsigned BEInsns) {
685685
CodeMetrics Metrics;
686686
for (BasicBlock *BB : L->blocks())
687-
Metrics.analyzeBasicBlock(BB, TTI, EphValues);
687+
Metrics.analyzeBasicBlock(BB, TTI, EphValues, /* PrepareForLTO= */ false,
688+
L);
688689
NumInlineCandidates = Metrics.NumInlineCandidates;
689690
NotDuplicatable = Metrics.notDuplicatable;
690-
Convergent = Metrics.convergent;
691+
Convergence = Metrics.Convergence;
691692
LoopSize = Metrics.NumInsts;
693+
ConvergenceAllowsRuntime =
694+
Metrics.Convergence != ConvergenceKind::Uncontrolled &&
695+
!getLoopConvergenceHeart(L);
692696

693697
// Don't allow an estimate of size zero. This would allow unrolling of loops
694698
// with huge iteration counts, which is a compile time problem even if it's
@@ -701,6 +705,25 @@ UnrollCostEstimator::UnrollCostEstimator(
701705
LoopSize = BEInsns + 1;
702706
}
703707

708+
bool UnrollCostEstimator::canUnroll() const {
709+
switch (Convergence) {
710+
case ConvergenceKind::ExtendedLoop:
711+
LLVM_DEBUG(dbgs() << " Convergence prevents unrolling.\n");
712+
return false;
713+
default:
714+
break;
715+
}
716+
if (!LoopSize.isValid()) {
717+
LLVM_DEBUG(dbgs() << " Invalid loop size prevents unrolling.\n");
718+
return false;
719+
}
720+
if (NotDuplicatable) {
721+
LLVM_DEBUG(dbgs() << " Non-duplicatable blocks prevent unrolling.\n");
722+
return false;
723+
}
724+
return true;
725+
}
726+
704727
uint64_t UnrollCostEstimator::getUnrolledLoopSize(
705728
const TargetTransformInfo::UnrollingPreferences &UP,
706729
unsigned CountOverwrite) const {
@@ -1206,8 +1229,7 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
12061229

12071230
UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);
12081231
if (!UCE.canUnroll()) {
1209-
LLVM_DEBUG(dbgs() << " Not unrolling loop which contains instructions"
1210-
<< " which cannot be duplicated or have invalid cost.\n");
1232+
LLVM_DEBUG(dbgs() << " Loop not considered unrollable.\n");
12111233
return LoopUnrollResult::Unmodified;
12121234
}
12131235

@@ -1254,15 +1276,9 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
12541276
// is unsafe -- it adds a control-flow dependency to the convergent
12551277
// operation. Therefore restrict remainder loop (try unrolling without).
12561278
//
1257-
// TODO: This is quite conservative. In practice, convergent_op()
1258-
// is likely to be called unconditionally in the loop. In this
1259-
// case, the program would be ill-formed (on most architectures)
1260-
// unless n were the same on all threads in a thread group.
1261-
// Assuming n is the same on all threads, any kind of unrolling is
1262-
// safe. But currently llvm's notion of convergence isn't powerful
1263-
// enough to express this.
1264-
if (UCE.Convergent)
1265-
UP.AllowRemainder = false;
1279+
// TODO: This is somewhat conservative; we could allow the remainder if the
1280+
// trip count is uniform.
1281+
UP.AllowRemainder &= UCE.ConvergenceAllowsRuntime;
12661282

12671283
// Try to find the trip count upper bound if we cannot find the exact trip
12681284
// count.
@@ -1282,6 +1298,8 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
12821298
if (!UP.Count)
12831299
return LoopUnrollResult::Unmodified;
12841300

1301+
UP.Runtime &= UCE.ConvergenceAllowsRuntime;
1302+
12851303
if (PP.PeelCount) {
12861304
assert(UP.Count == 1 && "Cannot perform peel and unroll in the same step");
12871305
LLVM_DEBUG(dbgs() << "PEELING loop %" << L->getHeader()->getName()
@@ -1324,11 +1342,16 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
13241342

13251343
// Unroll the loop.
13261344
Loop *RemainderLoop = nullptr;
1345+
UnrollLoopOptions ULO;
1346+
ULO.Count = UP.Count;
1347+
ULO.Force = UP.Force;
1348+
ULO.AllowExpensiveTripCount = UP.AllowExpensiveTripCount;
1349+
ULO.UnrollRemainder = UP.UnrollRemainder;
1350+
ULO.Runtime = UP.Runtime;
1351+
ULO.ForgetAllSCEV = ForgetAllSCEV;
1352+
ULO.Heart = getLoopConvergenceHeart(L);
13271353
LoopUnrollResult UnrollResult = UnrollLoop(
1328-
L,
1329-
{UP.Count, UP.Force, UP.Runtime, UP.AllowExpensiveTripCount,
1330-
UP.UnrollRemainder, ForgetAllSCEV},
1331-
LI, &SE, &DT, &AC, &TTI, &ORE, PreserveLCSSA, &RemainderLoop, AA);
1354+
L, ULO, LI, &SE, &DT, &AC, &TTI, &ORE, PreserveLCSSA, &RemainderLoop, AA);
13321355
if (UnrollResult == LoopUnrollResult::Unmodified)
13331356
return LoopUnrollResult::Unmodified;
13341357

llvm/lib/Transforms/Utils/LoopRotationUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
460460
L->dump());
461461
return Rotated;
462462
}
463-
if (Metrics.convergent) {
463+
if (Metrics.Convergence != ConvergenceKind::None) {
464464
LLVM_DEBUG(dbgs() << "LoopRotation: NOT rotating - contains convergent "
465465
"instructions: ";
466466
L->dump());

0 commit comments

Comments
 (0)