Commit 7fdf1fd

Only work-steal in the main loop for join and scope

1 parent f192a48 commit 7fdf1fd
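
With this change, a worker blocked in `join` (and, per the commit title, `scope`) no longer steals unrelated work while it waits: `wait_until` now delegates to a new `wait_or_steal_until(latch, steal)` with `steal: false`, and only the worker's `main_loop` passes `steal: true`. While waiting for its own jobs, `join` instead uses the new `wait_for_jobs` helper, which pops only the thread's local queue, runs the job it is looking for if it finds it there, and sets everything else aside.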

File tree: 8 files changed, +127 −89 lines changed

rayon-core/Cargo.toml

Lines changed: 1 addition & 0 deletions

@@ -22,6 +22,7 @@ num_cpus = "1.2"
 crossbeam-channel = "0.5.0"
 crossbeam-deque = "0.8.1"
 crossbeam-utils = "0.8.0"
+smallvec = "1.11.0"
 
 [dev-dependencies]
 rand = "0.8"
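The new dependency backs `wait_for_jobs` in registry.rs (further down), which buffers set-aside `JobRef`s in a `SmallVec<[JobRef; 8]>`: up to 8 elements live inline on the stack, and the vector only spills to the heap beyond that, so the common case allocates nothing. A minimal standalone sketch of that behavior (not part of this commit):

use smallvec::SmallVec;

fn main() {
    let mut buf: SmallVec<[u32; 8]> = SmallVec::new();
    for i in 0..8 {
        buf.push(i);
    }
    assert!(!buf.spilled()); // all 8 elements still live inline on the stack
    buf.push(8);
    assert!(buf.spilled()); // the ninth element forced a heap allocation
}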

rayon-core/src/broadcast/mod.rs

Lines changed: 1 addition & 1 deletion

@@ -117,7 +117,7 @@ where
     registry.inject_broadcast(job_refs);
 
     // Wait for all jobs to complete, then collect the results, maybe propagating a panic.
-    latch.wait(current_thread);
+    latch.wait(current_thread, None);
     jobs.into_iter().map(|job| job.into_result()).collect()
 }

rayon-core/src/job.rs

Lines changed: 14 additions & 12 deletions

@@ -26,6 +26,11 @@ pub(super) trait Job {
     unsafe fn execute(this: *const ());
 }
 
+#[derive(PartialEq, Eq, Hash, Copy, Clone)]
+pub(super) struct JobRefId {
+    pointer: usize,
+}
+
 /// Effectively a Job trait object. Each JobRef **must** be executed
 /// exactly once, or else data may leak.
 ///
@@ -54,11 +59,11 @@ impl JobRef {
         }
     }
 
-    /// Returns an opaque handle that can be saved and compared,
-    /// without making `JobRef` itself `Copy + Eq`.
     #[inline]
-    pub(super) fn id(&self) -> impl Eq {
-        (self.pointer, self.execute_fn)
+    pub(super) fn id(&self) -> JobRefId {
+        JobRefId {
+            pointer: self.pointer as usize,
+        }
     }
 
     #[inline]
@@ -102,10 +107,6 @@ where
         JobRef::new(self)
     }
 
-    pub(super) unsafe fn run_inline(self, stolen: bool) -> R {
-        self.func.into_inner().unwrap()(stolen)
-    }
-
     pub(super) unsafe fn into_result(self) -> R {
         self.result.into_inner().into_return_value()
     }
@@ -136,15 +137,15 @@ where
 /// (Probably `StackJob` should be refactored in a similar fashion.)
 pub(super) struct HeapJob<BODY>
 where
-    BODY: FnOnce() + Send,
+    BODY: FnOnce(JobRefId) + Send,
 {
     job: BODY,
     tlv: Tlv,
 }
 
 impl<BODY> HeapJob<BODY>
 where
-    BODY: FnOnce() + Send,
+    BODY: FnOnce(JobRefId) + Send,
 {
     pub(super) fn new(tlv: Tlv, job: BODY) -> Box<Self> {
         Box::new(HeapJob { job, tlv })
@@ -168,12 +169,13 @@ where
 
 impl<BODY> Job for HeapJob<BODY>
 where
-    BODY: FnOnce() + Send,
+    BODY: FnOnce(JobRefId) + Send,
 {
     unsafe fn execute(this: *const ()) {
+        let pointer = this as usize;
        let this = Box::from_raw(this as *mut Self);
         tlv::set(this.tlv);
-        (this.job)();
+        (this.job)(JobRefId { pointer });
     }
 }
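This replaces the old opaque `impl Eq` id, a `(pointer, execute_fn)` tuple, with a concrete `Copy + Eq + Hash` handle derived from the data pointer alone, so the id can be stored and compared across calls. A standalone sketch of the idea with simplified stand-in types (not the real `JobRef`, which also carries an `execute_fn`):

// Identify a job by the address of its job data, so the handle is
// Copy + Eq + Hash without making the job reference itself comparable.
#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
struct JobRefId {
    pointer: usize,
}

struct JobRef {
    pointer: *const (),
}

impl JobRef {
    fn id(&self) -> JobRefId {
        JobRefId {
            pointer: self.pointer as usize,
        }
    }
}

fn main() {
    let data = 42u32;
    let a = JobRef { pointer: &data as *const u32 as *const () };
    let b = JobRef { pointer: &data as *const u32 as *const () };
    assert_eq!(a.id(), b.id()); // same job data, same identity
}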

rayon-core/src/join/mod.rs

Lines changed: 23 additions & 54 deletions

@@ -1,9 +1,9 @@
 use crate::job::StackJob;
 use crate::latch::SpinLatch;
-use crate::registry::{self, WorkerThread};
-use crate::tlv::{self, Tlv};
+use crate::registry;
+use crate::tlv;
 use crate::unwind;
-use std::any::Any;
+use std::sync::atomic::{AtomicBool, Ordering};
 
 use crate::FnContext;
 
@@ -135,68 +135,37 @@ where
         // Create virtual wrapper for task b; this all has to be
         // done here so that the stack frame can keep it all live
         // long enough.
-        let job_b = StackJob::new(tlv, call_b(oper_b), SpinLatch::new(worker_thread));
+        let job_b_started = AtomicBool::new(false);
+        let job_b = StackJob::new(
+            tlv,
+            |migrated| {
+                job_b_started.store(true, Ordering::Relaxed);
+                call_b(oper_b)(migrated)
+            },
+            SpinLatch::new(worker_thread),
+        );
         let job_b_ref = job_b.as_job_ref();
         let job_b_id = job_b_ref.id();
         worker_thread.push(job_b_ref);
 
         // Execute task a; hopefully b gets stolen in the meantime.
         let status_a = unwind::halt_unwinding(call_a(oper_a, injected));
-        let result_a = match status_a {
-            Ok(v) => v,
-            Err(err) => join_recover_from_panic(worker_thread, &job_b.latch, err, tlv),
-        };
-
-        // Now that task A has finished, try to pop job B from the
-        // local stack. It may already have been popped by job A; it
-        // may also have been stolen. There may also be some tasks
-        // pushed on top of it in the stack, and we will have to pop
-        // those off to get to it.
-        while !job_b.latch.probe() {
-            if let Some(job) = worker_thread.take_local_job() {
-                if job_b_id == job.id() {
-                    // Found it! Let's run it.
-                    //
-                    // Note that this could panic, but it's ok if we unwind here.
 
-                    // Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
-                    tlv::set(tlv);
-
-                    let result_b = job_b.run_inline(injected);
-                    return (result_a, result_b);
-                } else {
-                    worker_thread.execute(job);
-                }
-            } else {
-                // Local deque is empty. Time to steal from other
-                // threads.
-                worker_thread.wait_until(&job_b.latch);
-                debug_assert!(job_b.latch.probe());
-                break;
-            }
-        }
+        // Wait for job B or execute it if it's in the local queue.
+        worker_thread.wait_for_jobs(
+            &job_b.latch,
+            || job_b_started.load(Ordering::Relaxed),
+            |job| job.id() == job_b_id,
+        );
 
         // Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
         tlv::set(tlv);
 
+        let result_a = match status_a {
+            Ok(v) => v,
+            Err(err) => unwind::resume_unwinding(err),
+        };
+
         (result_a, job_b.into_result())
     })
 }
-
-/// If job A panics, we still cannot return until we are sure that job
-/// B is complete. This is because it may contain references into the
-/// enclosing stack frame(s).
-#[cold] // cold path
-unsafe fn join_recover_from_panic(
-    worker_thread: &WorkerThread,
-    job_b_latch: &SpinLatch<'_>,
-    err: Box<dyn Any + Send>,
-    tlv: Tlv,
-) -> ! {
-    worker_thread.wait_until(job_b_latch);
-
-    // Restore the TLV since we might have run some jobs overwriting it when waiting for job b.
-    tlv::set(tlv);
-
-    unwind::resume_unwinding(err)
-}
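The key new piece is the `job_b_started` flag: the closure wrapping job B sets it as its first action, so the waiting thread can distinguish "B is still sitting in some queue" from "B is already running, here or on a thief", and stop scanning its local queue accordingly. Note also that a panic in A is now held back until after the wait, replacing the dedicated `join_recover_from_panic` path. A minimal sketch of the started-flag handshake, with plain threads and atomics standing in for rayon's worker, latch, and deque:

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;

fn main() {
    let started = Arc::new(AtomicBool::new(false));
    let done = Arc::new(AtomicBool::new(false));

    let thief = {
        let (started, done) = (started.clone(), done.clone());
        thread::spawn(move || {
            started.store(true, Ordering::Relaxed); // like job_b_started
            // ... job B's body would run here ...
            done.store(true, Ordering::Release); // like setting the SpinLatch
        })
    };

    // Phase 1: until B has provably started, the worker keeps popping its
    // local queue looking for B (elided here, see wait_for_jobs below).
    while !started.load(Ordering::Relaxed) {
        std::hint::spin_loop();
    }
    // Phase 2: B is running elsewhere; just wait on the "latch", no stealing.
    while !done.load(Ordering::Acquire) {
        std::hint::spin_loop();
    }
    thief.join().unwrap();
}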

rayon-core/src/latch.rs

Lines changed: 0 additions & 5 deletions

@@ -177,11 +177,6 @@ impl<'r> SpinLatch<'r> {
             ..SpinLatch::new(thread)
         }
     }
-
-    #[inline]
-    pub(super) fn probe(&self) -> bool {
-        self.core_latch.probe()
-    }
 }
 
 impl<'r> AsCoreLatch for SpinLatch<'r> {
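Design note: `SpinLatch::probe` was only a convenience wrapper, and its last callers were the `join` loop removed above. The remaining probes go through the core latch directly, as in the `debug_assert!(latch.as_core_latch().probe())` added to registry.rs below.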

rayon-core/src/registry.rs

Lines changed: 58 additions & 7 deletions

@@ -10,6 +10,7 @@ use crate::{
     ReleaseThreadHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder, Yield,
 };
 use crossbeam_deque::{Injector, Steal, Stealer, Worker};
+use smallvec::SmallVec;
 use std::cell::Cell;
 use std::collections::hash_map::DefaultHasher;
 use std::fmt;
@@ -840,14 +841,58 @@ impl WorkerThread {
     /// stealing tasks as necessary.
     #[inline]
     pub(super) unsafe fn wait_until<L: AsCoreLatch + ?Sized>(&self, latch: &L) {
+        self.wait_or_steal_until(latch, false)
+    }
+
+    /// Wait until the latch is set. Executes local jobs if `is_job` is true for them and
+    /// `all_jobs_started` still returns false.
+    #[inline]
+    pub(super) unsafe fn wait_for_jobs<L: AsCoreLatch + ?Sized>(
+        &self,
+        latch: &L,
+        mut all_jobs_started: impl FnMut() -> bool,
+        mut is_job: impl FnMut(&JobRef) -> bool,
+    ) {
+        let mut jobs = SmallVec::<[JobRef; 8]>::new();
+
+        // Make sure all jobs have started.
+        while !all_jobs_started() {
+            if let Some(job) = self.worker.pop() {
+                if is_job(&job) {
+                    // Found a job, let's run it.
+                    self.execute(job);
+                } else {
+                    jobs.push(job);
+                }
+            } else {
+                break;
+            }
+        }
+
+        // Restore the jobs that we weren't looking for.
+        for job in jobs.into_iter().rev() {
+            self.worker.push(job);
+        }
+
+        // Wait for the jobs to finish.
+        self.wait_until(latch);
+        debug_assert!(latch.as_core_latch().probe());
+    }
+
+    #[inline]
+    pub(super) unsafe fn wait_or_steal_until<L: AsCoreLatch + ?Sized>(
+        &self,
+        latch: &L,
+        steal: bool,
+    ) {
         let latch = latch.as_core_latch();
         if !latch.probe() {
-            self.wait_until_cold(latch);
+            self.wait_until_cold(latch, steal);
         }
     }
 
     #[cold]
-    unsafe fn wait_until_cold(&self, latch: &CoreLatch) {
+    unsafe fn wait_until_cold(&self, latch: &CoreLatch, steal: bool) {
         // the code below should swallow all panics and hence never
         // unwind; but if something does wrong, we want to abort,
         // because otherwise other code in rayon may assume that the
@@ -857,10 +902,16 @@ impl WorkerThread {
 
         let mut idle_state = self.registry.sleep.start_looking(self.index, latch);
         while !latch.probe() {
-            if let Some(job) = self.find_work() {
-                self.registry.sleep.work_found(idle_state);
-                self.execute(job);
-                idle_state = self.registry.sleep.start_looking(self.index, latch);
+            if steal {
+                if let Some(job) = self.find_work() {
+                    self.registry.sleep.work_found(idle_state);
+                    self.execute(job);
+                    idle_state = self.registry.sleep.start_looking(self.index, latch);
+                } else {
+                    self.registry
+                        .sleep
+                        .no_work_found(&mut idle_state, latch, &self)
+                }
             } else {
                 self.registry
                     .sleep
@@ -988,7 +1039,7 @@ unsafe fn main_loop(thread: ThreadBuilder) {
         terminate_addr: my_terminate_latch.as_core_latch().addr(),
     });
     registry.acquire_thread();
-    worker_thread.wait_until(my_terminate_latch);
+    worker_thread.wait_or_steal_until(my_terminate_latch, true);
 
     // Should not be any work left in our queue.
     debug_assert!(worker_thread.take_local_job().is_none());
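The loop in `wait_for_jobs` pops the thread's local deque (LIFO for its owner, a "stack" in the old join comments) looking for the jobs it is waiting on, runs the matches, sets the rest aside in the `SmallVec`, and then pushes the leftovers back in reverse so the original order is preserved. A toy sketch of that pop/re-push pattern, with a `Vec` standing in for the crossbeam worker deque and integers standing in for `JobRef`s:

fn main() {
    // LIFO queue with our target job (42) buried under unrelated jobs.
    let mut local_queue: Vec<u32> = vec![1, 2, 42, 3];
    let mut set_aside = Vec::new();

    while let Some(job) = local_queue.pop() {
        if job == 42 {
            println!("running our job {job}");
        } else {
            set_aside.push(job); // not the job we are waiting for
        }
    }

    // Reverse re-push restores the survivors' original order.
    for job in set_aside.into_iter().rev() {
        local_queue.push(job);
    }
    assert_eq!(local_queue, vec![1, 2, 3]);
}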
