Skip to content

[SandboxVec][BottomUpVec] Use SeedCollector and slice seeds #120826

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion llvm/include/llvm/SandboxIR/Pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ namespace llvm {

class AAResults;
class ScalarEvolution;
class TargetTransformInfo;

namespace sandboxir {

Expand All @@ -25,15 +26,18 @@ class Region;
class Analyses {
AAResults *AA = nullptr;
ScalarEvolution *SE = nullptr;
TargetTransformInfo *TTI = nullptr;

Analyses() = default;

public:
Analyses(AAResults &AA, ScalarEvolution &SE) : AA(&AA), SE(&SE) {}
Analyses(AAResults &AA, ScalarEvolution &SE, TargetTransformInfo &TTI)
: AA(&AA), SE(&SE), TTI(&TTI) {}

public:
AAResults &getAA() const { return *AA; }
ScalarEvolution &getScalarEvolution() const { return *SE; }
TargetTransformInfo &getTTI() const { return *TTI; }
/// For use by unit tests.
static Analyses emptyForTesting() { return Analyses(); }
};
Expand Down
7 changes: 6 additions & 1 deletion llvm/include/llvm/SandboxIR/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,16 @@ class Utils {
getUnderlyingObject(LSI->getPointerOperand()->Val));
}

/// \Returns the number of bits of \p Ty.
static unsigned getNumBits(Type *Ty, const DataLayout &DL) {
return DL.getTypeSizeInBits(Ty->LLVMTy);
}

/// \Returns the number of bits required to represent the operands or return
/// value of \p V in \p DL.
static unsigned getNumBits(Value *V, const DataLayout &DL) {
Type *Ty = getExpectedType(V);
return DL.getTypeSizeInBits(Ty->LLVMTy);
return getNumBits(Ty, DL);
}

/// \Returns the number of bits required to represent the operands or
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class LegalityAnalysis {
// TODO: Try to remove the SkipScheduling argument by refactoring the tests.
const LegalityResult &canVectorize(ArrayRef<Value *> Bndl,
bool SkipScheduling = false);
void clear() { Sched.clear(); }
};

} // namespace llvm::sandboxir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,13 @@ class Scheduler {
~Scheduler() {}

bool trySchedule(ArrayRef<Instruction *> Instrs);
/// Clear the scheduler's state, including the DAG.
void clear() {
Bndls.clear();
// TODO: clear view once it lands.
DAG.clear();
ScheduleTopItOpt = std::nullopt;
}

