Skip to content

Commit 2fdf191

Browse files
jtb20 authored and jprotze committed
[OpenMP] Fix crash with task stealing and task dependencies (#126049)
This patch series demonstrates and fixes a bug that causes crashes with OpenMP 'taskwait' directives in heavily multi-threaded scenarios. TL;DR: the early return from `__kmpc_omp_taskwait_deps_51` skipped the synchronization mechanism that is in place for the late return. Additional debug assertions check the invariants implied by the code. @jpeyton52 found the timing hole as this sequence of events:

> 1. THREAD 1: A regular task with dependences is created, call it T1.
> 2. THREAD 1: Call into `__kmpc_omp_taskwait_deps_51()` and create a stack-based depnode (`NULL` task), call it T2 (stack).
> 3. THREAD 2: Steals task T1 and executes it, getting to the `__kmp_release_deps()` region.
> 4. THREAD 1: During processing of dependences for T2 (stack) (within the `__kmp_check_deps()` region), a link T1 -> T2 is created. This increases T2's (stack) `nrefs` count.
> 5. THREAD 2: Iterates through the successors list: decrements T2's (stack) npredecessor count, BUT HASN'T YET `__kmp_node_deref()`-ed it.
> 6. THREAD 1: Now, when finished with `__kmp_check_deps()`, it returns false because the npredecessor count is 0, but T2's (stack) `nrefs` count is 2 because THREAD 2 still references it!
> 7. THREAD 1: Because `__kmp_check_deps()` returns false, early exit.
>    _Now the stack-based depnode is invalid, but THREAD 2 still references it._
>
> We've reached improper stack referencing behavior. Varied results/crashes/asserts can occur if THREAD 1 comes back and recreates the exact same depnode at the exact same stack address at the same time that THREAD 2 calls `__kmp_node_deref()`.
1 parent 2ad1089 commit 2fdf191

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

openmp/runtime/src/kmp_taskdeps.cpp

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,19 @@
3333
static std::atomic<kmp_int32> kmp_node_id_seed = 0;
3434
#endif
3535

36-
static void __kmp_init_node(kmp_depnode_t *node) {
36+
static void __kmp_init_node(kmp_depnode_t *node, bool on_stack) {
3737
node->dn.successors = NULL;
3838
node->dn.task = NULL; // will point to the right task
3939
// once dependences have been processed
4040
for (int i = 0; i < MAX_MTX_DEPS; ++i)
4141
node->dn.mtx_locks[i] = NULL;
4242
node->dn.mtx_num_locks = 0;
4343
__kmp_init_lock(&node->dn.lock);
44-
KMP_ATOMIC_ST_RLX(&node->dn.nrefs, 1); // init creates the first reference
44+
// Init creates the first reference. Bit 0 indicates that this node
45+
// resides on the stack. The refcount is incremented and decremented in
46+
// steps of two, maintaining use of even numbers for heap nodes and odd
47+
// numbers for stack nodes.
48+
KMP_ATOMIC_ST_RLX(&node->dn.nrefs, on_stack ? 3 : 2);
4549
#ifdef KMP_SUPPORT_GRAPH_OUTPUT
4650
node->dn.id = KMP_ATOMIC_INC(&kmp_node_id_seed);
4751
#endif
@@ -51,7 +55,7 @@ static void __kmp_init_node(kmp_depnode_t *node) {
5155
}
5256

5357
static inline kmp_depnode_t *__kmp_node_ref(kmp_depnode_t *node) {
54-
KMP_ATOMIC_INC(&node->dn.nrefs);
58+
KMP_ATOMIC_ADD(&node->dn.nrefs, 2);
5559
return node;
5660
}
5761

@@ -825,7 +829,7 @@ kmp_int32 __kmpc_omp_task_with_deps(ident_t *loc_ref, kmp_int32 gtid,
825829
(kmp_depnode_t *)__kmp_thread_malloc(thread, sizeof(kmp_depnode_t));
826830
#endif
827831

828-
__kmp_init_node(node);
832+
__kmp_init_node(node, /*on_stack=*/false);
829833
new_taskdata->td_depnode = node;
830834

831835
if (__kmp_check_deps(gtid, node, new_task, &current_task->td_dephash,
@@ -1007,7 +1011,7 @@ void __kmpc_omp_taskwait_deps_51(ident_t *loc_ref, kmp_int32 gtid,
10071011
}
10081012

10091013
kmp_depnode_t node = {0};
1010-
__kmp_init_node(&node);
1014+
__kmp_init_node(&node, /*on_stack=*/true);
10111015

10121016
if (!__kmp_check_deps(gtid, &node, NULL, &current_task->td_dephash,
10131017
DEP_BARRIER, ndeps, dep_list, ndeps_noalias,
@@ -1018,6 +1022,16 @@ void __kmpc_omp_taskwait_deps_51(ident_t *loc_ref, kmp_int32 gtid,
10181022
#if OMPT_SUPPORT
10191023
__ompt_taskwait_dep_finish(current_task, taskwait_task_data);
10201024
#endif /* OMPT_SUPPORT */
1025+
1026+
// There may still be references to this node here, due to task stealing.
1027+
// Wait for them to be released.
1028+
kmp_int32 nrefs;
1029+
while ((nrefs = node.dn.nrefs) > 3) {
1030+
KMP_DEBUG_ASSERT((nrefs & 1) == 1);
1031+
KMP_YIELD(TRUE);
1032+
}
1033+
KMP_DEBUG_ASSERT(nrefs == 3);
1034+
10211035
return;
10221036
}
10231037

@@ -1032,9 +1046,15 @@ void __kmpc_omp_taskwait_deps_51(ident_t *loc_ref, kmp_int32 gtid,
10321046

10331047
// Wait until the last __kmp_release_deps is finished before we free the
10341048
// current stack frame holding the "node" variable; once its nrefs count
1035-
// reaches 1, we're sure nobody else can try to reference it again.
1036-
while (node.dn.nrefs > 1)
1049+
// reaches 3 (meaning 1, since bit zero of the refcount indicates a stack
1050+
// rather than a heap address), we're sure nobody else can try to reference
1051+
// it again.
1052+
kmp_int32 nrefs;
1053+
while ((nrefs = node.dn.nrefs) > 3) {
1054+
KMP_DEBUG_ASSERT((nrefs & 1) == 1);
10371055
KMP_YIELD(TRUE);
1056+
}
1057+
KMP_DEBUG_ASSERT(nrefs == 3);
10381058

10391059
#if OMPT_SUPPORT
10401060
__ompt_taskwait_dep_finish(current_task, taskwait_task_data);

openmp/runtime/src/kmp_taskdeps.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,15 @@ static inline void __kmp_node_deref(kmp_info_t *thread, kmp_depnode_t *node) {
2222
if (!node)
2323
return;
2424

25-
kmp_int32 n = KMP_ATOMIC_DEC(&node->dn.nrefs) - 1;
25+
kmp_int32 n = KMP_ATOMIC_SUB(&node->dn.nrefs, 2) - 2;
2626
KMP_DEBUG_ASSERT(n >= 0);
27-
if (n == 0) {
27+
if ((n & ~1) == 0) {
2828
#if USE_ITT_BUILD && USE_ITT_NOTIFY
2929
__itt_sync_destroy(node);
3030
#endif
31+
// These two assertions are somewhat redundant. The first is intended to
32+
// detect if we are trying to free a depnode on the stack.
33+
KMP_DEBUG_ASSERT((node->dn.nrefs & 1) == 0);
3134
KMP_ASSERT(node->dn.nrefs == 0);
3235
#if USE_FAST_MEMORY
3336
__kmp_fast_free(thread, node);

0 commit comments

Comments
 (0)