rollup merge of #18941: reem/better-task-pool
This commit is contained in:
commit
c7fc332a22
1 changed files with 167 additions and 63 deletions
|
@ -1,4 +1,4 @@
|
|||
// Copyright 2012 The Rust Project Developers. See the COPYRIGHT
|
||||
// Copyright 2014 The Rust Project Developers. See the COPYRIGHT
|
||||
// file at the top-level directory of this distribution and at
|
||||
// http://rust-lang.org/COPYRIGHT.
|
||||
//
|
||||
|
@ -12,91 +12,195 @@
|
|||
|
||||
use core::prelude::*;
|
||||
|
||||
use task;
|
||||
use task::spawn;
|
||||
use vec::Vec;
|
||||
use comm::{channel, Sender};
|
||||
use comm::{channel, Sender, Receiver};
|
||||
use sync::{Arc, Mutex};
|
||||
|
||||
enum Msg<T> {
|
||||
Execute(proc(&T):Send),
|
||||
Quit
|
||||
struct Sentinel<'a> {
|
||||
jobs: &'a Arc<Mutex<Receiver<proc(): Send>>>,
|
||||
active: bool
|
||||
}
|
||||
|
||||
/// A task pool used to execute functions in parallel.
|
||||
pub struct TaskPool<T> {
|
||||
channels: Vec<Sender<Msg<T>>>,
|
||||
next_index: uint,
|
||||
impl<'a> Sentinel<'a> {
|
||||
fn new(jobs: &Arc<Mutex<Receiver<proc(): Send>>>) -> Sentinel {
|
||||
Sentinel {
|
||||
jobs: jobs,
|
||||
active: true
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel and destroy this sentinel.
|
||||
fn cancel(mut self) {
|
||||
self.active = false;
|
||||
}
|
||||
}
|
||||
|
||||
#[unsafe_destructor]
|
||||
impl<T> Drop for TaskPool<T> {
|
||||
impl<'a> Drop for Sentinel<'a> {
|
||||
fn drop(&mut self) {
|
||||
for channel in self.channels.iter_mut() {
|
||||
channel.send(Quit);
|
||||
if self.active {
|
||||
spawn_in_pool(self.jobs.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> TaskPool<T> {
|
||||
/// Spawns a new task pool with `n_tasks` tasks. The provided
|
||||
/// `init_fn_factory` returns a function which, given the index of the
|
||||
/// task, should return local data to be kept around in that task.
|
||||
/// A task pool used to execute functions in parallel.
|
||||
///
|
||||
/// Spawns `n` worker tasks and replenishes the pool if any worker tasks
|
||||
/// panic.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use sync::TaskPool;
|
||||
/// # use iter::AdditiveIterator;
|
||||
///
|
||||
/// let pool = TaskPool::new(4u);
|
||||
///
|
||||
/// let (tx, rx) = channel();
|
||||
/// for _ in range(0, 8u) {
|
||||
/// let tx = tx.clone();
|
||||
/// pool.execute(proc() {
|
||||
/// tx.send(1u);
|
||||
/// });
|
||||
/// }
|
||||
///
|
||||
/// assert_eq!(rx.iter().take(8u).sum(), 8u);
|
||||
/// ```
|
||||
pub struct TaskPool {
|
||||
// How the taskpool communicates with subtasks.
|
||||
//
|
||||
// This is the only such Sender, so when it is dropped all subtasks will
|
||||
// quit.
|
||||
jobs: Sender<proc(): Send>
|
||||
}
|
||||
|
||||
impl TaskPool {
|
||||
/// Spawns a new task pool with `tasks` tasks.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// This function will panic if `n_tasks` is less than 1.
|
||||
pub fn new(n_tasks: uint,
|
||||
init_fn_factory: || -> proc(uint):Send -> T)
|
||||
-> TaskPool<T> {
|
||||
assert!(n_tasks >= 1);
|
||||
/// This function will panic if `tasks` is 0.
|
||||
pub fn new(tasks: uint) -> TaskPool {
|
||||
assert!(tasks >= 1);
|
||||
|
||||
let channels = Vec::from_fn(n_tasks, |i| {
|
||||
let (tx, rx) = channel::<Msg<T>>();
|
||||
let init_fn = init_fn_factory();
|
||||
let (tx, rx) = channel::<proc(): Send>();
|
||||
let rx = Arc::new(Mutex::new(rx));
|
||||
|
||||
let task_body = proc() {
|
||||
let local_data = init_fn(i);
|
||||
loop {
|
||||
match rx.recv() {
|
||||
Execute(f) => f(&local_data),
|
||||
Quit => break
|
||||
}
|
||||
}
|
||||
// Taskpool tasks.
|
||||
for _ in range(0, tasks) {
|
||||
spawn_in_pool(rx.clone());
|
||||
}
|
||||
|
||||
TaskPool { jobs: tx }
|
||||
}
|
||||
|
||||
/// Executes the function `job` on a task in the pool.
|
||||
pub fn execute(&self, job: proc():Send) {
|
||||
self.jobs.send(job);
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_in_pool(jobs: Arc<Mutex<Receiver<proc(): Send>>>) {
|
||||
spawn(proc() {
|
||||
// Will spawn a new task on panic unless it is cancelled.
|
||||
let sentinel = Sentinel::new(&jobs);
|
||||
|
||||
loop {
|
||||
let message = {
|
||||
// Only lock jobs for the time it takes
|
||||
// to get a job, not run it.
|
||||
let lock = jobs.lock();
|
||||
lock.recv_opt()
|
||||
};
|
||||
|
||||
// Run on this scheduler.
|
||||
task::spawn(task_body);
|
||||
match message {
|
||||
Ok(job) => job(),
|
||||
|
||||
tx
|
||||
});
|
||||
// The Taskpool was dropped.
|
||||
Err(..) => break
|
||||
}
|
||||
}
|
||||
|
||||
return TaskPool {
|
||||
channels: channels,
|
||||
next_index: 0,
|
||||
};
|
||||
sentinel.cancel();
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use core::prelude::*;
|
||||
use super::*;
|
||||
use comm::channel;
|
||||
use iter::range;
|
||||
|
||||
const TEST_TASKS: uint = 4u;
|
||||
|
||||
#[test]
|
||||
fn test_works() {
|
||||
use iter::AdditiveIterator;
|
||||
|
||||
let pool = TaskPool::new(TEST_TASKS);
|
||||
|
||||
let (tx, rx) = channel();
|
||||
for _ in range(0, TEST_TASKS) {
|
||||
let tx = tx.clone();
|
||||
pool.execute(proc() {
|
||||
tx.send(1u);
|
||||
});
|
||||
}
|
||||
|
||||
assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
|
||||
}
|
||||
|
||||
/// Executes the function `f` on a task in the pool. The function
|
||||
/// receives a reference to the local data returned by the `init_fn`.
|
||||
pub fn execute(&mut self, f: proc(&T):Send) {
|
||||
self.channels[self.next_index].send(Execute(f));
|
||||
self.next_index += 1;
|
||||
if self.next_index == self.channels.len() { self.next_index = 0; }
|
||||
#[test]
|
||||
#[should_fail]
|
||||
fn test_zero_tasks_panic() {
|
||||
TaskPool::new(0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_recovery_from_subtask_panic() {
|
||||
use iter::AdditiveIterator;
|
||||
|
||||
let pool = TaskPool::new(TEST_TASKS);
|
||||
|
||||
// Panic all the existing tasks.
|
||||
for _ in range(0, TEST_TASKS) {
|
||||
pool.execute(proc() { panic!() });
|
||||
}
|
||||
|
||||
// Ensure new tasks were spawned to compensate.
|
||||
let (tx, rx) = channel();
|
||||
for _ in range(0, TEST_TASKS) {
|
||||
let tx = tx.clone();
|
||||
pool.execute(proc() {
|
||||
tx.send(1u);
|
||||
});
|
||||
}
|
||||
|
||||
assert_eq!(rx.iter().take(TEST_TASKS).sum(), TEST_TASKS);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
|
||||
use sync::{Arc, Barrier};
|
||||
|
||||
let pool = TaskPool::new(TEST_TASKS);
|
||||
let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
|
||||
|
||||
// Panic all the existing tasks in a bit.
|
||||
for _ in range(0, TEST_TASKS) {
|
||||
let waiter = waiter.clone();
|
||||
pool.execute(proc() {
|
||||
waiter.wait();
|
||||
panic!();
|
||||
});
|
||||
}
|
||||
|
||||
drop(pool);
|
||||
|
||||
// Kick off the failure.
|
||||
waiter.wait();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_task_pool() {
|
||||
let f: || -> proc(uint):Send -> uint = || { proc(i) i };
|
||||
let mut pool = TaskPool::new(4, f);
|
||||
for _ in range(0u, 8) {
|
||||
pool.execute(proc(i) println!("Hello from thread {}!", *i));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_fail]
|
||||
fn test_zero_tasks_panic() {
|
||||
let f: || -> proc(uint):Send -> uint = || { proc(i) i };
|
||||
TaskPool::new(0, f);
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue