Skip to content

Commit 7ab08c3

Browse files
committed
Add thread-local values which are preserved with jobs
1 parent 87dfd36 commit 7ab08c3

File tree

7 files changed

+70
-7
lines changed

7 files changed

+70
-7
lines changed

rayon-core/src/job.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::latch::Latch;
2+
use crate::tlv;
23
use crate::unwind;
34
use crossbeam_queue::SegQueue;
45
use std::any::Any;
@@ -73,6 +74,7 @@ where
7374
pub(super) latch: L,
7475
func: UnsafeCell<Option<F>>,
7576
result: UnsafeCell<JobResult<R>>,
77+
tlv: usize,
7678
}
7779

7880
impl<L, F, R> StackJob<L, F, R>
@@ -81,11 +83,12 @@ where
8183
F: FnOnce(bool) -> R + Send,
8284
R: Send,
8385
{
84-
pub(super) fn new(func: F, latch: L) -> StackJob<L, F, R> {
86+
pub(super) fn new(tlv: usize, func: F, latch: L) -> StackJob<L, F, R> {
8587
StackJob {
8688
latch,
8789
func: UnsafeCell::new(Some(func)),
8890
result: UnsafeCell::new(JobResult::None),
91+
tlv,
8992
}
9093
}
9194

@@ -114,6 +117,7 @@ where
114117
}
115118

116119
let this = &*this;
120+
tlv::set(this.tlv);
117121
let abort = unwind::AbortIfPanic;
118122
let func = (*this.func.get()).take().unwrap();
119123
(*this.result.get()) = match unwind::halt_unwinding(call(func)) {
@@ -136,15 +140,17 @@ where
136140
BODY: FnOnce() + Send,
137141
{
138142
job: UnsafeCell<Option<BODY>>,
143+
tlv: usize,
139144
}
140145

141146
impl<BODY> HeapJob<BODY>
142147
where
143148
BODY: FnOnce() + Send,
144149
{
145-
pub(super) fn new(func: BODY) -> Self {
150+
pub(super) fn new(tlv: usize, func: BODY) -> Self {
146151
HeapJob {
147152
job: UnsafeCell::new(Some(func)),
153+
tlv,
148154
}
149155
}
150156

@@ -163,6 +169,7 @@ where
163169
{
164170
unsafe fn execute(this: *const Self) {
165171
let this: Box<Self> = mem::transmute(this);
172+
tlv::set(this.tlv);
166173
let job = (*this.job.get()).take().unwrap();
167174
job();
168175
}

rayon-core/src/join/mod.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::job::StackJob;
22
use crate::latch::{LatchProbe, SpinLatch};
33
use crate::log::Event::*;
44
use crate::registry::{self, WorkerThread};
5+
use crate::tlv;
56
use crate::unwind;
67
use std::any::Any;
78

@@ -135,18 +136,19 @@ where
135136
worker: worker_thread.index()
136137
});
137138

139+
let tlv = tlv::get();
138140
// Create virtual wrapper for task b; this all has to be
139141
// done here so that the stack frame can keep it all live
140142
// long enough.
141-
let job_b = StackJob::new(call_b(oper_b), SpinLatch::new());
143+
let job_b = StackJob::new(tlv, call_b(oper_b), SpinLatch::new());
142144
let job_b_ref = job_b.as_job_ref();
143145
worker_thread.push(job_b_ref);
144146

145147
// Execute task a; hopefully b gets stolen in the meantime.
146148
let status_a = unwind::halt_unwinding(call_a(oper_a, injected));
147149
let result_a = match status_a {
148150
Ok(v) => v,
149-
Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err),
151+
Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err, tlv),
150152
};
151153

152154
// Now that task A has finished, try to pop job B from the
@@ -163,6 +165,9 @@ where
163165
log!(PoppedRhs {
164166
worker: worker_thread.index()
165167
});
168+
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
169+
tlv::set(tlv);
170+
166171
let result_b = job_b.run_inline(injected);
167172
return (result_a, result_b);
168173
} else {
@@ -183,6 +188,9 @@ where
183188
}
184189
}
185190

191+
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
192+
tlv::set(tlv);
193+
186194
(result_a, job_b.into_result())
187195
})
188196
}
@@ -195,7 +203,12 @@ unsafe fn join_recover_from_panic(
195203
worker_thread: &WorkerThread,
196204
job_b_latch: &SpinLatch,
197205
err: Box<dyn Any + Send>,
206+
tlv: usize,
198207
) -> ! {
199208
worker_thread.wait_until(job_b_latch);
209+
210+
// Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
211+
tlv::set(tlv);
212+
200213
unwind::resume_unwinding(err)
201214
}

rayon-core/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ mod util;
5151
mod compile_fail;
5252
mod test;
5353

54+
pub mod tlv;
55+
5456
pub use self::join::{join, join_context};
5557
pub use self::registry::ThreadBuilder;
5658
pub use self::scope::{scope, Scope};

