@@ -614,22 +614,25 @@ static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
614
614
}
615
615
616
616
// Quick check to see whether BinImage is a compiler-generated device image.
617
- static bool isSpecialDeviceImage (RTDeviceBinaryImage *BinImage) {
617
+ bool ProgramManager:: isSpecialDeviceImage (RTDeviceBinaryImage *BinImage) {
618
618
// SYCL devicelib image.
619
- if (BinImage->getDeviceLibMetadata ().isAvailable ())
619
+ if ((m_Bfloat16DeviceLibImages[0 ].get () == BinImage) ||
620
+ m_Bfloat16DeviceLibImages[1 ].get () == BinImage)
620
621
return true ;
621
622
622
623
return false ;
623
624
}
624
625
625
- static bool isSpecialDeviceImageShouldBeUsed (RTDeviceBinaryImage *BinImage,
626
- const device &Dev) {
626
+ bool ProgramManager:: isSpecialDeviceImageShouldBeUsed (
627
+ RTDeviceBinaryImage *BinImage, const device &Dev) {
627
628
// Decide whether a devicelib image should be used.
628
- if (BinImage->getDeviceLibMetadata ().isAvailable ()) {
629
- const RTDeviceBinaryImage::PropertyRange &DeviceLibMetaProp =
630
- BinImage->getDeviceLibMetadata ();
631
- uint32_t DeviceLibMeta =
632
- DeviceBinaryProperty (*(DeviceLibMetaProp.begin ())).asUint32 ();
629
+ int Bfloat16DeviceLibVersion = -1 ;
630
+ if (m_Bfloat16DeviceLibImages[0 ].get () == BinImage)
631
+ Bfloat16DeviceLibVersion = 0 ;
632
+ else if (m_Bfloat16DeviceLibImages[1 ].get () == BinImage)
633
+ Bfloat16DeviceLibVersion = 1 ;
634
+
635
+ if (Bfloat16DeviceLibVersion != -1 ) {
633
636
// Currently, only bfloat conversion devicelib are supported, so the prop
634
637
// DeviceLibMeta are only used to represent fallback or native version.
635
638
// For bfloat16 conversion devicelib, we have fallback and native version.
@@ -643,7 +646,8 @@ static bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
643
646
detail::getSyclObjImpl (Dev);
644
647
std::string NativeBF16ExtName = " cl_intel_bfloat16_conversions" ;
645
648
bool NativeBF16Supported = (DeviceImpl->has_extension (NativeBF16ExtName));
646
- return NativeBF16Supported == (DeviceLibMeta == DEVICELIB_NATIVE);
649
+ return NativeBF16Supported ==
650
+ (Bfloat16DeviceLibVersion == DEVICELIB_NATIVE);
647
651
}
648
652
649
653
return false ;
@@ -1837,87 +1841,69 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const {
1837
1841
return {};
1838
1842
}
1839
1843
1840
- static bool shouldSkipEmptyImage (sycl_device_binary RawImg, bool IsRTC) {
1841
- // For bfloat16 device library image, we should keep it. However, in some
1842
- // scenario, __sycl_register_lib can be called multiple times and the same
1843
- // bfloat16 device library image may be handled multiple times which is not
1844
- // needed. 2 static bool variables are created to record whether native or
1845
- // fallback bfloat16 device library image has been handled, if yes, we just
1846
- // need to skip it.
1847
- // We cannot prevent redundant loads of device library images if they are part
1848
- // of a runtime-compiled device binary, as these will be freed when the
1849
- // corresponding kernel bundle is destroyed. Hence, normal kernels cannot rely
1850
- // on the presence of RTC device library images.
1844
+ static bool isBfloat16DeviceLibImage (sycl_device_binary RawImg,
1845
+ uint32_t *LibVersion = nullptr ) {
1851
1846
sycl_device_binary_property_set ImgPS;
1852
- static bool IsNativeBF16DeviceLibHandled = false ;
1853
- static bool IsFallbackBF16DeviceLibHandled = false ;
1854
1847
for (ImgPS = RawImg->PropertySetsBegin ; ImgPS != RawImg->PropertySetsEnd ;
1855
1848
++ImgPS) {
1856
1849
if (ImgPS->Name &&
1857
1850
!strcmp (__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name )) {
1851
+ if (!LibVersion)
1852
+ return true ;
1853
+
1854
+ // Valid version for bfloat16 device library is 0(fallback), 1(native).
1855
+ *LibVersion = 2 ;
1858
1856
sycl_device_binary_property ImgP;
1859
1857
for (ImgP = ImgPS->PropertiesBegin ; ImgP != ImgPS->PropertiesEnd ;
1860
1858
++ImgP) {
1861
1859
if (ImgP->Name && !strcmp (" bfloat16" , ImgP->Name ) &&
1862
1860
(ImgP->Type == SYCL_PROPERTY_TYPE_UINT32))
1863
1861
break ;
1864
1862
}
1865
- if (ImgP == ImgPS->PropertiesEnd )
1866
- return true ;
1867
-
1868
- // A valid bfloat16 device library image is found here.
1869
- // If it originated from RTC, we cannot skip it, but do not mark it as
1870
- // being present.
1871
- if (IsRTC)
1872
- return false ;
1873
-
1874
- // Otherwise, we need to check whether it has been handled already.
1875
- uint32_t BF16NativeVal = DeviceBinaryProperty (ImgP).asUint32 ();
1876
- if (((BF16NativeVal == 0 ) && IsFallbackBF16DeviceLibHandled) ||
1877
- ((BF16NativeVal == 1 ) && IsNativeBF16DeviceLibHandled))
1878
- return true ;
1879
-
1880
- if (BF16NativeVal == 0 )
1881
- IsFallbackBF16DeviceLibHandled = true ;
1882
- else
1883
- IsNativeBF16DeviceLibHandled = true ;
1884
-
1885
- return false ;
1863
+ if (ImgP != ImgPS->PropertiesEnd )
1864
+ *LibVersion = DeviceBinaryProperty (ImgP).asUint32 ();
1865
+ return true ;
1886
1866
}
1887
1867
}
1888
- return true ;
1868
+
1869
+ return false ;
1889
1870
}
1890
1871
1891
- static bool isCompiledAtRuntime (sycl_device_binaries DeviceBinary) {
1892
- // Check whether the first device binary contains a legacy format offload
1893
- // entry with a `$` in its name.
1894
- if (DeviceBinary->NumDeviceBinaries > 0 ) {
1895
- sycl_device_binary Binary = DeviceBinary->DeviceBinaries ;
1896
- if (Binary->EntriesBegin != Binary->EntriesEnd ) {
1897
- sycl_offload_entry Entry = Binary->EntriesBegin ;
1898
- if (!Entry->IsNewOffloadEntryType () &&
1899
- std::string_view{Entry->name }.find (' $' ) != std::string_view::npos) {
1900
- return true ;
1901
- }
1902
- }
1872
+ static sycl_device_binary_property_set
1873
+ getExportedSymbolPS (sycl_device_binary RawImg) {
1874
+ sycl_device_binary_property_set ImgPS;
1875
+ for (ImgPS = RawImg->PropertySetsBegin ; ImgPS != RawImg->PropertySetsEnd ;
1876
+ ++ImgPS) {
1877
+ if (ImgPS->Name &&
1878
+ !strcmp (__SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS, ImgPS->Name ))
1879
+ return ImgPS;
1903
1880
}
1904
- return false ;
1881
+
1882
+ return nullptr ;
1883
+ }
1884
+
1885
+ static bool shouldSkipEmptyImage (sycl_device_binary RawImg) {
1886
+ // For bfloat16 device library image, we should keep it although it doesn't
1887
+ // include any kernel.
1888
+ if (isBfloat16DeviceLibImage (RawImg))
1889
+ return false ;
1890
+
1891
+ // We may extend the logic here other than bfloat16 device library image.
1892
+ return true ;
1905
1893
}
1906
1894
1907
1895
void ProgramManager::addImages (sycl_device_binaries DeviceBinary) {
1908
1896
const bool DumpImages = std::getenv (" SYCL_DUMP_IMAGES" ) && !m_UseSpvFile;
1909
- const bool IsRTC = isCompiledAtRuntime (DeviceBinary);
1910
1897
for (int I = 0 ; I < DeviceBinary->NumDeviceBinaries ; I++) {
1911
1898
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries [I]);
1912
1899
const sycl_offload_entry EntriesB = RawImg->EntriesBegin ;
1913
1900
const sycl_offload_entry EntriesE = RawImg->EntriesEnd ;
1914
- // If the image does not contain kernels, skip it unless it is one of the
1915
- // bfloat16 device libraries, and it wasn't loaded before or resulted from
1916
- // runtime compilation.
1917
- if ((EntriesB == EntriesE) && shouldSkipEmptyImage (RawImg, IsRTC))
1901
+ if ((EntriesB == EntriesE) && shouldSkipEmptyImage (RawImg))
1918
1902
continue ;
1919
1903
1920
1904
std::unique_ptr<RTDeviceBinaryImage> Img;
1905
+ bool IsBfloat16DeviceLib = false ;
1906
+ uint32_t Bfloat16DeviceLibVersion = 0 ;
1921
1907
if (isDeviceImageCompressed (RawImg))
1922
1908
#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE
1923
1909
Img = std::make_unique<CompressedRTDeviceBinaryImage>(RawImg);
@@ -1927,25 +1913,63 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
1927
1913
" SYCL RT was built without ZSTD support."
1928
1914
" Aborting. " );
1929
1915
#endif
1930
- else
1931
- Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1916
+ else {
1917
+ IsBfloat16DeviceLib =
1918
+ isBfloat16DeviceLibImage (RawImg, &Bfloat16DeviceLibVersion);
1919
+ if (!IsBfloat16DeviceLib)
1920
+ Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1921
+ }
1932
1922
1933
1923
static uint32_t SequenceID = 0 ;
1934
1924
1935
- // Fill the kernel argument mask map
1936
- const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1937
- Img->getKernelParamOptInfo ();
1938
- if (KPOIRange.isAvailable ()) {
1939
- KernelNameToArgMaskMap &ArgMaskMap =
1940
- m_EliminatedKernelArgMasks[Img.get ()];
1941
- for (const auto &Info : KPOIRange)
1942
- ArgMaskMap[Info->Name ] =
1943
- createKernelArgMask (DeviceBinaryProperty (Info).asByteArray ());
1925
+ // Fill the kernel argument mask map, no need to do this for bfloat16
1926
+ // device library image since it doesn't include any kernel.
1927
+ if (!IsBfloat16DeviceLib) {
1928
+ const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1929
+ Img->getKernelParamOptInfo ();
1930
+ if (KPOIRange.isAvailable ()) {
1931
+ KernelNameToArgMaskMap &ArgMaskMap =
1932
+ m_EliminatedKernelArgMasks[Img.get ()];
1933
+ for (const auto &Info : KPOIRange)
1934
+ ArgMaskMap[Info->Name ] =
1935
+ createKernelArgMask (DeviceBinaryProperty (Info).asByteArray ());
1936
+ }
1944
1937
}
1945
1938
1946
1939
// Fill maps for kernel bundles
1947
1940
std::lock_guard<std::mutex> KernelIDsGuard (m_KernelIDsMutex);
1948
1941
1942
+ // For bfloat16 device library image, it doesn't include any kernel, device
1943
+ // global, virtual function, so just skip adding it to any related maps.
1944
+ // The bfloat16 device library are provided by compiler and may be used by
1945
+ // different sycl device images, program manager will own single copy for
1946
+ // native and fallback version bfloat16 device library, these device
1947
+ // library images will not be erased unless program manager is destroyed.
1948
+ {
1949
+ if (IsBfloat16DeviceLib) {
1950
+ assert ((Bfloat16DeviceLibVersion < 2 ) &&
1951
+ " Invalid Bfloat16 Device Library Index." );
1952
+ if (m_Bfloat16DeviceLibImages[Bfloat16DeviceLibVersion].get ())
1953
+ continue ;
1954
+ size_t ImgSize =
1955
+ static_cast <size_t >(RawImg->BinaryEnd - RawImg->BinaryStart );
1956
+ std::unique_ptr<char []> Data (new char [ImgSize]);
1957
+ std::memcpy (Data.get (), RawImg->BinaryStart , ImgSize);
1958
+ auto DynBfloat16DeviceLibImg =
1959
+ std::make_unique<DynRTDeviceBinaryImage>(std::move (Data), ImgSize);
1960
+ auto ESPropSet = getExportedSymbolPS (RawImg);
1961
+ sycl_device_binary_property ESProp;
1962
+ for (ESProp = ESPropSet->PropertiesBegin ;
1963
+ ESProp != ESPropSet->PropertiesEnd ; ++ESProp) {
1964
+ m_ExportedSymbolImages.insert (
1965
+ {ESProp->Name , DynBfloat16DeviceLibImg.get ()});
1966
+ }
1967
+ m_Bfloat16DeviceLibImages[Bfloat16DeviceLibVersion] =
1968
+ std::move (DynBfloat16DeviceLibImg);
1969
+ continue ;
1970
+ }
1971
+ }
1972
+
1949
1973
// Register all exported symbols
1950
1974
for (const sycl_device_binary_property &ESProp :
1951
1975
Img->getExportedSymbols ()) {
@@ -2110,19 +2134,14 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
2110
2134
}
2111
2135
2112
2136
void ProgramManager::removeImages (sycl_device_binaries DeviceBinary) {
2113
- bool IsRTC = isCompiledAtRuntime (DeviceBinary);
2114
2137
for (int I = 0 ; I < DeviceBinary->NumDeviceBinaries ; I++) {
2115
2138
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries [I]);
2116
2139
auto DevImgIt = m_DeviceImages.find (RawImg);
2117
2140
if (DevImgIt == m_DeviceImages.end ())
2118
2141
continue ;
2119
2142
const sycl_offload_entry EntriesB = RawImg->EntriesBegin ;
2120
2143
const sycl_offload_entry EntriesE = RawImg->EntriesEnd ;
2121
- // Skip clean up if there are no offload entries, unless `DeviceBinary`
2122
- // resulted from runtime compilation: Then, this is one of the `bfloat16`
2123
- // device libraries, so we want to make sure that the image and its exported
2124
- // symbols are removed from the program manager's maps.
2125
- if (EntriesB == EntriesE && !IsRTC)
2144
+ if (EntriesB == EntriesE)
2126
2145
continue ;
2127
2146
2128
2147
RTDeviceBinaryImage *Img = DevImgIt->second .get ();
@@ -2650,7 +2669,11 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
2650
2669
std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
2651
2670
{
2652
2671
std::lock_guard<std::mutex> KernelIDsGuard (m_KernelIDsMutex);
2653
- DepKernelIDs = m_BinImg2KernelIDs[Dep];
2672
+ // For device library images, they are not in m_BinImg2KernelIDs since
2673
+ // no kernel is included.
2674
+ auto DepIt = m_BinImg2KernelIDs.find (Dep);
2675
+ if (DepIt != m_BinImg2KernelIDs.end ())
2676
+ DepKernelIDs = DepIt->second ;
2654
2677
}
2655
2678
2656
2679
assert (ImgInfoPair.second .State == getBinImageState (Dep) &&
@@ -2863,9 +2886,10 @@ static void mergeImageData(const std::vector<device_image_plain> &Imgs,
2863
2886
for (const device_image_plain &Img : Imgs) {
2864
2887
std::shared_ptr<device_image_impl> DeviceImageImpl = getSyclObjImpl (Img);
2865
2888
// Duplicates are not expected here, otherwise urProgramLink should fail
2866
- KernelIDs.insert (KernelIDs.end (),
2867
- DeviceImageImpl->get_kernel_ids_ptr ()->begin (),
2868
- DeviceImageImpl->get_kernel_ids_ptr ()->end ());
2889
+ if (DeviceImageImpl->get_kernel_ids_ptr ())
2890
+ KernelIDs.insert (KernelIDs.end (),
2891
+ DeviceImageImpl->get_kernel_ids_ptr ()->begin (),
2892
+ DeviceImageImpl->get_kernel_ids_ptr ()->end ());
2869
2893
// To be able to answer queries about specialization constants, the new
2870
2894
// device image should have the specialization constants from all the linked
2871
2895
// images.
0 commit comments