@@ -60,11 +60,6 @@ constexpr char SYCL_SCOPE_NAME[] = "<SYCL>";
 constexpr char ESIMD_SCOPE_NAME[] = "<ESIMD>";
 constexpr char ESIMD_MARKER_MD[] = "sycl_explicit_simd";

-cl::opt<bool> AllowDeviceImageDependencies{
-    "allow-device-image-dependencies",
-    cl::desc("Allow dependencies between device images"),
-    cl::cat(getModuleSplitCategory()), cl::init(false)};
-
 EntryPointsGroupScope selectDeviceCodeGroupScope(const Module &M,
                                                  IRSplitMode Mode,
                                                  bool AutoSplitIsGlobalScope) {
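Note: with the hunk above, the splitter library no longer reads this flag from a global cl::opt; callers own the option and pass the resulting bool through the new AllowDeviceImageDependencies parameters added below. A minimal caller-side sketch, assuming a tool front end keeps an equivalent command-line option (the option object, its name here, and where it lives are illustrative, not part of this diff):

// Hypothetical tool-side flag; the splitter library only ever sees the bool.
static cl::opt<bool> AllowDeviceImageDepsOpt{
    "allow-device-image-dependencies",
    cl::desc("Allow dependencies between device images"), cl::init(false)};

// Forward the parsed value explicitly instead of letting the library read a global:
auto Splitter = getDeviceCodeSplitter(std::move(MD), Mode, /*IROutputOnly=*/false,
                                      /*EmitOnlyKernelsAsEntryPoints=*/false,
                                      AllowDeviceImageDepsOpt);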
@@ -178,7 +173,7 @@ class DependencyGraph {
 public:
   using GlobalSet = SmallPtrSet<const GlobalValue *, 16>;

-  DependencyGraph(const Module &M) {
+  DependencyGraph(const Module &M, bool AllowDeviceImageDependencies) {
     // Group functions by their signature to handle case (2) described above
     DenseMap<const FunctionType *, DependencyGraph::GlobalSet>
         FuncTypeToFuncsMap;
@@ -196,7 +191,7 @@ class DependencyGraph {
     }

     for (const auto &F : M.functions()) {
-      if (canBeImportedFunction(F))
+      if (canBeImportedFunction(F, AllowDeviceImageDependencies))
         continue;

       // case (1), see comment above the class definition
@@ -311,7 +306,9 @@ static bool isIntrinsicOrBuiltin(const Function &F) {
 }

 // Checks for use of undefined user functions and emits a warning message.
-static void checkForCallsToUndefinedFunctions(const Module &M) {
+static void
+checkForCallsToUndefinedFunctions(const Module &M,
+                                  bool AllowDeviceImageDependencies) {
   if (AllowDeviceImageDependencies)
     return;
   for (const Function &F : M) {
@@ -391,11 +388,11 @@ ModuleDesc extractSubModule(const ModuleDesc &MD,
 // The function produces a copy of input LLVM IR module M with only those
 // functions and globals that can be called from entry points that are specified
 // in ModuleEntryPoints vector, in addition to the entry point functions.
-ModuleDesc extractCallGraph(const ModuleDesc &MD,
-                            EntryPointGroup &&ModuleEntryPoints,
-                            const DependencyGraph &CG,
-                            const std::function<bool(const Function *)>
-                                &IncludeFunctionPredicate = nullptr) {
+ModuleDesc extractCallGraph(
+    const ModuleDesc &MD, EntryPointGroup &&ModuleEntryPoints,
+    const DependencyGraph &CG, bool AllowDeviceImageDependencies,
+    const std::function<bool(const Function *)> &IncludeFunctionPredicate =
+        nullptr) {
   SetVector<const GlobalValue *> GVs;
   collectFunctionsAndGlobalVariablesToExtract(
       GVs, MD.getModule(), ModuleEntryPoints, CG, IncludeFunctionPredicate);
@@ -405,20 +402,21 @@ ModuleDesc extractCallGraph(const ModuleDesc &MD,
   // sycl-post-link. This call is redundant. However, we subsequently run
   // GenXSPIRVWriterAdaptor pass that relies on this cleanup. This cleanup call
   // can be removed once that pass no longer depends on this cleanup.
-  SplitM.cleanup();
-  checkForCallsToUndefinedFunctions(SplitM.getModule());
+  SplitM.cleanup(AllowDeviceImageDependencies);
+  checkForCallsToUndefinedFunctions(SplitM.getModule(),
+                                    AllowDeviceImageDependencies);

   return SplitM;
 }

 // The function is similar to 'extractCallGraph', but it produces a copy of
 // input LLVM IR module M with _all_ ESIMD functions and kernels included,
 // regardless of whether or not they are listed in ModuleEntryPoints.
-ModuleDesc extractESIMDSubModule(const ModuleDesc &MD,
-                                 EntryPointGroup &&ModuleEntryPoints,
-                                 const DependencyGraph &CG,
-                                 const std::function<bool(const Function *)>
-                                     &IncludeFunctionPredicate = nullptr) {
+ModuleDesc extractESIMDSubModule(
+    const ModuleDesc &MD, EntryPointGroup &&ModuleEntryPoints,
+    const DependencyGraph &CG, bool AllowDeviceImageDependencies,
+    const std::function<bool(const Function *)> &IncludeFunctionPredicate =
+        nullptr) {
   SetVector<const GlobalValue *> GVs;
   for (const auto &F : MD.getModule().functions())
     if (isESIMDFunction(F))
@@ -432,7 +430,7 @@ ModuleDesc extractESIMDSubModule(const ModuleDesc &MD,
   // sycl-post-link. This call is redundant. However, we subsequently run
   // GenXSPIRVWriterAdaptor pass that relies on this cleanup. This cleanup call
   // can be removed once that pass no longer depends on this cleanup.
-  SplitM.cleanup();
+  SplitM.cleanup(AllowDeviceImageDependencies);

   return SplitM;
 }
@@ -449,19 +447,22 @@ class ModuleCopier : public ModuleSplitterBase {
     // sycl-post-link. This call is redundant. However, we subsequently run
     // GenXSPIRVWriterAdaptor pass that relies on this cleanup. This cleanup
     // call can be removed once that pass no longer depends on this cleanup.
-    Desc.cleanup();
+    Desc.cleanup(AllowDeviceImageDependencies);
     return Desc;
   }
 };

 class ModuleSplitter : public ModuleSplitterBase {
 public:
-  ModuleSplitter(ModuleDesc &&MD, EntryPointGroupVec &&GroupVec)
-      : ModuleSplitterBase(std::move(MD), std::move(GroupVec)),
-        CG(Input.getModule()) {}
+  ModuleSplitter(ModuleDesc &&MD, EntryPointGroupVec &&GroupVec,
+                 bool AllowDeviceImageDependencies)
+      : ModuleSplitterBase(std::move(MD), std::move(GroupVec),
+                           AllowDeviceImageDependencies),
+        CG(Input.getModule(), AllowDeviceImageDependencies) {}

   ModuleDesc nextSplit() override {
-    return extractCallGraph(Input, nextGroup(), CG);
+    return extractCallGraph(Input, nextGroup(), CG,
+                            AllowDeviceImageDependencies);
   }

 private:
@@ -489,11 +490,6 @@ bool isESIMDFunction(const Function &F) {
   return F.getMetadata(ESIMD_MARKER_MD) != nullptr;
 }

-cl::OptionCategory &getModuleSplitCategory() {
-  static cl::OptionCategory ModuleSplitCategory{"Module Split options"};
-  return ModuleSplitCategory;
-}
-
 Error ModuleSplitterBase::verifyNoCrossModuleDeviceGlobalUsage() {
   const Module &M = getInputModule();
   // Early exit if there is only one group
@@ -692,7 +688,8 @@ void ModuleDesc::restoreLinkageOfDirectInvokeSimdTargets() {
 // tries to internalize absolutely everything. This function serves as "input
 // from a linker" that tells the pass what must be preserved in order to make
 // the transformation safe.
-static bool mustPreserveGV(const GlobalValue &GV) {
+static bool mustPreserveGV(const GlobalValue &GV,
+                           bool AllowDeviceImageDependencies) {
   if (const Function *F = dyn_cast<Function>(&GV)) {
     // When dynamic linking is supported, we internalize everything (except
     // kernels which are the entry points from host code to device code) that
@@ -703,7 +700,8 @@ static bool mustPreserveGV(const GlobalValue &GV) {
     const bool SpirOrGPU = CC == CallingConv::SPIR_KERNEL ||
                            CC == CallingConv::AMDGPU_KERNEL ||
                            CC == CallingConv::PTX_Kernel;
-    return SpirOrGPU || canBeImportedFunction(*F);
+    return SpirOrGPU ||
+           canBeImportedFunction(*F, AllowDeviceImageDependencies);
   }

   // Otherwise, we are being even more aggressive: SYCL modules are expected
@@ -754,7 +752,7 @@ void cleanupSYCLRegisteredKernels(Module *M) {

 // TODO: try to move all passes (cleanup, spec consts, compile time properties)
 // in one place and execute MPM.run() only once.
-void ModuleDesc::cleanup() {
+void ModuleDesc::cleanup(bool AllowDeviceImageDependencies) {
   // Any definitions of virtual functions should be removed and turned into
   // declarations, they are supposed to be provided by a different module.
   if (!EntryPoints.Props.HasVirtualFunctionDefinitions) {
@@ -781,7 +779,10 @@ void ModuleDesc::cleanup() {
   MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
   ModulePassManager MPM;
   // Do cleanup.
-  MPM.addPass(InternalizePass(mustPreserveGV));
+  MPM.addPass(
+      InternalizePass([AllowDeviceImageDependencies](const GlobalValue &GV) {
+        return mustPreserveGV(GV, AllowDeviceImageDependencies);
+      }));
   MPM.addPass(GlobalDCEPass());           // Delete unreachable globals.
   MPM.addPass(StripDeadDebugInfoPass());  // Remove dead debug info.
   MPM.addPass(StripDeadPrototypesPass()); // Remove dead func decls.
@@ -1157,7 +1158,8 @@ std::string FunctionsCategorizer::computeCategoryFor(Function *F) const {

 std::unique_ptr<ModuleSplitterBase>
 getDeviceCodeSplitter(ModuleDesc &&MD, IRSplitMode Mode, bool IROutputOnly,
-                      bool EmitOnlyKernelsAsEntryPoints) {
+                      bool EmitOnlyKernelsAsEntryPoints,
+                      bool AllowDeviceImageDependencies) {
   FunctionsCategorizer Categorizer;

   EntryPointsGroupScope Scope =
@@ -1252,9 +1254,11 @@ getDeviceCodeSplitter(ModuleDesc &&MD, IRSplitMode Mode, bool IROutputOnly,
                   (Groups.size() > 1 || !Groups.cbegin()->Functions.empty()));

   if (DoSplit)
-    return std::make_unique<ModuleSplitter>(std::move(MD), std::move(Groups));
+    return std::make_unique<ModuleSplitter>(std::move(MD), std::move(Groups),
+                                            AllowDeviceImageDependencies);

-  return std::make_unique<ModuleCopier>(std::move(MD), std::move(Groups));
+  return std::make_unique<ModuleCopier>(std::move(MD), std::move(Groups),
+                                        AllowDeviceImageDependencies);
 }

 // Splits input module into two:
@@ -1277,7 +1281,8 @@ getDeviceCodeSplitter(ModuleDesc &&MD, IRSplitMode Mode, bool IROutputOnly,
 // avoid undefined behavior at later stages. That is done at higher level,
 // outside of this function.
 SmallVector<ModuleDesc, 2> splitByESIMD(ModuleDesc &&MD,
-                                        bool EmitOnlyKernelsAsEntryPoints) {
+                                        bool EmitOnlyKernelsAsEntryPoints,
+                                        bool AllowDeviceImageDependencies) {

   SmallVector<module_split::ModuleDesc, 2> Result;
   EntryPointGroupVec EntryPointGroups{};
@@ -1320,12 +1325,13 @@ SmallVector<ModuleDesc, 2> splitByESIMD(ModuleDesc &&MD,
     return Result;
   }

-  DependencyGraph CG(MD.getModule());
+  DependencyGraph CG(MD.getModule(), AllowDeviceImageDependencies);
   for (auto &Group : EntryPointGroups) {
     if (Group.isEsimd()) {
       // For ESIMD module, we use full call graph of all entry points and all
       // ESIMD functions.
-      Result.emplace_back(extractESIMDSubModule(MD, std::move(Group), CG));
+      Result.emplace_back(extractESIMDSubModule(MD, std::move(Group), CG,
+                                                AllowDeviceImageDependencies));
     } else {
       // For non-ESIMD module we only use non-ESIMD functions. Additional filter
       // is needed, because there could be uses of ESIMD functions from
@@ -1334,7 +1340,7 @@ SmallVector<ModuleDesc, 2> splitByESIMD(ModuleDesc &&MD,
       // were processed and therefore it is fine to return an "incomplete"
      // module here.
       Result.emplace_back(extractCallGraph(
-          MD, std::move(Group), CG,
+          MD, std::move(Group), CG, AllowDeviceImageDependencies,
           [=](const Function *F) -> bool { return !isESIMDFunction(*F); }));
     }
   }
@@ -1477,7 +1483,8 @@ splitSYCLModule(std::unique_ptr<Module> M, ModuleSplitterSettings Settings) {
   // FIXME: false arguments are temporary for now.
   auto Splitter = getDeviceCodeSplitter(std::move(MD), Settings.Mode,
                                         /*IROutputOnly=*/false,
-                                        /*EmitOnlyKernelsAsEntryPoints=*/false);
+                                        /*EmitOnlyKernelsAsEntryPoints=*/false,
+                                        Settings.AllowDeviceImageDependencies);

   size_t ID = 0;
   std::vector<SplitModule> OutputImages;
@@ -1498,7 +1505,8 @@ splitSYCLModule(std::unique_ptr<Module> M, ModuleSplitterSettings Settings) {
   return OutputImages;
 }

-bool canBeImportedFunction(const Function &F) {
+bool canBeImportedFunction(const Function &F,
+                           bool AllowDeviceImageDependencies) {

   // We use sycl dynamic library mechanism to involve bf16 devicelib when
   // necessary, all __devicelib_* functions from native or fallback bf16
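Note: taken together, the diff threads AllowDeviceImageDependencies from the public entry points (getDeviceCodeSplitter, splitByESIMD, splitSYCLModule) down into DependencyGraph, ModuleDesc::cleanup, mustPreserveGV, and canBeImportedFunction. A rough sketch of the settings-driven path, assuming splitSYCLModule returns an llvm::Expected list of SplitModule as the surrounding code suggests; the split mode value is illustrative:

module_split::ModuleSplitterSettings Settings;
Settings.Mode = module_split::SPLIT_PER_KERNEL;  // illustrative split mode
Settings.AllowDeviceImageDependencies = true;    // replaces the removed global cl::opt
auto ImagesOrErr = module_split::splitSYCLModule(std::move(M), Settings);
if (!ImagesOrErr)
  return ImagesOrErr.takeError();                // propagate the llvm::Error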