
Add a standalone histogram intrinsic, use it to vectorize simple histograms #89366

Closed
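For context, the high-level pattern this patch targets is a loop whose body performs buckets[indices[i]] += step, as described in the LoopAccessAnalysis changes below. A minimal C++ sketch (illustrative only, not taken from the patch):

// A simple scalar histogram: several lanes of a vectorized iteration may hit
// the same bucket, which is the memory dependence the new
// llvm.experimental.vector.histogram.count intrinsic is meant to resolve.
void histogram(int *buckets, const int *indices, int n, int step) {
  for (int i = 0; i < n; ++i)
    buckets[indices[i]] += step;
}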
27 changes: 25 additions & 2 deletions llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -198,7 +198,8 @@ class MemoryDepChecker {
bool areDepsSafe(DepCandidates &AccessSets, MemAccessInfoList &CheckDeps,
const DenseMap<Value *, const SCEV *> &Strides,
const DenseMap<Value *, SmallVector<const Value *, 16>>
&UnderlyingObjects);
&UnderlyingObjects,
const SmallPtrSetImpl<Value *> &HistogramPtrs);

/// No memory dependence was encountered that would inhibit
/// vectorization.
@@ -330,7 +331,8 @@ class MemoryDepChecker {
isDependent(const MemAccessInfo &A, unsigned AIdx, const MemAccessInfo &B,
unsigned BIdx, const DenseMap<Value *, const SCEV *> &Strides,
const DenseMap<Value *, SmallVector<const Value *, 16>>
&UnderlyingObjects);
&UnderlyingObjects,
const SmallPtrSetImpl<Value *> &HistogramPtrs);

/// Check whether the data dependence could prevent store-load
/// forwarding.
@@ -608,6 +610,19 @@ class LoopAccessInfo {
unsigned getNumStores() const { return NumStores; }
unsigned getNumLoads() const { return NumLoads;}

const DenseMap<Instruction *, Instruction *> &getHistogramCounts() const {
return HistogramCounts;
}

/// Given a histogram count BinOp \p I, returns the index value (the load from
/// the 'indices' array) used to address the bucket being updated.
std::optional<Instruction *> getHistogramIndexValue(Instruction *I) const {
auto It = HistogramCounts.find(I);
if (It == HistogramCounts.end())
return std::nullopt;
return It->second;
}

/// The diagnostics report generated for the analysis. E.g. why we
/// couldn't analyze the loop.
const OptimizationRemarkAnalysis *getReport() const { return Report.get(); }
@@ -710,6 +725,14 @@ class LoopAccessInfo {
/// If an access has a symbolic strides, this maps the pointer value to
/// the stride symbol.
DenseMap<Value *, const SCEV *> SymbolicStrides;

/// Holds all the histogram BinOp/Index pairs found in the loop, where BinOp
/// is the Add/Sub that performs the histogram counting and Index is the load
/// of the bucket index to update.
DenseMap<Instruction *, Instruction *> HistogramCounts;

/// Holds the pointers used to update histogram buckets; load/store
/// dependences through these pointers are safely ignored by the dependence
/// checker.
SmallPtrSet<Value *, 2> HistogramPtrs;
};

/// Return the SCEV corresponding to a pointer with the symbolic stride
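Illustrative only (not part of this patch): a sketch of how a client such as the loop vectorizer might consume the new LoopAccessInfo interface; the helper name is hypothetical.

#include "llvm/Analysis/LoopAccessAnalysis.h"
#include <optional>
using namespace llvm;

// Hypothetical helper: if I is the Add/Sub of a recognized histogram update,
// return true and hand back the load of its bucket index (the value a
// vectorizer would need to widen); otherwise return false.
static bool isHistogramCountUpdate(const LoopAccessInfo &LAI, Instruction *I,
                                   Instruction *&IndexLoad) {
  if (std::optional<Instruction *> Idx = LAI.getHistogramIndexValue(I)) {
    IndexLoad = *Idx;
    return true;
  }
  return false;
}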
8 changes: 8 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -982,6 +982,9 @@ class TargetTransformInfo {
/// Return hardware support for population count.
PopcntSupportKind getPopcntSupport(unsigned IntTyWidthInBit) const;

/// Returns the cost of generating a vector histogram.
InstructionCost getHistogramCost(Type *Ty) const;

/// Return true if the hardware has a fast square-root instruction.
bool haveFastSqrt(Type *Ty) const;

@@ -1930,6 +1933,7 @@ class TargetTransformInfo::Concept {
unsigned *Fast) = 0;
virtual PopcntSupportKind getPopcntSupport(unsigned IntTyWidthInBit) = 0;
virtual bool haveFastSqrt(Type *Ty) = 0;
virtual InstructionCost getHistogramCost(Type *Ty) = 0;
virtual bool isExpensiveToSpeculativelyExecute(const Instruction *I) = 0;
virtual bool isFCmpOrdCheaperThanFCmpZero(Type *Ty) = 0;
virtual InstructionCost getFPOpCost(Type *Ty) = 0;
@@ -2490,6 +2494,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
}
bool haveFastSqrt(Type *Ty) override { return Impl.haveFastSqrt(Ty); }

InstructionCost getHistogramCost(Type *Ty) override {
return Impl.getHistogramCost(Ty);
}

bool isExpensiveToSpeculativelyExecute(const Instruction* I) override {
return Impl.isExpensiveToSpeculativelyExecute(I);
}
4 changes: 4 additions & 0 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -412,6 +412,10 @@ class TargetTransformInfoImplBase {

bool haveFastSqrt(Type *Ty) const { return false; }

InstructionCost getHistogramCost(Type *Ty) const {
return InstructionCost::getInvalid();
}

bool isExpensiveToSpeculativelyExecute(const Instruction *I) { return true; }

bool isFCmpOrdCheaperThanFCmpZero(Type *Ty) const { return true; }
4 changes: 4 additions & 0 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -538,6 +538,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
TLI->isOperationLegalOrCustom(ISD::FSQRT, VT);
}

InstructionCost getHistogramCost(Type *Ty) {
return InstructionCost::getInvalid();
}

bool isFCmpOrdCheaperThanFCmpZero(Type *Ty) {
return true;
}
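Illustrative only (not part of this patch): a sketch of how a cost-model client might use the new hook, relying on the invalid default cost returned by the base implementations above; the helper name is assumed.

#include "llvm/Analysis/TargetTransformInfo.h"
#include <optional>
using namespace llvm;

// Hypothetical helper: query the target for the cost of a vector histogram on
// VecTy. The default implementations return an invalid cost, so the caller
// must check validity before treating the operation as supported.
static std::optional<InstructionCost>
getSupportedHistogramCost(const TargetTransformInfo &TTI, Type *VecTy) {
  InstructionCost Cost = TTI.getHistogramCost(VecTy);
  if (!Cost.isValid())
    return std::nullopt;
  return Cost;
}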
6 changes: 6 additions & 0 deletions llvm/include/llvm/IR/Intrinsics.td
@@ -1849,6 +1849,12 @@ def int_experimental_vp_strided_load : DefaultAttrsIntrinsic<[llvm_anyvector_ty
llvm_i32_ty],
[ NoCapture<ArgIndex<0>>, IntrNoSync, IntrReadMem, IntrWillReturn, IntrArgMemOnly ]>;

// Experimental histogram count
def int_experimental_vector_histogram_count : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
[ LLVMMatchType<0>,
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>],
[ IntrNoMem ]>;

// Operators
let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn] in {
// Integer arithmetic
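A hypothetical sketch (not part of this patch) of emitting the new intrinsic via IRBuilder, following the signature above: the result and the first operand share the overloaded vector type, and the second operand is an i1 mask with the same element count. The helper name is assumed.

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
using namespace llvm;

// Emit llvm.experimental.vector.histogram.count on a vector of bucket indices
// with an all-true mask; BucketIndices must have a vector type.
static Value *emitHistogramCount(IRBuilder<> &Builder, Value *BucketIndices) {
  auto *VecTy = cast<VectorType>(BucketIndices->getType());
  Value *AllTrueMask =
      Builder.CreateVectorSplat(VecTy->getElementCount(), Builder.getTrue());
  return Builder.CreateIntrinsic(
      Intrinsic::experimental_vector_histogram_count, {VecTy},
      {BucketIndices, AllTrueMask});
}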
llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
@@ -386,6 +386,9 @@ class LoopVectorizationLegality {

unsigned getNumStores() const { return LAI->getNumStores(); }
unsigned getNumLoads() const { return LAI->getNumLoads(); }
std::optional<Instruction *> getHistogramIndexValue(Instruction *I) const {
return LAI->getHistogramIndexValue(I);
}

PredicatedScalarEvolution *getPredicatedScalarEvolution() const {
return &PSE;
132 changes: 124 additions & 8 deletions llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -21,6 +21,7 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AliasSetTracker.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
@@ -69,6 +70,8 @@ using namespace llvm::PatternMatch;

#define DEBUG_TYPE "loop-accesses"

STATISTIC(HistogramsDetected, "Number of Histograms detected");

static cl::opt<unsigned, true>
VectorizationFactor("force-vector-width", cl::Hidden,
cl::desc("Sets the SIMD width. Zero is autoselect."),
@@ -730,6 +733,23 @@ class AccessAnalysis {
return UnderlyingObjects;
}

/// Find Histogram counts that match high-level code in loops:
/// \code
/// buckets[indices[i]]+=step;
/// \endcode
///
/// The pattern is matched starting from the store \p HSt into the 'buckets'
/// array. The stored value must be produced by an Add/Sub BinOp on the value
/// loaded from the same bucket, where the bucket address is formed from a
/// loop-variant index loaded from the 'indices' input array.
///
/// On a successful match it updates the 'HistogramsDetected' statistic and
/// records the BinOp/Load pair in \p HistogramCounts, as well as the pointer
/// used to update the histogram in \p HistogramPtrs.
void findHistograms(StoreInst *HSt,
DenseMap<Instruction *, Instruction *> &HistogramCounts,
SmallPtrSetImpl<Value *> &HistogramPtrs);

private:
typedef MapVector<MemAccessInfo, SmallSetVector<Type *, 1>> PtrAccessMap;

@@ -1946,7 +1966,8 @@ getDependenceDistanceStrideAndSize(
const AccessAnalysis::MemAccessInfo &B, Instruction *BInst,
const DenseMap<Value *, const SCEV *> &Strides,
const DenseMap<Value *, SmallVector<const Value *, 16>> &UnderlyingObjects,
PredicatedScalarEvolution &PSE, const Loop *InnermostLoop) {
PredicatedScalarEvolution &PSE, const Loop *InnermostLoop,
const SmallPtrSetImpl<Value *> &HistogramPtrs) {
auto &DL = InnermostLoop->getHeader()->getModule()->getDataLayout();
auto &SE = *PSE.getSE();
auto [APtr, AIsWrite] = A;
@@ -1964,6 +1985,15 @@
BPtr->getType()->getPointerAddressSpace())
return MemoryDepChecker::Dependence::Unknown;

// Ignore histogram count updates, as they are handled by the histogram
// intrinsic. This case arises when the same pointer is first read from and
// then written to within the loop.
if (!AIsWrite && BIsWrite && APtr == BPtr && HistogramPtrs.contains(APtr)) {
LLVM_DEBUG(dbgs() << "LAA: Histogram: Update is safely ignored. Pointer: "
<< *APtr << "\n");
return MemoryDepChecker::Dependence::NoDep;
}

int64_t StrideAPtr =
getPtrStride(PSE, ATy, APtr, InnermostLoop, Strides, true).value_or(0);
int64_t StrideBPtr =
@@ -2016,15 +2046,15 @@
MemoryDepChecker::Dependence::DepType MemoryDepChecker::isDependent(
const MemAccessInfo &A, unsigned AIdx, const MemAccessInfo &B,
unsigned BIdx, const DenseMap<Value *, const SCEV *> &Strides,
const DenseMap<Value *, SmallVector<const Value *, 16>>
&UnderlyingObjects) {
const DenseMap<Value *, SmallVector<const Value *, 16>> &UnderlyingObjects,
const SmallPtrSetImpl<Value *> &HistogramPtrs) {
assert(AIdx < BIdx && "Must pass arguments in program order");

// Get the dependence distance, stride, type size and what access writes for
// the dependence between A and B.
auto Res = getDependenceDistanceStrideAndSize(
A, InstMap[AIdx], B, InstMap[BIdx], Strides, UnderlyingObjects, PSE,
InnermostLoop);
InnermostLoop, HistogramPtrs);
if (std::holds_alternative<Dependence::DepType>(Res))
return std::get<Dependence::DepType>(Res);

@@ -2185,8 +2215,8 @@ MemoryDepChecker::Dependence::DepType MemoryDepChecker::isDependent(
bool MemoryDepChecker::areDepsSafe(
DepCandidates &AccessSets, MemAccessInfoList &CheckDeps,
const DenseMap<Value *, const SCEV *> &Strides,
const DenseMap<Value *, SmallVector<const Value *, 16>>
&UnderlyingObjects) {
const DenseMap<Value *, SmallVector<const Value *, 16>> &UnderlyingObjects,
const SmallPtrSetImpl<Value *> &HistogramPtrs) {

MinDepDistBytes = -1;
SmallPtrSet<MemAccessInfo, 8> Visited;
@@ -2231,7 +2261,7 @@ bool MemoryDepChecker::areDepsSafe(

Dependence::DepType Type =
isDependent(*A.first, A.second, *B.first, B.second, Strides,
UnderlyingObjects);
UnderlyingObjects, HistogramPtrs);
mergeInStatus(Dependence::isSafeForVectorization(Type));

// Gather dependences unless we accumulated MaxDependences
@@ -2567,6 +2597,9 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI,
// check.
Accesses.buildDependenceSets();

for (StoreInst *ST : Stores)
Accesses.findHistograms(ST, HistogramCounts, HistogramPtrs);

// Find pointers with computable bounds. We are going to use this information
// to place a runtime bound check.
Value *UncomputablePtr = nullptr;
@@ -2591,7 +2624,7 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI,
LLVM_DEBUG(dbgs() << "LAA: Checking memory dependencies\n");
CanVecMem = DepChecker->areDepsSafe(
DependentAccesses, Accesses.getDependenciesToCheck(), SymbolicStrides,
Accesses.getUnderlyingObjects());
Accesses.getUnderlyingObjects(), HistogramPtrs);

if (!CanVecMem && DepChecker->shouldRetryWithRuntimeCheck()) {
LLVM_DEBUG(dbgs() << "LAA: Retrying with memory checks\n");
@@ -3025,6 +3058,89 @@ const LoopAccessInfo &LoopAccessInfoManager::getInfo(Loop &L) {
return *I.first->second;
}

void AccessAnalysis::findHistograms(
StoreInst *HSt, DenseMap<Instruction *, Instruction *> &HistogramCounts,
SmallPtrSetImpl<Value *> &HistogramPtrs) {
LLVM_DEBUG(dbgs() << "LAA: Attempting to match histogram from " << *HSt
<< "\n");
// Store value must come from a Binary Operation.
Instruction *HPtrInstr = nullptr;
BinaryOperator *HBinOp = nullptr;
if (!match(HSt, m_Store(m_BinOp(HBinOp), m_Instruction(HPtrInstr)))) {
LLVM_DEBUG(dbgs() << "\tNo BinOp\n");
return;
}

// BinOp must be an Add or a Sub that modifies the bucket value by a
// loop-invariant amount.
// FIXME: We assume the loop-invariant term is on the RHS. This is fine for an
// immediate/constant, but maybe not for a generic value?
Value *HIncVal = nullptr;
if (!match(HBinOp, m_Add(m_Load(m_Specific(HPtrInstr)), m_Value(HIncVal))) &&
!match(HBinOp, m_Sub(m_Load(m_Specific(HPtrInstr)), m_Value(HIncVal)))) {
LLVM_DEBUG(dbgs() << "\tNo matching load\n");
return;
}

// The address to store to is calculated through a GEP instruction.
// FIXME: Support GEPs with more operands.
GetElementPtrInst *HPtr = dyn_cast<GetElementPtrInst>(HPtrInstr);
if (!HPtr || HPtr->getNumOperands() > 2) {
LLVM_DEBUG(dbgs() << "\tToo many GEP operands\n");
return;
}

// Check that the index is calculated by loading from another array. Ignore
// any extensions.
// FIXME: Support indices from sources other than a linear load from memory?
Value *HIdx = HPtr->getOperand(1);
Instruction *IdxInst = nullptr;
// FIXME: Can this fail? Maybe if IdxInst isn't an instruction. Just need to
// look through extensions, find another way?
if (!match(HIdx, m_ZExtOrSExtOrSelf(m_Instruction(IdxInst))))
return;

// Currently restricting this to linear addressing when loading indices.
LoadInst *VLoad = dyn_cast<LoadInst>(IdxInst);
Value *VPtrVal;
if (!VLoad || !match(VLoad, m_Load(m_Value(VPtrVal)))) {
LLVM_DEBUG(dbgs() << "\tBad Index Load\n");
return;
}

if (!isa<SCEVAddRecExpr>(PSE.getSCEV(VPtrVal))) {
LLVM_DEBUG(dbgs() << "\tCannot determine index load stride\n");
return;
}

// FIXME: support smaller types of input arrays. Integers can be promoted
// for codegen.
Type *VLoadTy = VLoad->getType();
if (!VLoadTy->isIntegerTy() || (VLoadTy->getScalarSizeInBits() != 32 &&
VLoadTy->getScalarSizeInBits() != 64)) {
LLVM_DEBUG(dbgs() << "\tUnsupported bucket type: " << *VLoadTy << "\n");
return;
}

// A histogram pointer may only alias to itself, and must only have two uses,
// the load and the store.
for (AliasSet &AS : AST)
if (AS.isMustAlias() || AS.isMayAlias())
if ((is_contained(AS.getPointers(), HPtr) && AS.size() > 1) ||
HPtr->getNumUses() != 2) {
LLVM_DEBUG(dbgs() << "\tAliasing problem\n");
return;
}

LLVM_DEBUG(dbgs() << "LAA: Found Histogram Operation: " << *HBinOp << "\n");
HistogramsDetected++;

// Store pairs of BinOp (Add/Sub) that modify the count and the index load.
HistogramCounts.insert(std::make_pair(HBinOp, VLoad));
// Store pointers used to write those counts in the computed histogram.
HistogramPtrs.insert(HPtr);
}
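As an illustration of the current restrictions (not part of this patch), a C++ loop shape the matcher rejects because the bucket index is computed rather than loaded from memory:

// Not matched: the GEP index is an arithmetic expression, not a (possibly
// extended) load, so the LoadInst check above bails out ("Bad Index Load")
// before any histogram is recorded.
void notAHistogramCandidate(int *buckets, int n) {
  for (int i = 0; i < n; ++i)
    buckets[(i * 7) & 255] += 1;
}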

bool LoopAccessInfoManager::invalidate(
Function &F, const PreservedAnalyses &PA,
FunctionAnalysisManager::Invalidator &Inv) {
4 changes: 4 additions & 0 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -653,6 +653,10 @@ bool TargetTransformInfo::haveFastSqrt(Type *Ty) const {
return TTIImpl->haveFastSqrt(Ty);
}

InstructionCost TargetTransformInfo::getHistogramCost(Type *Ty) const {
return TTIImpl->getHistogramCost(Ty);
}

bool TargetTransformInfo::isExpensiveToSpeculativelyExecute(
const Instruction *I) const {
return TTIImpl->isExpensiveToSpeculativelyExecute(I);
15 changes: 15 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5293,6 +5293,21 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
SDLoc dl(Op);
switch (IntNo) {
default: return SDValue(); // Don't custom lower most intrinsics.
case Intrinsic::experimental_vector_histogram_count: {
// Replace the IR intrinsic with the AArch64 SVE2 HISTCNT instruction.
assert((Op.getNumOperands() == 3) &&
"Histogram Intrinsic requires 3 operands");
EVT Ty = Op.getValueType();
assert((Ty == MVT::nxv4i32 || Ty == MVT::nxv2i64) &&
"Intrinsic supports only i64 or i32 types");
SDValue InputVector = Op.getOperand(1);
SDValue Mask = Op.getOperand(2);
SDValue ID =
DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, dl, MVT::i32);
return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, Ty, ID, Mask, InputVector,
InputVector);
}
case Intrinsic::thread_pointer: {
EVT PtrVT = getPointerTy(DAG.getDataLayout());
return DAG.getNode(AArch64ISD::THREAD_POINTER, dl, PtrVT);