Skip to content

[SYCL] Keep multiple copies for bf16 device library image #17461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 36 additions & 70 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1837,84 +1837,27 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const {
return {};
}

static bool shouldSkipEmptyImage(sycl_device_binary RawImg, bool IsRTC) {
// For bfloat16 device library image, we should keep it. However, in some
// scenario, __sycl_register_lib can be called multiple times and the same
// bfloat16 device library image may be handled multiple times which is not
// needed. 2 static bool variables are created to record whether native or
// fallback bfloat16 device library image has been handled, if yes, we just
// need to skip it.
// We cannot prevent redundant loads of device library images if they are part
// of a runtime-compiled device binary, as these will be freed when the
// corresponding kernel bundle is destroyed. Hence, normal kernels cannot rely
// on the presence of RTC device library images.
static bool shouldSkipEmptyImage(sycl_device_binary RawImg) {
// For bfloat16 device library image, we should keep it although it doesn't
// include any kernel.
sycl_device_binary_property_set ImgPS;
static bool IsNativeBF16DeviceLibHandled = false;
static bool IsFallbackBF16DeviceLibHandled = false;
for (ImgPS = RawImg->PropertySetsBegin; ImgPS != RawImg->PropertySetsEnd;
++ImgPS) {
if (ImgPS->Name &&
!strcmp(__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name)) {
sycl_device_binary_property ImgP;
for (ImgP = ImgPS->PropertiesBegin; ImgP != ImgPS->PropertiesEnd;
++ImgP) {
if (ImgP->Name && !strcmp("bfloat16", ImgP->Name) &&
(ImgP->Type == SYCL_PROPERTY_TYPE_UINT32))
break;
}
if (ImgP == ImgPS->PropertiesEnd)
return true;

// A valid bfloat16 device library image is found here.
// If it originated from RTC, we cannot skip it, but do not mark it as
// being present.
if (IsRTC)
return false;

// Otherwise, we need to check whether it has been handled already.
uint32_t BF16NativeVal = DeviceBinaryProperty(ImgP).asUint32();
if (((BF16NativeVal == 0) && IsFallbackBF16DeviceLibHandled) ||
((BF16NativeVal == 1) && IsNativeBF16DeviceLibHandled))
return true;

if (BF16NativeVal == 0)
IsFallbackBF16DeviceLibHandled = true;
else
IsNativeBF16DeviceLibHandled = true;

!strcmp(__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name))
return false;
}
}
return true;
}

static bool isCompiledAtRuntime(sycl_device_binaries DeviceBinary) {
// Check whether the first device binary contains a legacy format offload
// entry with a `$` in its name.
if (DeviceBinary->NumDeviceBinaries > 0) {
sycl_device_binary Binary = DeviceBinary->DeviceBinaries;
if (Binary->EntriesBegin != Binary->EntriesEnd) {
sycl_offload_entry Entry = Binary->EntriesBegin;
if (!Entry->IsNewOffloadEntryType() &&
std::string_view{Entry->name}.find('$') != std::string_view::npos) {
return true;
}
}
}
return false;
return true;
}

void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
const bool DumpImages = std::getenv("SYCL_DUMP_IMAGES") && !m_UseSpvFile;
const bool IsRTC = isCompiledAtRuntime(DeviceBinary);
for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
// If the image does not contain kernels, skip it unless it is one of the
// bfloat16 device libraries, and it wasn't loaded before or resulted from
// runtime compilation.
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg, IsRTC))
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg))
continue;

std::unique_ptr<RTDeviceBinaryImage> Img;
Expand Down Expand Up @@ -1946,6 +1889,25 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
// Fill maps for kernel bundles
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);

