Add a standalone histogram intrinsic, use it to vectorize simple histograms #89366

Closed

Conversation

huntergr-arm
Collaborator

This is a proof-of-concept using the original intrinsic proposed in https://discourse.llvm.org/t/rfc-vectorization-support-for-histogram-count-operations/74788/5

Much of the work was originally done by @paschalis-mpeis; I've changed it slightly to support arbitrary increments and to use the appropriate mask when tail folding.

The proposed alternative in the thread was implemented in #88106 to compare the two approaches (though without LoopVec support).

Given code like the following:

void simple_histogram(int *restrict buckets, unsigned *indices, int N, int inc) {
  for (int i = 0; i < N; ++i)
    buckets[indices[i]] += inc;
}

this patch allows LLVM to generate a vectorized loop for SVE:

simple_histogram:
	.cfi_startproc
	cmp	w2, #1
	b.lt	.LBB0_3

	mov	w8, w2
	mov	z0.s, w3
	ptrue	p0.s
	whilelo	p1.s, xzr, x8
	mov	x9, xzr

.LBB0_2:
	ld1w	{ z1.s }, p1/z, [x1, x9, lsl #2]
	incw	x9
	histcnt	z2.s, p1/z, z1.s, z1.s
	ld1w	{ z3.s }, p1/z, [x0, z1.s, uxtw #2]
	mad	z2.s, p0/m, z0.s, z3.s
	st1w	{ z2.s }, p1, [x0, z1.s, uxtw #2]
	whilelo	p1.s, x9, x8
	b.mi	.LBB0_2

.LBB0_3:
	ret
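
To make the mapping to the assembly above concrete, here is a hand-written sketch of the kind of IR the vectorized loop body contains (value names, the gather/scatter shape, and the index zero-extension are illustrative assumptions rather than the vectorizer's exact output; %inc.splat is %inc broadcast to all lanes, with its setup omitted):

vector.body:                        ; induction/active-lane-mask phis omitted
	; Load a vector of bucket indices under the loop mask.
	%gep.ind = getelementptr inbounds i32, ptr %indices, i64 %index
	%wide.ind = call <vscale x 4 x i32> @llvm.masked.load.nxv4i32.p0(ptr %gep.ind, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> poison)
	; Per-lane occurrence counts of each index value; lowers to SVE2 histcnt.
	%counts = call <vscale x 4 x i32> @llvm.experimental.vector.histogram.count.nxv4i32(<vscale x 4 x i32> %wide.ind, <vscale x 4 x i1> %mask)
	; Gather current bucket values, add count * inc, scatter the results back.
	%ext.ind = zext <vscale x 4 x i32> %wide.ind to <vscale x 4 x i64>
	%gep.bkt = getelementptr inbounds i32, ptr %buckets, <vscale x 4 x i64> %ext.ind
	%bkt = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> %gep.bkt, i32 4, <vscale x 4 x i1> %mask, <vscale x 4 x i32> poison)
	%scaled = mul <vscale x 4 x i32> %counts, %inc.splat
	%new = add <vscale x 4 x i32> %bkt, %scaled
	call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> %new, <vscale x 4 x ptr> %gep.bkt, i32 4, <vscale x 4 x i1> %mask)

Duplicate indices within a vector are what make a plain gather/update/scatter unsound. Assuming histcnt-style prefix counts (each lane counts matching values at its own or lower lane positions), the highest duplicate lane carries the full count, and since a masked scatter resolves address collisions in favor of the highest-numbered lane, the final store to each bucket is correct.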

I'll add a comment in the RFC thread about the tradeoffs we've thought of so far.

@llvmbot
Member

llvmbot commented Apr 19, 2024

@llvm/pr-subscribers-llvm-transforms
@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-llvm-ir

Author: Graham Hunter (huntergr-arm)

Patch is 36.32 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89366.diff

17 Files Affected:

  • (modified) llvm/include/llvm/Analysis/LoopAccessAnalysis.h (+25-2)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+8)
  • (modified) llvm/include/llvm/Analysis/TargetTransformInfoImpl.h (+4)
  • (modified) llvm/include/llvm/CodeGen/BasicTTIImpl.h (+4)
  • (modified) llvm/include/llvm/IR/Intrinsics.td (+6)
  • (modified) llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h (+3)
  • (modified) llvm/lib/Analysis/LoopAccessAnalysis.cpp (+124-8)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+4)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+15)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+30)
  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h (+2)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp (+20)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+44-1)
  • (modified) llvm/lib/Transforms/Vectorize/VPRecipeBuilder.h (+7)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+62)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+33)
  • (added) llvm/test/Transforms/LoopVectorize/AArch64/sve2-histcnt.ll (+86)
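
For reference, the Intrinsics.td entry in the diff below corresponds to an IR declaration along these lines (a sketch; the mangled suffix follows the usual overload rules, and the per-lane semantics in the comment are my reading of the SVE2 histcnt mapping rather than normative wording from the patch):

	; Assumed semantics: for each active lane i, the result is the number of
	; active lanes j <= i whose value equals the value in lane i; inactive
	; lanes are presumably zeroed, matching the zeroing-predicated SVE form.
	declare <vscale x 4 x i32> @llvm.experimental.vector.histogram.count.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i1>)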
diff --git a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
index e39c371b41ec5c..0efc9a2814d8fb 100644
--- a/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
+++ b/llvm/include/llvm/Analysis/LoopAccessAnalysis.h
@@ -198,7 +198,8 @@ class MemoryDepChecker {
   bool areDepsSafe(DepCandidates &AccessSets, MemAccessInfoList &CheckDeps,
                    const DenseMap<Value *, const SCEV *> &Strides,
                    const DenseMap<Value *, SmallVector<const Value *, 16>>
-                       &UnderlyingObjects);
+                       &UnderlyingObjects,
+                   const SmallPtrSetImpl<Value *> &HistogramPtrs);
 
   /// No memory dependence was encountered that would inhibit
   /// vectorization.
@@ -330,7 +331,8 @@ class MemoryDepChecker {
   isDependent(const MemAccessInfo &A, unsigned AIdx, const MemAccessInfo &B,
               unsigned BIdx, const DenseMap<Value *, const SCEV *> &Strides,
               const DenseMap<Value *, SmallVector<const Value *, 16>>
-                  &UnderlyingObjects);
+                  &UnderlyingObjects,
+              const SmallPtrSetImpl<Value *> &HistogramPtrs);
 
   /// Check whether the data dependence could prevent store-load
   /// forwarding.
@@ -608,6 +610,19 @@ class LoopAccessInfo {
   unsigned getNumStores() const { return NumStores; }
   unsigned getNumLoads() const { return NumLoads;}
 
+  const DenseMap<Instruction *, Instruction *> &getHistogramCounts() const {
+    return HistogramCounts;
+  }
+
+  /// Given a Histogram count BinOp \p I, returns the index value (the load
+  /// from the 'indices' array) that selects the bucket to update.
+  std::optional<Instruction *> getHistogramIndexValue(Instruction *I) const {
+    auto It = HistogramCounts.find(I);
+    if (It == HistogramCounts.end())
+      return std::nullopt;
+    return It->second;
+  }
+
   /// The diagnostics report generated for the analysis.  E.g. why we
   /// couldn't analyze the loop.
   const OptimizationRemarkAnalysis *getReport() const { return Report.get(); }
@@ -710,6 +725,14 @@ class LoopAccessInfo {
   /// If an access has a symbolic strides, this maps the pointer value to
   /// the stride symbol.
   DenseMap<Value *, const SCEV *> SymbolicStrides;
+
+  /// Holds all the Histogram count BinOp/Index pairs that we found in the
+  /// loop, where BinOp is an Add/Sub that does the histogram counting and
+  /// Index is the value selecting the bucket to update.
+  DenseMap<Instruction *, Instruction *> HistogramCounts;
+
+  /// The pointers used to update the histogram buckets.
+  SmallPtrSet<Value *, 2> HistogramPtrs;
 };
 
 /// Return the SCEV corresponding to a pointer with the symbolic stride
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 58c69ac939763a..c136f1dabfa7ce 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -982,6 +982,9 @@ class TargetTransformInfo {
   /// Return hardware support for population count.
   PopcntSupportKind getPopcntSupport(unsigned IntTyWidthInBit) const;
 
+  /// Returns the cost of generating a vector histogram.
+  InstructionCost getHistogramCost(Type *Ty) const;
+
   /// Return true if the hardware has a fast square-root instruction.
   bool haveFastSqrt(Type *Ty) const;
 
@@ -1930,6 +1933,7 @@ class TargetTransformInfo::Concept {
                                               unsigned *Fast) = 0;
   virtual PopcntSupportKind getPopcntSupport(unsigned IntTyWidthInBit) = 0;
   virtual bool haveFastSqrt(Type *Ty) = 0;
+  virtual InstructionCost getHistogramCost(Type *Ty) = 0;
   virtual bool isExpensiveToSpeculativelyExecute(const Instruction *I) = 0;
   virtual bool isFCmpOrdCheaperThanFCmpZero(Type *Ty) = 0;
   virtual InstructionCost getFPOpCost(Type *Ty) = 0;
@@ -2490,6 +2494,10 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
   }
   bool haveFastSqrt(Type *Ty) override { return Impl.haveFastSqrt(Ty); }
 
+  InstructionCost getHistogramCost(Type *Ty) override {
+    return Impl.getHistogramCost(Ty);
+  }
+
   bool isExpensiveToSpeculativelyExecute(const Instruction* I) override {
     return Impl.isExpensiveToSpeculativelyExecute(I);
   }
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index 5b40e49714069f..67c04979cfe5db 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -412,6 +412,10 @@ class TargetTransformInfoImplBase {
 
   bool haveFastSqrt(Type *Ty) const { return false; }
 
+  InstructionCost getHistogramCost(Type *Ty) const {
+    return InstructionCost::getInvalid();
+  }
+
   bool isExpensiveToSpeculativelyExecute(const Instruction *I) { return true; }
 
   bool isFCmpOrdCheaperThanFCmpZero(Type *Ty) const { return true; }
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 06a19c75cf873a..bb65a5a6ce905c 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -538,6 +538,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
            TLI->isOperationLegalOrCustom(ISD::FSQRT, VT);
   }
 
+  InstructionCost getHistogramCost(Type *Ty) {
+    return InstructionCost::getInvalid();
+  }
+
   bool isFCmpOrdCheaperThanFCmpZero(Type *Ty) {
     return true;
   }
diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td
index bdd8465883fcff..92cad82ff03b5e 100644
--- a/llvm/include/llvm/IR/Intrinsics.td
+++ b/llvm/include/llvm/IR/Intrinsics.td
@@ -1849,6 +1849,12 @@ def int_experimental_vp_strided_load  : DefaultAttrsIntrinsic<[llvm_anyvector_ty
                                llvm_i32_ty],
                              [ NoCapture<ArgIndex<0>>, IntrNoSync, IntrReadMem, IntrWillReturn, IntrArgMemOnly ]>;
 
+// Experimental histogram count
+def int_experimental_vector_histogram_count : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+                             [ LLVMMatchType<0>,
+                               LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>],
+                             [ IntrNoMem ]>;
+
 // Operators
 let IntrProperties = [IntrNoMem, IntrNoSync, IntrWillReturn] in {
   // Integer arithmetic
diff --git a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
index a509ebf6a7e1b3..8f0e8f26f7fb13 100644
--- a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
+++ b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h
@@ -386,6 +386,9 @@ class LoopVectorizationLegality {
 
   unsigned getNumStores() const { return LAI->getNumStores(); }
   unsigned getNumLoads() const { return LAI->getNumLoads(); }
+  std::optional<Instruction *> getHistogramIndexValue(Instruction *I) const {
+    return LAI->getHistogramIndexValue(I);
+  }
 
   PredicatedScalarEvolution *getPredicatedScalarEvolution() const {
     return &PSE;
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 3bfc9700a14559..9fbd47a329c6f9 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -21,6 +21,7 @@
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/AliasSetTracker.h"
 #include "llvm/Analysis/LoopAnalysisManager.h"
@@ -69,6 +70,8 @@ using namespace llvm::PatternMatch;
 
 #define DEBUG_TYPE "loop-accesses"
 
+STATISTIC(HistogramsDetected, "Number of Histograms detected");
+
 static cl::opt<unsigned, true>
 VectorizationFactor("force-vector-width", cl::Hidden,
                     cl::desc("Sets the SIMD width. Zero is autoselect."),
@@ -730,6 +733,23 @@ class AccessAnalysis {
     return UnderlyingObjects;
   }
 
+  /// Find Histogram counts that match high-level code in loops:
+  /// \code
+  /// buckets[indices[i]]+=step;
+  /// \endcode
+  ///
+  /// It matches a pattern starting from \p HSt, which stores the computed
+  /// histogram value to the 'buckets' array. The BinOp updates the count, and
+  /// a loop-variant index loaded from the 'indices' array selects the bucket.
+  ///
+  /// On successful matches it updates the STATISTIC 'HistogramsDetected',
+  /// regardless of hardware support. When there is support, it additionally
+  /// stores the BinOp/Load pairs in \p HistogramCounts, as well as the
+  /// pointers used to update the histogram in \p HistogramPtrs.
+  void findHistograms(StoreInst *HSt,
+                      DenseMap<Instruction *, Instruction *> &HistogramCounts,
+                      SmallPtrSetImpl<Value *> &HistogramPtrs);
+
 private:
   typedef MapVector<MemAccessInfo, SmallSetVector<Type *, 1>> PtrAccessMap;
 
@@ -1946,7 +1966,8 @@ getDependenceDistanceStrideAndSize(
     const AccessAnalysis::MemAccessInfo &B, Instruction *BInst,
     const DenseMap<Value *, const SCEV *> &Strides,
     const DenseMap<Value *, SmallVector<const Value *, 16>> &UnderlyingObjects,
-    PredicatedScalarEvolution &PSE, const Loop *InnermostLoop) {
+    PredicatedScalarEvolution &PSE, const Loop *InnermostLoop,
+    const SmallPtrSetImpl<Value *> &HistogramPtrs) {
   auto &DL = InnermostLoop->getHeader()->getModule()->getDataLayout();
   auto &SE = *PSE.getSE();
   auto [APtr, AIsWrite] = A;
@@ -1964,6 +1985,15 @@ getDependenceDistanceStrideAndSize(
       BPtr->getType()->getPointerAddressSpace())
     return MemoryDepChecker::Dependence::Unknown;
 
+  // Ignore Histogram count updates as they are handled by the Intrinsic. This
+  // happens when the same pointer is first used to read from and then is used
+  // to write to.
+  if (!AIsWrite && BIsWrite && APtr == BPtr && HistogramPtrs.contains(APtr)) {
+    LLVM_DEBUG(dbgs() << "LAA: Histogram: Update is safely ignored. Pointer: "
+                      << *APtr);
+    return MemoryDepChecker::Dependence::NoDep;
+  }
+
   int64_t StrideAPtr =
       getPtrStride(PSE, ATy, APtr, InnermostLoop, Strides, true).value_or(0);
   int64_t StrideBPtr =
@@ -2016,15 +2046,15 @@ getDependenceDistanceStrideAndSize(
 MemoryDepChecker::Dependence::DepType MemoryDepChecker::isDependent(
     const MemAccessInfo &A, unsigned AIdx, const MemAccessInfo &B,
     unsigned BIdx, const DenseMap<Value *, const SCEV *> &Strides,
-    const DenseMap<Value *, SmallVector<const Value *, 16>>
-        &UnderlyingObjects) {
+    const DenseMap<Value *, SmallVector<const Value *, 16>> &UnderlyingObjects,
+    const SmallPtrSetImpl<Value *> &HistogramPtrs) {
   assert(AIdx < BIdx && "Must pass arguments in program order");
 
   // Get the dependence distance, stride, type size and what access writes for
   // the dependence between A and B.
   auto Res = getDependenceDistanceStrideAndSize(
       A, InstMap[AIdx], B, InstMap[BIdx], Strides, UnderlyingObjects, PSE,
-      InnermostLoop);
+      InnermostLoop, HistogramPtrs);
   if (std::holds_alternative<Dependence::DepType>(Res))
     return std::get<Dependence::DepType>(Res);
 
@@ -2185,8 +2215,8 @@ MemoryDepChecker::Dependence::DepType MemoryDepChecker::isDependent(
 bool MemoryDepChecker::areDepsSafe(
     DepCandidates &AccessSets, MemAccessInfoList &CheckDeps,
     const DenseMap<Value *, const SCEV *> &Strides,
-    const DenseMap<Value *, SmallVector<const Value *, 16>>
-        &UnderlyingObjects) {
+    const DenseMap<Value *, SmallVector<const Value *, 16>> &UnderlyingObjects,
+    const SmallPtrSetImpl<Value *> &HistogramPtrs) {
 
   MinDepDistBytes = -1;
   SmallPtrSet<MemAccessInfo, 8> Visited;
@@ -2231,7 +2261,7 @@ bool MemoryDepChecker::areDepsSafe(
 
             Dependence::DepType Type =
                 isDependent(*A.first, A.second, *B.first, B.second, Strides,
-                            UnderlyingObjects);
+                            UnderlyingObjects, HistogramPtrs);
             mergeInStatus(Dependence::isSafeForVectorization(Type));
 
             // Gather dependences unless we accumulated MaxDependences
@@ -2567,6 +2597,9 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI,
   // check.
   Accesses.buildDependenceSets();
 
+  for (StoreInst *ST : Stores)
+    Accesses.findHistograms(ST, HistogramCounts, HistogramPtrs);
+
   // Find pointers with computable bounds. We are going to use this information
   // to place a runtime bound check.
   Value *UncomputablePtr = nullptr;
@@ -2591,7 +2624,7 @@ void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI,
     LLVM_DEBUG(dbgs() << "LAA: Checking memory dependencies\n");
     CanVecMem = DepChecker->areDepsSafe(
         DependentAccesses, Accesses.getDependenciesToCheck(), SymbolicStrides,
-        Accesses.getUnderlyingObjects());
+        Accesses.getUnderlyingObjects(), HistogramPtrs);
 
     if (!CanVecMem && DepChecker->shouldRetryWithRuntimeCheck()) {
       LLVM_DEBUG(dbgs() << "LAA: Retrying with memory checks\n");
@@ -3025,6 +3058,89 @@ const LoopAccessInfo &LoopAccessInfoManager::getInfo(Loop &L) {
   return *I.first->second;
 }
 
+void AccessAnalysis::findHistograms(
+    StoreInst *HSt, DenseMap<Instruction *, Instruction *> &HistogramCounts,
+    SmallPtrSetImpl<Value *> &HistogramPtrs) {
+  LLVM_DEBUG(dbgs() << "LAA: Attempting to match histogram from " << *HSt
+                    << "\n");
+  // Store value must come from a Binary Operation.
+  Instruction *HPtrInstr = nullptr;
+  BinaryOperator *HBinOp = nullptr;
+  if (!match(HSt, m_Store(m_BinOp(HBinOp), m_Instruction(HPtrInstr)))) {
+    LLVM_DEBUG(dbgs() << "\tNo BinOp\n");
+    return;
+  }
+
+  // BinOp must be an Add or a Sub that modifies the bucket value by a
+  // loop-invariant amount.
+  // FIXME: We assume the loop invariant term is on the RHS.
+  //        Fine for an immediate/constant, but maybe not a generic value?
+  Value *HIncVal = nullptr;
+  if (!match(HBinOp, m_Add(m_Load(m_Specific(HPtrInstr)), m_Value(HIncVal))) &&
+      !match(HBinOp, m_Sub(m_Load(m_Specific(HPtrInstr)), m_Value(HIncVal)))) {
+    LLVM_DEBUG(dbgs() << "\tNo matching load\n");
+    return;
+  }
+
+  // The address to store is calculated through a GEP Instruction.
+  // FIXME: Support GEPs with more operands.
+  GetElementPtrInst *HPtr = dyn_cast<GetElementPtrInst>(HPtrInstr);
+  if (!HPtr || HPtr->getNumOperands() > 2) {
+    LLVM_DEBUG(dbgs() << "\tToo many GEP operands\n");
+    return;
+  }
+
+  // Check that the index is calculated by loading from another array. Ignore
+  // any extensions.
+  // FIXME: Support indices from other sources than a linear load from memory?
+  Value *HIdx = HPtr->getOperand(1);
+  Instruction *IdxInst = nullptr;
+  // FIXME: Can this fail? Maybe if IdxInst isn't an instruction. Just need to
+  //        look through extensions, find another way?
+  if (!match(HIdx, m_ZExtOrSExtOrSelf(m_Instruction(IdxInst))))
+    return;
+
+  // Currently restricting this to linear addressing when loading indices.
+  LoadInst *VLoad = dyn_cast<LoadInst>(IdxInst);
+  Value *VPtrVal;
+  if (!VLoad || !match(VLoad, m_Load(m_Value(VPtrVal)))) {
+    LLVM_DEBUG(dbgs() << "\tBad Index Load\n");
+    return;
+  }
+
+  if (!isa<SCEVAddRecExpr>(PSE.getSCEV(VPtrVal))) {
+    LLVM_DEBUG(dbgs() << "\tCannot determine index load stride\n");
+    return;
+  }
+
+  // FIXME: support smaller types of input arrays. Integers can be promoted
+  //        for codegen.
+  Type *VLoadTy = VLoad->getType();
+  if (!VLoadTy->isIntegerTy() || (VLoadTy->getScalarSizeInBits() != 32 &&
+                                  VLoadTy->getScalarSizeInBits() != 64)) {
+    LLVM_DEBUG(dbgs() << "\tUnsupported bucket type: " << *VLoadTy << "\n");
+    return;
+  }
+
+  // A histogram pointer may only alias to itself, and must only have two uses,
+  // the load and the store.
+  for (AliasSet &AS : AST)
+    if (AS.isMustAlias() || AS.isMayAlias())
+      if ((is_contained(AS.getPointers(), HPtr) && AS.size() > 1) ||
+          HPtr->getNumUses() != 2) {
+        LLVM_DEBUG(dbgs() << "\tAliasing problem\n");
+        return;
+      }
+
+  LLVM_DEBUG(dbgs() << "LAA: Found Histogram Operation: " << *HBinOp << "\n");
+  HistogramsDetected++;
+
+  // Store pairs of BinOp (Add/Sub) that modify the count and the index load.
+  HistogramCounts.insert(std::make_pair(HBinOp, VLoad));
+  // Store pointers used to write those counts in the computed histogram.
+  HistogramPtrs.insert(HPtr);
+}
+
 bool LoopAccessInfoManager::invalidate(
     Function &F, const PreservedAnalyses &PA,
     FunctionAnalysisManager::Invalidator &Inv) {
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 33c899fe889990..18149cb2f5393b 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -653,6 +653,10 @@ bool TargetTransformInfo::haveFastSqrt(Type *Ty) const {
   return TTIImpl->haveFastSqrt(Ty);
 }
 
+InstructionCost TargetTransformInfo::getHistogramCost(Type *Ty) const {
+  return TTIImpl->getHistogramCost(Ty);
+}
+
 bool TargetTransformInfo::isExpensiveToSpeculativelyExecute(
     const Instruction *I) const {
   return TTIImpl->isExpensiveToSpeculativelyExecute(I);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 7947d73f9a4dd0..ab94df4d05d468 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -5293,6 +5293,21 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
   SDLoc dl(Op);
   switch (IntNo) {
   default: return SDValue();    // Don't custom lower most intrinsics.
+  case Intrinsic::experimental_vector_histogram_count: {
+    // Lower the IR intrinsic to the AArch64 SVE2 histcnt intrinsic.
+    assert((Op.getNumOperands() == 3) &&
+           "Histogram Intrinsic requires 3 operands");
+    EVT Ty = Op.getValueType();
+    assert((Ty == MVT::nxv4i32 || Ty == MVT::nxv2i64) &&
+           "Intrinsic supports only i64 or i32 types");
+    // EVT VT = (Ty == MVT::nxv4i32) ? MVT::i32 : MVT::i64;
+    SDValue InputVector = Op.getOperand(1);
+    SDValue Mask = Op.getOperand(2);
+    SDValue ID =
+        DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, dl, MVT::i32);
+    return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, Ty, ID, Mask, InputVector,
+                       InputVector);
+  }
   case Intrinsic::thread_pointer: {
     EVT PtrVT = getPointerTy(DAG.getDataLayout());
     return DAG.getNode(AArch64ISD::THREAD_POINTER, dl, PtrVT);
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index e80931a03f30b6..1019b0928bfd75 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -58,6 +58,11 @@ static cl::opt<unsigned> InlineCallPenaltyChangeSM(
 static cl::opt<bool> EnableOrLikeSelectOpt("enable-aarch64-or-like-select",
                                            cl::init(true), cl::Hidden);
 
+// A complete guess as to a reasonable cost.
+static cl::opt<unsigned>
+    BaseHistCntCost("aarch64-base-histcnt-cost", cl::init(8), cl::Hidden,
+                    cl::desc("The cost of a histcnt instruction"));
+
 namespace {
 class TailFoldingOption {
   // These bitfields will only ever be set to something non-zero in operator=,
@@ -505,6 +510,31 @@ static bool isUnpackedVectorVT(EVT VecVT) {
          VecVT.getSizeInBits().getKnownMinValue() < AArch64::SVEBitsPerBlock;
 }
 
+InstructionCost AArch64TTIImpl::getHistogramCost(Type *Ty) const {
+  if (!ST->hasSVE2orSME())
+    return InstructionCost::getInvalid();
+
+  Type *EltTy = Ty->getScalarType();
+
+  // Only allow (<=64b) integers or pointers for now...
+  if ((!EltTy->isIntegerTy() && !EltTy->isPointerTy()) ||
+      EltTy->getScalarSizeInBits() > 64)
+    return InstructionCost::getInvalid();
+
+  // FIXME: Hacky check for legal vector types. We can promote smaller types
+  //        but we cannot legalize vectors via splitting for histcnt.
+  // FIXME: We should be able to generate histcnt for fixed-length vectors
+  //        using ptrue with a specific VL.
+  if (VectorType *VTy = dyn_cast<VectorType>(Ty))
+    if ((VTy->getElementCount().getKnownMinValue() != 2 &&
+      ...
[truncated]

@huntergr-arm
Collaborator Author

The all-in-one approach seems to have more support (see comments and approval on #88106), so I'll repurpose this PR to handle the LoopVec side of that unless there are objections.

@huntergr-arm
Collaborator Author

Superseded by the all-in-one intrinsic approach (#88106).

@huntergr-arm huntergr-arm deleted the standalone-histogram-intrinsic-poc branch July 11, 2024 14:36