Skip to content

[LoopUnroll] Consider convergence control tokens when unrolling #91715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions llvm/include/llvm/Analysis/CodeMetrics.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
namespace llvm {
class AssumptionCache;
class BasicBlock;
class Instruction;
class Loop;
class Function;
template <class T> class SmallPtrSetImpl;
class TargetTransformInfo;
class Value;

/// Classification of the convergent operations found while analyzing code.
///
/// The enumerator order is significant: CodeMetrics compares kinds with `<=`.
enum class ConvergenceKind {
  /// No convergent call has been seen.
  None,
  /// Convergent calls seen so far are convergence control intrinsics or carry
  /// a convergence control token.
  Controlled,
  /// A convergence control intrinsic inside the analyzed loop has a user
  /// outside of that loop.
  ExtendedLoop,
  /// A convergent call without a convergence control token was seen.
  Uncontrolled,
};

/// Utility to calculate the size and a few similar metrics for a set
/// of basic blocks.
struct CodeMetrics {
Expand All @@ -42,8 +45,8 @@ struct CodeMetrics {
/// one or more 'noduplicate' instructions.
bool notDuplicatable = false;

/// True if this function contains a call to a convergent function.
bool convergent = false;
/// The kind of convergence specified in this function.
ConvergenceKind Convergence = ConvergenceKind::None;

/// True if this function calls alloca (in the C sense).
bool usesDynamicAlloca = false;
Expand Down Expand Up @@ -77,7 +80,7 @@ struct CodeMetrics {
/// Add information about a block to the current state.
void analyzeBasicBlock(const BasicBlock *BB, const TargetTransformInfo &TTI,
const SmallPtrSetImpl<const Value *> &EphValues,
bool PrepareForLTO = false);
bool PrepareForLTO = false, const Loop *L = nullptr);

/// Collect a loop's ephemeral values (those used only by an assume
/// or similar intrinsics in the loop).
Expand Down
3 changes: 3 additions & 0 deletions llvm/include/llvm/Analysis/LoopInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,9 @@ int getIntLoopAttribute(const Loop *TheLoop, StringRef Name, int Default = 0);
std::optional<const MDOperand *> findStringMetadataForLoop(const Loop *TheLoop,
StringRef Name);

/// Find the convergence heart of the loop.
///
/// Looks at the first convergent call in the header of \p TheLoop and returns
/// it if it carries a convergence control token whose defining instruction
/// lies outside the loop. Returns nullptr otherwise, including when the header
/// contains no convergent call.
CallBase *getLoopConvergenceHeart(const Loop *TheLoop);

/// Look for the loop attribute that requires progress within the loop.
/// Note: Most consumers probably want "isMustProgress" which checks
/// the containing function attribute too.
Expand Down
8 changes: 8 additions & 0 deletions llvm/include/llvm/IR/InstrTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,14 @@ class CallBase : public Instruction {
static CallBase *removeOperandBundle(CallBase *CB, uint32_t ID,
BasicBlock::iterator InsertPt);

/// Fetch the convergence control token attached to this call.
///
/// \returns the first input of the OB_convergencectrl operand bundle, or
/// nullptr when the call carries no such bundle.
Value *getConvergenceControlToken() const {
  auto CtrlBundle = getOperandBundle(llvm::LLVMContext::OB_convergencectrl);
  if (!CtrlBundle)
    return nullptr;
  return CtrlBundle->Inputs[0].get();
}

static bool classof(const Instruction *I) {
return I->getOpcode() == Instruction::Call ||
I->getOpcode() == Instruction::Invoke ||
Expand Down
19 changes: 8 additions & 11 deletions llvm/include/llvm/IR/IntrinsicInst.h
Original file line number Diff line number Diff line change
Expand Up @@ -1799,17 +1799,14 @@ class ConvergenceControlInst : public IntrinsicInst {
return isa<IntrinsicInst>(V) && classof(cast<IntrinsicInst>(V));
}

// Returns the convergence intrinsic referenced by |I|'s convergencectrl
// attribute if any.
static IntrinsicInst *getParentConvergenceToken(Instruction *I) {
auto *CI = dyn_cast<llvm::CallInst>(I);
if (!CI)
return nullptr;

auto Bundle = CI->getOperandBundle(llvm::LLVMContext::OB_convergencectrl);
assert(Bundle->Inputs.size() == 1 &&
Bundle->Inputs[0]->getType()->isTokenTy());
return dyn_cast<llvm::IntrinsicInst>(Bundle->Inputs[0].get());
bool isAnchor() {
return getIntrinsicID() == Intrinsic::experimental_convergence_anchor;
}
bool isEntry() {
return getIntrinsicID() == Intrinsic::experimental_convergence_entry;
}
bool isLoop() {
return getIntrinsicID() == Intrinsic::experimental_convergence_loop;
}
};

Expand Down
7 changes: 5 additions & 2 deletions llvm/include/llvm/Transforms/Utils/UnrollLoop.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define LLVM_TRANSFORMS_UTILS_UNROLLLOOP_H

#include "llvm/ADT/DenseMap.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Support/InstructionCost.h"

Expand Down Expand Up @@ -73,6 +74,7 @@ struct UnrollLoopOptions {
bool AllowExpensiveTripCount;
bool UnrollRemainder;
bool ForgetAllSCEV;
const Instruction *Heart = nullptr;
};

LoopUnrollResult UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
Expand Down Expand Up @@ -128,14 +130,15 @@ class UnrollCostEstimator {

public:
unsigned NumInlineCandidates;
bool Convergent;
ConvergenceKind Convergence;
bool ConvergenceAllowsRuntime;

UnrollCostEstimator(const Loop *L, const TargetTransformInfo &TTI,
const SmallPtrSetImpl<const Value *> &EphValues,
unsigned BEInsns);

/// Whether it is legal to unroll this loop.
bool canUnroll() const { return LoopSize.isValid() && !NotDuplicatable; }
bool canUnroll() const;

uint64_t getRolledLoopSize() const { return *LoopSize.getValue(); }

Expand Down
53 changes: 43 additions & 10 deletions llvm/lib/Analysis/CodeMetrics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/InstructionCost.h"

Expand Down Expand Up @@ -111,11 +112,24 @@ void CodeMetrics::collectEphemeralValues(
completeEphemeralValues(Visited, Worklist, EphValues);
}

/// Return true if \p I is a convergence control intrinsic that has at least
/// one user which is not contained in \p L.
///
/// Always returns false when \p L is null or when \p I is not a
/// ConvergenceControlInst.
static bool extendsConvergenceOutsideLoop(const Instruction &I, const Loop *L) {
  if (!L || !isa<ConvergenceControlInst>(I))
    return false;

  // A single user outside the loop is enough to answer the question.
  for (const auto *TokenUser : I.users())
    if (!L->contains(cast<Instruction>(TokenUser)))
      return true;

  return false;
}

/// Fill in the current structure with information gleaned from the specified
/// block.
void CodeMetrics::analyzeBasicBlock(
const BasicBlock *BB, const TargetTransformInfo &TTI,
const SmallPtrSetImpl<const Value *> &EphValues, bool PrepareForLTO) {
const SmallPtrSetImpl<const Value *> &EphValues, bool PrepareForLTO,
const Loop *L) {
++NumBlocks;
InstructionCost NumInstsBeforeThisBB = NumInsts;
for (const Instruction &I : *BB) {
Expand Down Expand Up @@ -163,19 +177,38 @@ void CodeMetrics::analyzeBasicBlock(
if (isa<ExtractElementInst>(I) || I.getType()->isVectorTy())
++NumVectorInsts;

if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB))
if (I.getType()->isTokenTy() && !isa<ConvergenceControlInst>(I) &&
I.isUsedOutsideOfBlock(BB)) {
LLVM_DEBUG(dbgs() << I
<< "\n Cannot duplicate a token value used outside "
"the current block (except convergence control).\n");
notDuplicatable = true;

if (const CallInst *CI = dyn_cast<CallInst>(&I)) {
if (CI->cannotDuplicate())
notDuplicatable = true;
if (CI->isConvergent())
convergent = true;
}

if (const InvokeInst *InvI = dyn_cast<InvokeInst>(&I))
if (InvI->cannotDuplicate())
if (const CallBase *CB = dyn_cast<CallBase>(&I)) {
if (CB->cannotDuplicate())
notDuplicatable = true;
// Compute a meet over the visited blocks for the following partial order:
//
// None -> { Controlled, ExtendedLoop, Uncontrolled}
// Controlled -> ExtendedLoop
if (Convergence <= ConvergenceKind::Controlled && CB->isConvergent()) {
if (isa<ConvergenceControlInst>(CB) ||
CB->getConvergenceControlToken()) {
assert(Convergence != ConvergenceKind::Uncontrolled);
LLVM_DEBUG(dbgs() << "Found controlled convergence:\n" << I << "\n");
if (extendsConvergenceOutsideLoop(I, L))
Convergence = ConvergenceKind::ExtendedLoop;
else {
assert(Convergence != ConvergenceKind::ExtendedLoop);
Convergence = ConvergenceKind::Controlled;
}
} else {
assert(Convergence == ConvergenceKind::None);
Convergence = ConvergenceKind::Uncontrolled;
}
}
}

NumInsts += TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
}
Expand Down
20 changes: 20 additions & 0 deletions llvm/lib/Analysis/LoopInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1105,6 +1105,26 @@ int llvm::getIntLoopAttribute(const Loop *TheLoop, StringRef Name,
return getOptionalIntLoopAttribute(TheLoop, Name).value_or(Default);
}

/// Return the convergence heart of \p TheLoop, or nullptr if there is none.
///
/// Only the first convergent call in the loop header is considered; it is the
/// heart exactly when it has a convergence control token whose definition
/// lives outside the loop.
CallBase *llvm::getLoopConvergenceHeart(const Loop *TheLoop) {
  for (Instruction &Inst : *TheLoop->getHeader()) {
    auto *Call = dyn_cast<CallBase>(&Inst);
    if (!Call || !Call->isConvergent())
      continue;

    // The first convergent call decides the result; later ones are never
    // inspected.
    auto *Token = Call->getConvergenceControlToken();
    if (!Token)
      return nullptr;

    // This is the heart if it uses a token defined outside the loop. The
    // verifier has already checked that only the loop intrinsic can use such
    // a token.
    auto *DefBlock = cast<Instruction>(Token)->getParent();
    return TheLoop->contains(DefBlock) ? nullptr : Call;
  }
  return nullptr;
}

bool llvm::isFinite(const Loop *L) {
return L->getHeader()->getParent()->willReturn();
}
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3961,7 +3961,7 @@ static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);

// Loop is not unrollable if the loop contains certain instructions.
if (!UCE.canUnroll() || UCE.Convergent) {
if (!UCE.canUnroll()) {
LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
return 1;
}
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Transforms/Scalar/DFAJumpThreading.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,8 @@ struct TransformDFA {
return false;
}

if (Metrics.convergent) {
// FIXME: Allow jump threading with controlled convergence.
if (Metrics.Convergence != ConvergenceKind::None) {
LLVM_DEBUG(dbgs() << "DFA Jump Threading: Not jump threading, contains "
<< "convergent instructions.\n");
ORE->emit([&]() {
Expand Down
8 changes: 5 additions & 3 deletions llvm/lib/Transforms/Scalar/LoopUnrollAndJamPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,7 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
UnrollCostEstimator OuterUCE(L, TTI, EphValues, UP.BEInsns);

if (!InnerUCE.canUnroll() || !OuterUCE.canUnroll()) {
LLVM_DEBUG(dbgs() << " Not unrolling loop which contains instructions"
<< " which cannot be duplicated or have invalid cost.\n");
LLVM_DEBUG(dbgs() << " Loop not considered unrollable\n");
return LoopUnrollResult::Unmodified;
}

Expand All @@ -341,7 +340,10 @@ tryToUnrollAndJamLoop(Loop *L, DominatorTree &DT, LoopInfo *LI,
LLVM_DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n");
return LoopUnrollResult::Unmodified;
}
if (InnerUCE.Convergent || OuterUCE.Convergent) {
// FIXME: The call to canUnroll() allows some controlled convergent
// operations, but we block them here for future changes.
if (InnerUCE.Convergence != ConvergenceKind::None ||
OuterUCE.Convergence != ConvergenceKind::None) {
LLVM_DEBUG(
dbgs() << " Not unrolling loop with convergent instructions.\n");
return LoopUnrollResult::Unmodified;
Expand Down
57 changes: 40 additions & 17 deletions llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,11 +684,15 @@ UnrollCostEstimator::UnrollCostEstimator(
const SmallPtrSetImpl<const Value *> &EphValues, unsigned BEInsns) {
CodeMetrics Metrics;
for (BasicBlock *BB : L->blocks())
Metrics.analyzeBasicBlock(BB, TTI, EphValues);
Metrics.analyzeBasicBlock(BB, TTI, EphValues, /* PrepareForLTO= */ false,
L);
NumInlineCandidates = Metrics.NumInlineCandidates;
NotDuplicatable = Metrics.notDuplicatable;
Convergent = Metrics.convergent;
Convergence = Metrics.Convergence;
LoopSize = Metrics.NumInsts;
ConvergenceAllowsRuntime =
Metrics.Convergence != ConvergenceKind::Uncontrolled &&
!getLoopConvergenceHeart(L);

// Don't allow an estimate of size zero. This would allow unrolling of loops
// with huge iteration counts, which is a compile time problem even if it's
Expand All @@ -701,6 +705,25 @@ UnrollCostEstimator::UnrollCostEstimator(
LoopSize = BEInsns + 1;
}

/// Decide whether unrolling this loop is legal.
///
/// Unrolling is rejected, in this order, when the loop's convergence kind is
/// ExtendedLoop, when the computed loop size is invalid, or when the loop was
/// flagged as not duplicatable. Each rejection logs its reason under
/// LLVM_DEBUG.
bool UnrollCostEstimator::canUnroll() const {
  // ExtendedLoop is the only convergence kind that blocks unrolling here.
  if (Convergence == ConvergenceKind::ExtendedLoop) {
    LLVM_DEBUG(dbgs() << " Convergence prevents unrolling.\n");
    return false;
  }

  if (!LoopSize.isValid()) {
    LLVM_DEBUG(dbgs() << " Invalid loop size prevents unrolling.\n");
    return false;
  }

  if (NotDuplicatable) {
    LLVM_DEBUG(dbgs() << " Non-duplicatable blocks prevent unrolling.\n");
    return false;
  }

  return true;
}

uint64_t UnrollCostEstimator::getUnrolledLoopSize(
const TargetTransformInfo::UnrollingPreferences &UP,
unsigned CountOverwrite) const {
Expand Down Expand Up @@ -1206,8 +1229,7 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,

UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);
if (!UCE.canUnroll()) {
LLVM_DEBUG(dbgs() << " Not unrolling loop which contains instructions"
<< " which cannot be duplicated or have invalid cost.\n");
LLVM_DEBUG(dbgs() << " Loop not considered unrollable.\n");
return LoopUnrollResult::Unmodified;
}

Expand Down Expand Up @@ -1254,15 +1276,9 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
// is unsafe -- it adds a control-flow dependency to the convergent
// operation. Therefore restrict remainder loop (try unrolling without).
//
// TODO: This is quite conservative. In practice, convergent_op()
// is likely to be called unconditionally in the loop. In this
// case, the program would be ill-formed (on most architectures)
// unless n were the same on all threads in a thread group.
// Assuming n is the same on all threads, any kind of unrolling is
// safe. But currently llvm's notion of convergence isn't powerful
// enough to express this.
if (UCE.Convergent)
UP.AllowRemainder = false;
// TODO: This is somewhat conservative; we could allow the remainder if the
// trip count is uniform.
UP.AllowRemainder &= UCE.ConvergenceAllowsRuntime;

// Try to find the trip count upper bound if we cannot find the exact trip
// count.
Expand All @@ -1282,6 +1298,8 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,
if (!UP.Count)
return LoopUnrollResult::Unmodified;

UP.Runtime &= UCE.ConvergenceAllowsRuntime;

if (PP.PeelCount) {
assert(UP.Count == 1 && "Cannot perform peel and unroll in the same step");
LLVM_DEBUG(dbgs() << "PEELING loop %" << L->getHeader()->getName()
Expand Down Expand Up @@ -1324,11 +1342,16 @@ tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE,

// Unroll the loop.
Loop *RemainderLoop = nullptr;
UnrollLoopOptions ULO;
ULO.Count = UP.Count;
ULO.Force = UP.Force;
ULO.AllowExpensiveTripCount = UP.AllowExpensiveTripCount;
ULO.UnrollRemainder = UP.UnrollRemainder;
ULO.Runtime = UP.Runtime;
ULO.ForgetAllSCEV = ForgetAllSCEV;
ULO.Heart = getLoopConvergenceHeart(L);
LoopUnrollResult UnrollResult = UnrollLoop(
L,
{UP.Count, UP.Force, UP.Runtime, UP.AllowExpensiveTripCount,
UP.UnrollRemainder, ForgetAllSCEV},
LI, &SE, &DT, &AC, &TTI, &ORE, PreserveLCSSA, &RemainderLoop, AA);
L, ULO, LI, &SE, &DT, &AC, &TTI, &ORE, PreserveLCSSA, &RemainderLoop, AA);
if (UnrollResult == LoopUnrollResult::Unmodified)
return LoopUnrollResult::Unmodified;

Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Transforms/Utils/LoopRotationUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) {
L->dump());
return Rotated;
}
if (Metrics.convergent) {
if (Metrics.Convergence != ConvergenceKind::None) {
LLVM_DEBUG(dbgs() << "LoopRotation: NOT rotating - contains convergent "
"instructions: ";
L->dump());
Expand Down
Loading
Loading