Skip to content

Commit e7c54c3

Browse files
[SYCL] Prioritize set kernels over lookup (#18157)
The current implementation of SYCL kernel launches prioritizes looking up kernels through the kernel bundles rather than using the set kernel. These changes instead prioritizes using the kernel, which not only saves the look-up overhead and fixes a kernel implementation lifetime issue caused by #17380. Signed-off-by: Larsen, Steffen <[email protected]>
1 parent c19e176 commit e7c54c3

File tree

3 files changed

+42
-42
lines changed

3 files changed

+42
-42
lines changed

sycl/source/detail/graph_impl.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1474,18 +1474,18 @@ void exec_graph_impl::populateURKernelUpdateStructs(
14741474
ur_kernel_handle_t UrKernel = nullptr;
14751475
auto Kernel = ExecCG.MSyclKernel;
14761476
auto KernelBundleImplPtr = ExecCG.MKernelBundle;
1477-
std::shared_ptr<sycl::detail::kernel_impl> SyclKernelImpl = nullptr;
14781477
const sycl::detail::KernelArgMask *EliminatedArgMask = nullptr;
14791478

1480-
if (auto SyclKernelImpl = KernelBundleImplPtr
1481-
? KernelBundleImplPtr->tryGetKernel(
1482-
ExecCG.MKernelName, KernelBundleImplPtr)
1483-
: std::shared_ptr<kernel_impl>{nullptr}) {
1484-
UrKernel = SyclKernelImpl->getHandleRef();
1485-
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
1486-
} else if (Kernel != nullptr) {
1479+
if (Kernel != nullptr) {
14871480
UrKernel = Kernel->getHandleRef();
14881481
EliminatedArgMask = Kernel->getKernelArgMask();
1482+
} else if (auto SyclKernelImpl =
1483+
KernelBundleImplPtr
1484+
? KernelBundleImplPtr->tryGetKernel(ExecCG.MKernelName,
1485+
KernelBundleImplPtr)
1486+
: std::shared_ptr<kernel_impl>{nullptr}) {
1487+
UrKernel = SyclKernelImpl->getHandleRef();
1488+
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
14891489
} else {
14901490
ur_program_handle_t UrProgram = nullptr;
14911491
std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) =

sycl/source/detail/helpers.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,16 @@ retrieveKernelBinary(const QueueImplPtr &Queue, const char *KernelName,
7272
const RTDeviceBinaryImage *DeviceImage = nullptr;
7373
ur_program_handle_t Program = nullptr;
7474
auto KernelBundleImpl = KernelCG->getKernelBundle();
75-
if (auto SyclKernelImpl =
76-
KernelBundleImpl
77-
? KernelBundleImpl->tryGetKernel(KernelName, KernelBundleImpl)
78-
: std::shared_ptr<kernel_impl>{nullptr}) {
75+
if (KernelCG->MSyclKernel != nullptr) {
76+
DeviceImage = KernelCG->MSyclKernel->getDeviceImage()->get_bin_image_ref();
77+
Program = KernelCG->MSyclKernel->getDeviceImage()->get_ur_program_ref();
78+
} else if (auto SyclKernelImpl =
79+
KernelBundleImpl ? KernelBundleImpl->tryGetKernel(
80+
KernelName, KernelBundleImpl)
81+
: std::shared_ptr<kernel_impl>{nullptr}) {
7982
// Retrieve the device image from the kernel bundle.
8083
DeviceImage = SyclKernelImpl->getDeviceImage()->get_bin_image_ref();
8184
Program = SyclKernelImpl->getDeviceImage()->get_ur_program_ref();
82-
} else if (KernelCG->MSyclKernel != nullptr) {
83-
DeviceImage = KernelCG->MSyclKernel->getDeviceImage()->get_bin_image_ref();
84-
Program = KernelCG->MSyclKernel->getDeviceImage()->get_ur_program_ref();
8585
} else {
8686
auto ContextImpl = Queue->getContextImplPtr();
8787
auto DeviceImpl = Queue->getDeviceImplPtr();

sycl/source/detail/scheduler/commands.cpp

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1997,16 +1997,16 @@ void instrumentationAddExtraKernelMetadata(
19971997
std::mutex *KernelMutex = nullptr;
19981998
const KernelArgMask *EliminatedArgMask = nullptr;
19991999

2000-
if (auto SyclKernelImpl = KernelBundleImplPtr
2001-
? KernelBundleImplPtr->tryGetKernel(
2002-
KernelName, KernelBundleImplPtr)
2003-
: std::shared_ptr<kernel_impl>{nullptr}) {
2004-
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
2005-
Program = SyclKernelImpl->getDeviceImage()->get_ur_program_ref();
2006-
} else if (nullptr != SyclKernel) {
2000+
if (nullptr != SyclKernel) {
20072001
Program = SyclKernel->getProgramRef();
20082002
if (!SyclKernel->isCreatedFromSource())
20092003
EliminatedArgMask = SyclKernel->getKernelArgMask();
2004+
} else if (auto SyclKernelImpl =
2005+
KernelBundleImplPtr ? KernelBundleImplPtr->tryGetKernel(
2006+
KernelName, KernelBundleImplPtr)
2007+
: std::shared_ptr<kernel_impl>{nullptr}) {
2008+
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
2009+
Program = SyclKernelImpl->getDeviceImage()->get_ur_program_ref();
20102010
} else if (Queue) {
20112011
// NOTE: Queue can be null when kernel is directly enqueued to a command
20122012
// buffer
@@ -2521,17 +2521,17 @@ getCGKernelInfo(const CGExecKernel &CommandGroup, ContextImplPtr ContextImpl,
25212521
const KernelArgMask *EliminatedArgMask = nullptr;
25222522
auto &KernelBundleImplPtr = CommandGroup.MKernelBundle;
25232523

2524-
if (auto SyclKernelImpl =
2525-
KernelBundleImplPtr
2526-
? KernelBundleImplPtr->tryGetKernel(CommandGroup.MKernelName,
2527-
KernelBundleImplPtr)
2528-
: std::shared_ptr<kernel_impl>{nullptr}) {
2524+
if (auto Kernel = CommandGroup.MSyclKernel; Kernel != nullptr) {
2525+
UrKernel = Kernel->getHandleRef();
2526+
EliminatedArgMask = Kernel->getKernelArgMask();
2527+
} else if (auto SyclKernelImpl =
2528+
KernelBundleImplPtr
2529+
? KernelBundleImplPtr->tryGetKernel(
2530+
CommandGroup.MKernelName, KernelBundleImplPtr)
2531+
: std::shared_ptr<kernel_impl>{nullptr}) {
25292532
UrKernel = SyclKernelImpl->getHandleRef();
25302533
DeviceImageImpl = SyclKernelImpl->getDeviceImage();
25312534
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
2532-
} else if (auto Kernel = CommandGroup.MSyclKernel; Kernel != nullptr) {
2533-
UrKernel = Kernel->getHandleRef();
2534-
EliminatedArgMask = Kernel->getKernelArgMask();
25352535
} else {
25362536
ur_program_handle_t UrProgram = nullptr;
25372537
std::tie(UrKernel, std::ignore, EliminatedArgMask, UrProgram) =
@@ -2678,18 +2678,7 @@ void enqueueImpKernel(
26782678
std::shared_ptr<kernel_impl> SyclKernelImpl;
26792679
std::shared_ptr<device_image_impl> DeviceImageImpl;
26802680

2681-
if ((SyclKernelImpl = KernelBundleImplPtr
2682-
? KernelBundleImplPtr->tryGetKernel(
2683-
KernelName, KernelBundleImplPtr)
2684-
: std::shared_ptr<kernel_impl>{nullptr})) {
2685-
Kernel = SyclKernelImpl->getHandleRef();
2686-
DeviceImageImpl = SyclKernelImpl->getDeviceImage();
2687-
2688-
Program = DeviceImageImpl->get_ur_program_ref();
2689-
2690-
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
2691-
KernelMutex = SyclKernelImpl->getCacheMutex();
2692-
} else if (nullptr != MSyclKernel) {
2681+
if (nullptr != MSyclKernel) {
26932682
assert(MSyclKernel->get_info<info::kernel::context>() ==
26942683
Queue->get_context());
26952684
Kernel = MSyclKernel->getHandleRef();
@@ -2703,6 +2692,17 @@ void enqueueImpKernel(
27032692
// their duplication in such cases.
27042693
KernelMutex = &MSyclKernel->getNoncacheableEnqueueMutex();
27052694
EliminatedArgMask = MSyclKernel->getKernelArgMask();
2695+
} else if ((SyclKernelImpl = KernelBundleImplPtr
2696+
? KernelBundleImplPtr->tryGetKernel(
2697+
KernelName, KernelBundleImplPtr)
2698+
: std::shared_ptr<kernel_impl>{nullptr})) {
2699+
Kernel = SyclKernelImpl->getHandleRef();
2700+
DeviceImageImpl = SyclKernelImpl->getDeviceImage();
2701+
2702+
Program = DeviceImageImpl->get_ur_program_ref();
2703+
2704+
EliminatedArgMask = SyclKernelImpl->getKernelArgMask();
2705+
KernelMutex = SyclKernelImpl->getCacheMutex();
27062706
} else {
27072707
std::tie(Kernel, KernelMutex, EliminatedArgMask, Program) =
27082708
detail::ProgramManager::getInstance().getOrCreateKernel(

0 commit comments

Comments
 (0)