Skip to content

Commit 00723ce

Browse files
jinge90 authored and
KornevNikita committed
[SYCL] Keep multiple copies for bf16 device library image (#17461)
The SYCL RT addImages function may be invoked multiple times for different SYCL binary images, and more than one of these SYCL binary images may depend on the bfloat16 device library. These bfloat16 device library images are provided by the compiler and their implementation is stable now, so we keep only a single copy of the native and fallback versions of the bfloat16 device library in the program manager; these images will not be removed unless the program manager is destroyed. --------- Signed-off-by: jinge90 <[email protected]>
1 parent b23d69e commit 00723ce

File tree

3 files changed

+123
-88
lines changed

3 files changed

+123
-88
lines changed

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 108 additions & 84 deletions
Original file line number | Diff line number | Diff line change
@@ -614,22 +614,25 @@ static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
614614
}
615615

616616
// Quick check to see whether BinImage is a compiler-generated device image.
617-
static bool isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) {
617+
bool ProgramManager::isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) {
618618
// SYCL devicelib image.
619-
if (BinImage->getDeviceLibMetadata().isAvailable())
619+
if ((m_Bfloat16DeviceLibImages[0].get() == BinImage) ||
620+
m_Bfloat16DeviceLibImages[1].get() == BinImage)
620621
return true;
621622

622623
return false;
623624
}
624625

