Skip to content

[SYCL] Keep multiple copies for bf16 device library image #17461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Apr 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 108 additions & 84 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,22 +615,25 @@ static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
}

// Quick check to see whether BinImage is a compiler-generated device image.
static bool isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) {
bool ProgramManager::isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) {
// SYCL devicelib image.
if (BinImage->getDeviceLibMetadata().isAvailable())
if ((m_Bfloat16DeviceLibImages[0].get() == BinImage) ||
m_Bfloat16DeviceLibImages[1].get() == BinImage)
return true;

return false;
}

static bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
const device &Dev) {
bool ProgramManager::isSpecialDeviceImageShouldBeUsed(
RTDeviceBinaryImage *BinImage, const device &Dev) {
// Decide whether a devicelib image should be used.
if (BinImage->getDeviceLibMetadata().isAvailable()) {
const RTDeviceBinaryImage::PropertyRange &DeviceLibMetaProp =
BinImage->getDeviceLibMetadata();
uint32_t DeviceLibMeta =
DeviceBinaryProperty(*(DeviceLibMetaProp.begin())).asUint32();
int Bfloat16DeviceLibVersion = -1;
if (m_Bfloat16DeviceLibImages[0].get() == BinImage)
Bfloat16DeviceLibVersion = 0;
else if (m_Bfloat16DeviceLibImages[1].get() == BinImage)
Bfloat16DeviceLibVersion = 1;

if (Bfloat16DeviceLibVersion != -1) {
// Currently, only bfloat conversion devicelib are supported, so the prop
// DeviceLibMeta are only used to represent fallback or native version.
// For bfloat16 conversion devicelib, we have fallback and native version.
Expand All @@ -644,7 +647,8 @@ static bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
detail::getSyclObjImpl(Dev);
std::string NativeBF16ExtName = "cl_intel_bfloat16_conversions";
bool NativeBF16Supported = (DeviceImpl->has_extension(NativeBF16ExtName));
return NativeBF16Supported == (DeviceLibMeta == DEVICELIB_NATIVE);
return NativeBF16Supported ==
(Bfloat16DeviceLibVersion == DEVICELIB_NATIVE);
}

return false;
Expand Down Expand Up @@ -1838,87 +1842,69 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const {
return {};
}

static bool shouldSkipEmptyImage(sycl_device_binary RawImg, bool IsRTC) {
// For bfloat16 device library image, we should keep it. However, in some
// scenario, __sycl_register_lib can be called multiple times and the same
// bfloat16 device library image may be handled multiple times which is not
// needed. 2 static bool variables are created to record whether native or
// fallback bfloat16 device library image has been handled, if yes, we just
// need to skip it.
// We cannot prevent redundant loads of device library images if they are part
// of a runtime-compiled device binary, as these will be freed when the
// corresponding kernel bundle is destroyed. Hence, normal kernels cannot rely
// on the presence of RTC device library images.
static bool isBfloat16DeviceLibImage(sycl_device_binary RawImg,
uint32_t *LibVersion = nullptr) {
sycl_device_binary_property_set ImgPS;
static bool IsNativeBF16DeviceLibHandled = false;
static bool IsFallbackBF16DeviceLibHandled = false;
for (ImgPS = RawImg->PropertySetsBegin; ImgPS != RawImg->PropertySetsEnd;
++ImgPS) {
if (ImgPS->Name &&
!strcmp(__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name)) {
if (!LibVersion)
return true;

// Valid version for bfloat16 device library is 0(fallback), 1(native).
*LibVersion = 2;
sycl_device_binary_property ImgP;
for (ImgP = ImgPS->PropertiesBegin; ImgP != ImgPS->PropertiesEnd;
++ImgP) {
if (ImgP->Name && !strcmp("bfloat16", ImgP->Name) &&
(ImgP->Type == SYCL_PROPERTY_TYPE_UINT32))
break;
}
if (ImgP == ImgPS->PropertiesEnd)
return true;

// A valid bfloat16 device library image is found here.
// If it originated from RTC, we cannot skip it, but do not mark it as
// being present.
if (IsRTC)
return false;

// Otherwise, we need to check whether it has been handled already.
uint32_t BF16NativeVal = DeviceBinaryProperty(ImgP).asUint32();
if (((BF16NativeVal == 0) && IsFallbackBF16DeviceLibHandled) ||
((BF16NativeVal == 1) && IsNativeBF16DeviceLibHandled))
return true;

if (BF16NativeVal == 0)
IsFallbackBF16DeviceLibHandled = true;
else
IsNativeBF16DeviceLibHandled = true;

return false;
if (ImgP != ImgPS->PropertiesEnd)
*LibVersion = DeviceBinaryProperty(ImgP).asUint32();
return true;
}
}
return true;

return false;
}

static bool isCompiledAtRuntime(sycl_device_binaries DeviceBinary) {
// Check whether the first device binary contains a legacy format offload
// entry with a `$` in its name.
if (DeviceBinary->NumDeviceBinaries > 0) {
sycl_device_binary Binary = DeviceBinary->DeviceBinaries;
if (Binary->EntriesBegin != Binary->EntriesEnd) {
sycl_offload_entry Entry = Binary->EntriesBegin;
if (!Entry->IsNewOffloadEntryType() &&
std::string_view{Entry->name}.find('$') != std::string_view::npos) {
return true;
}
}
static sycl_device_binary_property_set
getExportedSymbolPS(sycl_device_binary RawImg) {
sycl_device_binary_property_set ImgPS;
for (ImgPS = RawImg->PropertySetsBegin; ImgPS != RawImg->PropertySetsEnd;
++ImgPS) {
if (ImgPS->Name &&
!strcmp(__SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS, ImgPS->Name))
return ImgPS;
}
return false;

return nullptr;
}

// Tells whether a device image that contains no kernels can be skipped.
// Returns false (keep the image) for a bfloat16 device library image and
// true (skip the image) for anything else.
static bool shouldSkipEmptyImage(sycl_device_binary RawImg) {
  // A bfloat16 device library image must be kept even though it does not
  // include any kernel. This is currently the only exemption; the check may
  // be extended to other kinds of images later.
  const bool IsBfloat16DeviceLib = isBfloat16DeviceLibImage(RawImg);
  return !IsBfloat16DeviceLib;
}

void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
const bool DumpImages = std::getenv("SYCL_DUMP_IMAGES") && !m_UseSpvFile;
const bool IsRTC = isCompiledAtRuntime(DeviceBinary);
for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
// If the image does not contain kernels, skip it unless it is one of the
// bfloat16 device libraries, and it wasn't loaded before or resulted from
// runtime compilation.
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg, IsRTC))
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg))
continue;

std::unique_ptr<RTDeviceBinaryImage> Img;
bool IsBfloat16DeviceLib = false;
uint32_t Bfloat16DeviceLibVersion = 0;
if (isDeviceImageCompressed(RawImg))
#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE
Img = std::make_unique<CompressedRTDeviceBinaryImage>(RawImg);
Expand All @@ -1928,25 +1914,63 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
"SYCL RT was built without ZSTD support."
"Aborting. ");
#endif
else
Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
else {
IsBfloat16DeviceLib =
isBfloat16DeviceLibImage(RawImg, &Bfloat16DeviceLibVersion);
if (!IsBfloat16DeviceLib)
Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
}

static uint32_t SequenceID = 0;

// Fill the kernel argument mask map
const RTDeviceBinaryImage::PropertyRange &KPOIRange =
Img->getKernelParamOptInfo();
if (KPOIRange.isAvailable()) {
KernelNameToArgMaskMap &ArgMaskMap =
m_EliminatedKernelArgMasks[Img.get()];
for (const auto &Info : KPOIRange)
ArgMaskMap[Info->Name] =
createKernelArgMask(DeviceBinaryProperty(Info).asByteArray());
// Fill the kernel argument mask map, no need to do this for bfloat16
// device library image since it doesn't include any kernel.
if (!IsBfloat16DeviceLib) {
const RTDeviceBinaryImage::PropertyRange &KPOIRange =
Img->getKernelParamOptInfo();
if (KPOIRange.isAvailable()) {
KernelNameToArgMaskMap &ArgMaskMap =
m_EliminatedKernelArgMasks[Img.get()];
for (const auto &Info : KPOIRange)
ArgMaskMap[Info->Name] =
createKernelArgMask(DeviceBinaryProperty(Info).asByteArray());
}
}

// Fill maps for kernel bundles
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);

// A bfloat16 device library image doesn't include any kernels, device
// globals or virtual functions, so skip adding it to any related maps.
// The bfloat16 device libraries are provided by the compiler and may be used
// by different SYCL device images. The program manager owns a single copy of
// each of the native and fallback bfloat16 device libraries; these images
// are not erased until the program manager is destroyed.
{
if (IsBfloat16DeviceLib) {
assert((Bfloat16DeviceLibVersion < 2) &&
"Invalid Bfloat16 Device Library Index.");
if (m_Bfloat16DeviceLibImages[Bfloat16DeviceLibVersion].get())
continue;
size_t ImgSize =
static_cast<size_t>(RawImg->BinaryEnd - RawImg->BinaryStart);
std::unique_ptr<char[]> Data(new char[ImgSize]);
std::memcpy(Data.get(), RawImg->BinaryStart, ImgSize);
auto DynBfloat16DeviceLibImg =
std::make_unique<DynRTDeviceBinaryImage>(std::move(Data), ImgSize);
auto ESPropSet = getExportedSymbolPS(RawImg);
sycl_device_binary_property ESProp;
for (ESProp = ESPropSet->PropertiesBegin;
ESProp != ESPropSet->PropertiesEnd; ++ESProp) {
m_ExportedSymbolImages.insert(
{ESProp->Name, DynBfloat16DeviceLibImg.get()});
}
m_Bfloat16DeviceLibImages[Bfloat16DeviceLibVersion] =
std::move(DynBfloat16DeviceLibImg);
continue;
}
}

// Register all exported symbols
for (const sycl_device_binary_property &ESProp :
Img->getExportedSymbols()) {
Expand Down Expand Up @@ -2111,19 +2135,14 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
}

void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
bool IsRTC = isCompiledAtRuntime(DeviceBinary);
for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
auto DevImgIt = m_DeviceImages.find(RawImg);
if (DevImgIt == m_DeviceImages.end())
continue;
const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
// Skip clean up if there are no offload entries, unless `DeviceBinary`
// resulted from runtime compilation: Then, this is one of the `bfloat16`
// device libraries, so we want to make sure that the image and its exported
// symbols are removed from the program manager's maps.
if (EntriesB == EntriesE && !IsRTC)
if (EntriesB == EntriesE)
continue;

RTDeviceBinaryImage *Img = DevImgIt->second.get();
Expand Down Expand Up @@ -2651,7 +2670,11 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
DepKernelIDs = m_BinImg2KernelIDs[Dep];
// Device library images are not present in m_BinImg2KernelIDs since they
// do not include any kernels.
auto DepIt = m_BinImg2KernelIDs.find(Dep);
if (DepIt != m_BinImg2KernelIDs.end())
DepKernelIDs = DepIt->second;
}