// For bfloat16 device library image, it doesn't include any kernel, device
// global, virtual function, so just skip adding it to any related maps. We
// only need to 1) add exported symbols to m_ExportedSymbolImages, and 2)
// add the device image to m_Bfloat16DeviceLibImages.
{
auto Bfloat16DeviceLibProp = Img->getDeviceLibMetadata();
if (Bfloat16DeviceLibProp.isAvailable()) {
uint32_t LibVersion = DeviceBinaryProperty(*(Bfloat16DeviceLibProp.begin())).asUint32();
if (m_Bfloat16DeviceLibImages.count(LibVersion) > 0)
continue;
for (const sycl_device_binary_property &ESProp :
Img->getExportedSymbols()) {
m_ExportedSymbolImages.insert({ESProp->Name, Img.get()});
}
m_Bfloat16DeviceLibImages.insert({LibVersion, std::move(Img)});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't sufficient to keep the image alive. The RTDeviceBinaryImage does not own the sycl_device_binary_struct, which itself does not own the raw binary data (e.g. SPIRV) at sycl_device_binary_struct::BinaryStart; both entities become unavailable after ProgramManager::removeImages (either through dlclose or explicit deleting from a context object in the sycl-jit library). One way to trigger a crash is to execute test_device_libraries twice in sycl/test-e2e/KernelCompiler/sycl.cpp.

I think a solution could be to create a DynRTDeviceBinaryImage with a copy of the data coming from a bfloat device library image. (NB: DynRTDeviceBinaryImage doesn't set compile/link options, is that a problem?)

Copy link
Contributor Author

@jinge90 jinge90 Mar 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @jopperm
Thanks very much to point out this. I updated the PR to let program manager to own devicelib DynRTDeviceBinaryImage. It seems compile/link options are not problems but we can't set PropertySet value for DynRTDeviceBinaryImage, so can't check BFloat16DevicelibMetadata. I have to use a workaround to compare the DeviceImagePtr. I also updated the KernelCompiler/sycl.cpp to run test_device_library 2 times. The PM owned DynRTDeviceBinaryImage won't be removed unless PM is destroyed.

Thanks very much.

continue;
}
}

// Register all exported symbols
for (const sycl_device_binary_property &ESProp :
Img->getExportedSymbols()) {
Expand Down Expand Up @@ -2110,19 +2072,14 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
}

void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
bool IsRTC = isCompiledAtRuntime(DeviceBinary);
for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
auto DevImgIt = m_DeviceImages.find(RawImg);
if (DevImgIt == m_DeviceImages.end())
continue;
const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
// Skip clean up if there are no offload entries, unless `DeviceBinary`
// resulted from runtime compilation: Then, this is one of the `bfloat16`
// device libraries, so we want to make sure that the image and its exported
// symbols are removed from the program manager's maps.
if (EntriesB == EntriesE && !IsRTC)
if ((EntriesB == EntriesE))
continue;

RTDeviceBinaryImage *Img = DevImgIt->second.get();
Expand Down Expand Up @@ -2650,7 +2607,11 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
DepKernelIDs = m_BinImg2KernelIDs[Dep];
// For device library images, they are not in m_BinImg2KernelIDs since
// no kernel is included.
auto DepIt = m_BinImg2KernelIDs.find(Dep);
if (DepIt != m_BinImg2KernelIDs.end())
DepKernelIDs = DepIt->second;
}

assert(ImgInfoPair.second.State == getBinImageState(Dep) &&
Expand Down Expand Up @@ -2863,6 +2824,11 @@ static void mergeImageData(const std::vector<device_image_plain> &Imgs,
for (const device_image_plain &Img : Imgs) {
const std::shared_ptr<device_image_impl> &DeviceImageImpl =
getSyclObjImpl(Img);
auto BinImgRef = DeviceImageImpl->get_bin_image_ref();
// For bfloat16 deice library image, no kernels, spec const are included,
// so we just skip merging data.
if (BinImgRef && BinImgRef->getDeviceLibMetadata().isAvailable())
continue;
// Duplicates are not expected here, otherwise urProgramLink should fail
KernelIDs.insert(KernelIDs.end(),
DeviceImageImpl->get_kernel_ids_ptr()->begin(),
Expand Down
7 changes: 7 additions & 0 deletions sycl/source/detail/program_manager/program_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,13 @@ class ProgramManager {
std::map<std::vector<unsigned char>, ur_kernel_handle_t>;
std::unordered_map<std::string, MaterializedEntries> m_MaterializedKernels;

// Holds bfloat16 device library images, the key is 0 for fallback version
// and 1 for native version. These bfloat16 device library images are
// provided by compiler long time ago, we expect no further update, so
// keeping 1 copy should be OK.
std::unordered_map<uint32_t, RTDeviceBinaryImageUPtr>
m_Bfloat16DeviceLibImages;

friend class ::ProgramManagerTest;
};
} // namespace detail
Expand Down
Loading