rayon-core/src/registry.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,7 @@ impl Registry {
441441
// This thread isn't a member of *any* thread pool, so just block.
442442
debug_assert!(WorkerThread::current().is_null());
443443
let job = StackJob::new(
444+
0,
444445
|injected| {
445446
let worker_thread = WorkerThread::current();
446447
assert!(injected && !worker_thread.is_null());
@@ -465,6 +466,7 @@ impl Registry {
465466
debug_assert!(current_thread.registry().id() != self.id());
466467
let latch = TickleLatch::new(SpinLatch::new(), &current_thread.registry().sleep);
467468
let job = StackJob::new(
469+
0,
468470
|injected| {
469471
let worker_thread = WorkerThread::current();
470472
assert!(injected && !worker_thread.is_null());

rayon-core/src/scope/mod.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use crate::job::{HeapJob, JobFifo};
88
use crate::latch::{CountLatch, Latch};
99
use crate::log::Event::*;
1010
use crate::registry::{in_worker, Registry, WorkerThread};
11+
use crate::tlv;
1112
use crate::unwind;
1213
use std::any::Any;
1314
use std::fmt;
@@ -59,6 +60,9 @@ struct ScopeBase<'scope> {
5960
/// `Sync`, but it's still safe to let the `Scope` implement `Sync` because
6061
/// the closures are only *moved* across threads to be executed.
6162
marker: PhantomData<Box<dyn FnOnce(&Scope<'scope>) + Send + Sync + 'scope>>,
63+
64+
/// The TLV at the scope's creation. Used to set the TLV for spawned jobs.
65+
tlv: usize,
6266
}
6367

6468
/// Create a "fork-join" scope `s` and invokes the closure with a
@@ -451,7 +455,7 @@ impl<'scope> Scope<'scope> {
451455
{
452456
self.base.increment();
453457
unsafe {
454-
let job_ref = Box::new(HeapJob::new(move || {
458+
let job_ref = Box::new(HeapJob::new(self.base.tlv, move || {
455459
self.base.execute_job(move || body(self))
456460
}))
457461
.as_job_ref();
@@ -492,7 +496,7 @@ impl<'scope> ScopeFifo<'scope> {
492496
{
493497
self.base.increment();
494498
unsafe {
495-
let job_ref = Box::new(HeapJob::new(move || {
499+
let job_ref = Box::new(HeapJob::new(self.base.tlv, move || {
496500
self.base.execute_job(move || body(self))
497501
}))
498502
.as_job_ref();
@@ -519,6 +523,7 @@ impl<'scope> ScopeBase<'scope> {
519523
panic: AtomicPtr::new(ptr::null_mut()),
520524
job_completed_latch: CountLatch::new(),
521525
marker: PhantomData,
526+
tlv: tlv::get(),
522527
}
523528
}
524529

@@ -536,6 +541,8 @@ impl<'scope> ScopeBase<'scope> {
536541
{
537542
let result = self.execute_job_closure(func);
538543
self.steal_till_jobs_complete(owner_thread);
544+
// Restore the TLV if we ran some jobs while waiting
545+
tlv::set(self.tlv);
539546
result.unwrap() // only None if `op` panicked, and that would have been propagated
540547
}
541548

@@ -612,6 +619,8 @@ impl<'scope> ScopeBase<'scope> {
612619
log!(ScopeCompletePanicked {
613620
owner_thread: owner_thread.index()
614621
});
622+
// Restore the TLV if we ran some jobs while waiting
623+
tlv::set(self.tlv);
615624
let value: Box<Box<dyn Any + Send + 'static>> = mem::transmute(panic);
616625
unwind::resume_unwinding(*value);
617626
} else {

rayon-core/src/spawn/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ where
9292
// executed. This ref is decremented at the (*) below.
9393
registry.increment_terminate_count();
9494

95-
Box::new(HeapJob::new({
95+
Box::new(HeapJob::new(0, {
9696
let registry = registry.clone();
9797
move || {
9898
match unwind::halt_unwinding(func) {

rayon-core/src/tlv.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//! Allows access to the Rayon's thread local value
2+
//! which is preserved when moving jobs across threads
3+
4+
use std::cell::Cell;
5+
6+
thread_local!(pub(crate) static TLV: Cell<usize> = Cell::new(0));
7+
8+
/// Sets the current thread-local value to `value` inside the closure.
9+
/// The old value is restored when the closure ends
10+
pub fn with<F: FnOnce() -> R, R>(value: usize, f: F) -> R {
11+
struct Reset(usize);
12+
impl Drop for Reset {
13+
fn drop(&mut self) {
14+
TLV.with(|tlv| tlv.set(self.0));
15+
}
16+
}
17+
let _reset = Reset(get());
18+
TLV.with(|tlv| tlv.set(value));
19+
f()
20+
}
21+
22+
/// Sets the current thread-local value
23+
pub fn set(value: usize) {
24+
TLV.with(|tlv| tlv.set(value));
25+
}
26+
27+
/// Returns the current thread-local value
28+
pub fn get() -> usize {
29+
TLV.with(|tlv| tlv.get())
30+
}

0 commit comments

Comments
 (0)