diff --git a/src/thread.rs b/src/thread.rs index a635801..6bc0fd8 100644 --- a/src/thread.rs +++ b/src/thread.rs @@ -106,17 +106,17 @@ /// ``` /// /// Much more straightforward. - +use std::any::Any; use std::cell::RefCell; use std::fmt; +use std::io; use std::marker::PhantomData; use std::mem; use std::ops::DerefMut; -use std::panic::{self, AssertUnwindSafe}; +use std::panic; use std::rc::Rc; -use std::thread; -use std::io; use std::sync::{Arc, Mutex}; +use std::thread; #[doc(hidden)] trait FnBox { @@ -152,50 +152,36 @@ where { let closure: Box + 'a> = Box::new(f); let closure: Box + Send> = mem::transmute(closure); - builder.spawn(move || { - closure.call_box() - }) + builder.spawn(move || closure.call_box()) } pub struct Scope<'env> { - /// The list of the deferred functions and thread join jobs. - dtors: RefCell>>>, + /// The list of the thread join jobs. + joins: RefCell> + 'env>>>, + /// Thread panics invoked so far. + panics: RefCell>>, // !Send + !Sync _marker: PhantomData<*const ()>, } -struct DtorChain<'env, T> { - dtor: Box + 'env>, - next: Option>>, -} - -impl<'env, T> DtorChain<'env, T> { - pub fn pop(chain: &mut Option>) -> Option + 'env>> { - chain.take().map(|mut node| { - *chain = node.next.take().map(|b| *b); - node.dtor - }) - } -} - struct JoinState { join_handle: thread::JoinHandle<()>, - result: ScopedThreadResult + result: ScopedThreadResult, } impl JoinState { fn new(join_handle: thread::JoinHandle<()>, result: ScopedThreadResult) -> JoinState { JoinState { join_handle, - result + result, } } fn join(self) -> thread::Result { let result = self.result; - self.join_handle.join().map(|_| { - result.lock().unwrap().take().unwrap() - }) + self.join_handle + .join() + .map(|_| result.lock().unwrap().take().unwrap()) } } @@ -210,16 +196,14 @@ pub struct ScopedJoinHandle<'scope, T: 'scope> { unsafe impl<'scope, T> Send for ScopedJoinHandle<'scope, T> {} unsafe impl<'scope, T> Sync for ScopedJoinHandle<'scope, T> {} -/// Create a new `Scope` for [*scoped thread spawning*](struct.Scope.html#method.spawn). +/// Creates a new `Scope` for [*scoped thread spawning*](struct.Scope.html#method.spawn). /// -/// In addition, you can [register ad-hoc functions](struct.Scope.html#method.defer) that are -/// deferred to be run. No matter what happens, before the `Scope` is dropped, it is guaranteed that -/// all the unjoined spawned scoped threads are joined and the deferred functions are run. +/// No matter what happens, before the `Scope` is dropped, it is guaranteed that all the unjoined +/// spawned scoped threads are joined. /// -/// `thread::scope()` returns `Ok(())` if all the unjoined spawned threads and the deferred -/// functions did not panic. It returns `Err(e)` if one of them panics with `e`. If many of them -/// panics, it is still guaranteed that all the threads are joined and all the functions are run, -/// and `thread::scope()` returns `Err(e)` with `e` from a panicking thread or function. +/// `thread::scope()` returns `Ok(())` if all the unjoined spawned threads did not panic. It returns +/// `Err(e)` if one of them panics with `e`. If many of them panic, it is still guaranteed that all +/// the threads are joined, and `thread::scope()` returns `Err(e)` with `e` from a panicking thread. /// /// # Examples /// @@ -227,22 +211,34 @@ unsafe impl<'scope, T> Sync for ScopedJoinHandle<'scope, T> {} /// /// ``` /// crossbeam_utils::thread::scope(|scope| { -/// scope.defer(|| println!("Exiting scope")); +/// scope.spawn(|| println!("Exiting scope")); /// scope.spawn(|| println!("Running child thread in scope")); /// }).unwrap(); -/// // Prints messages /// ``` pub fn scope<'env, F, R>(f: F) -> thread::Result where F: FnOnce(&Scope<'env>) -> R, { let mut scope = Scope { - dtors: RefCell::new(None), + joins: RefCell::new(Vec::new()), + panics: RefCell::new(Vec::new()), _marker: PhantomData, }; - let ret = f(&scope); - scope.drop_all()?; - Ok(ret) + + // Executes the scoped function. Panic will be catched as `Err`. + let result = panic::catch_unwind(panic::AssertUnwindSafe(|| f(&scope))); + + // Joins all the threads. + scope.join_all(); + let panic = scope.panics.borrow_mut().pop(); + + // If any of the threads panicked, returns the panic's payload. + if let Some(payload) = panic { + return Err(payload); + } + + // Returns the result of the scoped function. + result } impl<'env> fmt::Debug for Scope<'env> { @@ -259,37 +255,20 @@ impl<'scope, T> fmt::Debug for ScopedJoinHandle<'scope, T> { impl<'env> Scope<'env> { // This method is carefully written in a transactional style, so that it can be called directly - // and, if any dtor panics, can be resumed in the unwinding this causes. By initially running - // the method outside of any destructor, we avoid any leakage problems due to + // and, if any thread join panics, can be resumed in the unwinding this causes. By initially + // running the method outside of any destructor, we avoid any leakage problems due to // @rust-lang/rust#14875. - fn drop_all(&mut self) -> thread::Result<()> { - let mut ret = Ok(()); - while let Some(dtor) = DtorChain::pop(&mut self.dtors.borrow_mut()) { - ret = ret.and(dtor.call_box()); + // + // FIXME(jeehoonkang): @rust-lang/rust#14875 is fixed, so maybe we can remove the above comment. + // But I'd like to write tests to check it before removing the comment. + fn join_all(&mut self) { + let mut joins = self.joins.borrow_mut(); + for join in joins.drain(..) { + let result = join.call_box(); + if let Err(payload) = result { + self.panics.borrow_mut().push(payload); + } } - ret - } - - fn defer_inner(&self, f: F) - where - F: (FnOnce() -> thread::Result<()>) + 'env, - { - let mut dtors = self.dtors.borrow_mut(); - *dtors = Some(DtorChain { - dtor: Box::new(f), - next: dtors.take().map(Box::new), - }); - } - - /// Schedule code to be executed when exiting the scope. - /// - /// This is akin to having a destructor on the stack, except that it is *guaranteed* to be - /// run. It is guaranteed that the function is called after all the spawned threads are joined. - pub fn defer(&self, f: F) - where - F: FnOnce() + 'env, - { - self.defer_inner(move || panic::catch_unwind(AssertUnwindSafe(f))); } /// Create a scoped thread. @@ -363,14 +342,14 @@ impl<'scope, 'env: 'scope> ScopedThreadBuilder<'scope, 'env> { let deferred_handle = Rc::new(RefCell::new(Some(join_state))); let my_handle = deferred_handle.clone(); - self.scope.defer_inner(move || { + self.scope.joins.borrow_mut().push(Box::new(move || { let state = deferred_handle.borrow_mut().deref_mut().take(); if let Some(state) = state { state.join().map(|_| ()) } else { Ok(()) } - }); + })); Ok(ScopedJoinHandle { inner: my_handle, @@ -407,9 +386,107 @@ impl<'scope, T> ScopedJoinHandle<'scope, T> { impl<'env> Drop for Scope<'env> { fn drop(&mut self) { - // Actually, there should be no deferred functions left to be run. - self.drop_all().unwrap(); + // Note that `self.joins` can be non-empty when the code inside a `scope()` panics and + // `drop()` is called in unwinding. Even if it's the case, we will join the unjoined + // threads. + // + // We ignore panics from any threads because we're in course of unwinding anyway. + self.join_all(); } } type ScopedThreadResult = Arc>>; + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::AtomicUsize; + use std::sync::atomic::Ordering; + use std::{thread, time}; + + const TIMES: usize = 10; + const SMALL_STACK_SIZE: usize = 20; + + #[test] + fn join() { + let counter = AtomicUsize::new(0); + scope(|scope| { + let handle = scope.spawn(|| { + counter.store(1, Ordering::Relaxed); + }); + assert!(handle.join().is_ok()); + + let panic_handle = scope.spawn(|| { + panic!("\"My honey is running out!\", said Pooh."); + }); + assert!(panic_handle.join().is_err()); + }).unwrap(); + + // There should be sufficient synchronization. + assert_eq!(1, counter.load(Ordering::Relaxed)); + } + + #[test] + fn counter() { + let counter = AtomicUsize::new(0); + scope(|scope| { + for _ in 0..TIMES { + scope.spawn(|| { + counter.fetch_add(1, Ordering::Relaxed); + }); + } + }).unwrap(); + + assert_eq!(TIMES, counter.load(Ordering::Relaxed)); + } + + #[test] + fn counter_builder() { + let counter = AtomicUsize::new(0); + scope(|scope| { + for i in 0..TIMES { + scope + .builder() + .name(format!("child-{}", i)) + .stack_size(SMALL_STACK_SIZE) + .spawn(|| { + counter.fetch_add(1, Ordering::Relaxed); + }) + .unwrap(); + } + }).unwrap(); + + assert_eq!(TIMES, counter.load(Ordering::Relaxed)); + } + + #[test] + fn counter_panic() { + let counter = AtomicUsize::new(0); + let result = scope(|scope| { + scope.spawn(|| { + panic!("\"My honey is running out!\", said Pooh."); + }); + thread::sleep(time::Duration::from_millis(100)); + + for _ in 0..TIMES { + scope.spawn(|| { + counter.fetch_add(1, Ordering::Relaxed); + }); + } + }); + + assert_eq!(TIMES, counter.load(Ordering::Relaxed)); + assert!(result.is_err()); + } + + #[test] + fn panic_twice() { + let result = scope(|scope| { + scope.spawn(|| { + panic!(); + }); + panic!(); + }); + assert!(result.is_err()); + } +}