Skip to content

Commit ffbe777

Browse files
committed
Fix dependency resolution
Signed-off-by: Larsen, Steffen <[email protected]>
1 parent 5499ff2 commit ffbe777

File tree

3 files changed

+71
-32
lines changed

3 files changed

+71
-32
lines changed

sycl/source/detail/device_image_impl.hpp

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -746,19 +746,28 @@ class device_image_impl {
746746
MRTCBinInfo->MIncludePairs, BuildOptions, LogPtr);
747747

748748
auto &PM = detail::ProgramManager::getInstance();
749-
std::vector<std::shared_ptr<device_image_impl>> Result;
750-
Result.reserve(Binaries->NumDeviceBinaries);
749+
750+
// Add all binaries and keep the images for processing.
751+
std::vector<std::pair<RTDeviceBinaryImage *,
752+
std::shared_ptr<std::vector<kernel_id>>>>
753+
NewImages;
754+
NewImages.reserve(Binaries->NumDeviceBinaries);
751755
for (int I = 0; I < Binaries->NumDeviceBinaries; I++) {
752756
sycl_device_binary Binary = &(Binaries->DeviceBinaries[I]);
753-
754757
RTDeviceBinaryImage *NewImage = nullptr;
755758
auto KernelIDs = std::make_shared<std::vector<kernel_id>>();
756759
PM.addImage(Binary, &NewImage, KernelIDs.get());
760+
if (NewImage)
761+
NewImages.push_back(
762+
std::make_pair(std::move(NewImage), std::move(KernelIDs)));
763+
}
757764

758-
// If the image is empty, we can skip it.
759-
if (!NewImage)
760-
continue;
761-
765+
// Now bring all images into the proper state. Note that we do this in a
766+
// separate pass over NewImages to make sure dependency images have been
767+
// registered beforehand.
768+
std::vector<std::shared_ptr<device_image_impl>> Result;
769+
Result.reserve(NewImages.size());
770+
for (auto &[NewImage, KernelIDs] : NewImages) {
762771
std::set<std::string> KernelNames;
763772
std::unordered_map<std::string, std::string> MangledKernelNames;
764773
std::unordered_set<std::string> DeviceGlobalIDSet;
@@ -843,7 +852,26 @@ class device_image_impl {
843852
std::move(KernelNames), std::move(MangledKernelNames),
844853
std::string{Prefix}, std::move(DGRegs));
845854

846-
DevImgPlainWithDeps ImgWithDeps{DevImgImpl};
855+
// Resolve dependencies.
856+
// TODO: Consider making a collectDeviceImageDeps variant that takes a
857+
// set reference and inserts into that instead.
858+
std::set<RTDeviceBinaryImage *> ImgDeps;
859+
for (const device &Device : Devices) {
860+
std::set<RTDeviceBinaryImage *> DevImgDeps =
861+
PM.collectDeviceImageDeps(*NewImage, Device);
862+
ImgDeps.insert(DevImgDeps.begin(), DevImgDeps.end());
863+
}
864+
865+
// Pack main image and dependencies together.
866+
std::vector<device_image_plain> NewImageAndDeps;
867+
NewImageAndDeps.reserve(1 + ImgDeps.size());
868+
NewImageAndDeps.push_back(std::move(
869+
createSyclObjFromImpl<device_image_plain>(std::move(DevImgImpl))));
870+
for (RTDeviceBinaryImage *ImgDep : ImgDeps)
871+
NewImageAndDeps.push_back(PM.createDependencyImage(
872+
MContext, Devices, ImgDep, bundle_state::input));
873+
874+
DevImgPlainWithDeps ImgWithDeps(std::move(NewImageAndDeps));
847875
PM.bringSYCLDeviceImageToState(ImgWithDeps, bundle_state::executable);
848876
Result.push_back(getSyclObjImpl(ImgWithDeps.getMain()));
849877
}

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2698,32 +2698,36 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
26982698
Images.reserve(Deps.size() + 1);
26992699
Images.push_back(
27002700
createSyclObjFromImpl<device_image_plain>(std::move(MainImpl)));
2701-
for (RTDeviceBinaryImage *Dep : Deps) {
2702-
std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
2703-
{
2704-
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2705-
// Device library images are not in m_BinImg2KernelIDs since they
2706-
// contain no kernels.
2707-
auto DepIt = m_BinImg2KernelIDs.find(Dep);
2708-
if (DepIt != m_BinImg2KernelIDs.end())
2709-
DepKernelIDs = DepIt->second;
2710-
}
2711-
2712-
assert(ImgInfoPair.second.State == getBinImageState(Dep) &&
2713-
"State mismatch between main image and its dependency");
2714-
DeviceImageImplPtr DepImpl = std::make_shared<detail::device_image_impl>(
2715-
Dep, Ctx, Devs, ImgInfoPair.second.State, DepKernelIDs,
2716-
/*PIProgram=*/nullptr);
2717-
2701+
for (RTDeviceBinaryImage *Dep : Deps)
27182702
Images.push_back(
2719-
createSyclObjFromImpl<device_image_plain>(std::move(DepImpl)));
2720-
}
2703+
createDependencyImage(Ctx, Devs, Dep, ImgInfoPair.second.State));
27212704
SYCLDeviceImages.push_back(std::move(Images));
27222705
}
27232706

