Skip to content

Rewrite std::sync::TaskPool to be load balancing and panic-resistant #18941

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 16, 2014
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 167 additions & 63 deletions src/libstd/sync/task_pool.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2012 The Rust Project Developers. See the COPYRIGHT
// Copyright 2014 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
Expand All @@ -12,91 +12,195 @@

use core::prelude::*;

use task;
use task::spawn;
use vec::Vec;
use comm::{channel, Sender};
use comm::{channel, Sender, Receiver};
use sync::{Arc, Mutex};

enum Msg<T> {
Execute(proc(&T):Send),
Quit
// Guard value held by a pool worker for the duration of its job loop.
//
// Its `Drop` impl checks `active` and, while it is still set, calls
// `spawn_in_pool` with a clone of `jobs`.
struct Sentinel<'a> {
// Borrowed handle to the shared, mutex-guarded receiving end of the job
// channel.
jobs: &'a Arc<Mutex<Receiver<proc(): Send>>>,
// Set to `true` by `Sentinel::new`; `cancel` sets it to `false`.
active: bool
}

/// A task pool used to execute functions in parallel.
pub struct TaskPool<T> {
channels: Vec<Sender<Msg<T>>>,
next_index: uint,
impl<'a> Sentinel<'a> {
fn new(jobs: &Arc<Mutex<Receiver<proc(): Send>>>) -> Sentinel {
Sentinel {
jobs: jobs,
active: true
}
}

// Cancel and destroy this sentinel.
fn cancel(mut self) {
self.active = false;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh nice idea, I like this better than task::failing()

}
}

#[unsafe_destructor]
impl<T> Drop for TaskPool<T> {
impl<'a> Drop for Sentinel<'a> {
fn drop(&mut self) {
for channel in self.channels.iter_mut() {
channel.send(Quit);
if self.active {
spawn_in_pool(self.jobs.clone())
}
}
}

impl<T> TaskPool<T> {
/// Spawns a new task pool with `n_tasks` tasks. The provided
/// `init_fn_factory` returns a function which, given the index of the
/// task, should return local data to be kept around in that task.
/// A task pool used to execute functions in parallel.
///
/// Spawns `n` worker tasks and replenishes the pool if any worker tasks
/// panic.
///
/// # Example
///
/// ```rust
/// # use sync::TaskPool;
/// # use iter::AdditiveIterator;
///
/// let pool = TaskPool::new(4u);
///
/// let (tx, rx) = channel();
/// for _ in range(0, 8u) {
/// let tx = tx.clone();
/// pool.execute(proc() {
/// tx.send(1u);
/// });
/// }
///
/// assert_eq!(rx.iter().take(8u).sum(), 8u);
/// ```
pub struct TaskPool {
// How the taskpool communicates with subtasks: the sending half of the
// job channel, through which `execute` hands each job to the workers.
//
// This is the only such Sender, so when it is dropped all subtasks will
// quit (their `recv_opt` call returns `Err` and they leave their loop).
jobs: Sender<proc(): Send>
}

impl TaskPool {
/// Spawns a new task pool with `tasks` tasks.
///
/// # Panics
///
/// This function will panic if `n_tasks` is less than 1.
pub fn new(n_tasks: uint,
init_fn_factory: || -> proc(uint):Send -> T)
-> TaskPool<T> {
assert!(n_tasks >= 1);

let channels = Vec::from_fn(n_tasks, |i| {
let (tx, rx) = channel::<Msg<T>>();
let init_fn = init_fn_factory();

let task_body = proc() {
let local_data = init_fn(i);
loop {
match rx.recv() {
Execute(f) => f(&local_data),
Quit => break
}
}
};
/// This function will panic if `tasks` is 0.
pub fn new(tasks: uint) -> TaskPool {
assert!(tasks >= 1);

// Run on this scheduler.
task::spawn(task_body);
let (tx, rx) = channel::<proc(): Send>();
let rx = Arc::new(Mutex::new(rx));

tx
});
// Spawn the worker tasks; each one gets a clone of the shared receiver handle.
for _ in range(0, tasks) {
spawn_in_pool(rx.clone());
}

return TaskPool {
channels: channels,
next_index: 0,
};
TaskPool { jobs: tx }
}

/// Executes the function `f` on a task in the pool. The function
/// receives a reference to the local data returned by the `init_fn`.
pub fn execute(&mut self, f: proc(&T):Send) {
self.channels[self.next_index].send(Execute(f));
self.next_index += 1;
if self.next_index == self.channels.len() { self.next_index = 0; }
/// Executes the function `job` on a task in the pool.
pub fn execute(&self, job: proc():Send) {
self.jobs.send(job);
}
}

#[test]
fn test_task_pool() {
let f: || -> proc(uint):Send -> uint = || { proc(i) i };
let mut pool = TaskPool::new(4, f);
for _ in range(0u, 8) {
pool.execute(proc(i) println!("Hello from thread {}!", *i));
}
// Spawns one worker task that repeatedly takes a job from the shared receiver
// and runs it, stopping once receiving fails.
//
// `jobs` is the mutex-guarded receiving end of the job channel; the `Arc`
// lets several workers hold a handle to the same receiver.
fn spawn_in_pool(jobs: Arc<Mutex<Receiver<proc(): Send>>>) {
spawn(proc() {
// Will spawn a new task on panic unless it is cancelled.
let sentinel = Sentinel::new(&jobs);

loop {
let message = {
// Only lock jobs for the time it takes
// to get a job, not run it.
// The guard is local to this block, so it is released before the
// `match` below runs the job.
let lock = jobs.lock();
lock.recv_opt()
};

match message {
Ok(job) => job(),

// Receiving failed: the sending side is gone, i.e. the TaskPool
// was dropped.
Err(..) => break
}
}

// Normal exit from the loop: disarm the sentinel so that dropping it
// does not spawn a replacement worker.
sentinel.cancel();
})
}

#[test]
#[should_fail]
fn test_zero_tasks_panic() {
let f: || -> proc(uint):Send -> uint = || { proc(i) i };
TaskPool::new(0, f);
// Tests for `TaskPool`: basic execution, the size assertion in `new`, and
// behaviour when jobs panic.
#[cfg(test)]
mod test {
use core::prelude::*;
use super::*;
use comm::channel;
use iter::range;

// Pool size used by every test; also the number of jobs submitted per batch.
const TEST_TASKS: uint = 4u;

// Submits `TEST_TASKS` jobs that each send `1` back over a channel and checks
// that the received values add up to `TEST_TASKS`.
#[test]
fn test_works() {
use iter::AdditiveIterator;

let pool = TaskPool::new(TEST_TASKS);

let (tx, rx) = channel();
for _ in range(0, TEST_TASKS) {
// Each job gets its own clone of the sender.
let tx = tx.clone();
pool.execute(proc() {
tx.send(1u);
});
}

// `take` bounds the iteration to the number of jobs submitted.
assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
}

// `TaskPool::new` asserts `tasks >= 1`, so requesting zero tasks must fail.
#[test]
#[should_fail]
fn test_zero_tasks_panic() {
TaskPool::new(0);
}

// Submits a batch of panicking jobs, then a batch of ordinary jobs, and checks
// that all of the ordinary jobs still report back.
#[test]
fn test_recovery_from_subtask_panic() {
use iter::AdditiveIterator;

let pool = TaskPool::new(TEST_TASKS);

// Panic all the existing tasks.
for _ in range(0, TEST_TASKS) {
pool.execute(proc() { panic!() });
}

// Ensure new tasks were spawned to compensate.
let (tx, rx) = channel();
for _ in range(0, TEST_TASKS) {
let tx = tx.clone();
pool.execute(proc() {
tx.send(1u);
});
}

assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
}

// Drops the pool before its jobs are released to panic; the test itself is not
// marked `should_fail`, so it passes only if this task does not panic.
#[test]
fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
use sync::{Arc, Barrier};

let pool = TaskPool::new(TEST_TASKS);
// `TEST_TASKS + 1` participants: one per submitted job, plus this test task.
let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));

// Panic all the existing tasks in a bit.
for _ in range(0, TEST_TASKS) {
let waiter = waiter.clone();
pool.execute(proc() {
// Hold the job at the barrier until the test task also arrives.
waiter.wait();
panic!();
});
}

// Drop the pool before arriving at the barrier that releases the jobs.
drop(pool);

// Kick off the failure.
waiter.wait();
}
}