625-
static bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
626-
const device &Dev) {
626+
bool ProgramManager::isSpecialDeviceImageShouldBeUsed(
627+
RTDeviceBinaryImage *BinImage, const device &Dev) {
627628
// Decide whether a devicelib image should be used.
628-
if (BinImage->getDeviceLibMetadata().isAvailable()) {
629-
const RTDeviceBinaryImage::PropertyRange &DeviceLibMetaProp =
630-
BinImage->getDeviceLibMetadata();
631-
uint32_t DeviceLibMeta =
632-
DeviceBinaryProperty(*(DeviceLibMetaProp.begin())).asUint32();
629+
int Bfloat16DeviceLibVersion = -1;
630+
if (m_Bfloat16DeviceLibImages[0].get() == BinImage)
631+
Bfloat16DeviceLibVersion = 0;
632+
else if (m_Bfloat16DeviceLibImages[1].get() == BinImage)
633+
Bfloat16DeviceLibVersion = 1;
634+
635+
if (Bfloat16DeviceLibVersion != -1) {
633636
// Currently, only bfloat conversion devicelib are supported, so the prop
634637
// DeviceLibMeta are only used to represent fallback or native version.
635638
// For bfloat16 conversion devicelib, we have fallback and native version.
@@ -643,7 +646,8 @@ static bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
643646
detail::getSyclObjImpl(Dev);
644647
std::string NativeBF16ExtName = "cl_intel_bfloat16_conversions";
645648
bool NativeBF16Supported = (DeviceImpl->has_extension(NativeBF16ExtName));
646-
return NativeBF16Supported == (DeviceLibMeta == DEVICELIB_NATIVE);
649+
return NativeBF16Supported ==
650+
(Bfloat16DeviceLibVersion == DEVICELIB_NATIVE);
647651
}
648652

649653
return false;
@@ -1837,87 +1841,69 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const {
18371841
return {};
18381842
}
18391843

1840-
static bool shouldSkipEmptyImage(sycl_device_binary RawImg, bool IsRTC) {
1841-
// For bfloat16 device library image, we should keep it. However, in some
1842-
// scenario, __sycl_register_lib can be called multiple times and the same
1843-
// bfloat16 device library image may be handled multiple times which is not
1844-
// needed. 2 static bool variables are created to record whether native or
1845-
// fallback bfloat16 device library image has been handled, if yes, we just
1846-
// need to skip it.
1847-
// We cannot prevent redundant loads of device library images if they are part
1848-
// of a runtime-compiled device binary, as these will be freed when the
1849-
// corresponding kernel bundle is destroyed. Hence, normal kernels cannot rely
1850-
// on the presence of RTC device library images.
1844+
static bool isBfloat16DeviceLibImage(sycl_device_binary RawImg,
1845+
uint32_t *LibVersion = nullptr) {
18511846
sycl_device_binary_property_set ImgPS;
1852-
static bool IsNativeBF16DeviceLibHandled = false;
1853-
static bool IsFallbackBF16DeviceLibHandled = false;
18541847
for (ImgPS = RawImg->PropertySetsBegin; ImgPS != RawImg->PropertySetsEnd;
18551848
++ImgPS) {
18561849
if (ImgPS->Name &&
18571850
!strcmp(__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name)) {
1851+
if (!LibVersion)
1852+
return true;
1853+
1854+
// Valid version for bfloat16 device library is 0(fallback), 1(native).
1855+
*LibVersion = 2;
18581856
sycl_device_binary_property ImgP;
18591857
for (ImgP = ImgPS->PropertiesBegin; ImgP != ImgPS->PropertiesEnd;
18601858
++ImgP) {
18611859
if (ImgP->Name && !strcmp("bfloat16", ImgP->Name) &&
18621860
(ImgP->Type == SYCL_PROPERTY_TYPE_UINT32))
18631861
break;
18641862
}
1865-
if (ImgP == ImgPS->PropertiesEnd)
1866-
return true;
1867-
1868-
// A valid bfloat16 device library image is found here.
1869-
// If it originated from RTC, we cannot skip it, but do not mark it as
1870-
// being present.
1871-
if (IsRTC)
1872-
return false;
1873-
1874-
// Otherwise, we need to check whether it has been handled already.
1875-
uint32_t BF16NativeVal = DeviceBinaryProperty(ImgP).asUint32();
1876-
if (((BF16NativeVal == 0) && IsFallbackBF16DeviceLibHandled) ||
1877-
((BF16NativeVal == 1) && IsNativeBF16DeviceLibHandled))
1878-
return true;
1879-
1880-
if (BF16NativeVal == 0)
1881-
IsFallbackBF16DeviceLibHandled = true;
1882-
else
1883-
IsNativeBF16DeviceLibHandled = true;
1884-
1885-
return false;
1863+
if (ImgP != ImgPS->PropertiesEnd)
1864+
*LibVersion = DeviceBinaryProperty(ImgP).asUint32();
1865+
return true;
18861866
}
18871867
}
1888-
return true;
1868+
1869+
return false;
18891870
}
18901871

1891-
static bool isCompiledAtRuntime(sycl_device_binaries DeviceBinary) {
1892-
// Check whether the first device binary contains a legacy format offload
1893-
// entry with a `$` in its name.
1894-
if (DeviceBinary->NumDeviceBinaries > 0) {
1895-
sycl_device_binary Binary = DeviceBinary->DeviceBinaries;
1896-
if (Binary->EntriesBegin != Binary->EntriesEnd) {
1897-
sycl_offload_entry Entry = Binary->EntriesBegin;
1898-
if (!Entry->IsNewOffloadEntryType() &&
1899-
std::string_view{Entry->name}.find('$') != std::string_view::npos) {
1900-
return true;
1901-
}
1902-
}
1872+
static sycl_device_binary_property_set
1873+
getExportedSymbolPS(sycl_device_binary RawImg) {
1874+
sycl_device_binary_property_set ImgPS;
1875+
for (ImgPS = RawImg->PropertySetsBegin; ImgPS != RawImg->PropertySetsEnd;
1876+
++ImgPS) {
1877+
if (ImgPS->Name &&
1878+
!strcmp(__SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS, ImgPS->Name))
1879+
return ImgPS;
19031880
}
1904-
return false;
1881+
1882+
return nullptr;
1883+
}
1884+
1885+
static bool shouldSkipEmptyImage(sycl_device_binary RawImg) {
1886+
// For bfloat16 device library image, we should keep it although it doesn't
1887+
// include any kernel.
1888+
if (isBfloat16DeviceLibImage(RawImg))
1889+
return false;
1890+
1891+
// We may extend the logic here other than bfloat16 device library image.
1892+
return true;
19051893
}
19061894

19071895
void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
19081896
const bool DumpImages = std::getenv("SYCL_DUMP_IMAGES") && !m_UseSpvFile;
1909-
const bool IsRTC = isCompiledAtRuntime(DeviceBinary);
19101897
for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
19111898
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
19121899
const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
19131900
const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
1914-
// If the image does not contain kernels, skip it unless it is one of the
1915-
// bfloat16 device libraries, and it wasn't loaded before or resulted from
1916-
// runtime compilation.
1917-
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg, IsRTC))
1901+
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg))
19181902
continue;
19191903

19201904
std::unique_ptr<RTDeviceBinaryImage> Img;
1905+
bool IsBfloat16DeviceLib = false;
1906+
uint32_t Bfloat16DeviceLibVersion = 0;
19211907
if (isDeviceImageCompressed(RawImg))
19221908
#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE
19231909
Img = std::make_unique<CompressedRTDeviceBinaryImage>(RawImg);
@@ -1927,25 +1913,63 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
19271913
"SYCL RT was built without ZSTD support."
19281914
"Aborting. ");
19291915
#endif
1930-
else
1931-
Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1916+
else {
1917+
IsBfloat16DeviceLib =
1918+
isBfloat16DeviceLibImage(RawImg, &Bfloat16DeviceLibVersion);
1919+
if (!IsBfloat16DeviceLib)
1920+
Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1921+
}
19321922

19331923
static uint32_t SequenceID = 0;
19341924

1935-
// Fill the kernel argument mask map
1936-
const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1937-
Img->getKernelParamOptInfo();
1938-
if (KPOIRange.isAvailable()) {
1939-
KernelNameToArgMaskMap &ArgMaskMap =
1940-
m_EliminatedKernelArgMasks[Img.get()];
1941-
for (const auto &Info : KPOIRange)
1942-
ArgMaskMap[Info->Name] =
1943-
createKernelArgMask(DeviceBinaryProperty(Info).asByteArray());
1925+
// Fill the kernel argument mask map, no need to do this for bfloat16
1926+
// device library image since it doesn't include any kernel.
1927+
if (!IsBfloat16DeviceLib) {
1928+
const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1929+
Img->getKernelParamOptInfo();
1930+
if (KPOIRange.isAvailable()) {
1931+
KernelNameToArgMaskMap &ArgMaskMap =
1932+
m_EliminatedKernelArgMasks[Img.get()];
1933+
for (const auto &Info : KPOIRange)
1934+
ArgMaskMap[Info->Name] =
1935+
createKernelArgMask(DeviceBinaryProperty(Info).asByteArray());
1936+
}
19441937
}
19451938

19461939
// Fill maps for kernel bundles
19471940
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
19481941

1942+
// For bfloat16 device library image, it doesn't include any kernel, device
1943+
// global, virtual function, so just skip adding it to any related maps.
1944+
// The bfloat16 device library are provided by compiler and may be used by
1945+
// different sycl device images, program manager will own single copy for
1946+
// native and fallback version bfloat16 device library, these device
1947+
// library images will not be erased unless program manager is destroyed.
1948+
{
1949+
if (IsBfloat16DeviceLib) {
1950+
assert((Bfloat16DeviceLibVersion < 2) &&
1951+
"Invalid Bfloat16 Device Library Index.");
1952+
if (m_Bfloat16DeviceLibImages[Bfloat16DeviceLibVersion].get())
1953+
continue;
1954+
size_t ImgSize =
1955+
static_cast<size_t>(RawImg->BinaryEnd - RawImg->BinaryStart);
1956+
std::unique_ptr<char[]> Data(new char[ImgSize]);
1957+
std::memcpy(Data.get(), RawImg->BinaryStart, ImgSize);
1958+
auto DynBfloat16DeviceLibImg =
1959+
std::make_unique<DynRTDeviceBinaryImage>(std::move(Data), ImgSize);
1960+
auto ESPropSet = getExportedSymbolPS(RawImg);
1961+
sycl_device_binary_property ESProp;
1962+
for (ESProp = ESPropSet->PropertiesBegin;
1963+
ESProp != ESPropSet->PropertiesEnd; ++ESProp) {
1964+
m_ExportedSymbolImages.insert(
1965+
{ESProp->Name, DynBfloat16DeviceLibImg.get()});
1966+
}
1967+
m_Bfloat16DeviceLibImages[Bfloat16DeviceLibVersion] =
1968+
std::move(DynBfloat16DeviceLibImg);
1969+
continue;
1970+
}
1971+
}
1972+
19491973
// Register all exported symbols
19501974
for (const sycl_device_binary_property &ESProp :
19511975
Img->getExportedSymbols()) {
@@ -2110,19 +2134,14 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
21102134
}
21112135

21122136
void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
2113-
bool IsRTC = isCompiledAtRuntime(DeviceBinary);
21142137
for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
21152138
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
21162139
auto DevImgIt = m_DeviceImages.find(RawImg);
21172140
if (DevImgIt == m_DeviceImages.end())
21182141
continue;
21192142
const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
21202143
const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
2121-
// Skip clean up if there are no offload entries, unless `DeviceBinary`
2122-
// resulted from runtime compilation: Then, this is one of the `bfloat16`
2123-
// device libraries, so we want to make sure that the image and its exported
2124-
// symbols are removed from the program manager's maps.
2125-
if (EntriesB == EntriesE && !IsRTC)
2144+
if (EntriesB == EntriesE)
21262145
continue;
21272146

21282147
RTDeviceBinaryImage *Img = DevImgIt->second.get();
@@ -2650,7 +2669,11 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
26502669
std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
26512670
{
26522671
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2653-
DepKernelIDs = m_BinImg2KernelIDs[Dep];
2672+
// For device library images, they are not in m_BinImg2KernelIDs since
2673+
// no kernel is included.
2674+
auto DepIt = m_BinImg2KernelIDs.find(Dep);
2675+
if (DepIt != m_BinImg2KernelIDs.end())
2676+
DepKernelIDs = DepIt->second;
26542677
}
26552678

26562679
assert(ImgInfoPair.second.State == getBinImageState(Dep) &&
@@ -2863,9 +2886,10 @@ static void mergeImageData(const std::vector<device_image_plain> &Imgs,
28632886
for (const device_image_plain &Img : Imgs) {
28642887
std::shared_ptr<device_image_impl> DeviceImageImpl = getSyclObjImpl(Img);
28652888
// Duplicates are not expected here, otherwise urProgramLink should fail
2866-
KernelIDs.insert(KernelIDs.end(),
2867-
DeviceImageImpl->get_kernel_ids_ptr()->begin(),
2868-
DeviceImageImpl->get_kernel_ids_ptr()->end());
2889+
if (DeviceImageImpl->get_kernel_ids_ptr())
2890+
KernelIDs.insert(KernelIDs.end(),
2891+
DeviceImageImpl->get_kernel_ids_ptr()->begin(),
2892+
DeviceImageImpl->get_kernel_ids_ptr()->end());
28692893
// To be able to answer queries about specialziation constants, the new
28702894
// device image should have the specialization constants from all the linked
28712895
// images.

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
#include <sycl/device.hpp>
2525
#include <sycl/kernel_bundle.hpp>
2626

27+
#include <array>
2728
#include <cstdint>
2829
#include <map>
2930
#include <memory>
@@ -376,11 +377,15 @@ class ProgramManager {
376377
collectDependentDeviceImagesForVirtualFunctions(
377378
const RTDeviceBinaryImage &Img, const device &Dev);
378379

380+
bool isSpecialDeviceImage(RTDeviceBinaryImage *BinImage);
381+
bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
382+
const device &Dev);
383+
379384
protected:
380385
/// The three maps below are used during kernel resolution. Any kernel is
381386
/// identified by its name.
382387
using RTDeviceBinaryImageUPtr = std::unique_ptr<RTDeviceBinaryImage>;
383-
388+
using DynRTDeviceBinaryImageUPtr = std::unique_ptr<DynRTDeviceBinaryImage>;
384389
/// Maps names of kernels to their unique kernel IDs.
385390
/// TODO: Use std::unordered_set with transparent hash and equality functions
386391
/// when C++20 is enabled for the runtime library.
@@ -498,6 +503,12 @@ class ProgramManager {
498503
std::map<std::vector<unsigned char>, ur_kernel_handle_t>;
499504
std::unordered_map<std::string, MaterializedEntries> m_MaterializedKernels;
500505

506+
// Holds bfloat16 device library images, the 1st element is for fallback
507+
// version and 2nd is for native version. These bfloat16 device library
508+
// images are provided by compiler long time ago, we expect no further
509+
// update, so keeping 1 copy should be OK.
510+
std::array<DynRTDeviceBinaryImageUPtr, 2> m_Bfloat16DeviceLibImages;
511+
501512
friend class ::ProgramManagerTest;
502513
};
503514
} // namespace detail

sycl/test-e2e/KernelCompiler/sycl.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -533,11 +533,11 @@ int main() {
533533
if (!ok) {
534534
return -1;
535535
}
536-
536+
// Run test_device_libraries twice to verify bfloat16 device library.
537537
return test_build_and_run(q) || test_device_code_split(q) ||
538538
test_device_libraries(q) || test_esimd(q) ||
539-
test_unsupported_options(q) || test_error(q) ||
540-
test_no_visible_ids(q) || test_warning(q);
539+
test_device_libraries(q) || test_unsupported_options(q) ||
540+
test_error(q) || test_no_visible_ids(q) || test_warning(q);
541541
#else
542542
static_assert(false, "Kernel Compiler feature test macro undefined");
543543
#endif

0 commit comments

Comments
 (0)