Skip to content

[OpenMP] Fix work-stealing stack clobber with taskwait #126049

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions openmp/runtime/src/kmp_taskdeps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,19 @@
static std::atomic<kmp_int32> kmp_node_id_seed = 0;
#endif

static void __kmp_init_node(kmp_depnode_t *node) {
static void __kmp_init_node(kmp_depnode_t *node, bool on_stack) {
node->dn.successors = NULL;
node->dn.task = NULL; // will point to the right task
// once dependences have been processed
for (int i = 0; i < MAX_MTX_DEPS; ++i)
node->dn.mtx_locks[i] = NULL;
node->dn.mtx_num_locks = 0;
__kmp_init_lock(&node->dn.lock);
KMP_ATOMIC_ST_RLX(&node->dn.nrefs, 1); // init creates the first reference
// Init creates the first reference. Bit 0 indicates that this node
// resides on the stack. The refcount is incremented and decremented in
// steps of two, maintaining use of even numbers for heap nodes and odd
// numbers for stack nodes.
KMP_ATOMIC_ST_RLX(&node->dn.nrefs, on_stack ? 3 : 2);
#ifdef KMP_SUPPORT_GRAPH_OUTPUT
node->dn.id = KMP_ATOMIC_INC(&kmp_node_id_seed);
#endif
Expand All @@ -51,7 +55,7 @@ static void __kmp_init_node(kmp_depnode_t *node) {
}

static inline kmp_depnode_t *__kmp_node_ref(kmp_depnode_t *node) {
KMP_ATOMIC_INC(&node->dn.nrefs);
KMP_ATOMIC_ADD(&node->dn.nrefs, 2);
return node;
}

Expand Down Expand Up @@ -825,7 +829,7 @@ kmp_int32 __kmpc_omp_task_with_deps(ident_t *loc_ref, kmp_int32 gtid,
(kmp_depnode_t *)__kmp_thread_malloc(thread, sizeof(kmp_depnode_t));
#endif

__kmp_init_node(node);
__kmp_init_node(node, /*on_stack=*/false);
new_taskdata->td_depnode = node;

if (__kmp_check_deps(gtid, node, new_task, &current_task->td_dephash,
Expand Down Expand Up @@ -1007,7 +1011,7 @@ void __kmpc_omp_taskwait_deps_51(ident_t *loc_ref, kmp_int32 gtid,
}

kmp_depnode_t node = {0};
__kmp_init_node(&node);
__kmp_init_node(&node, /*on_stack=*/true);

if (!__kmp_check_deps(gtid, &node, NULL, &current_task->td_dephash,
DEP_BARRIER, ndeps, dep_list, ndeps_noalias,
Expand All @@ -1018,6 +1022,16 @@ void __kmpc_omp_taskwait_deps_51(ident_t *loc_ref, kmp_int32 gtid,
#if OMPT_SUPPORT
__ompt_taskwait_dep_finish(current_task, taskwait_task_data);
#endif /* OMPT_SUPPORT */

// There may still be references to this node here, due to task stealing.
// Wait for them to be released.
kmp_int32 nrefs;
while ((nrefs = node.dn.nrefs) > 3) {
KMP_DEBUG_ASSERT((nrefs & 1) == 1);
KMP_YIELD(TRUE);
}
KMP_DEBUG_ASSERT(nrefs == 3);

return;
}

Expand All @@ -1032,9 +1046,15 @@ void __kmpc_omp_taskwait_deps_51(ident_t *loc_ref, kmp_int32 gtid,

// Wait until the last __kmp_release_deps is finished before we free the
// current stack frame holding the "node" variable; once its nrefs count
// reaches 1, we're sure nobody else can try to reference it again.
while (node.dn.nrefs > 1)
// reaches 3 (meaning 1, since bit zero of the refcount indicates a stack
// rather than a heap address), we're sure nobody else can try to reference
// it again.
kmp_int32 nrefs;
while ((nrefs = node.dn.nrefs) > 3) {
KMP_DEBUG_ASSERT((nrefs & 1) == 1);
KMP_YIELD(TRUE);
}
KMP_DEBUG_ASSERT(nrefs == 3);

#if OMPT_SUPPORT
__ompt_taskwait_dep_finish(current_task, taskwait_task_data);
Expand Down
7 changes: 5 additions & 2 deletions openmp/runtime/src/kmp_taskdeps.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@ static inline void __kmp_node_deref(kmp_info_t *thread, kmp_depnode_t *node) {
if (!node)
return;

kmp_int32 n = KMP_ATOMIC_DEC(&node->dn.nrefs) - 1;
kmp_int32 n = KMP_ATOMIC_SUB(&node->dn.nrefs, 2) - 2;
KMP_DEBUG_ASSERT(n >= 0);
if (n == 0) {
if ((n & ~1) == 0) {
#if USE_ITT_BUILD && USE_ITT_NOTIFY
__itt_sync_destroy(node);
#endif
// These two assertions are somewhat redundant. The first is intended to
// detect if we are trying to free a depnode on the stack.
KMP_DEBUG_ASSERT((node->dn.nrefs & 1) == 0);
KMP_ASSERT(node->dn.nrefs == 0);
#if USE_FAST_MEMORY
__kmp_fast_free(thread, node);
Expand Down