assert(ImgInfoPair.second.State == getBinImageState(Dep) &&
Expand Down Expand Up @@ -2865,9 +2888,10 @@ static void mergeImageData(const std::vector<device_image_plain> &Imgs,
const std::shared_ptr<device_image_impl> &DeviceImageImpl =
getSyclObjImpl(Img);
// Duplicates are not expected here, otherwise urProgramLink should fail
KernelIDs.insert(KernelIDs.end(),
DeviceImageImpl->get_kernel_ids_ptr()->begin(),
DeviceImageImpl->get_kernel_ids_ptr()->end());
if (DeviceImageImpl->get_kernel_ids_ptr())
KernelIDs.insert(KernelIDs.end(),
DeviceImageImpl->get_kernel_ids_ptr()->begin(),
DeviceImageImpl->get_kernel_ids_ptr()->end());
// To be able to answer queries about specialization constants, the new
// device image should have the specialization constants from all the linked
// images.
Expand Down
13 changes: 12 additions & 1 deletion sycl/source/detail/program_manager/program_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <sycl/device.hpp>
#include <sycl/kernel_bundle.hpp>

#include <array>
#include <cstdint>
#include <map>
#include <memory>
Expand Down Expand Up @@ -376,11 +377,15 @@ class ProgramManager {
collectDependentDeviceImagesForVirtualFunctions(
const RTDeviceBinaryImage &Img, const device &Dev);

bool isSpecialDeviceImage(RTDeviceBinaryImage *BinImage);
bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
const device &Dev);

protected:
/// The three maps below are used during kernel resolution. Any kernel is
/// identified by its name.
using RTDeviceBinaryImageUPtr = std::unique_ptr<RTDeviceBinaryImage>;

using DynRTDeviceBinaryImageUPtr = std::unique_ptr<DynRTDeviceBinaryImage>;
/// Maps names of kernels to their unique kernel IDs.
/// TODO: Use std::unordered_set with transparent hash and equality functions
/// when C++20 is enabled for the runtime library.
Expand Down Expand Up @@ -498,6 +503,12 @@ class ProgramManager {
std::map<std::vector<unsigned char>, ur_kernel_handle_t>;
std::unordered_map<std::string, MaterializedEntries> m_MaterializedKernels;

// Holds the bfloat16 device library images: the 1st element is the fallback
// version and the 2nd is the native version. These images have been provided
// by the compiler for a long time and no further updates are expected, so
// keeping a single copy of each should be OK.
std::array<DynRTDeviceBinaryImageUPtr, 2> m_Bfloat16DeviceLibImages;

friend class ::ProgramManagerTest;
};
} // namespace detail
Expand Down
6 changes: 3 additions & 3 deletions sycl/test-e2e/KernelCompiler/sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,11 +532,11 @@ int main() {
if (!ok) {
return -1;
}

// Run test_device_libraries twice to verify bfloat16 device library.
return test_build_and_run(q) || test_device_code_split(q) ||
test_device_libraries(q) || test_esimd(q) ||
test_unsupported_options(q) || test_error(q) ||
test_no_visible_ids(q) || test_warning(q);
test_device_libraries(q) || test_unsupported_options(q) ||
test_error(q) || test_no_visible_ids(q) || test_warning(q);
#else
static_assert(false, "Kernel Compiler feature test macro undefined");
#endif
Expand Down