Skip to content

[SYCL] Move allow-device-image-dependencies out of ModuleSplitter #18060

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion llvm/include/llvm/SYCLPostLink/ComputeModuleRuntimeInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ PropSetRegTy computeDeviceLibProperties(const Module &M,

PropSetRegTy computeModuleProperties(const Module &M,
const EntryPointSet &EntryPoints,
const GlobalBinImageProps &GlobProps);
const GlobalBinImageProps &GlobProps,
bool AllowDeviceImageDependencies);

std::string computeModuleSymbolTable(const Module &M,
const EntryPointSet &EntryPoints);
Expand Down
21 changes: 13 additions & 8 deletions llvm/include/llvm/SYCLPostLink/ModuleSplitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ constexpr char SYCL_ESIMD_SPLIT_MD_NAME[] = "sycl-esimd-split-status";
// File names of the SYCL bfloat16 device libraries: the fallback variant
// first, followed by the native variant.
constexpr std::array<const char *, 2> SYCLDeviceLibs{
    "libsycl-fallback-bfloat16.bc",
    "libsycl-native-bfloat16.bc",
};

extern cl::OptionCategory &getModuleSplitCategory();

enum IRSplitMode {
SPLIT_PER_TU, // one module per translation unit
SPLIT_PER_KERNEL, // one module per kernel
Expand Down Expand Up @@ -221,7 +219,7 @@ class ModuleDesc {
void restoreLinkageOfDirectInvokeSimdTargets();

// Cleans up module IR - removes dead globals, debug info etc.
void cleanup();
void cleanup(bool AllowDeviceImageDependencies);

bool isSpecConstantDefault() const;
void setSpecConstantDefault(bool Value);
Expand Down Expand Up @@ -252,6 +250,7 @@ class ModuleSplitterBase {
protected:
ModuleDesc Input;
EntryPointGroupVec Groups;
bool AllowDeviceImageDependencies;

protected:
EntryPointGroup nextGroup() {
Expand All @@ -268,8 +267,10 @@ class ModuleSplitterBase {
}

public:
ModuleSplitterBase(ModuleDesc &&MD, EntryPointGroupVec &&GroupVec)
: Input(std::move(MD)), Groups(std::move(GroupVec)) {
ModuleSplitterBase(ModuleDesc &&MD, EntryPointGroupVec &&GroupVec,
bool AllowDeviceImageDependencies)
: Input(std::move(MD)), Groups(std::move(GroupVec)),
AllowDeviceImageDependencies(AllowDeviceImageDependencies) {
assert(!Groups.empty() && "Entry points groups collection is empty!");
}

Expand All @@ -294,11 +295,13 @@ class ModuleSplitterBase {
};

SmallVector<ModuleDesc, 2> splitByESIMD(ModuleDesc &&MD,
bool EmitOnlyKernelsAsEntryPoints);
bool EmitOnlyKernelsAsEntryPoints,
bool AllowDeviceImageDependencies);

std::unique_ptr<ModuleSplitterBase>
getDeviceCodeSplitter(ModuleDesc &&MD, IRSplitMode Mode, bool IROutputOnly,
bool EmitOnlyKernelsAsEntryPoints);
bool EmitOnlyKernelsAsEntryPoints,
bool AllowDeviceImageDependencies);

#ifndef NDEBUG
void dumpEntryPoints(const EntryPointSet &C, const char *Msg = "", int Tab = 0);
Expand Down Expand Up @@ -327,6 +330,7 @@ struct ModuleSplitterSettings {
IRSplitMode Mode;
bool OutputAssembly = false; // Bitcode or LLVM IR.
StringRef OutputPrefix;
bool AllowDeviceImageDependencies = false;
};

/// Parses the output table file from sycl-post-link tool.
Expand All @@ -342,7 +346,8 @@ Expected<std::vector<SplitModule>>
splitSYCLModule(std::unique_ptr<Module> M, ModuleSplitterSettings Settings);

bool isESIMDFunction(const Function &F);
bool canBeImportedFunction(const Function &F);
bool canBeImportedFunction(const Function &F,
bool AllowDeviceImageDependencies);

} // namespace module_split

Expand Down
6 changes: 4 additions & 2 deletions llvm/lib/SYCLPostLink/ComputeModuleRuntimeInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,8 @@ PropSetRegTy computeDeviceLibProperties(const Module &M,

PropSetRegTy computeModuleProperties(const Module &M,
const EntryPointSet &EntryPoints,
const GlobalBinImageProps &GlobProps) {
const GlobalBinImageProps &GlobProps,
bool AllowDeviceImageDependencies) {

PropSetRegTy PropSet;
{
Expand Down Expand Up @@ -286,7 +287,8 @@ PropSetRegTy computeModuleProperties(const Module &M,
if (F.hasFnAttribute("indirectly-callable"))
continue;

if (module_split::canBeImportedFunction(F)) {
if (module_split::canBeImportedFunction(F,
AllowDeviceImageDependencies)) {
// StripDeadPrototypes is called during module splitting
// cleanup. At this point all function decls should have uses.
assert(!F.use_empty() && "Function F has no uses");
Expand Down
96 changes: 52 additions & 44 deletions llvm/lib/SYCLPostLink/ModuleSplitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@ constexpr char SYCL_SCOPE_NAME[] = "<SYCL>";
constexpr char ESIMD_SCOPE_NAME[] = "<ESIMD>";
constexpr char ESIMD_MARKER_MD[] = "sycl_explicit_simd";

// Boolean command-line option "allow-device-image-dependencies". It is off
// unless given explicitly and is listed under the option category returned by
// getModuleSplitCategory().
cl::opt<bool> AllowDeviceImageDependencies(
    "allow-device-image-dependencies",
    cl::desc("Allow dependencies between device images"),
    cl::cat(getModuleSplitCategory()), cl::init(false));

EntryPointsGroupScope selectDeviceCodeGroupScope(const Module &M,
IRSplitMode Mode,
bool AutoSplitIsGlobalScope) {
Expand Down Expand Up @@ -178,7 +173,7 @@ class DependencyGraph {
public:
using GlobalSet = SmallPtrSet<const GlobalValue *, 16>;

DependencyGraph(const Module &M) {
DependencyGraph(const Module &M, bool AllowDeviceImageDependencies) {
// Group functions by their signature to handle case (2) described above
DenseMap<const FunctionType *, DependencyGraph::GlobalSet>
FuncTypeToFuncsMap;
Expand All @@ -196,7 +191,7 @@ class DependencyGraph {
}

for (const auto &F : M.functions()) {
if (canBeImportedFunction(F))
if (canBeImportedFunction(F, AllowDeviceImageDependencies))
continue;

// case (1), see comment above the class definition
Expand Down Expand Up @@ -311,7 +306,9 @@ static bool isIntrinsicOrBuiltin(const Function &F) {
}

// Checks for use of undefined user functions and emits a warning message.
static void checkForCallsToUndefinedFunctions(const Module &M) {
static void
checkForCallsToUndefinedFunctions(const Module &M,
bool AllowDeviceImageDependencies) {
if (AllowDeviceImageDependencies)
return;
for (const Function &F : M) {
Expand Down Expand Up @@ -391,11 +388,11 @@ ModuleDesc extractSubModule(const ModuleDesc &MD,
// The function produces a copy of input LLVM IR module M with only those
// functions and globals that can be called from entry points that are specified
// in ModuleEntryPoints vector, in addition to the entry point functions.
ModuleDesc extractCallGraph(const ModuleDesc &MD,
EntryPointGroup &&ModuleEntryPoints,
const DependencyGraph &CG,
const std::function<bool(const Function *)>
&IncludeFunctionPredicate = nullptr) {
ModuleDesc extractCallGraph(
const ModuleDesc &MD, EntryPointGroup &&ModuleEntryPoints,
const DependencyGraph &CG, bool AllowDeviceImageDependencies,
const std::function<bool(const Function *)> &IncludeFunctionPredicate =
nullptr) {
SetVector<const GlobalValue *> GVs;
collectFunctionsAndGlobalVariablesToExtract(
GVs, MD.getModule(), ModuleEntryPoints, CG, IncludeFunctionPredicate);
Expand All @@ -405,20 +402,21 @@ ModuleDesc extractCallGraph(const ModuleDesc &MD,
// sycl-post-link. This call is redundant. However, we subsequently run
// GenXSPIRVWriterAdaptor pass that relies on this cleanup. This cleanup call
// can be removed once that pass no longer depends on this cleanup.
SplitM.cleanup();
checkForCallsToUndefinedFunctions(SplitM.getModule());
SplitM.cleanup(AllowDeviceImageDependencies);
checkForCallsToUndefinedFunctions(SplitM.getModule(),
AllowDeviceImageDependencies);

return SplitM;
}

// The function is similar to 'extractCallGraph', but it produces a copy of
// input LLVM IR module M with _all_ ESIMD functions and kernels included,
// regardless of whether or not they are listed in ModuleEntryPoints.
ModuleDesc extractESIMDSubModule(const ModuleDesc &MD,
EntryPointGroup &&ModuleEntryPoints,
const DependencyGraph &CG,
const std::function<bool(const Function *)>
&IncludeFunctionPredicate = nullptr) {
ModuleDesc extractESIMDSubModule(
const ModuleDesc &MD, EntryPointGroup &&ModuleEntryPoints,
const DependencyGraph &CG, bool AllowDeviceImageDependencies,
const std::function<bool(const Function *)> &IncludeFunctionPredicate =
nullptr) {
SetVector<const GlobalValue *> GVs;
for (const auto &F : MD.getModule().functions())
if (isESIMDFunction(F))
Expand All @@ -432,7 +430,7 @@ ModuleDesc extractESIMDSubModule(const ModuleDesc &MD,
// sycl-post-link. This call is redundant. However, we subsequently run
// GenXSPIRVWriterAdaptor pass that relies on this cleanup. This cleanup call
// can be removed once that pass no longer depends on this cleanup.
SplitM.cleanup();
SplitM.cleanup(AllowDeviceImageDependencies);

return SplitM;
}
Expand All @@ -449,19 +447,22 @@ class ModuleCopier : public ModuleSplitterBase {
// sycl-post-link. This call is redundant. However, we subsequently run
// GenXSPIRVWriterAdaptor pass that relies on this cleanup. This cleanup
// call can be removed once that pass no longer depends on this cleanup.
Desc.cleanup();
Desc.cleanup(AllowDeviceImageDependencies);
return Desc;
}
};

class ModuleSplitter : public ModuleSplitterBase {
public:
ModuleSplitter(ModuleDesc &&MD, EntryPointGroupVec &&GroupVec)
: ModuleSplitterBase(std::move(MD), std::move(GroupVec)),
CG(Input.getModule()) {}
ModuleSplitter(ModuleDesc &&MD, EntryPointGroupVec &&GroupVec,
bool AllowDeviceImageDependencies)
: ModuleSplitterBase(std::move(MD), std::move(GroupVec),
AllowDeviceImageDependencies),
CG(Input.getModule(), AllowDeviceImageDependencies) {}

ModuleDesc nextSplit() override {
return extractCallGraph(Input, nextGroup(), CG);
return extractCallGraph(Input, nextGroup(), CG,
AllowDeviceImageDependencies);
}

private:
Expand Down Expand Up @@ -489,11 +490,6 @@ bool isESIMDFunction(const Function &F) {
return F.getMetadata(ESIMD_MARKER_MD) != nullptr;
}

// Returns the option category named "Module Split options". The category is a
// function-local static, so every call yields a reference to the same object.
cl::OptionCategory &getModuleSplitCategory() {
  static cl::OptionCategory Category("Module Split options");
  return Category;
}

Error ModuleSplitterBase::verifyNoCrossModuleDeviceGlobalUsage() {
const Module &M = getInputModule();
// Early exit if there is only one group
Expand Down Expand Up @@ -692,7 +688,8 @@ void ModuleDesc::restoreLinkageOfDirectInvokeSimdTargets() {
// tries to internalize absolutely everything. This function serves as "input
// from a linker" that tells the pass what must be preserved in order to make
// the transformation safe.
static bool mustPreserveGV(const GlobalValue &GV) {
static bool mustPreserveGV(const GlobalValue &GV,
bool AllowDeviceImageDependencies) {
if (const Function *F = dyn_cast<Function>(&GV)) {
// When dynamic linking is supported, we internalize everything (except
// kernels which are the entry points from host code to device code) that
Expand All @@ -703,7 +700,8 @@ static bool mustPreserveGV(const GlobalValue &GV) {
const bool SpirOrGPU = CC == CallingConv::SPIR_KERNEL ||
CC == CallingConv::AMDGPU_KERNEL ||
CC == CallingConv::PTX_Kernel;
return SpirOrGPU || canBeImportedFunction(*F);
return SpirOrGPU ||
canBeImportedFunction(*F, AllowDeviceImageDependencies);
}

// Otherwise, we are being even more aggressive: SYCL modules are expected
Expand Down Expand Up @@ -754,7 +752,7 @@ void cleanupSYCLRegisteredKernels(Module *M) {

// TODO: try to move all passes (cleanup, spec consts, compile time properties)
// in one place and execute MPM.run() only once.
void ModuleDesc::cleanup() {
void ModuleDesc::cleanup(bool AllowDeviceImageDependencies) {
// Any definitions of virtual functions should be removed and turned into
// declarations, they are supposed to be provided by a different module.
if (!EntryPoints.Props.HasVirtualFunctionDefinitions) {
Expand All @@ -781,7 +779,10 @@ void ModuleDesc::cleanup() {
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
ModulePassManager MPM;
// Do cleanup.
MPM.addPass(InternalizePass(mustPreserveGV));
MPM.addPass(
InternalizePass([AllowDeviceImageDependencies](const GlobalValue &GV) {
return mustPreserveGV(GV, AllowDeviceImageDependencies);
}));
MPM.addPass(GlobalDCEPass()); // Delete unreachable globals.
MPM.addPass(StripDeadDebugInfoPass()); // Remove dead debug info.
MPM.addPass(StripDeadPrototypesPass()); // Remove dead func decls.
Expand Down Expand Up @@ -1157,7 +1158,8 @@ std::string FunctionsCategorizer::computeCategoryFor(Function *F) const {

std::unique_ptr<ModuleSplitterBase>
getDeviceCodeSplitter(ModuleDesc &&MD, IRSplitMode Mode, bool IROutputOnly,
bool EmitOnlyKernelsAsEntryPoints) {
bool EmitOnlyKernelsAsEntryPoints,
bool AllowDeviceImageDependencies) {
FunctionsCategorizer Categorizer;

EntryPointsGroupScope Scope =
Expand Down Expand Up @@ -1252,9 +1254,11 @@ getDeviceCodeSplitter(ModuleDesc &&MD, IRSplitMode Mode, bool IROutputOnly,
(Groups.size() > 1 || !Groups.cbegin()->Functions.empty()));

if (DoSplit)
return std::make_unique<ModuleSplitter>(std::move(MD), std::move(Groups));
return std::make_unique<ModuleSplitter>(std::move(MD), std::move(Groups),
AllowDeviceImageDependencies);

return std::make_unique<ModuleCopier>(std::move(MD), std::move(Groups));
return std::make_unique<ModuleCopier>(std::move(MD), std::move(Groups),
AllowDeviceImageDependencies);
}

// Splits input module into two:
Expand All @@ -1277,7 +1281,8 @@ getDeviceCodeSplitter(ModuleDesc &&MD, IRSplitMode Mode, bool IROutputOnly,
// avoid undefined behavior at later stages. That is done at higher level,
// outside of this function.
SmallVector<ModuleDesc, 2> splitByESIMD(ModuleDesc &&MD,
bool EmitOnlyKernelsAsEntryPoints) {
bool EmitOnlyKernelsAsEntryPoints,
bool AllowDeviceImageDependencies) {

SmallVector<module_split::ModuleDesc, 2> Result;
EntryPointGroupVec EntryPointGroups{};
Expand Down Expand Up @@ -1320,12 +1325,13 @@ SmallVector<ModuleDesc, 2> splitByESIMD(ModuleDesc &&MD,
return Result;
}

DependencyGraph CG(MD.getModule());
DependencyGraph CG(MD.getModule(), AllowDeviceImageDependencies);
for (auto &Group : EntryPointGroups) {
if (Group.isEsimd()) {
// For ESIMD module, we use full call graph of all entry points and all
// ESIMD functions.
Result.emplace_back(extractESIMDSubModule(MD, std::move(Group), CG));
Result.emplace_back(extractESIMDSubModule(MD, std::move(Group), CG,
AllowDeviceImageDependencies));
} else {
// For non-ESIMD module we only use non-ESIMD functions. Additional filter
// is needed, because there could be uses of ESIMD functions from
Expand All @@ -1334,7 +1340,7 @@ SmallVector<ModuleDesc, 2> splitByESIMD(ModuleDesc &&MD,
// were processed and therefore it is fine to return an "incomplete"
// module here.
Result.emplace_back(extractCallGraph(
MD, std::move(Group), CG,
MD, std::move(Group), CG, AllowDeviceImageDependencies,
[=](const Function *F) -> bool { return !isESIMDFunction(*F); }));
}
}
Expand Down Expand Up @@ -1477,7 +1483,8 @@ splitSYCLModule(std::unique_ptr<Module> M, ModuleSplitterSettings Settings) {
// FIXME: false arguments are temporary for now.
auto Splitter = getDeviceCodeSplitter(std::move(MD), Settings.Mode,
/*IROutputOnly=*/false,
/*EmitOnlyKernelsAsEntryPoints=*/false);
/*EmitOnlyKernelsAsEntryPoints=*/false,
Settings.AllowDeviceImageDependencies);

size_t ID = 0;
std::vector<SplitModule> OutputImages;
Expand All @@ -1498,7 +1505,8 @@ splitSYCLModule(std::unique_ptr<Module> M, ModuleSplitterSettings Settings) {
return OutputImages;
}

bool canBeImportedFunction(const Function &F) {
bool canBeImportedFunction(const Function &F,
bool AllowDeviceImageDependencies) {

// We use sycl dynamic library mechanism to involve bf16 devicelib when
// necessary, all __devicelib_* functions from native or fallback bf16
Expand Down
6 changes: 6 additions & 0 deletions llvm/tools/sycl-module-split/sycl-module-split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ cl::opt<IRSplitMode> SplitMode(
"Choose split mode automatically")),
cl::cat(SplitCategory));

// Boolean command-line option "allow-device-image-dependencies" of this tool.
// It is off unless given explicitly and is listed under SplitCategory.
cl::opt<bool> AllowDeviceImageDependencies(
    "allow-device-image-dependencies",
    cl::desc("Allow dependencies between device images"),
    cl::cat(SplitCategory), cl::init(false));

void writeStringToFile(const std::string &Content, StringRef Path) {
std::error_code EC;
raw_fd_ostream OS(Path, EC);
Expand Down Expand Up @@ -120,6 +125,7 @@ int main(int argc, char *argv[]) {
Settings.Mode = SplitMode;
Settings.OutputAssembly = OutputAssembly;
Settings.OutputPrefix = OutputFilenamePrefix;
Settings.AllowDeviceImageDependencies = AllowDeviceImageDependencies;
auto SplitModulesOrErr = splitSYCLModule(std::move(M), Settings);
if (!SplitModulesOrErr) {
Err.print(argv[0], errs());
Expand Down
Loading