Skip to content

Commit 80fd665

Browse files
authored
[SYCL] Keep multiple copies for bf16 device library image (#17461)
The SYCL RT addImages function may be invoked multiple times for different SYCL binary images, and more than one of these binary images may depend on the bfloat16 device library. The bfloat16 device library images are provided by the compiler and their implementation is stable now, so the program manager keeps only a single copy each of the native and fallback versions of the bfloat16 device library; these images will not be removed unless the program manager is destroyed. --------- Signed-off-by: jinge90 <[email protected]>
1 parent 816a8da commit 80fd665

File tree

3 files changed

+123
-88
lines changed

3 files changed

+123
-88
lines changed

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 108 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -615,22 +615,25 @@ static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
615615
}
616616

617617
// Quick check to see whether BinImage is a compiler-generated device image.
618-
static bool isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) {
618+
bool ProgramManager::isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) {
619619
// SYCL devicelib image.
620-
if (BinImage->getDeviceLibMetadata().isAvailable())
620+
if ((m_Bfloat16DeviceLibImages[0].get() == BinImage) ||
621+
m_Bfloat16DeviceLibImages[1].get() == BinImage)
621622
return true;
622623

623624
return false;
624625
}
625626

626-
static bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
627-
const device &Dev) {
627+
bool ProgramManager::isSpecialDeviceImageShouldBeUsed(
628+
RTDeviceBinaryImage *BinImage, const device &Dev) {
628629
// Decide whether a devicelib image should be used.
629-
if (BinImage->getDeviceLibMetadata().isAvailable()) {
630-
const RTDeviceBinaryImage::PropertyRange &DeviceLibMetaProp =
631-
BinImage->getDeviceLibMetadata();
632-
uint32_t DeviceLibMeta =
633-
DeviceBinaryProperty(*(DeviceLibMetaProp.begin())).asUint32();
630+
int Bfloat16DeviceLibVersion = -1;
631+
if (m_Bfloat16DeviceLibImages[0].get() == BinImage)
632+
Bfloat16DeviceLibVersion = 0;
633+
else if (m_Bfloat16DeviceLibImages[1].get() == BinImage)
634+
Bfloat16DeviceLibVersion = 1;
635+
636+
if (Bfloat16DeviceLibVersion != -1) {
634637
// Currently, only bfloat conversion devicelib are supported, so the prop
635638
// DeviceLibMeta are only used to represent fallback or native version.
636639
// For bfloat16 conversion devicelib, we have fallback and native version.
@@ -644,7 +647,8 @@ static bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
644647
detail::getSyclObjImpl(Dev);
645648
std::string NativeBF16ExtName = "cl_intel_bfloat16_conversions";
646649
bool NativeBF16Supported = (DeviceImpl->has_extension(NativeBF16ExtName));
647-
return NativeBF16Supported == (DeviceLibMeta == DEVICELIB_NATIVE);
650+
return NativeBF16Supported ==
651+
(Bfloat16DeviceLibVersion == DEVICELIB_NATIVE);
648652
}
649653

650654
return false;
@@ -1838,87 +1842,69 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const {
18381842
return {};
18391843
}
18401844

1841-
static bool shouldSkipEmptyImage(sycl_device_binary RawImg, bool IsRTC) {
1842-
// For bfloat16 device library image, we should keep it. However, in some
1843-
// scenario, __sycl_register_lib can be called multiple times and the same
1844-
// bfloat16 device library image may be handled multiple times which is not
1845-
// needed. 2 static bool variables are created to record whether native or
1846-
// fallback bfloat16 device library image has been handled, if yes, we just
1847-
// need to skip it.
1848-
// We cannot prevent redundant loads of device library images if they are part
1849-
// of a runtime-compiled device binary, as these will be freed when the
1850-
// corresponding kernel bundle is destroyed. Hence, normal kernels cannot rely
1851-
// on the presence of RTC device library images.
1845+
static bool isBfloat16DeviceLibImage(sycl_device_binary RawImg,
1846+
uint32_t *LibVersion = nullptr) {
18521847
sycl_device_binary_property_set ImgPS;
1853-
static bool IsNativeBF16DeviceLibHandled = false;
1854-
static bool IsFallbackBF16DeviceLibHandled = false;
18551848
for (ImgPS = RawImg->PropertySetsBegin; ImgPS != RawImg->PropertySetsEnd;
18561849
++ImgPS) {
18571850
if (ImgPS->Name &&
18581851
!strcmp(__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name)) {
1852+
if (!LibVersion)
1853+
return true;
1854+
1855+
// Valid version for bfloat16 device library is 0(fallback), 1(native).
1856+
*LibVersion = 2;
18591857
sycl_device_binary_property ImgP;
18601858
for (ImgP = ImgPS->PropertiesBegin; ImgP != ImgPS->PropertiesEnd;
18611859
++ImgP) {
18621860
if (ImgP->Name && !strcmp("bfloat16", ImgP->Name) &&
18631861
(ImgP->Type == SYCL_PROPERTY_TYPE_UINT32))
18641862
break;
18651863
}
1866-
if (ImgP == ImgPS->PropertiesEnd)
1867-
return true;
1868-
1869-
// A valid bfloat16 device library image is found here.
1870-
// If it originated from RTC, we cannot skip it, but do not mark it as
1871-
// being present.
1872-
if (IsRTC)
1873-
return false;
1874-
1875-
// Otherwise, we need to check whether it has been handled already.
1876-
uint32_t BF16NativeVal = DeviceBinaryProperty(ImgP).asUint32();
1877-
if (((BF16NativeVal == 0) && IsFallbackBF16DeviceLibHandled) ||
1878-
((BF16NativeVal == 1) && IsNativeBF16DeviceLibHandled))
1879-
return true;
1880-
1881-
if (BF16NativeVal == 0)
1882-
IsFallbackBF16DeviceLibHandled = true;
1883-
else
1884-
IsNativeBF16DeviceLibHandled = true;
1885-
1886-
return false;
1864+
if (ImgP != ImgPS->PropertiesEnd)
1865+
*LibVersion = DeviceBinaryProperty(ImgP).asUint32();
1866+
return true;
18871867
}
18881868
}
1889-
return true;
1869+
1870+
return false;
18901871
}
18911872

1892-
static bool isCompiledAtRuntime(sycl_device_binaries DeviceBinary) {
1893-
// Check whether the first device binary contains a legacy format offload
1894-
// entry with a `$` in its name.
1895-
if (DeviceBinary->NumDeviceBinaries > 0) {
1896-
sycl_device_binary Binary = DeviceBinary->DeviceBinaries;
1897-
if (Binary->EntriesBegin != Binary->EntriesEnd) {
1898-
sycl_offload_entry Entry = Binary->EntriesBegin;
1899-
if (!Entry->IsNewOffloadEntryType() &&
1900-
std::string_view{Entry->name}.find('$') != std::string_view::npos) {
1901-
return true;
1902-
}
1903-
}
1873+
static sycl_device_binary_property_set
1874+
getExportedSymbolPS(sycl_device_binary RawImg) {
1875+
sycl_device_binary_property_set ImgPS;
1876+
for (ImgPS = RawImg->PropertySetsBegin; ImgPS != RawImg->PropertySetsEnd;
1877+
++ImgPS) {
1878+
if (ImgPS->Name &&
1879+
!strcmp(__SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS, ImgPS->Name))
1880+
return ImgPS;
19041881
}
1905-
return false;
1882+
1883+
return nullptr;
1884+
}
1885+
1886+
static bool shouldSkipEmptyImage(sycl_device_binary RawImg) {
1887+
// For bfloat16 device library image, we should keep it although it doesn't
1888+
// include any kernel.
1889+
if (isBfloat16DeviceLibImage(RawImg))
1890+
return false;
1891+
1892+
// We may extend the logic here other than bfloat16 device library image.
1893+
return true;
19061894
}
19071895

19081896
void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
19091897
const bool DumpImages = std::getenv("SYCL_DUMP_IMAGES") && !m_UseSpvFile;
1910-
const bool IsRTC = isCompiledAtRuntime(DeviceBinary);
19111898
for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
19121899
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
19131900
const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
19141901
const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
1915-
// If the image does not contain kernels, skip it unless it is one of the
1916-
// bfloat16 device libraries, and it wasn't loaded before or resulted from
1917-
// runtime compilation.
1918-
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg, IsRTC))
1902+
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg))
19191903
continue;
19201904

19211905
std::unique_ptr<RTDeviceBinaryImage> Img;
1906+
bool IsBfloat16DeviceLib = false;
1907+
uint32_t Bfloat16DeviceLibVersion = 0;
19221908
if (isDeviceImageCompressed(RawImg))
19231909
#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE
19241910
Img = std::make_unique<CompressedRTDeviceBinaryImage>(RawImg);
@@ -1928,25 +1914,63 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
19281914
"SYCL RT was built without ZSTD support."
19291915
"Aborting. ");
19301916
#endif
1931-
else
1932-
Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1917+
else {
1918+
IsBfloat16DeviceLib =
1919+
isBfloat16DeviceLibImage(RawImg, &Bfloat16DeviceLibVersion);
1920+
if (!IsBfloat16DeviceLib)
1921+
Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1922+
}
19331923

19341924
static uint32_t SequenceID = 0;
19351925

1936-
// Fill the kernel argument mask map
1937-
const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1938-
Img->getKernelParamOptInfo();
1939-
if (KPOIRange.isAvailable()) {
1940-
KernelNameToArgMaskMap &ArgMaskMap =
1941-
m_EliminatedKernelArgMasks[Img.get()];
1942-
for (const auto &Info : KPOIRange)
1943-
ArgMaskMap[Info->Name] =
1944-
createKernelArgMask(DeviceBinaryProperty(Info).asByteArray());
1926+
// Fill the kernel argument mask map, no need to do this for bfloat16
1927+
// device library image since it doesn't include any kernel.
1928+
if (!IsBfloat16DeviceLib) {
1929+
const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1930+
Img->getKernelParamOptInfo();
1931+
if (KPOIRange.isAvailable()) {
1932+
KernelNameToArgMaskMap &ArgMaskMap =
1933+
m_EliminatedKernelArgMasks[Img.get()];
1934+
for (const auto &Info : KPOIRange)
1935+
ArgMaskMap[Info->Name] =
1936+
createKernelArgMask(DeviceBinaryProperty(Info).asByteArray());
1937+
}
19451938
}
19461939

19471940
// Fill maps for kernel bundles
19481941
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
19491942

1943+
// For bfloat16 device library image, it doesn't include any kernel, device
1944+
// global, virtual function, so just skip adding it to any related maps.
1945+
// The bfloat16 device library are provided by compiler and may be used by
1946+
// different sycl device images, program manager will own single copy for
1947+
// native and fallback version bfloat16 device library, these device
1948+
// library images will not be erased unless program manager is destroyed.
1949+
{
1950+
if (IsBfloat16DeviceLib) {
1951+
assert((Bfloat16DeviceLibVersion < 2) &&
1952+
"Invalid Bfloat16 Device Library Index.");
1953+
if (m_Bfloat16DeviceLibImages[Bfloat16DeviceLibVersion].get())
1954+
continue;
1955+
size_t ImgSize =
1956+
static_cast<size_t>(RawImg->BinaryEnd - RawImg->BinaryStart);
1957+
std::unique_ptr<char[]> Data(new char[ImgSize]);
1958+
std::memcpy(Data.get(), RawImg->BinaryStart, ImgSize);
1959+
auto DynBfloat16DeviceLibImg =
1960+
std::make_unique<DynRTDeviceBinaryImage>(std::move(Data), ImgSize);
1961+
auto ESPropSet = getExportedSymbolPS(RawImg);
1962+
sycl_device_binary_property ESProp;
1963+
for (ESProp = ESPropSet->PropertiesBegin;
1964+
ESProp != ESPropSet->PropertiesEnd; ++ESProp) {
1965+
m_ExportedSymbolImages.insert(
1966+
{ESProp->Name, DynBfloat16DeviceLibImg.get()});
1967+
}
1968+
m_Bfloat16DeviceLibImages[Bfloat16DeviceLibVersion] =
1969+
std::move(DynBfloat16DeviceLibImg);
1970+
continue;
1971+
}
1972+
}
1973+
19501974
// Register all exported symbols
19511975
for (const sycl_device_binary_property &ESProp :
19521976
Img->getExportedSymbols()) {
@@ -2111,19 +2135,14 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
21112135
}
21122136

21132137
void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
2114-
bool IsRTC = isCompiledAtRuntime(DeviceBinary);
21152138
for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
21162139
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
21172140
auto DevImgIt = m_DeviceImages.find(RawImg);
21182141
if (DevImgIt == m_DeviceImages.end())
21192142
continue;
21202143
const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
21212144
const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
2122-
// Skip clean up if there are no offload entries, unless `DeviceBinary`
2123-
// resulted from runtime compilation: Then, this is one of the `bfloat16`
2124-
// device libraries, so we want to make sure that the image and its exported
2125-
// symbols are removed from the program manager's maps.
2126-
if (EntriesB == EntriesE && !IsRTC)
2145+
if (EntriesB == EntriesE)
21272146
continue;
21282147

21292148
RTDeviceBinaryImage *Img = DevImgIt->second.get();
@@ -2651,7 +2670,11 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
26512670
std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
26522671
{
26532672
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
2654-
DepKernelIDs = m_BinImg2KernelIDs[Dep];
2673+
// For device library images, they are not in m_BinImg2KernelIDs since
2674+
// no kernel is included.
2675+
auto DepIt = m_BinImg2KernelIDs.find(Dep);
2676+
if (DepIt != m_BinImg2KernelIDs.end())
2677+
DepKernelIDs = DepIt->second;
26552678
}
26562679

26572680
assert(ImgInfoPair.second.State == getBinImageState(Dep) &&
@@ -2865,9 +2888,10 @@ static void mergeImageData(const std::vector<device_image_plain> &Imgs,
28652888
const std::shared_ptr<device_image_impl> &DeviceImageImpl =
28662889
getSyclObjImpl(Img);
28672890
// Duplicates are not expected here, otherwise urProgramLink should fail
2868-
KernelIDs.insert(KernelIDs.end(),
2869-
DeviceImageImpl->get_kernel_ids_ptr()->begin(),
2870-
DeviceImageImpl->get_kernel_ids_ptr()->end());
2891+
if (DeviceImageImpl->get_kernel_ids_ptr())
2892+
KernelIDs.insert(KernelIDs.end(),
2893+
DeviceImageImpl->get_kernel_ids_ptr()->begin(),
2894+
DeviceImageImpl->get_kernel_ids_ptr()->end());
28712895
// To be able to answer queries about specialziation constants, the new
28722896
// device image should have the specialization constants from all the linked
28732897
// images.

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include <sycl/device.hpp>
2525
#include <sycl/kernel_bundle.hpp>
2626

27+
#include <array>
2728
#include <cstdint>
2829
#include <map>
2930
#include <memory>
@@ -376,11 +377,15 @@ class ProgramManager {
376377
collectDependentDeviceImagesForVirtualFunctions(
377378
const RTDeviceBinaryImage &Img, const device &Dev);
378379

380+
bool isSpecialDeviceImage(RTDeviceBinaryImage *BinImage);
381+
bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
382+
const device &Dev);
383+
379384
protected:
380385
/// The three maps below are used during kernel resolution. Any kernel is
381386
/// identified by its name.
382387
using RTDeviceBinaryImageUPtr = std::unique_ptr<RTDeviceBinaryImage>;
383-
388+
using DynRTDeviceBinaryImageUPtr = std::unique_ptr<DynRTDeviceBinaryImage>;
384389
/// Maps names of kernels to their unique kernel IDs.
385390
/// TODO: Use std::unordered_set with transparent hash and equality functions
386391
/// when C++20 is enabled for the runtime library.
@@ -498,6 +503,12 @@ class ProgramManager {
498503
std::map<std::vector<unsigned char>, ur_kernel_handle_t>;
499504
std::unordered_map<std::string, MaterializedEntries> m_MaterializedKernels;
500505

506+
// Holds bfloat16 device library images, the 1st element is for fallback
507+
// version and 2nd is for native version. These bfloat16 device library
508+
// images are provided by compiler long time ago, we expect no further
509+
// update, so keeping 1 copy should be OK.
510+
std::array<DynRTDeviceBinaryImageUPtr, 2> m_Bfloat16DeviceLibImages;
511+
501512
friend class ::ProgramManagerTest;
502513
};
503514
} // namespace detail

sycl/test-e2e/KernelCompiler/sycl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -532,11 +532,11 @@ int main() {
532532
if (!ok) {
533533
return -1;
534534
}
535-
535+
// Run test_device_libraries twice to verify bfloat16 device library.
536536
return test_build_and_run(q) || test_device_code_split(q) ||
537537
test_device_libraries(q) || test_esimd(q) ||
538-
test_unsupported_options(q) || test_error(q) ||
539-
test_no_visible_ids(q) || test_warning(q);
538+
test_device_libraries(q) || test_unsupported_options(q) ||
539+
test_error(q) || test_no_visible_ids(q) || test_warning(q);
540540
#else
541541
static_assert(false, "Kernel Compiler feature test macro undefined");
542542
#endif

0 commit comments

Comments
 (0)