27242707
return SYCLDeviceImages;
27252708
}
27262709

2710+
device_image_plain ProgramManager::createDependencyImage(
2711+
const context &Ctx, const std::vector<device> &Devs,
2712+
RTDeviceBinaryImage *DepImage, bundle_state DepState) {
2713+
std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
2714+
{
2715+
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2716+
// Device library images are not in m_BinImg2KernelIDs since they
2717+
// contain no kernels.
2718+
auto DepIt = m_BinImg2KernelIDs.find(DepImage);
2719+
if (DepIt != m_BinImg2KernelIDs.end())
2720+
DepKernelIDs = DepIt->second;
2721+
}
2722+
2723+
assert(DepState == getBinImageState(DepImage) &&
2724+
"State mismatch between main image and its dependency");
2725+
DeviceImageImplPtr DepImpl = std::make_shared<detail::device_image_impl>(
2726+
DepImage, Ctx, Devs, DepState, DepKernelIDs, /*PIProgram=*/nullptr);
2727+
2728+
return createSyclObjFromImpl<device_image_plain>(std::move(DepImpl));
2729+
}
2730+
27272731
void ProgramManager::bringSYCLDeviceImageToState(
27282732
DevImgPlainWithDeps &DeviceImage, bundle_state TargetState) {
27292733
device_image_plain &MainImg = DeviceImage.getMain();

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,12 @@ class ProgramManager {
297297
const context &Ctx, const std::vector<device> &Devs,
298298
bundle_state TargetState, const std::vector<kernel_id> &KernelIDs = {});
299299

300+
// Creates a new dependency image for a given dependency binary image.
301+
device_image_plain createDependencyImage(const context &Ctx,
302+
const std::vector<device> &Devs,
303+
RTDeviceBinaryImage *DepImage,
304+
bundle_state DepState);
305+
300306
// Bring image to the required state. The image is modified in place.
301307
void bringSYCLDeviceImageToState(DevImgPlainWithDeps &DeviceImage,
302308
bundle_state TargetState);
@@ -363,6 +369,12 @@ class ProgramManager {
363369
std::set<RTDeviceBinaryImage *>
364370
getRawDeviceImages(const std::vector<kernel_id> &KernelIDs);
365371

372+
std::set<RTDeviceBinaryImage *>
373+
collectDeviceImageDeps(const RTDeviceBinaryImage &Img, const device &Dev);
374+
std::set<RTDeviceBinaryImage *>
375+
collectDeviceImageDepsForImportedSymbols(const RTDeviceBinaryImage &Img,
376+
const device &Dev);
377+
366378
private:
367379
ProgramManager(ProgramManager const &) = delete;
368380
ProgramManager &operator=(ProgramManager const &) = delete;
@@ -386,11 +398,6 @@ class ProgramManager {
386398
/// Add info on kernels using local arg into cache
387399
void cacheKernelImplicitLocalArg(RTDeviceBinaryImage &Img);
388400

389-
std::set<RTDeviceBinaryImage *>
390-
collectDeviceImageDeps(const RTDeviceBinaryImage &Img, const device &Dev);
391-
std::set<RTDeviceBinaryImage *>
392-
collectDeviceImageDepsForImportedSymbols(const RTDeviceBinaryImage &Img,
393-
const device &Dev);
394401
std::set<RTDeviceBinaryImage *>
395402
collectDependentDeviceImagesForVirtualFunctions(
396403
const RTDeviceBinaryImage &Img, const device &Dev);

0 commit comments

Comments
 (0)