@@ -615,22 +615,25 @@ static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
615
615
}
616
616
617
617
// Quick check to see whether BinImage is a compiler-generated device image.
618
- static bool isSpecialDeviceImage (RTDeviceBinaryImage *BinImage) {
618
+ bool ProgramManager:: isSpecialDeviceImage (RTDeviceBinaryImage *BinImage) {
619
619
// SYCL devicelib image.
620
- if (BinImage->getDeviceLibMetadata ().isAvailable ())
620
+ if ((m_Bfloat16DeviceLibImages[0 ].get () == BinImage) ||
621
+ m_Bfloat16DeviceLibImages[1 ].get () == BinImage)
621
622
return true ;
622
623
623
624
return false ;
624
625
}
625
626
626
- static bool isSpecialDeviceImageShouldBeUsed (RTDeviceBinaryImage *BinImage,
627
- const device &Dev) {
627
+ bool ProgramManager:: isSpecialDeviceImageShouldBeUsed (
628
+ RTDeviceBinaryImage *BinImage, const device &Dev) {
628
629
// Decide whether a devicelib image should be used.
629
- if (BinImage->getDeviceLibMetadata ().isAvailable ()) {
630
- const RTDeviceBinaryImage::PropertyRange &DeviceLibMetaProp =
631
- BinImage->getDeviceLibMetadata ();
632
- uint32_t DeviceLibMeta =
633
- DeviceBinaryProperty (*(DeviceLibMetaProp.begin ())).asUint32 ();
630
+ int Bfloat16DeviceLibVersion = -1 ;
631
+ if (m_Bfloat16DeviceLibImages[0 ].get () == BinImage)
632
+ Bfloat16DeviceLibVersion = 0 ;
633
+ else if (m_Bfloat16DeviceLibImages[1 ].get () == BinImage)
634
+ Bfloat16DeviceLibVersion = 1 ;
635
+
636
+ if (Bfloat16DeviceLibVersion != -1 ) {
634
637
// Currently, only bfloat conversion devicelib are supported, so the prop
635
638
// DeviceLibMeta are only used to represent fallback or native version.
636
639
// For bfloat16 conversion devicelib, we have fallback and native version.
@@ -644,7 +647,8 @@ static bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
644
647
detail::getSyclObjImpl (Dev);
645
648
std::string NativeBF16ExtName = " cl_intel_bfloat16_conversions" ;
646
649
bool NativeBF16Supported = (DeviceImpl->has_extension (NativeBF16ExtName));
647
- return NativeBF16Supported == (DeviceLibMeta == DEVICELIB_NATIVE);
650
+ return NativeBF16Supported ==
651
+ (Bfloat16DeviceLibVersion == DEVICELIB_NATIVE);
648
652
}
649
653
650
654
return false ;
@@ -1838,87 +1842,69 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const {
1838
1842
return {};
1839
1843
}
1840
1844
1841
- static bool shouldSkipEmptyImage (sycl_device_binary RawImg, bool IsRTC) {
1842
- // For bfloat16 device library image, we should keep it. However, in some
1843
- // scenario, __sycl_register_lib can be called multiple times and the same
1844
- // bfloat16 device library image may be handled multiple times which is not
1845
- // needed. 2 static bool variables are created to record whether native or
1846
- // fallback bfloat16 device library image has been handled, if yes, we just
1847
- // need to skip it.
1848
- // We cannot prevent redundant loads of device library images if they are part
1849
- // of a runtime-compiled device binary, as these will be freed when the
1850
- // corresponding kernel bundle is destroyed. Hence, normal kernels cannot rely
1851
- // on the presence of RTC device library images.
1845
+ static bool isBfloat16DeviceLibImage (sycl_device_binary RawImg,
1846
+ uint32_t *LibVersion = nullptr ) {
1852
1847
sycl_device_binary_property_set ImgPS;
1853
- static bool IsNativeBF16DeviceLibHandled = false ;
1854
- static bool IsFallbackBF16DeviceLibHandled = false ;
1855
1848
for (ImgPS = RawImg->PropertySetsBegin ; ImgPS != RawImg->PropertySetsEnd ;
1856
1849
++ImgPS) {
1857
1850
if (ImgPS->Name &&
1858
1851
!strcmp (__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name )) {
1852
+ if (!LibVersion)
1853
+ return true ;
1854
+
1855
+ // Valid bfloat16 device library versions: 0 (fallback) and 1 (native).
1856
+ *LibVersion = 2 ;
1859
1857
sycl_device_binary_property ImgP;
1860
1858
for (ImgP = ImgPS->PropertiesBegin ; ImgP != ImgPS->PropertiesEnd ;
1861
1859
++ImgP) {
1862
1860
if (ImgP->Name && !strcmp (" bfloat16" , ImgP->Name ) &&
1863
1861
(ImgP->Type == SYCL_PROPERTY_TYPE_UINT32))
1864
1862
break ;
1865
1863
}
1866
- if (ImgP == ImgPS->PropertiesEnd )
1867
- return true ;
1868
-
1869
- // A valid bfloat16 device library image is found here.
1870
- // If it originated from RTC, we cannot skip it, but do not mark it as
1871
- // being present.
1872
- if (IsRTC)
1873
- return false ;
1874
-
1875
- // Otherwise, we need to check whether it has been handled already.
1876
- uint32_t BF16NativeVal = DeviceBinaryProperty (ImgP).asUint32 ();
1877
- if (((BF16NativeVal == 0 ) && IsFallbackBF16DeviceLibHandled) ||
1878
- ((BF16NativeVal == 1 ) && IsNativeBF16DeviceLibHandled))
1879
- return true ;
1880
-
1881
- if (BF16NativeVal == 0 )
1882
- IsFallbackBF16DeviceLibHandled = true ;
1883
- else
1884
- IsNativeBF16DeviceLibHandled = true ;
1885
-
1886
- return false ;
1864
+ if (ImgP != ImgPS->PropertiesEnd )
1865
+ *LibVersion = DeviceBinaryProperty (ImgP).asUint32 ();
1866
+ return true ;
1887
1867
}
1888
1868
}
1889
- return true ;
1869
+
1870
+ return false ;
1890
1871
}
1891
1872
1892
- static bool isCompiledAtRuntime (sycl_device_binaries DeviceBinary) {
1893
- // Check whether the first device binary contains a legacy format offload
1894
- // entry with a `$` in its name.
1895
- if (DeviceBinary->NumDeviceBinaries > 0 ) {
1896
- sycl_device_binary Binary = DeviceBinary->DeviceBinaries ;
1897
- if (Binary->EntriesBegin != Binary->EntriesEnd ) {
1898
- sycl_offload_entry Entry = Binary->EntriesBegin ;
1899
- if (!Entry->IsNewOffloadEntryType () &&
1900
- std::string_view{Entry->name }.find (' $' ) != std::string_view::npos) {
1901
- return true ;
1902
- }
1903
- }
1873
+ static sycl_device_binary_property_set
1874
+ getExportedSymbolPS (sycl_device_binary RawImg) {
1875
+ sycl_device_binary_property_set ImgPS;
1876
+ for (ImgPS = RawImg->PropertySetsBegin ; ImgPS != RawImg->PropertySetsEnd ;
1877
+ ++ImgPS) {
1878
+ if (ImgPS->Name &&
1879
+ !strcmp (__SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS, ImgPS->Name ))
1880
+ return ImgPS;
1904
1881
}
1905
- return false ;
1882
+
1883
+ return nullptr ;
1884
+ }
1885
+
1886
+ static bool shouldSkipEmptyImage (sycl_device_binary RawImg) {
1887
+ // A bfloat16 device library image must be kept even though it does not
1888
+ // include any kernels.
1889
+ if (isBfloat16DeviceLibImage (RawImg))
1890
+ return false ;
1891
+
1892
+ // This check may later be extended beyond bfloat16 device library images.
1893
+ return true ;
1906
1894
}
1907
1895
1908
1896
void ProgramManager::addImages (sycl_device_binaries DeviceBinary) {
1909
1897
const bool DumpImages = std::getenv (" SYCL_DUMP_IMAGES" ) && !m_UseSpvFile;
1910
- const bool IsRTC = isCompiledAtRuntime (DeviceBinary);
1911
1898
for (int I = 0 ; I < DeviceBinary->NumDeviceBinaries ; I++) {
1912
1899
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries [I]);
1913
1900
const sycl_offload_entry EntriesB = RawImg->EntriesBegin ;
1914
1901
const sycl_offload_entry EntriesE = RawImg->EntriesEnd ;
1915
- // If the image does not contain kernels, skip it unless it is one of the
1916
- // bfloat16 device libraries, and it wasn't loaded before or resulted from
1917
- // runtime compilation.
1918
- if ((EntriesB == EntriesE) && shouldSkipEmptyImage (RawImg, IsRTC))
1902
+ if ((EntriesB == EntriesE) && shouldSkipEmptyImage (RawImg))
1919
1903
continue ;
1920
1904
1921
1905
std::unique_ptr<RTDeviceBinaryImage> Img;
1906
+ bool IsBfloat16DeviceLib = false ;
1907
+ uint32_t Bfloat16DeviceLibVersion = 0 ;
1922
1908
if (isDeviceImageCompressed (RawImg))
1923
1909
#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE
1924
1910
Img = std::make_unique<CompressedRTDeviceBinaryImage>(RawImg);
@@ -1928,25 +1914,63 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
1928
1914
" SYCL RT was built without ZSTD support."
1929
1915
" Aborting. " );
1930
1916
#endif
1931
- else
1932
- Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1917
+ else {
1918
+ IsBfloat16DeviceLib =
1919
+ isBfloat16DeviceLibImage (RawImg, &Bfloat16DeviceLibVersion);
1920
+ if (!IsBfloat16DeviceLib)
1921
+ Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1922
+ }
1933
1923
1934
1924
static uint32_t SequenceID = 0 ;
1935
1925
1936
- // Fill the kernel argument mask map
1937
- const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1938
- Img->getKernelParamOptInfo ();
1939
- if (KPOIRange.isAvailable ()) {
1940
- KernelNameToArgMaskMap &ArgMaskMap =
1941
- m_EliminatedKernelArgMasks[Img.get ()];
1942
- for (const auto &Info : KPOIRange)
1943
- ArgMaskMap[Info->Name ] =
1944
- createKernelArgMask (DeviceBinaryProperty (Info).asByteArray ());
1926
+ // Fill the kernel argument mask map. This is not needed for a bfloat16
1927
+ // device library image, since it does not include any kernels.
1928
+ if (!IsBfloat16DeviceLib) {
1929
+ const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1930
+ Img->getKernelParamOptInfo ();
1931
+ if (KPOIRange.isAvailable ()) {
1932
+ KernelNameToArgMaskMap &ArgMaskMap =
1933
+ m_EliminatedKernelArgMasks[Img.get ()];
1934
+ for (const auto &Info : KPOIRange)
1935
+ ArgMaskMap[Info->Name ] =
1936
+ createKernelArgMask (DeviceBinaryProperty (Info).asByteArray ());
1937
+ }
1945
1938
}
1946
1939
1947
1940
// Fill maps for kernel bundles
1948
1941
std::lock_guard<std::mutex> KernelIDsGuard (m_KernelIDsMutex);
1949
1942
1943
+ // A bfloat16 device library image contains no kernels, device globals, or
1944
+ // virtual functions, so it is not added to any of the related maps.
1945
+ // The bfloat16 device libraries are provided by the compiler and may be
1946
+ // used by different SYCL device images. The program manager owns a single
1947
+ // copy of each of the native and fallback versions; these images are not
1948
+ // erased until the program manager is destroyed.
1949
+ {
1950
+ if (IsBfloat16DeviceLib) {
1951
+ assert ((Bfloat16DeviceLibVersion < 2 ) &&
1952
+ " Invalid Bfloat16 Device Library Index." );
1953
+ if (m_Bfloat16DeviceLibImages[Bfloat16DeviceLibVersion].get ())
1954
+ continue ;
1955
+ size_t ImgSize =
1956
+ static_cast <size_t >(RawImg->BinaryEnd - RawImg->BinaryStart );
1957
+ std::unique_ptr<char []> Data (new char [ImgSize]);
1958
+ std::memcpy (Data.get (), RawImg->BinaryStart , ImgSize);
1959
+ auto DynBfloat16DeviceLibImg =
1960
+ std::make_unique<DynRTDeviceBinaryImage>(std::move (Data), ImgSize);
1961
+ auto ESPropSet = getExportedSymbolPS (RawImg);
1962
+ sycl_device_binary_property ESProp;
1963
+ for (ESProp = ESPropSet->PropertiesBegin ;
1964
+ ESProp != ESPropSet->PropertiesEnd ; ++ESProp) {
1965
+ m_ExportedSymbolImages.insert (
1966
+ {ESProp->Name , DynBfloat16DeviceLibImg.get ()});
1967
+ }
1968
+ m_Bfloat16DeviceLibImages[Bfloat16DeviceLibVersion] =
1969
+ std::move (DynBfloat16DeviceLibImg);
1970
+ continue ;
1971
+ }
1972
+ }
1973
+
1950
1974
// Register all exported symbols
1951
1975
for (const sycl_device_binary_property &ESProp :
1952
1976
Img->getExportedSymbols ()) {
@@ -2111,19 +2135,14 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
2111
2135
}
2112
2136
2113
2137
void ProgramManager::removeImages (sycl_device_binaries DeviceBinary) {
2114
- bool IsRTC = isCompiledAtRuntime (DeviceBinary);
2115
2138
for (int I = 0 ; I < DeviceBinary->NumDeviceBinaries ; I++) {
2116
2139
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries [I]);
2117
2140
auto DevImgIt = m_DeviceImages.find (RawImg);
2118
2141
if (DevImgIt == m_DeviceImages.end ())
2119
2142
continue ;
2120
2143
const sycl_offload_entry EntriesB = RawImg->EntriesBegin ;
2121
2144
const sycl_offload_entry EntriesE = RawImg->EntriesEnd ;
2122
- // Skip clean up if there are no offload entries, unless `DeviceBinary`
2123
- // resulted from runtime compilation: Then, this is one of the `bfloat16`
2124
- // device libraries, so we want to make sure that the image and its exported
2125
- // symbols are removed from the program manager's maps.
2126
- if (EntriesB == EntriesE && !IsRTC)
2145
+ if (EntriesB == EntriesE)
2127
2146
continue ;
2128
2147
2129
2148
RTDeviceBinaryImage *Img = DevImgIt->second .get ();
@@ -2651,7 +2670,11 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
2651
2670
std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
2652
2671
{
2653
2672
std::lock_guard<std::mutex> KernelIDsGuard (m_KernelIDsMutex);
2654
- DepKernelIDs = m_BinImg2KernelIDs[Dep];
2673
+ // Device library images are not in m_BinImg2KernelIDs because they
2674
+ // contain no kernels.
2675
+ auto DepIt = m_BinImg2KernelIDs.find (Dep);
2676
+ if (DepIt != m_BinImg2KernelIDs.end ())
2677
+ DepKernelIDs = DepIt->second ;
2655
2678
}
2656
2679
2657
2680
assert (ImgInfoPair.second .State == getBinImageState (Dep) &&
@@ -2865,9 +2888,10 @@ static void mergeImageData(const std::vector<device_image_plain> &Imgs,
2865
2888
const std::shared_ptr<device_image_impl> &DeviceImageImpl =
2866
2889
getSyclObjImpl (Img);
2867
2890
// Duplicates are not expected here, otherwise urProgramLink should fail
2868
- KernelIDs.insert (KernelIDs.end (),
2869
- DeviceImageImpl->get_kernel_ids_ptr ()->begin (),
2870
- DeviceImageImpl->get_kernel_ids_ptr ()->end ());
2891
+ if (DeviceImageImpl->get_kernel_ids_ptr ())
2892
+ KernelIDs.insert (KernelIDs.end (),
2893
+ DeviceImageImpl->get_kernel_ids_ptr ()->begin (),
2894
+ DeviceImageImpl->get_kernel_ids_ptr ()->end ());
2871
2895
// To be able to answer queries about specialziation constants, the new
2872
2896
// device image should have the specialization constants from all the linked
2873
2897
// images.
0 commit comments