#ifndef NDEBUG
void dump(raw_ostream &OS) const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ class SeedBundle {
/// with a total size <= \p MaxVecRegBits, or an empty slice if the
/// requirements cannot be met . If \p ForcePowOf2 is true, then the returned
/// slice will have a total number of bits that is a power of 2.
MutableArrayRef<Instruction *>
getSlice(unsigned StartIdx, unsigned MaxVecRegBits, bool ForcePowOf2);
ArrayRef<Instruction *> getSlice(unsigned StartIdx, unsigned MaxVecRegBits,
bool ForcePowOf2);

/// \Returns the number of seed elements in the bundle.
std::size_t size() const { return Seeds.size(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,16 @@ class VecUtils {
assert(tryGetCommonScalarType(Bndl) && "Expected common scalar type!");
return ScalarTy;
}
/// \Returns the first integer power of 2 that is <= Num.
static unsigned getFloorPowerOf2(unsigned Num) {
if (Num == 0)
return Num;
unsigned Mask = Num;
Mask >>= 1;
for (unsigned ShiftBy = 1; ShiftBy < sizeof(Num) * 8; ShiftBy <<= 1)
Mask |= Mask >> ShiftBy;
return Num & ~Mask;
}
};

} // namespace llvm::sandboxir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,31 @@

#include "llvm/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/SandboxIR/Function.h"
#include "llvm/SandboxIR/Instruction.h"
#include "llvm/SandboxIR/Module.h"
#include "llvm/SandboxIR/Utils.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SandboxVectorizerPassBuilder.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/SeedCollector.h"
#include "llvm/Transforms/Vectorize/SandboxVectorizer/VecUtils.h"

namespace llvm::sandboxir {
namespace llvm {

static cl::opt<unsigned>
OverrideVecRegBits("sbvec-vec-reg-bits", cl::init(0), cl::Hidden,
cl::desc("Override the vector register size in bits, "
"which is otherwise found by querying TTI."));
static cl::opt<bool>
AllowNonPow2("sbvec-allow-non-pow2", cl::init(false), cl::Hidden,
cl::desc("Allow non-power-of-2 vectorization."));

namespace sandboxir {

BottomUpVec::BottomUpVec(StringRef Pipeline)
: FunctionPass("bottom-up-vec"),
RPM("rpm", Pipeline, SandboxVectorizerPassBuilder::createRegionPass) {}

// TODO: This is a temporary function that returns some seeds.
// Replace this with SeedCollector's function when it lands.
static llvm::SmallVector<Value *, 4> collectSeeds(BasicBlock &BB) {
llvm::SmallVector<Value *, 4> Seeds;
for (auto &I : BB)
if (auto *SI = llvm::dyn_cast<StoreInst>(&I))
Seeds.push_back(SI);
return Seeds;
}

static SmallVector<Value *, 4> getOperand(ArrayRef<Value *> Bndl,
unsigned OpIdx) {
SmallVector<Value *, 4> Operands;
Expand Down Expand Up @@ -265,6 +267,7 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl, unsigned Depth) {

bool BottomUpVec::tryVectorize(ArrayRef<Value *> Bndl) {
DeadInstrCandidates.clear();
Legality->clear();
vectorizeRec(Bndl, /*Depth=*/0);
tryEraseDeadInstrs();
return Change;
Expand All @@ -275,17 +278,67 @@ bool BottomUpVec::runOnFunction(Function &F, const Analyses &A) {
A.getAA(), A.getScalarEvolution(), F.getParent()->getDataLayout(),
F.getContext());
Change = false;
const auto &DL = F.getParent()->getDataLayout();
unsigned VecRegBits =
OverrideVecRegBits != 0
? OverrideVecRegBits
: A.getTTI()
.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
.getFixedValue();

// TODO: Start from innermost BBs first
for (auto &BB : F) {
// TODO: Replace with proper SeedCollector function.
auto Seeds = collectSeeds(BB);
// TODO: Slice Seeds into smaller chunks.
// TODO: If vectorization succeeds, run the RegionPassManager on the
// resulting region.
if (Seeds.size() >= 2)
Change |= tryVectorize(Seeds);
SeedCollector SC(&BB, A.getScalarEvolution());
for (SeedBundle &Seeds : SC.getStoreSeeds()) {
unsigned ElmBits =
Utils::getNumBits(VecUtils::getElementType(Utils::getExpectedType(
Seeds[Seeds.getFirstUnusedElementIdx()])),
DL);

auto DivideBy2 = [](unsigned Num) {
auto Floor = VecUtils::getFloorPowerOf2(Num);
if (Floor == Num)
return Floor / 2;
return Floor;
};
// Try to create the largest vector supported by the target. If it fails
// reduce the vector size by half.
for (unsigned SliceElms = std::min(VecRegBits / ElmBits,
Seeds.getNumUnusedBits() / ElmBits);
SliceElms >= 2u; SliceElms = DivideBy2(SliceElms)) {
if (Seeds.allUsed())
break;
// Keep trying offsets after FirstUnusedElementIdx, until we vectorize
// the slice. This could be quite expensive, so we enforce a limit.
for (unsigned Offset = Seeds.getFirstUnusedElementIdx(),
OE = Seeds.size();
Offset + 1 < OE; Offset += 1) {
// Seeds are getting used as we vectorize, so skip them.
if (Seeds.isUsed(Offset))
continue;
if (Seeds.allUsed())
break;

auto SeedSlice =
Seeds.getSlice(Offset, SliceElms * ElmBits, !AllowNonPow2);
if (SeedSlice.empty())
continue;

assert(SeedSlice.size() >= 2 && "Should have been rejected!");

// TODO: If vectorization succeeds, run the RegionPassManager on the
// resulting region.

// TODO: Refactor to remove the unnecessary copy to SeedSliceVals.
SmallVector<Value *> SeedSliceVals(SeedSlice.begin(),
SeedSlice.end());
Change |= tryVectorize(SeedSliceVals);
}
}
}
}
return Change;
}

} // namespace llvm::sandboxir
} // namespace sandboxir
} // namespace llvm
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,6 @@ bool SandboxVectorizerPass::runImpl(Function &LLVMF) {

// Create SandboxIR for LLVMF and run BottomUpVec on it.
sandboxir::Function &F = *Ctx->createFunction(&LLVMF);
sandboxir::Analyses A(*AA, *SE);
sandboxir::Analyses A(*AA, *SE, *TTI);
return FPM.runOnFunction(F, A);
}
17 changes: 10 additions & 7 deletions llvm/lib/Transforms/Vectorize/SandboxVectorizer/SeedCollector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ cl::opt<unsigned> SeedGroupsLimit(
cl::desc("Limit the number of collected seeds groups in a BB to "
"cap compilation time."));

MutableArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
unsigned MaxVecRegBits,
bool ForcePowerOf2) {
ArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
unsigned MaxVecRegBits,
bool ForcePowerOf2) {
// Use uint32_t here for compatibility with IsPowerOf2_32

// BitCount tracks the size of the working slice. From that we can tell
Expand All @@ -47,10 +47,13 @@ MutableArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
// Can't start a slice with a used instruction.
assert(!isUsed(StartIdx) && "Expected unused at StartIdx");
for (auto S : make_range(Seeds.begin() + StartIdx, Seeds.end())) {
// Stop if this instruction is used. This needs to be done before
// getNumBits() because a "used" instruction may have been erased.
if (isUsed(StartIdx + NumElements))
break;
uint32_t InstBits = Utils::getNumBits(S);
// Stop if this instruction is used, or if adding it puts the slice over
// the limit.
if (isUsed(StartIdx + NumElements) || BitCount + InstBits > MaxVecRegBits)
// Stop if adding it puts the slice over the limit.
if (BitCount + InstBits > MaxVecRegBits)
break;
NumElements++;
BitCount += InstBits;
Expand All @@ -68,7 +71,7 @@ MutableArrayRef<Instruction *> SeedBundle::getSlice(unsigned StartIdx,
"Must be a power of two");
// Return any non-empty slice
if (NumElements > 1)
return MutableArrayRef<Instruction *>(&Seeds[StartIdx], NumElements);
return ArrayRef<Instruction *>(&Seeds[StartIdx], NumElements);
else
return {};
}
Expand Down
2 changes: 1 addition & 1 deletion llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -passes=sandbox-vectorizer -sbvec-passes="bottom-up-vec<>" %s -S | FileCheck %s
; RUN: opt -passes=sandbox-vectorizer -sbvec-vec-reg-bits=1024 -sbvec-allow-non-pow2 -sbvec-passes="bottom-up-vec<>" %s -S | FileCheck %s

define void @store_load(ptr %ptr) {
; CHECK-LABEL: define void @store_load(
Expand Down
33 changes: 33 additions & 0 deletions llvm/test/Transforms/SandboxVectorizer/bottomup_seed_slice.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -passes=sandbox-vectorizer -sbvec-vec-reg-bits=1024 -sbvec-allow-non-pow2 -sbvec-passes="bottom-up-vec<>" %s -S | FileCheck %s


declare void @foo()
define void @slice_seeds(ptr %ptr, float %val) {
; CHECK-LABEL: define void @slice_seeds(
; CHECK-SAME: ptr [[PTR:%.*]], float [[VAL:%.*]]) {
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
; CHECK-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
; CHECK-NEXT: [[PTR2:%.*]] = getelementptr float, ptr [[PTR]], i32 2
; CHECK-NEXT: [[LD2:%.*]] = load float, ptr [[PTR2]], align 4
; CHECK-NEXT: store float [[LD2]], ptr [[PTR2]], align 4
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
; CHECK-NEXT: store <2 x float> [[VECL]], ptr [[PTR0]], align 4
; CHECK-NEXT: ret void
;
%ptr0 = getelementptr float, ptr %ptr, i32 0
%ptr1 = getelementptr float, ptr %ptr, i32 1
%ptr2 = getelementptr float, ptr %ptr, i32 2

%ld2 = load float, ptr %ptr2
store float %ld2, ptr %ptr2
; This call blocks scheduling of all 3 stores.
call void @foo()

%ld0 = load float, ptr %ptr0
%ld1 = load float, ptr %ptr1
store float %ld0, ptr %ptr0
store float %ld1, ptr %ptr1
ret void
}
37 changes: 37 additions & 0 deletions llvm/test/Transforms/SandboxVectorizer/bottomup_seed_slice_pow2.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -passes=sandbox-vectorizer -sbvec-vec-reg-bits=1024 -sbvec-allow-non-pow2=false -sbvec-passes="bottom-up-vec<>" %s -S | FileCheck %s --check-prefix=POW2
; RUN: opt -passes=sandbox-vectorizer -sbvec-vec-reg-bits=1024 -sbvec-allow-non-pow2=true -sbvec-passes="bottom-up-vec<>" %s -S | FileCheck %s --check-prefix=NON-POW2

define void @pow2(ptr %ptr, float %val) {
; POW2-LABEL: define void @pow2(
; POW2-SAME: ptr [[PTR:%.*]], float [[VAL:%.*]]) {
; POW2-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
; POW2-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
; POW2-NEXT: [[PTR2:%.*]] = getelementptr float, ptr [[PTR]], i32 2
; POW2-NEXT: [[VECL:%.*]] = load <2 x float>, ptr [[PTR0]], align 4
; POW2-NEXT: [[LD2:%.*]] = load float, ptr [[PTR2]], align 4
; POW2-NEXT: store <2 x float> [[VECL]], ptr [[PTR0]], align 4
; POW2-NEXT: store float [[LD2]], ptr [[PTR2]], align 4
; POW2-NEXT: ret void
;
; NON-POW2-LABEL: define void @pow2(
; NON-POW2-SAME: ptr [[PTR:%.*]], float [[VAL:%.*]]) {
; NON-POW2-NEXT: [[PTR0:%.*]] = getelementptr float, ptr [[PTR]], i32 0
; NON-POW2-NEXT: [[PTR1:%.*]] = getelementptr float, ptr [[PTR]], i32 1
; NON-POW2-NEXT: [[PTR2:%.*]] = getelementptr float, ptr [[PTR]], i32 2
; NON-POW2-NEXT: [[PACK2:%.*]] = load <3 x float>, ptr [[PTR0]], align 4
; NON-POW2-NEXT: store <3 x float> [[PACK2]], ptr [[PTR0]], align 4
; NON-POW2-NEXT: ret void
;
%ptr0 = getelementptr float, ptr %ptr, i32 0
%ptr1 = getelementptr float, ptr %ptr, i32 1
%ptr2 = getelementptr float, ptr %ptr, i32 2

%ld0 = load float, ptr %ptr0
%ld1 = load float, ptr %ptr1
%ld2 = load float, ptr %ptr2
store float %ld0, ptr %ptr0
store float %ld1, ptr %ptr1
store float %ld2, ptr %ptr2
ret void
}
Original file line number Diff line number Diff line change
Expand Up @@ -472,3 +472,14 @@ define void @foo(i8 %v, ptr %ptr) {
#endif // NDEBUG
}
}

TEST_F(VecUtilsTest, FloorPowerOf2) {
EXPECT_EQ(sandboxir::VecUtils::getFloorPowerOf2(0), 0u);
EXPECT_EQ(sandboxir::VecUtils::getFloorPowerOf2(1 << 0), 1u << 0);
EXPECT_EQ(sandboxir::VecUtils::getFloorPowerOf2(3), 2u);
EXPECT_EQ(sandboxir::VecUtils::getFloorPowerOf2(4), 4u);
EXPECT_EQ(sandboxir::VecUtils::getFloorPowerOf2(5), 4u);
EXPECT_EQ(sandboxir::VecUtils::getFloorPowerOf2(7), 4u);
EXPECT_EQ(sandboxir::VecUtils::getFloorPowerOf2(8), 8u);
EXPECT_EQ(sandboxir::VecUtils::getFloorPowerOf2(9), 8u);
}
Loading