diff --git a/src/tools/parallel_handler.rs b/src/tools/parallel_handler.rs index 855cd690..cfa32199 100644 --- a/src/tools/parallel_handler.rs +++ b/src/tools/parallel_handler.rs @@ -9,20 +9,6 @@ pub struct SendHandle { abort: Arc>>, } -/// A thread pool which run the supplied closure -/// -/// The send command sends data to the worker threads. If one handler -/// returns an error, we mark the channel as failed and it is no -/// longer possible to send data. -/// -/// When done, the 'complete()' method needs to be called to check for -/// outstanding errors. -pub struct ParallelHandler { - handles: Vec>, - name: String, - input: SendHandle, -} - impl SendHandle { /// Returns the first error happened, if any @@ -42,6 +28,20 @@ impl SendHandle { } } +/// A thread pool which run the supplied closure +/// +/// The send command sends data to the worker threads. If one handler +/// returns an error, we mark the channel as failed and it is no +/// longer possible to send data. +/// +/// When done, the 'complete()' method needs to be called to check for +/// outstanding errors. +pub struct ParallelHandler { + handles: Vec>, + name: String, + input: Option>, +} + impl Clone for SendHandle { fn clone(&self) -> Self { Self { input: self.input.clone(), abort: self.abort.clone() } @@ -94,40 +94,74 @@ impl ParallelHandler { Self { handles, name: name.to_string(), - input: SendHandle { + input: Some(SendHandle { input: input_tx, abort, - }, + }), } } /// Returns a cloneable channel to send data to the worker threads pub fn channel(&self) -> SendHandle { - self.input.clone() + self.input.as_ref().unwrap().clone() } /// Send data to the worker threads pub fn send(&self, input: I) -> Result<(), Error> { - self.input.send(input)?; + self.input.as_ref().unwrap().send(input)?; Ok(()) } /// Wait for worker threads to complete and check for errors - pub fn complete(self) -> Result<(), Error> { - self.input.check_abort()?; - drop(self.input); - let mut msg = Vec::new(); - for (i, handle) in self.handles.into_iter().enumerate() { - if let Err(panic) = handle.join() { - match panic.downcast::<&str>() { - Ok(panic_msg) => msg.push(format!("thread {} ({}) paniced: {}", self.name, i, panic_msg)), - Err(_) => msg.push(format!("thread {} ({}) paniced", self.name, i)), - } - } - } - if msg.is_empty() { + pub fn complete(mut self) -> Result<(), Error> { + self.input.as_ref().unwrap().check_abort()?; + drop(self.input.take()); + + let msg_list = self.join_threads(); + + if msg_list.is_empty() { return Ok(()); } - Err(format_err!("{}", msg.join("\n"))) + Err(format_err!("{}", msg_list.join("\n"))) + } + + fn join_threads(&mut self) -> Vec { + + let mut msg_list = Vec::new(); + + let mut i = 0; + loop { + let handle = match self.handles.pop() { + Some(handle) => handle, + None => break, + }; + if let Err(panic) = handle.join() { + match panic.downcast::<&str>() { + Ok(panic_msg) => msg_list.push( + format!("thread {} ({}) paniced: {}", self.name, i, panic_msg) + ), + Err(_) => msg_list.push( + format!("thread {} ({}) paniced", self.name, i) + ), + } + } + i += 1; + } + msg_list + } +} + +// Note: We make sure that all threads will be joined +impl Drop for ParallelHandler { + fn drop(&mut self) { + drop(self.input.take()); + loop { + match self.handles.pop() { + Some(handle) => { + let _ = handle.join(); + } + None => break, + } + } } }