Skip to content

[SYCL][RTC] Adopt recent changes from sycl-post-link #17447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 92 additions & 33 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,19 @@ static bool getDeviceLibraries(const ArgList &Args,
return FoundUnknownLib;
}

/// Parses the IR file at \p LibPath into a module owned by \p Context.
///
/// \param LibPath Path of the file handed to `parseIRFile`.
/// \param Context LLVM context in which the module is created.
/// \returns The parsed module on success; otherwise an error whose message is
///          the text printed by the parser diagnostic.
static Expected<std::unique_ptr<llvm::Module>>
loadBitcodeLibrary(StringRef LibPath, LLVMContext &Context) {
  SMDiagnostic ParseDiag;

  // Happy path first: hand the freshly parsed module straight to the caller.
  if (std::unique_ptr<llvm::Module> Parsed =
          parseIRFile(LibPath, ParseDiag, Context)) {
    return std::move(Parsed);
  }

  // Parsing failed: render the diagnostic into a string so that it can travel
  // inside the returned error.
  std::string Rendered;
  raw_string_ostream RenderedOS(Rendered);
  ParseDiag.print(/*ProgName=*/nullptr, RenderedOS);
  return createStringError(Rendered);
}

Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
const InputArgList &UserArgList,
std::string &BuildLog) {
Expand Down Expand Up @@ -558,16 +571,13 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
for (const std::string &LibName : LibNames) {
std::string LibPath = DPCPPRoot + "/lib/" + LibName;

SMDiagnostic Diag;
std::unique_ptr<llvm::Module> Lib = parseIRFile(LibPath, Diag, Context);
if (!Lib) {
std::string DiagMsg;
raw_string_ostream SOS(DiagMsg);
Diag.print(/*ProgName=*/nullptr, SOS);
return createStringError(DiagMsg);
auto LibOrErr = loadBitcodeLibrary(LibPath, Context);
if (!LibOrErr) {
return LibOrErr.takeError();
}

if (Linker::linkModules(Module, std::move(Lib), Linker::LinkOnlyNeeded)) {
if (Linker::linkModules(Module, std::move(*LibOrErr),
Linker::LinkOnlyNeeded)) {
return createStringError("Unable to link device library %s: %s",
LibPath.c_str(), BuildLog.c_str());
}
Expand Down Expand Up @@ -607,6 +617,31 @@ static IRSplitMode getDeviceCodeSplitMode(const InputArgList &UserArgList) {
return SPLIT_AUTO;
}

/// Mirrors the contents of \p Properties into `DevImgInfo.Properties`.
///
/// `DevImgInfo.Properties` is replaced by a `FrozenPropertyRegistry` sized to
/// the number of property sets. Each property set becomes a
/// `FrozenPropertySet` carrying the set's name and one `FrozenPropertyValue`
/// per property.
///
/// \param Properties Registry whose property sets are read.
/// \param DevImgInfo Device image info whose `Properties` member is
///                   overwritten.
static void encodeProperties(PropertySetRegistry &Properties,
                             RTCDevImgInfo &DevImgInfo) {
  const auto &SourceSets = Properties.getPropSets();
  DevImgInfo.Properties = FrozenPropertyRegistry{SourceSets.size()};

  // Walk the source sets and the freshly sized frozen sets in lockstep.
  for (auto [SetEntry, FrozenSet] :
       zip_equal(SourceSets, DevImgInfo.Properties)) {
    const auto &SetName = SetEntry.first;
    const auto &SetContents = SetEntry.second;
    FrozenSet = FrozenPropertySet{SetName.str(), SetContents.size()};

    for (auto [PropEntry, FrozenValue] :
         zip_equal(SetContents, FrozenSet.Values)) {
      const auto &Name = PropEntry.first;
      const auto &Value = PropEntry.second;

      // UINT32 properties are stored by value; every other property is
      // stored as its raw byte array together with the array's size.
      if (Value.getType() == PropertyValue::Type::UINT32) {
        FrozenValue = FrozenPropertyValue{Name.str(), Value.asUint32()};
      } else {
        FrozenValue = FrozenPropertyValue{Name.str(), Value.asRawByteArray(),
                                          Value.getRawByteArraySize()};
      }
    }
  }
}

Expected<PostLinkResult>
jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
const InputArgList &UserArgList) {
Expand Down Expand Up @@ -637,9 +672,9 @@ jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
// Otherwise: Port over the `removeSYCLKernelsConstRefArray` and
// `removeDeviceGlobalFromCompilerUsed` methods.

assert(!isModuleUsingAsan(*Module));
// Otherwise: Need to instrument each image scope device globals if the module
// has been instrumented by sanitizer pass.
assert(!(isModuleUsingAsan(*Module) || isModuleUsingMsan(*Module) ||
isModuleUsingTsan(*Module)));
// Otherwise: Run `SanitizerKernelMetadataPass`.

// Transform Joint Matrix builtin calls to align them with SPIR-V friendly
// LLVM IR specification.
Expand Down Expand Up @@ -668,6 +703,7 @@ jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
// `-fno-sycl-device-code-split-esimd` as a prerequisite for compiling
// `invoke_simd` code.

bool IsBF16DeviceLibUsed = false;
while (Splitter->hasMoreSplits()) {
ModuleDesc MDesc = Splitter->nextSplit();

Expand Down Expand Up @@ -701,35 +737,58 @@ jit_compiler::performPostLink(std::unique_ptr<llvm::Module> Module,
/*DeviceGlobals=*/false};
PropertySetRegistry Properties =
computeModuleProperties(MDesc.getModule(), MDesc.entries(), PropReq);

// When the split mode is none, the required work group size will be added
// to the whole module, which will make the runtime unable to launch the
// other kernels in the module that have different required work group
// sizes or no required work group sizes. So we need to remove the
// required work group size metadata in this case.
if (SplitMode == module_split::SPLIT_NONE) {
Properties.remove(PropSetRegTy::SYCL_DEVICE_REQUIREMENTS,
PropSetRegTy::PROPERTY_REQD_WORK_GROUP_SIZE);
}

// TODO: Manually add `compile_target` property as in
// `saveModuleProperties`?
const auto &PropertySets = Properties.getPropSets();

DevImgInfo.Properties = FrozenPropertyRegistry{PropertySets.size()};
for (auto [KV, FrozenPropSet] :
zip_equal(PropertySets, DevImgInfo.Properties)) {
const auto &PropertySetName = KV.first;
const auto &PropertySet = KV.second;
FrozenPropSet =
FrozenPropertySet{PropertySetName.str(), PropertySet.size()};
for (auto [KV2, FrozenProp] :
zip_equal(PropertySet, FrozenPropSet.Values)) {
const auto &PropertyName = KV2.first;
const auto &PropertyValue = KV2.second;
FrozenProp =
PropertyValue.getType() == PropertyValue::Type::UINT32
? FrozenPropertyValue{PropertyName.str(),
PropertyValue.asUint32()}
: FrozenPropertyValue{PropertyName.str(),
PropertyValue.asRawByteArray(),
PropertyValue.getRawByteArraySize()};
}
};

encodeProperties(Properties, DevImgInfo);

IsBF16DeviceLibUsed |= isSYCLDeviceLibBF16Used(MDesc.getModule());
Modules.push_back(MDesc.releaseModulePtr());
}
}

if (IsBF16DeviceLibUsed) {
const std::string &DPCPPRoot = getDPCPPRoot();
if (DPCPPRoot == InvalidDPCPPRoot) {
return createStringError("Could not locate DPCPP root directory");
}

auto &Ctx = Modules.front()->getContext();
auto WrapLibraryInDevImg = [&](const std::string &LibName) -> Error {
std::string LibPath = DPCPPRoot + "/lib/" + LibName;
auto LibOrErr = loadBitcodeLibrary(LibPath, Ctx);
if (!LibOrErr) {
return LibOrErr.takeError();
}

std::unique_ptr<llvm::Module> LibModule = std::move(*LibOrErr);
PropertySetRegistry Properties =
computeDeviceLibProperties(*LibModule, LibName);
encodeProperties(Properties, DevImgInfoVec.emplace_back());
Modules.push_back(std::move(LibModule));

return Error::success();
};

if (auto Err = WrapLibraryInDevImg("libsycl-fallback-bfloat16.bc")) {
return std::move(Err);
}
if (auto Err = WrapLibraryInDevImg("libsycl-native-bfloat16.bc")) {
return std::move(Err);
}
}

assert(DevImgInfoVec.size() == Modules.size());
RTCBundleInfo BundleInfo;
BundleInfo.DevImgInfos = DynArray<RTCDevImgInfo>{DevImgInfoVec.size()};
Expand Down
46 changes: 39 additions & 7 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1837,13 +1837,17 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const {
return {};
}

static bool shouldSkipEmptyImage(sycl_device_binary RawImg) {
static bool shouldSkipEmptyImage(sycl_device_binary RawImg, bool IsRTC) {
// For bfloat16 device library image, we should keep it. However, in some
// scenario, __sycl_register_lib can be called multiple times and the same
// bfloat16 device library image may be handled multiple times which is not
// needed. 2 static bool variables are created to record whether native or
// fallback bfloat16 device library image has been handled, if yes, we just
// need to skip it.
// We cannot prevent redundant loads of device library images if they are part
// of a runtime-compiled device binary, as these will be freed when the
// corresponding kernel bundle is destroyed. Hence, normal kernels cannot rely
// on the presence of RTC device library images.
sycl_device_binary_property_set ImgPS;
static bool IsNativeBF16DeviceLibHandled = false;
static bool IsFallbackBF16DeviceLibHandled = false;
Expand All @@ -1861,8 +1865,13 @@ static bool shouldSkipEmptyImage(sycl_device_binary RawImg) {
if (ImgP == ImgPS->PropertiesEnd)
return true;

// A valid bfloat16 device library image is found here, need to check
// whether it has been handled already.
// A valid bfloat16 device library image is found here.
// If it originated from RTC, we cannot skip it, but do not mark it as
// being present.
if (IsRTC)
return false;

// Otherwise, we need to check whether it has been handled already.
uint32_t BF16NativeVal = DeviceBinaryProperty(ImgP).asUint32();
if (((BF16NativeVal == 0) && IsFallbackBF16DeviceLibHandled) ||
((BF16NativeVal == 1) && IsNativeBF16DeviceLibHandled))
Expand All @@ -1879,14 +1888,33 @@ static bool shouldSkipEmptyImage(sycl_device_binary RawImg) {
return true;
}

/// Reports whether \p DeviceBinary looks like the result of runtime
/// compilation.
///
/// Only the first device binary is inspected, and within it only the first
/// offload entry: the result is true when that entry is in the legacy format
/// and its name contains a `$`.
///
/// \param DeviceBinary Collection of device binaries to inspect.
/// \returns true if the check above succeeds; false otherwise, including when
///          there are no device binaries or the first one has no entries.
static bool isCompiledAtRuntime(sycl_device_binaries DeviceBinary) {
  // Nothing to inspect without at least one device binary.
  if (!(DeviceBinary->NumDeviceBinaries > 0))
    return false;

  sycl_device_binary FirstImg = DeviceBinary->DeviceBinaries;
  // An image without offload entries cannot be classified by entry name.
  if (FirstImg->EntriesBegin == FirstImg->EntriesEnd)
    return false;

  sycl_offload_entry FirstEntry = FirstImg->EntriesBegin;
  // Only legacy format offload entries are considered.
  if (FirstEntry->IsNewOffloadEntryType())
    return false;

  const std::string_view EntryName{FirstEntry->name};
  return EntryName.find('$') != std::string_view::npos;
}

void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
const bool DumpImages = std::getenv("SYCL_DUMP_IMAGES") && !m_UseSpvFile;
const bool IsRTC = isCompiledAtRuntime(DeviceBinary);
for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
// Treat the image as empty one
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg))
// If the image does not contain kernels, skip it unless it is one of the
// bfloat16 device libraries, and it wasn't loaded before or resulted from
// runtime compilation.
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg, IsRTC))
continue;

std::unique_ptr<RTDeviceBinaryImage> Img;
Expand Down Expand Up @@ -2081,15 +2109,19 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
}

void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
bool IsRTC = isCompiledAtRuntime(DeviceBinary);
for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
auto DevImgIt = m_DeviceImages.find(RawImg);
if (DevImgIt == m_DeviceImages.end())
continue;
const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
// Treat the image as empty one
if (EntriesB == EntriesE)
// Skip clean up if there are no offload entries, unless `DeviceBinary`
// resulted from runtime compilation: Then, this is one of the `bfloat16`
// device libraries, so we want to make sure that the image and its exported
// symbols are removed from the program manager's maps.
if (EntriesB == EntriesE && !IsRTC)
continue;

RTDeviceBinaryImage *Img = DevImgIt->second.get();
Expand Down
12 changes: 9 additions & 3 deletions sycl/test-e2e/KernelCompiler/sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@ void device_libs_kernel(float *ptr) {

// cl_intel_devicelib_imf
ptr[3] = sycl::ext::intel::math::sqrt(ptr[3] * 2);

// cl_intel_devicelib_imf_bf16
ptr[4] = sycl::ext::intel::math::float2bfloat16(ptr[4] * 0.5f);

// cl_intel_devicelib_bfloat16
ptr[5] = sycl::ext::oneapi::bfloat16{ptr[5] / 0.25f};
}
)===";

Expand Down Expand Up @@ -435,7 +441,7 @@ int test_device_libraries() {
exe_kb kbExe = syclex::build(kbSrc);

sycl::kernel k = kbExe.ext_oneapi_get_kernel("device_libs_kernel");
constexpr size_t nElem = 4;
constexpr size_t nElem = 6;
float *ptr = sycl::malloc_shared<float>(nElem, q);
for (int i = 0; i < nElem; ++i)
ptr[i] = 1.0f;
Expand All @@ -446,8 +452,8 @@ int test_device_libraries() {
});
q.wait_and_throw();

// Check that the kernel was executed. Given the {1.0, 1.0, 1.0, 1.0} input,
// the expected result is approximately {0.84, 1.41, 0.0, 1.41}.
// Check that the kernel was executed. Given the {1.0, ..., 1.0} input,
// the expected result is approximately {0.84, 1.41, 0.0, 1.41, 0.5, 4.0}.
for (unsigned i = 0; i < nElem; ++i) {
std::cout << ptr[i] << ' ';
assert(ptr[i] != 1.0f);
Expand Down