diff --git a/Cargo.toml b/Cargo.toml index f2b78215..91dde574 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -153,3 +153,6 @@ enum-try-as-inner = "0.1.0" bitvec = "1.0" hashbrown = "0.14.5" crossbeam-channel = "0.5" +crossbeam-deque = "0.8" +crossbeam-utils = "0.8" +async-task = "4" diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index 9e06251e..64ed7038 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -24,6 +24,9 @@ cfg-if.workspace = true tokio = { workspace = true, optional = true } tokio-util = { workspace = true } crossbeam-channel = { workspace = true } +crossbeam-deque = { workspace = true } +crossbeam-utils = { workspace = true } +async-task = { workspace = true } [dev-dependencies] tokio = { workspace = true, features = [ diff --git a/crates/common/src/context.rs b/crates/common/src/context.rs index f2fb4fad..a666a59e 100644 --- a/crates/common/src/context.rs +++ b/crates/common/src/context.rs @@ -1,283 +1,382 @@ //! Execution context. -mod mt; -mod st; #[cfg(any(test, feature = "test-utils"))] mod test; -pub use mt::{ - CustomSpawn, Multithread, MultithreadBuilder, MultithreadBuilderError, Spawn, SpawnError, - StdSpawn, +use std::sync::Arc; + +use futures::{ + AsyncRead, AsyncWrite, + future::{self, BoxFuture, Either}, }; + #[cfg(any(test, feature = "test-utils"))] pub use test::{ RecordedMtData, RecordingDuplex, ReplayDuplex, recording_mt_context, - recording_mt_context_with_limit, recording_mt_context_with_spawn, - recording_mt_context_with_spawn_and_limit, recording_st_context, - recording_st_context_with_limit, replay_mt_context, replay_mt_context_with_limit, - replay_mt_context_with_spawn, replay_mt_context_with_spawn_and_limit, replay_st_context, - test_mt_context, test_mt_context_with_concurrency, test_mt_context_with_spawn, test_st_context, + recording_mt_context_with_limit, recording_mt_context_with_spawn_and_limit, + recording_st_context, recording_st_context_with_limit, replay_mt_context, + replay_mt_context_with_limit, replay_mt_context_with_spawn_and_limit, replay_st_context, + test_mt_context, test_mt_context_with_spawn, test_st_context, }; -use core::fmt; -use std::sync::{Arc, Mutex}; - -use futures::{AsyncRead, AsyncWrite}; - -use crate::{ - ThreadId, - context::mt::{MtConfig, ThreadBuilder, Threads}, - io::Io, -}; +use crate::{ContextId, executor::Inner, io::Io, mux::Mux}; -/// A thread context. -#[derive(Debug)] +/// A task execution context. +/// +/// Each context owns an I/O channel and a [`ContextId`]. Use [`join`], +/// [`try_join`], [`map`] etc. to run sub-tasks concurrently; whether they +/// actually execute in parallel depends on how the context was built. +/// +/// [`join`]: Self::join +/// [`try_join`]: Self::try_join +/// [`map`]: Self::map pub struct Context { - id: ThreadId, + id: ContextId, io: Io, mode: Mode, + /// Sub-namespace counter incremented on each fork. + fork_counter: u32, +} + +enum Mode { + Single, + Multi { + mux: Arc, + executor: Option>, + }, +} + +impl std::fmt::Debug for Context { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mode = match &self.mode { + Mode::Single => "single", + Mode::Multi { + executor: Some(_), .. + } => "multi-threaded", + Mode::Multi { executor: None, .. } => "multi-cooperative", + }; + f.debug_struct("Context") + .field("id", &self.id) + .field("io", &self.io) + .field("mode", &mode) + .finish() + } } impl Context { - /// Creates a new single-threaded context. + /// Creates a new context that uses `mux` to allocate a channel per + /// sub-task. /// - /// # Arguments + /// Sub-tasks are executed cooperatively on the calling future. For + /// parallel execution, build an [`Executor`](crate::Executor) and use + /// [`Executor::new_context`](crate::Executor::new_context) instead. + pub fn new(mux: M) -> Result { + Self::with_prefix(mux, ContextId::default()) + } + + /// Creates a new context backed by a single I/O channel. + /// + /// Sub-tasks spawned via [`join`], [`try_join`], [`map`] etc. share the + /// channel and run **sequentially** in the order given. /// - /// * `io` - The I/O channel used by the context. - pub fn new_single_threaded(io: Io) -> Self + /// [`join`]: Self::join + /// [`try_join`]: Self::try_join + /// [`map`]: Self::map + pub fn new_single_threaded(io: I) -> Self where - Io: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, + I: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, { - Self { - id: ThreadId::default(), - io: crate::io::Io::from_io(io), - mode: Mode::St, - } + Self::from_io(Io::from_io(io)) } - #[allow(dead_code)] pub(crate) fn from_io(io: Io) -> Self { Self { - id: ThreadId::default(), + id: ContextId::default(), io, - mode: Mode::St, + mode: Mode::Single, + fork_counter: 0, } } - pub(crate) fn new_multi_threaded( - id: ThreadId, + /// Like [`Context::new`], but namespaces all channels under `prefix` so + /// several sub-protocols can share a mux without colliding. + pub fn with_prefix( + mux: M, + prefix: impl AsRef<[u8]>, + ) -> Result { + let mux: Arc = Arc::new(mux); + let id = ContextId::from_prefix(prefix); + let io = mux.open(id.as_ref()).map_err(ContextError::mux)?; + Ok(Self { + id, + io, + mode: Mode::Multi { + mux, + executor: None, + }, + fork_counter: 0, + }) + } + + pub(crate) fn with_executor( + id: ContextId, io: Io, - config: Arc, - builder: Arc>, + mux: Arc, + executor: Arc, ) -> Self { Self { - id: id.clone(), + id, io, - mode: Mode::Mt { - threads: Threads::new(id, config, builder), + mode: Mode::Multi { + mux, + executor: Some(executor), }, + fork_counter: 0, } } - /// Returns `true` if the context is multi-threaded. - pub fn is_multi_threaded(&self) -> bool { - matches!(self.mode, Mode::Mt { .. }) + fn child(&self, id: ContextId) -> Result { + let Mode::Multi { mux, executor } = &self.mode else { + unreachable!("child() called on a single-channel context"); + }; + let io = mux.open(id.as_ref()).map_err(ContextError::mux)?; + Ok(Self { + id, + io, + mode: Mode::Multi { + mux: mux.clone(), + executor: executor.clone(), + }, + fork_counter: 0, + }) + } + + fn next_fork(&mut self) -> ContextId { + let base = self.id.child(self.fork_counter); + self.fork_counter += 1; + base } - /// Returns the thread ID. - pub fn id(&self) -> &ThreadId { + /// Returns the context ID. + pub fn id(&self) -> &ContextId { &self.id } - /// Returns a reference to the thread's I/O channel. + /// Returns a reference to the I/O channel. pub fn io(&self) -> &Io { &self.io } - /// Returns a mutable reference to the thread's I/O channel. + /// Returns a mutable reference to the I/O channel. pub fn io_mut(&mut self) -> &mut Io { &mut self.io } - /// Executes a collection of tasks provided with a context. - /// - /// If multi-threading is available, the tasks are load balanced across - /// threads. Otherwise, they are executed sequentially. - pub async fn map<'a, F, T, R, W>( - &'a mut self, - items: Vec, - f: F, - weight: W, - ) -> Result, ContextError> + /// Applies `f` to each item concurrently, returning the results in input + /// order. + pub async fn map(&mut self, items: Vec, f: F) -> Result, ContextError> where - F: for<'b> AsyncFn(&'b mut Self, T) -> R + Clone + Send + 'static, + F: for<'a> Fn(&'a mut Context, T) -> BoxFuture<'a, R> + Clone + Send + 'static, T: Send + 'static, R: Send + 'static, - W: Fn(&T) -> usize + Send + 'static, { - match &mut self.mode { - Mode::St => Ok(st::map(self, items, f).await), - Mode::Mt { threads } => { - let threads = threads.get(threads.concurrency())?; - mt::map(threads, items, f, weight).await + if matches!(self.mode, Mode::Single) { + let mut results = Vec::with_capacity(items.len()); + for item in items { + results.push(f(self, item).await); } + return Ok(results); } + + let parent_id = self.next_fork(); + let executor = self.executor().cloned(); + let mut tasks = Vec::with_capacity(items.len()); + for (i, item) in items.into_iter().enumerate() { + let i = u32::try_from(i).expect("more than u32::MAX items"); + let mut ctx = self.child(parent_id.child(i))?; + let f = f.clone(); + tasks.push(run( + executor.as_ref(), + async move { f(&mut ctx, item).await }, + )); + } + Ok(future::join_all(tasks).await) } - /// Forks the thread and executes the provided closures concurrently. - /// - /// Implementations may not be able to fork, in which case the closures are - /// executed sequentially. - pub async fn join<'a, A, B, RA, RB>(&'a mut self, a: A, b: B) -> Result<(RA, RB), ContextError> + /// Runs `a` and `b` concurrently and returns both results. + pub async fn join(&mut self, a: A, b: B) -> Result<(RA, RB), ContextError> where - A: for<'b> AsyncFnOnce(&'b mut Self) -> RA + Send + 'static, - B: for<'b> AsyncFnOnce(&'b mut Self) -> RB + Send + 'static, + A: for<'a> FnOnce(&'a mut Context) -> BoxFuture<'a, RA> + Send + 'static, + B: for<'a> FnOnce(&'a mut Context) -> BoxFuture<'a, RB> + Send + 'static, RA: Send + 'static, RB: Send + 'static, { - match &mut self.mode { - Mode::St => Ok(st::join(self, a, b).await), - Mode::Mt { threads } => { - let threads = threads.get(2)?; - mt::join(threads, a, b).await - } + if matches!(self.mode, Mode::Single) { + let ra = a(self).await; + let rb = b(self).await; + return Ok((ra, rb)); } + + let parent_id = self.next_fork(); + let executor = self.executor().cloned(); + let mut ctx_a = self.child(parent_id.child(0))?; + let mut ctx_b = self.child(parent_id.child(1))?; + + let task_a = run(executor.as_ref(), async move { a(&mut ctx_a).await }); + let task_b = run(executor.as_ref(), async move { b(&mut ctx_b).await }); + Ok(future::join(task_a, task_b).await) } - /// Forks the thread and executes the provided closures concurrently, - /// returning an error if one of the closures fails. - /// - /// This method is short circuiting, meaning that it returns as soon as one - /// of the closures fails, potentially canceling the other. - /// - /// Implementations may not be able to fork, in which case the closures are - /// executed sequentially. - pub async fn try_join<'a, A, B, RA, RB, E>( - &'a mut self, + /// Like [`Context::join`], but short-circuits as soon as either branch + /// returns an error, potentially cancelling the other. + pub async fn try_join( + &mut self, a: A, b: B, ) -> Result, ContextError> where - A: for<'b> AsyncFnOnce(&'b mut Self) -> Result + Send + 'static, - B: for<'b> AsyncFnOnce(&'b mut Self) -> Result + Send + 'static, + A: for<'a> FnOnce(&'a mut Context) -> BoxFuture<'a, Result> + Send + 'static, + B: for<'a> FnOnce(&'a mut Context) -> BoxFuture<'a, Result> + Send + 'static, RA: Send + 'static, RB: Send + 'static, E: Send + 'static, { - match &mut self.mode { - Mode::St => Ok(st::try_join(self, a, b).await), - Mode::Mt { threads } => { - let threads = threads.get(2)?; - mt::try_join(threads, a, b).await + if matches!(self.mode, Mode::Single) { + return Ok(async { + let ra = a(self).await?; + let rb = b(self).await?; + Ok((ra, rb)) } + .await); } + + let parent_id = self.next_fork(); + let executor = self.executor().cloned(); + let mut ctx_a = self.child(parent_id.child(0))?; + let mut ctx_b = self.child(parent_id.child(1))?; + + let task_a = run(executor.as_ref(), async move { a(&mut ctx_a).await }); + let task_b = run(executor.as_ref(), async move { b(&mut ctx_b).await }); + Ok(future::try_join(task_a, task_b).await) } - /// Same as [`Context::try_join`], but with three closures. - pub async fn try_join3<'a, A, B, C, RA, RB, RC, E>( - &'a mut self, + /// Same as [`Context::try_join`], but with three branches. + pub async fn try_join3( + &mut self, a: A, b: B, c: C, ) -> Result, ContextError> where - A: for<'b> AsyncFnOnce(&'b mut Self) -> Result + Send + 'static, - B: for<'b> AsyncFnOnce(&'b mut Self) -> Result + Send + 'static, - C: for<'b> AsyncFnOnce(&'b mut Self) -> Result + Send + 'static, + A: for<'a> FnOnce(&'a mut Context) -> BoxFuture<'a, Result> + Send + 'static, + B: for<'a> FnOnce(&'a mut Context) -> BoxFuture<'a, Result> + Send + 'static, + C: for<'a> FnOnce(&'a mut Context) -> BoxFuture<'a, Result> + Send + 'static, RA: Send + 'static, RB: Send + 'static, RC: Send + 'static, E: Send + 'static, { - match &mut self.mode { - Mode::St => Ok(st::try_join3(self, a, b, c).await), - Mode::Mt { threads } => { - let threads = threads.get(3)?; - mt::try_join3(threads, a, b, c).await + if matches!(self.mode, Mode::Single) { + return Ok(async { + let ra = a(self).await?; + let rb = b(self).await?; + let rc = c(self).await?; + Ok((ra, rb, rc)) } + .await); } + + let parent_id = self.next_fork(); + let executor = self.executor().cloned(); + let mut ctx_a = self.child(parent_id.child(0))?; + let mut ctx_b = self.child(parent_id.child(1))?; + let mut ctx_c = self.child(parent_id.child(2))?; + + let task_a = run(executor.as_ref(), async move { a(&mut ctx_a).await }); + let task_b = run(executor.as_ref(), async move { b(&mut ctx_b).await }); + let task_c = run(executor.as_ref(), async move { c(&mut ctx_c).await }); + Ok(future::try_join3(task_a, task_b, task_c).await) } - /// Same as [`Context::try_join`], but with four closures. - pub async fn try_join4<'a, A, B, C, D, RA, RB, RC, RD, E>( - &'a mut self, + /// Same as [`Context::try_join`], but with four branches. + pub async fn try_join4( + &mut self, a: A, b: B, c: C, d: D, ) -> Result, ContextError> where - A: for<'b> AsyncFnOnce(&'b mut Self) -> Result + Send + 'static, - B: for<'b> AsyncFnOnce(&'b mut Self) -> Result + Send + 'static, - C: for<'b> AsyncFnOnce(&'b mut Self) -> Result + Send + 'static, - D: for<'b> AsyncFnOnce(&'b mut Self) -> Result + Send + 'static, + A: for<'a> FnOnce(&'a mut Context) -> BoxFuture<'a, Result> + Send + 'static, + B: for<'a> FnOnce(&'a mut Context) -> BoxFuture<'a, Result> + Send + 'static, + C: for<'a> FnOnce(&'a mut Context) -> BoxFuture<'a, Result> + Send + 'static, + D: for<'a> FnOnce(&'a mut Context) -> BoxFuture<'a, Result> + Send + 'static, RA: Send + 'static, RB: Send + 'static, RC: Send + 'static, RD: Send + 'static, E: Send + 'static, { - match &mut self.mode { - Mode::St => Ok(st::try_join4(self, a, b, c, d).await), - Mode::Mt { threads } => { - let threads = threads.get(4)?; - mt::try_join4(threads, a, b, c, d).await + if matches!(self.mode, Mode::Single) { + return Ok(async { + let ra = a(self).await?; + let rb = b(self).await?; + let rc = c(self).await?; + let rd = d(self).await?; + Ok((ra, rb, rc, rd)) } + .await); + } + + let parent_id = self.next_fork(); + let executor = self.executor().cloned(); + let mut ctx_a = self.child(parent_id.child(0))?; + let mut ctx_b = self.child(parent_id.child(1))?; + let mut ctx_c = self.child(parent_id.child(2))?; + let mut ctx_d = self.child(parent_id.child(3))?; + + let task_a = run(executor.as_ref(), async move { a(&mut ctx_a).await }); + let task_b = run(executor.as_ref(), async move { b(&mut ctx_b).await }); + let task_c = run(executor.as_ref(), async move { c(&mut ctx_c).await }); + let task_d = run(executor.as_ref(), async move { d(&mut ctx_d).await }); + Ok(future::try_join4(task_a, task_b, task_c, task_d).await) + } + + fn executor(&self) -> Option<&Arc> { + if let Mode::Multi { executor, .. } = &self.mode { + executor.as_ref() + } else { + None } } } -#[derive(Debug)] -enum Mode { - /// Single-threaded. - St, - /// Multi-threaded. - Mt { threads: Threads }, +/// Spawns `fut` on `executor` if one is provided, otherwise yields the future +/// as-is. The output type is identical either way. +fn run( + executor: Option<&Arc>, + fut: F, +) -> impl std::future::Future + Send +where + F: std::future::Future + Send + 'static, + F::Output: Send + 'static, +{ + match executor { + Some(exec) => Either::Left(crate::executor::spawn_on(exec, fut)), + None => Either::Right(fut), + } } /// Error for [`Context`]. #[derive(Debug, thiserror::Error)] -#[error("context error: {kind}")] +#[error("context mux error")] pub struct ContextError { - kind: ErrorKind, #[source] - source: Option>, + source: std::io::Error, } impl ContextError { - #[allow(dead_code)] - pub(crate) fn new>>( - kind: ErrorKind, - source: E, - ) -> Self { - Self { - kind, - source: Some(source.into()), - } - } -} - -#[derive(Debug)] -#[allow(dead_code)] -pub(crate) enum ErrorKind { - Mux, - Thread, -} - -impl fmt::Display for ErrorKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - ErrorKind::Mux => write!(f, "multiplexer error"), - ErrorKind::Thread => write!(f, "thread error"), - } - } -} - -impl From for ContextError { - fn from(err: SpawnError) -> Self { - Self { - kind: ErrorKind::Thread, - source: Some(Box::new(err)), - } + fn mux(source: std::io::Error) -> Self { + Self { source } } } diff --git a/crates/common/src/context/mt.rs b/crates/common/src/context/mt.rs deleted file mode 100644 index be62570e..00000000 --- a/crates/common/src/context/mt.rs +++ /dev/null @@ -1,369 +0,0 @@ -mod builder; -mod spawn; -mod worker; - -use std::sync::{Arc, Mutex}; - -use futures::{FutureExt, StreamExt as _, stream::FuturesUnordered}; -use pollster::FutureExt as _; -use worker::{Handle, Worker}; - -use crate::{ - Context, ContextError, ThreadId, context::ErrorKind, load_balance::distribute_by_weight, - mux::Mux, -}; - -pub use builder::{MultithreadBuilder, MultithreadBuilderError}; -pub use spawn::{CustomSpawn, Spawn, SpawnError, StdSpawn}; - -#[derive(Debug)] -pub(crate) struct MtConfig { - concurrency: usize, -} - -/// A multi-threaded context. -#[derive(Debug)] -pub struct Multithread { - current_id: ThreadId, - config: Arc, - builder: Arc>, -} - -impl Multithread { - /// Creates a new builder. - pub fn builder() -> MultithreadBuilder { - MultithreadBuilder::default() - } - - /// Creates a new multi-threaded context. - pub fn new_context(&mut self) -> Result { - let id = self.current_id.increment().ok_or_else(|| { - ContextError::new(ErrorKind::Thread, "thread ID overflow".to_string()) - })?; - - let io = self - .builder - .lock() - .unwrap() - .mux - .open(id.clone()) - .map_err(|e| ContextError::new(ErrorKind::Mux, e))?; - - let ctx = - Context::new_multi_threaded(id.clone(), io, self.config.clone(), self.builder.clone()); - - Ok(ctx) - } -} - -pub(crate) struct ThreadBuilder { - spawn: Box, - mux: Box, -} - -impl ThreadBuilder { - fn spawn( - this: Arc>, - id: ThreadId, - config: Arc, - ) -> Result { - let io = this - .lock() - .unwrap() - .mux - .open(id.clone()) - .map_err(|e| ContextError::new(ErrorKind::Mux, e))?; - - let ctx = Context::new_multi_threaded(id.clone(), io, config, this.clone()); - let (worker, handle) = Worker::new(id, ctx); - - this.lock() - .unwrap() - .spawn - .spawn(Box::new(move || worker.run()))?; - - Ok(handle) - } -} - -impl std::fmt::Debug for ThreadBuilder { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("ThreadBuilder").finish_non_exhaustive() - } -} - -#[derive(Debug)] -pub(crate) struct Threads { - config: Arc, - builder: Arc>, - child_id: ThreadId, - children: Vec, -} - -impl Threads { - pub(crate) fn new( - parent_id: ThreadId, - config: Arc, - builder: Arc>, - ) -> Self { - Self { - config, - builder, - child_id: parent_id.fork(), - children: Vec::new(), - } - } - - pub(crate) fn concurrency(&self) -> usize { - self.config.concurrency - } - - pub(crate) fn get(&mut self, count: usize) -> Result<&[Handle], ContextError> { - if count > self.config.concurrency { - return Err(ContextError::new( - ErrorKind::Thread, - "requested more threads than available".to_string(), - )); - } else if self.children.len() < count { - let diff = count - self.children.len(); - for _ in 0..diff { - let id = self.child_id.increment_in_place().ok_or_else(|| { - ContextError::new(ErrorKind::Thread, "thread ID overflow".to_string()) - })?; - - let child = ThreadBuilder::spawn(self.builder.clone(), id, self.config.clone())?; - self.children.push(child); - } - } - - Ok(&self.children[..count]) - } -} - -pub(crate) async fn map( - threads: &[Handle], - items: Vec, - f: F, - weight: W, -) -> Result, ContextError> -where - F: for<'b> AsyncFn(&'b mut Context, T) -> R + Clone + Send + 'static, - T: Send + 'static, - R: Send + 'static, - W: Fn(&T) -> usize + Send + 'static, -{ - let items = items.into_iter().enumerate().collect::>(); - let item_count = items.len(); - let lanes = distribute_by_weight(items, |item| weight(&item.1), threads.len()); - - let mut queue = FuturesUnordered::new(); - for (lane, thread) in lanes.into_iter().zip(threads) { - let f = f.clone(); - let task = thread.send_with_return(move |ctx| { - async move { - let mut outputs = Vec::with_capacity(lane.len()); - for (i, item) in lane { - outputs.push((i, f(ctx, item).await)); - } - outputs - } - .block_on() - })?; - queue.push(task); - } - - let mut outputs = Vec::with_capacity(item_count); - while let Some(lane) = queue.next().await { - outputs.extend(lane?); - } - - outputs.sort_by_key(|(i, _)| *i); - - Ok(outputs.into_iter().map(|(_, output)| output).collect()) -} - -pub(crate) async fn join<'a, A, B, RA, RB>( - threads: &[Handle], - a: A, - b: B, -) -> Result<(RA, RB), ContextError> -where - A: for<'b> AsyncFnOnce(&'b mut Context) -> RA + Send + 'static, - B: for<'b> AsyncFnOnce(&'b mut Context) -> RB + Send + 'static, - RA: Send + 'static, - RB: Send + 'static, -{ - assert_eq!(threads.len(), 2, "expecting exactly two threads"); - - let ra = threads[0].send_with_return(|ctx| a(ctx).block_on())?; - let rb = threads[1].send_with_return(|ctx| b(ctx).block_on())?; - - let (ra, rb) = futures::try_join!(ra, rb)?; - - Ok((ra, rb)) -} - -pub(crate) async fn try_join<'a, A, B, RA, RB, E>( - threads: &[Handle], - a: A, - b: B, -) -> Result, ContextError> -where - A: for<'b> AsyncFnOnce(&'b mut Context) -> Result + Send + 'static, - B: for<'b> AsyncFnOnce(&'b mut Context) -> Result + Send + 'static, - RA: Send + 'static, - RB: Send + 'static, - E: Send + 'static, -{ - assert_eq!(threads.len(), 2, "expecting exactly two threads"); - - let mut a = threads[0].send_with_return(|ctx| a(ctx).block_on())?.fuse(); - let mut b = threads[1].send_with_return(|ctx| b(ctx).block_on())?.fuse(); - - let mut ra = None; - let mut rb = None; - loop { - futures::select! { - output = a => { - match output? { - Ok(output) => ra = Some(output), - Err(error) => return Ok(Err(error)), - } - }, - output = b => { - match output? { - Ok(output) => rb = Some(output), - Err(error) => return Ok(Err(error)), - } - } - complete => break, - } - } - - let ra = ra.expect("a future should have resolved"); - let rb = rb.expect("b future should have resolved"); - - Ok(Ok((ra, rb))) -} - -pub(crate) async fn try_join3<'a, A, B, C, RA, RB, RC, E>( - threads: &[Handle], - a: A, - b: B, - c: C, -) -> Result, ContextError> -where - A: for<'b> AsyncFnOnce(&'b mut Context) -> Result + Send + 'static, - B: for<'b> AsyncFnOnce(&'b mut Context) -> Result + Send + 'static, - C: for<'b> AsyncFnOnce(&'b mut Context) -> Result + Send + 'static, - RA: Send + 'static, - RB: Send + 'static, - RC: Send + 'static, - E: Send + 'static, -{ - assert_eq!(threads.len(), 3, "expecting exactly three threads"); - - let mut a = threads[0].send_with_return(|ctx| a(ctx).block_on())?.fuse(); - let mut b = threads[1].send_with_return(|ctx| b(ctx).block_on())?.fuse(); - let mut c = threads[2].send_with_return(|ctx| c(ctx).block_on())?.fuse(); - - let mut ra = None; - let mut rb = None; - let mut rc = None; - loop { - futures::select! { - output = a => { - match output? { - Ok(output) => ra = Some(output), - Err(error) => return Ok(Err(error)), - } - }, - output = b => { - match output? { - Ok(output) => rb = Some(output), - Err(error) => return Ok(Err(error)), - } - } - output = c => { - match output? { - Ok(output) => rc = Some(output), - Err(error) => return Ok(Err(error)), - } - } - complete => break, - } - } - - let ra = ra.expect("a future should have resolved"); - let rb = rb.expect("b future should have resolved"); - let rc = rc.expect("c future should have resolved"); - - Ok(Ok((ra, rb, rc))) -} - -pub(crate) async fn try_join4<'a, A, B, C, D, RA, RB, RC, RD, E>( - threads: &[Handle], - a: A, - b: B, - c: C, - d: D, -) -> Result, ContextError> -where - A: for<'b> AsyncFnOnce(&'b mut Context) -> Result + Send + 'static, - B: for<'b> AsyncFnOnce(&'b mut Context) -> Result + Send + 'static, - C: for<'b> AsyncFnOnce(&'b mut Context) -> Result + Send + 'static, - D: for<'b> AsyncFnOnce(&'b mut Context) -> Result + Send + 'static, - RA: Send + 'static, - RB: Send + 'static, - RC: Send + 'static, - RD: Send + 'static, - E: Send + 'static, -{ - assert_eq!(threads.len(), 4, "expecting exactly four threads"); - - let mut a = threads[0].send_with_return(|ctx| a(ctx).block_on())?.fuse(); - let mut b = threads[1].send_with_return(|ctx| b(ctx).block_on())?.fuse(); - let mut c = threads[2].send_with_return(|ctx| c(ctx).block_on())?.fuse(); - let mut d = threads[3].send_with_return(|ctx| d(ctx).block_on())?.fuse(); - - let mut ra = None; - let mut rb = None; - let mut rc = None; - let mut rd = None; - loop { - futures::select! { - output = a => { - match output? { - Ok(output) => ra = Some(output), - Err(error) => return Ok(Err(error)), - } - }, - output = b => { - match output? { - Ok(output) => rb = Some(output), - Err(error) => return Ok(Err(error)), - } - } - output = c => { - match output? { - Ok(output) => rc = Some(output), - Err(error) => return Ok(Err(error)), - } - } - output = d => { - match output? { - Ok(output) => rd = Some(output), - Err(error) => return Ok(Err(error)), - } - } - complete => break, - } - } - - let ra = ra.expect("a future should have resolved"); - let rb = rb.expect("b future should have resolved"); - let rc = rc.expect("c future should have resolved"); - let rd = rd.expect("d future should have resolved"); - - Ok(Ok((ra, rb, rc, rd))) -} diff --git a/crates/common/src/context/mt/builder.rs b/crates/common/src/context/mt/builder.rs deleted file mode 100644 index 4eb9e935..00000000 --- a/crates/common/src/context/mt/builder.rs +++ /dev/null @@ -1,95 +0,0 @@ -use std::sync::{Arc, Mutex}; - -use crate::{ - ThreadId, - context::{ - CustomSpawn, MtConfig, SpawnError, ThreadBuilder, - mt::{ - Multithread, - spawn::{Spawn, StdSpawn}, - }, - }, - mux::Mux, -}; - -/// Builder for [`Multithread`]. -pub struct MultithreadBuilder { - /// Maximum concurrency level per thread. - concurrency: usize, - /// Closure invoked to spawn a new thread. - spawn_handler: S, - /// Multiplexer. - mux: Option>, -} - -impl Default for MultithreadBuilder { - fn default() -> Self { - Self { - concurrency: 8, - spawn_handler: StdSpawn, - mux: None, - } - } -} - -impl MultithreadBuilder -where - S: Spawn, -{ - /// Builds a new multi-threaded context. - pub fn build(self) -> Result { - let mux = self - .mux - .ok_or(MultithreadBuilderError(ErrorRepr::MissingField("mux")))?; - - let builder = ThreadBuilder { - spawn: Box::new(self.spawn_handler), - mux, - }; - - Ok(Multithread { - current_id: ThreadId::default(), - config: Arc::new(MtConfig { - concurrency: self.concurrency, - }), - builder: Arc::new(Mutex::new(builder)), - }) - } -} - -impl MultithreadBuilder { - /// Sets a custom function for spawning threads. - pub fn spawn_handler(self, spawn: F) -> MultithreadBuilder> - where - F: FnMut(Box) -> Result<(), SpawnError> + Send + 'static, - { - MultithreadBuilder { - spawn_handler: CustomSpawn(spawn), - concurrency: self.concurrency, - mux: self.mux, - } - } - - /// Sets the multiplexer. - pub fn mux>>(mut self, mux: M) -> Self { - self.mux = Some(mux.into()); - self - } - - /// Sets the maximum concurrency level per thread. - pub fn concurrency(mut self, concurrency: usize) -> Self { - self.concurrency = concurrency; - self - } -} - -/// Error for [`MultithreadBuilder`]. -#[derive(Debug, thiserror::Error)] -#[error(transparent)] -pub struct MultithreadBuilderError(#[from] ErrorRepr); - -#[derive(Debug, thiserror::Error)] -enum ErrorRepr { - #[error("missing required field: {0}")] - MissingField(&'static str), -} diff --git a/crates/common/src/context/mt/spawn.rs b/crates/common/src/context/mt/spawn.rs deleted file mode 100644 index 67ab8e27..00000000 --- a/crates/common/src/context/mt/spawn.rs +++ /dev/null @@ -1,48 +0,0 @@ -/// Error for [`Spawn`] -#[derive(Debug, thiserror::Error)] -#[error("spawn error: {source}")] -pub struct SpawnError { - source: Box, -} - -impl SpawnError { - /// Creates a new spawn error. - pub fn new(source: E) -> Self - where - E: Into>, - { - Self { - source: source.into(), - } - } -} - -#[doc(hidden)] -pub trait Spawn: Send + 'static { - /// Spawns a new thread. - fn spawn(&mut self, f: Box) -> Result<(), SpawnError>; -} - -#[doc(hidden)] -pub struct StdSpawn; - -impl Spawn for StdSpawn { - fn spawn(&mut self, f: Box) -> Result<(), SpawnError> { - std::thread::Builder::new() - .spawn(f) - .map(|_| ()) - .map_err(SpawnError::new) - } -} - -#[doc(hidden)] -pub struct CustomSpawn(pub F); - -impl Spawn for CustomSpawn -where - F: FnMut(Box) -> Result<(), SpawnError> + Send + 'static, -{ - fn spawn(&mut self, f: Box) -> Result<(), SpawnError> { - (self.0)(f) - } -} diff --git a/crates/common/src/context/mt/worker.rs b/crates/common/src/context/mt/worker.rs deleted file mode 100644 index ff7f3374..00000000 --- a/crates/common/src/context/mt/worker.rs +++ /dev/null @@ -1,83 +0,0 @@ -use std::future::Future; - -use crossbeam_channel::{Receiver, Sender, unbounded}; -use futures::{TryFutureExt, channel::oneshot}; - -use crate::{Context, ContextError, ThreadId, context::ErrorKind}; - -type Job = Box; - -pub(crate) struct Handle { - id: ThreadId, - sender: Sender, -} - -impl Handle { - /// Sends a job to the worker. - pub(crate) fn send(&self, job: F) -> Result<(), ContextError> - where - F: FnOnce(&mut Context) + Send + 'static, - { - self.sender.send(Box::new(job)).map_err(|_| { - ContextError::new( - ErrorKind::Thread, - format!("failed to send job to worker {}", self.id), - ) - }) - } - - /// Sends a job to the worker and returns a future that resolves to the - /// result of the job. - pub(crate) fn send_with_return( - &self, - job: F, - ) -> Result>, ContextError> - where - F: FnOnce(&mut Context) -> R + Send + 'static, - R: Send + 'static, - { - let (sender, receive) = oneshot::channel(); - - self.send(move |ctx| { - let result = job(ctx); - let _ = sender.send(result); - })?; - - let id = self.id.clone(); - Ok(receive.map_err(move |_| { - ContextError::new( - ErrorKind::Thread, - format!("failed to receive result from worker {id}"), - ) - })) - } -} - -impl std::fmt::Debug for Handle { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("Handle").field("id", &self.id).finish() - } -} - -pub(crate) struct Worker { - ctx: Context, - queue: Receiver, -} - -impl Worker { - pub(crate) fn new(id: ThreadId, ctx: Context) -> (Self, Handle) { - let (sender, receiver) = unbounded(); - let worker = Self { - ctx, - queue: receiver, - }; - let handle = Handle { id, sender }; - (worker, handle) - } - - pub(crate) fn run(mut self) { - while let Ok(job) = self.queue.recv() { - job(&mut self.ctx); - } - } -} diff --git a/crates/common/src/context/st.rs b/crates/common/src/context/st.rs deleted file mode 100644 index 5d67f60a..00000000 --- a/crates/common/src/context/st.rs +++ /dev/null @@ -1,73 +0,0 @@ -use crate::Context; - -pub(crate) async fn map<'a, F, T, R>(ctx: &'a mut Context, items: Vec, f: F) -> Vec -where - F: for<'b> AsyncFn(&'b mut Context, T) -> R, -{ - let mut results = Vec::with_capacity(items.len()); - for item in items { - results.push(f(ctx, item).await); - } - results -} - -pub(crate) async fn join<'a, A, B, RA, RB>(ctx: &'a mut Context, a: A, b: B) -> (RA, RB) -where - A: for<'b> AsyncFnOnce(&'b mut Context) -> RA, - B: for<'b> AsyncFnOnce(&'b mut Context) -> RB, -{ - let a = a(ctx).await; - let b = b(ctx).await; - (a, b) -} - -pub(crate) async fn try_join<'a, A, B, RA, RB, E>( - ctx: &'a mut Context, - a: A, - b: B, -) -> Result<(RA, RB), E> -where - A: for<'b> AsyncFnOnce(&'b mut Context) -> Result, - B: for<'b> AsyncFnOnce(&'b mut Context) -> Result, -{ - let a = a(ctx).await?; - let b = b(ctx).await?; - Ok((a, b)) -} - -pub(crate) async fn try_join3<'a, A, B, C, RA, RB, RC, E>( - ctx: &'a mut Context, - a: A, - b: B, - c: C, -) -> Result<(RA, RB, RC), E> -where - A: for<'b> AsyncFnOnce(&'b mut Context) -> Result, - B: for<'b> AsyncFnOnce(&'b mut Context) -> Result, - C: for<'b> AsyncFnOnce(&'b mut Context) -> Result, -{ - let a = a(ctx).await?; - let b = b(ctx).await?; - let c = c(ctx).await?; - Ok((a, b, c)) -} - -pub(crate) async fn try_join4<'a, A, B, C, D, RA, RB, RC, RD, E>( - ctx: &'a mut Context, - a: A, - b: B, - c: C, - d: D, -) -> Result<(RA, RB, RC, RD), E> -where - A: for<'b> AsyncFnOnce(&'b mut Context) -> Result, - B: for<'b> AsyncFnOnce(&'b mut Context) -> Result, - C: for<'b> AsyncFnOnce(&'b mut Context) -> Result, - D: for<'b> AsyncFnOnce(&'b mut Context) -> Result, -{ - let a = a(ctx).await?; - let b = b(ctx).await?; - let c = c(ctx).await?; - let d = d(ctx).await?; - Ok((a, b, c, d)) -} diff --git a/crates/common/src/context/test/helpers.rs b/crates/common/src/context/test/helpers.rs index c5bfa43c..9a6f1bb1 100644 --- a/crates/common/src/context/test/helpers.rs +++ b/crates/common/src/context/test/helpers.rs @@ -1,28 +1,14 @@ //! Basic test context helpers. -use crate::mux::test_framed_mux; -use futures::{AsyncRead, AsyncWrite}; use serio::channel::duplex; use crate::{ - context::{Context, Multithread, SpawnError}, + context::Context, + executor::{Executor, ExecutorBuilder}, io::Io, - mux::Mux, + mux::test_framed_mux, }; -/// Creates a single-threaded context with a custom frame limit. -/// -/// # Arguments -/// -/// * `io` - The I/O channel used by the context. -/// * `max_frame_length` - Maximum frame size in bytes. -pub(super) fn new_st_context_with_limit(io: I, max_frame_length: usize) -> Context -where - I: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, -{ - Context::from_io(Io::from_io_with_limit(io, max_frame_length)) -} - /// Creates a pair of single-threaded contexts using memory I/O channels. pub fn test_st_context(io_buffer: usize) -> (Context, Context) { let (io_0, io_1) = duplex(io_buffer); @@ -33,76 +19,30 @@ pub fn test_st_context(io_buffer: usize) -> (Context, Context) { ) } -/// Creates a pair of multi-threaded contexts using multiplexed I/O channels. -pub fn test_mt_context(io_buffer: usize) -> (Multithread, Multithread) { +/// Creates a pair of multi-threaded executors sharing multiplexed I/O channels. +pub fn test_mt_context(io_buffer: usize) -> (Executor, Executor) { let (mux_0, mux_1) = test_framed_mux(io_buffer); - let mux_0: Box = Box::new(mux_0); - let mux_1: Box = Box::new(mux_1); - ( - Multithread::builder().mux(mux_0).build().unwrap(), - Multithread::builder().mux(mux_1).build().unwrap(), + ExecutorBuilder::default().build(mux_0), + ExecutorBuilder::default().build(mux_1), ) } -/// Creates a pair of multi-threaded contexts with a custom spawn handler. -/// -/// This is useful for WASM environments where `std::thread::spawn` is not -/// available and a custom spawner like `web_spawn` is needed. -pub fn test_mt_context_with_spawn(io_buffer: usize, spawn: F) -> (Multithread, Multithread) +/// Like [`test_mt_context`], but uses a custom worker spawn callback (e.g. +/// `web_spawn::spawn` on wasm). +pub fn test_mt_context_with_spawn(io_buffer: usize, spawn: F) -> (Executor, Executor) where - F: FnMut(Box) -> Result<(), SpawnError> + Clone + Send + 'static, + F: Fn(Box) -> Result<(), std::io::Error> + + Clone + + Send + + Sync + + 'static, { let (mux_0, mux_1) = test_framed_mux(io_buffer); - let mux_0: Box = Box::new(mux_0); - let mux_1: Box = Box::new(mux_1); - - ( - Multithread::builder() - .spawn_handler(spawn.clone()) - .mux(mux_0) - .build() - .unwrap(), - Multithread::builder() - .spawn_handler(spawn) - .mux(mux_1) - .build() - .unwrap(), - ) -} - -/// Creates a pair of multi-threaded contexts with a custom spawn handler and -/// concurrency. -/// -/// Like [`test_mt_context_with_spawn`], but allows configuring the maximum -/// concurrency level (number of worker threads) per context. -pub fn test_mt_context_with_concurrency( - io_buffer: usize, - concurrency: usize, - spawn: F, -) -> (Multithread, Multithread) -where - F: FnMut(Box) -> Result<(), SpawnError> + Clone + Send + 'static, -{ - let (mux_0, mux_1) = test_framed_mux(io_buffer); - - let mux_0: Box = Box::new(mux_0); - let mux_1: Box = Box::new(mux_1); - ( - Multithread::builder() - .concurrency(concurrency) - .spawn_handler(spawn.clone()) - .mux(mux_0) - .build() - .unwrap(), - Multithread::builder() - .concurrency(concurrency) - .spawn_handler(spawn) - .mux(mux_1) - .build() - .unwrap(), + ExecutorBuilder::default().spawn(spawn.clone()).build(mux_0), + ExecutorBuilder::default().spawn(spawn).build(mux_1), ) } diff --git a/crates/common/src/context/test/mod.rs b/crates/common/src/context/test/mod.rs index a3a52523..179a32cf 100644 --- a/crates/common/src/context/test/mod.rs +++ b/crates/common/src/context/test/mod.rs @@ -6,15 +6,13 @@ mod replay; #[cfg(test)] mod tests; -pub use helpers::{ - test_mt_context, test_mt_context_with_concurrency, test_mt_context_with_spawn, test_st_context, -}; +pub use helpers::{test_mt_context, test_mt_context_with_spawn, test_st_context}; pub use recording::{ RecordedMtData, RecordingDuplex, recording_mt_context, recording_mt_context_with_limit, - recording_mt_context_with_spawn, recording_mt_context_with_spawn_and_limit, - recording_st_context, recording_st_context_with_limit, + recording_mt_context_with_spawn_and_limit, recording_st_context, + recording_st_context_with_limit, }; pub use replay::{ - ReplayDuplex, replay_mt_context, replay_mt_context_with_limit, replay_mt_context_with_spawn, + ReplayDuplex, replay_mt_context, replay_mt_context_with_limit, replay_mt_context_with_spawn_and_limit, replay_st_context, }; diff --git a/crates/common/src/context/test/recording.rs b/crates/common/src/context/test/recording.rs index fc94ab10..7d30c6ff 100644 --- a/crates/common/src/context/test/recording.rs +++ b/crates/common/src/context/test/recording.rs @@ -11,14 +11,13 @@ use futures::{AsyncRead, AsyncWrite}; use tokio_util::compat::{Compat, TokioAsyncReadCompatExt}; use crate::{ - ThreadId, - context::{Context, Multithread, SpawnError}, + ContextId, + context::Context, + executor::{Executor, ExecutorBuilder}, io::Io, mux::Mux, }; -use super::helpers::new_st_context_with_limit; - /// A duplex stream that records all bytes written. /// /// Used for recording protocol messages for replay in isolated benchmarks. @@ -69,6 +68,46 @@ impl AsyncWrite for RecordingDuplex { } } +/// A simple mux that wraps a single I/O stream for recording tests. +struct SingleChannelMux { + io: Mutex>, + max_frame_length: Option, +} + +impl SingleChannelMux { + fn new(io: I, max_frame_length: Option) -> Self { + Self { + io: Mutex::new(Some(io)), + max_frame_length, + } + } +} + +impl Mux for SingleChannelMux +where + I: AsyncRead + AsyncWrite + Send + Sync + Unpin + 'static, +{ + fn open(&self, id: &[u8]) -> Result { + // Only allow opening the root ID + if id != ContextId::default().as_bytes() { + return Err(std::io::Error::other( + "single channel mux only supports root ID", + )); + } + let io = self + .io + .lock() + .unwrap() + .take() + .ok_or_else(|| std::io::Error::other("channel already opened"))?; + if let Some(limit) = self.max_frame_length { + Ok(Io::from_io_with_limit(io, limit)) + } else { + Ok(Io::from_io(io)) + } + } +} + /// Creates a pair of single-threaded contexts where writes from ctx_1 to ctx_0 /// are recorded. /// @@ -85,9 +124,12 @@ pub fn recording_st_context(io_buffer: usize) -> (Context, Context, Arc>, + pub channels: HashMap, Vec>, } /// Shared state for recording test mux. @@ -138,11 +183,11 @@ pub struct RecordedMtData { #[derive(Default)] struct RecordingMuxState { /// Channels waiting to be opened by role A. - waiting_a: HashMap>, + waiting_a: HashMap, Compat>, /// Channels waiting to be opened by role B. - waiting_b: HashMap, + waiting_b: HashMap, RecordingDuplexMt>, /// Track which channels have been opened. - opened: std::collections::HashSet, + opened: std::collections::HashSet>, } /// Role in the recording mux. @@ -177,13 +222,14 @@ impl std::fmt::Debug for RecordingTestMux { } impl Mux for RecordingTestMux { - fn open(&self, id: ThreadId) -> Result { + fn open(&self, id: &[u8]) -> Result { let mut state = self.state.lock().unwrap(); + let id_vec = id.to_vec(); // Check if channel already exists from the other side match self.role { RecordingRole::A => { - if let Some(stream) = state.waiting_a.remove(&id) { + if let Some(stream) = state.waiting_a.remove(&id_vec) { return Ok(if let Some(limit) = self.max_frame_length { Io::from_io_with_limit(stream, limit) } else { @@ -192,7 +238,7 @@ impl Mux for RecordingTestMux { } } RecordingRole::B => { - if let Some(recording_stream) = state.waiting_b.remove(&id) { + if let Some(recording_stream) = state.waiting_b.remove(&id_vec) { return Ok(if let Some(limit) = self.max_frame_length { Io::from_io_with_limit(recording_stream, limit) } else { @@ -203,7 +249,7 @@ impl Mux for RecordingTestMux { } // Check for duplicate - if !state.opened.insert(id.clone()) { + if !state.opened.insert(id_vec.clone()) { return Err(std::io::Error::other("duplicate stream id")); } @@ -212,7 +258,7 @@ impl Mux for RecordingTestMux { // Role B's writes are recorded let recorded_for_channel = self.recorded.clone(); - let channel_id = id.clone(); + let channel_id = id_vec.clone(); match self.role { RecordingRole::A => { @@ -221,7 +267,7 @@ impl Mux for RecordingTestMux { RecordingDuplexWithId::new(stream_b, channel_id, recorded_for_channel); state .waiting_b - .insert(id, recording_stream.into_recording_duplex()); + .insert(id_vec, recording_stream.into_recording_duplex()); Ok(if let Some(limit) = self.max_frame_length { Io::from_io_with_limit(stream_a.compat(), limit) } else { @@ -230,7 +276,7 @@ impl Mux for RecordingTestMux { } RecordingRole::B => { // B gets recording stream, A gets plain stream - state.waiting_a.insert(id, stream_a.compat()); + state.waiting_a.insert(id_vec, stream_a.compat()); let recording_stream = RecordingDuplexWithId::new(stream_b, channel_id, recorded_for_channel); Ok(if let Some(limit) = self.max_frame_length { @@ -246,14 +292,14 @@ impl Mux for RecordingTestMux { /// Helper to create RecordingDuplex with per-channel recording. struct RecordingDuplexWithId { inner: tokio::io::DuplexStream, - channel_id: ThreadId, + channel_id: Vec, recorded: Arc>, } impl RecordingDuplexWithId { fn new( inner: tokio::io::DuplexStream, - channel_id: ThreadId, + channel_id: Vec, recorded: Arc>, ) -> Self { Self { @@ -277,7 +323,7 @@ impl RecordingDuplexWithId { /// Like `RecordingDuplex` but stores bytes per-channel for MT contexts. struct RecordingDuplexMt { inner: Compat, - channel_id: ThreadId, + channel_id: Vec, recorded: Arc>, } @@ -359,17 +405,12 @@ fn recording_test_mux( /// # Arguments /// /// * `io_buffer` - Size of the I/O buffer per channel. -pub fn recording_mt_context( - io_buffer: usize, -) -> (Multithread, Multithread, Arc>) { +pub fn recording_mt_context(io_buffer: usize) -> (Executor, Executor, Arc>) { let (mux_0, mux_1, recorded) = recording_test_mux(io_buffer, None); - let mux_0: Box = Box::new(mux_0); - let mux_1: Box = Box::new(mux_1); - ( - Multithread::builder().mux(mux_0).build().unwrap(), - Multithread::builder().mux(mux_1).build().unwrap(), + ExecutorBuilder::default().build(mux_0), + ExecutorBuilder::default().build(mux_1), recorded, ) } @@ -384,90 +425,41 @@ pub fn recording_mt_context( pub fn recording_mt_context_with_limit( io_buffer: usize, max_frame_length: usize, -) -> (Multithread, Multithread, Arc>) { +) -> (Executor, Executor, Arc>) { let (mux_0, mux_1, recorded) = recording_test_mux(io_buffer, Some(max_frame_length)); - let mux_0: Box = Box::new(mux_0); - let mux_1: Box = Box::new(mux_1); - ( - Multithread::builder().mux(mux_0).build().unwrap(), - Multithread::builder().mux(mux_1).build().unwrap(), + ExecutorBuilder::default().build(mux_0), + ExecutorBuilder::default().build(mux_1), recorded, ) } -/// Creates a pair of multi-threaded contexts with custom spawn handler where -/// writes from ctx_1 are recorded. -/// -/// # Arguments +/// Like [`recording_mt_context_with_limit`], but uses a custom worker spawn +/// callback (e.g. `web_spawn::spawn` on wasm) and a fixed concurrency level. /// -/// * `io_buffer` - Size of the I/O buffer per channel. -/// * `spawn` - Custom spawn handler for worker threads. -pub fn recording_mt_context_with_spawn( - io_buffer: usize, - spawn: F, -) -> (Multithread, Multithread, Arc>) -where - F: FnMut(Box) -> Result<(), SpawnError> + Clone + Send + 'static, -{ - let (mux_0, mux_1, recorded) = recording_test_mux(io_buffer, None); - - let mux_0: Box = Box::new(mux_0); - let mux_1: Box = Box::new(mux_1); - - ( - Multithread::builder() - .spawn_handler(spawn.clone()) - .mux(mux_0) - .build() - .unwrap(), - Multithread::builder() - .spawn_handler(spawn) - .mux(mux_1) - .build() - .unwrap(), - recorded, - ) -} - -/// Creates a pair of multi-threaded contexts with custom spawn handler and -/// frame limit where writes from ctx_1 are recorded. -/// -/// # Arguments -/// -/// * `io_buffer` - Size of the I/O buffer per channel. -/// * `max_frame_length` - Maximum frame size in bytes. -/// * `concurrency` - Maximum parallelism level (max children per parent -/// thread). -/// * `spawn` - Custom spawn handler for worker threads. +/// The same `spawn` callback is used for both executors. pub fn recording_mt_context_with_spawn_and_limit( io_buffer: usize, max_frame_length: usize, concurrency: usize, spawn: F, -) -> (Multithread, Multithread, Arc>) +) -> (Executor, Executor, Arc>) where - F: FnMut(Box) -> Result<(), SpawnError> + Clone + Send + 'static, + F: Fn(Box) -> Result<(), std::io::Error> + + Clone + + Send + + Sync + + 'static, { let (mux_0, mux_1, recorded) = recording_test_mux(io_buffer, Some(max_frame_length)); - - let mux_0: Box = Box::new(mux_0); - let mux_1: Box = Box::new(mux_1); - - ( - Multithread::builder() - .spawn_handler(spawn.clone()) - .concurrency(concurrency) - .mux(mux_0) - .build() - .unwrap(), - Multithread::builder() - .spawn_handler(spawn) - .concurrency(concurrency) - .mux(mux_1) - .build() - .unwrap(), - recorded, - ) + let exec_0 = ExecutorBuilder::default() + .num_threads(concurrency) + .spawn(spawn.clone()) + .build(mux_0); + let exec_1 = ExecutorBuilder::default() + .num_threads(concurrency) + .spawn(spawn) + .build(mux_1); + (exec_0, exec_1, recorded) } diff --git a/crates/common/src/context/test/replay.rs b/crates/common/src/context/test/replay.rs index f3695343..89880384 100644 --- a/crates/common/src/context/test/replay.rs +++ b/crates/common/src/context/test/replay.rs @@ -9,13 +9,13 @@ use std::{ use futures::{AsyncRead, AsyncWrite}; use crate::{ - ThreadId, - context::{Context, Multithread, SpawnError}, + context::Context, + executor::{Executor, ExecutorBuilder}, io::Io, mux::Mux, }; -use super::{helpers::new_st_context_with_limit, recording::RecordedMtData}; +use super::recording::RecordedMtData; /// A duplex stream that replays recorded bytes on read and discards writes. /// @@ -65,6 +65,43 @@ impl AsyncWrite for ReplayDuplex { } } +/// A simple mux that wraps a single replay stream. +struct SingleReplayMux { + replay: Mutex>, + max_frame_length: Option, +} + +impl SingleReplayMux { + fn new(recorded: Vec, max_frame_length: Option) -> Self { + Self { + replay: Mutex::new(Some(ReplayDuplex::new(recorded))), + max_frame_length, + } + } +} + +impl Mux for SingleReplayMux { + fn open(&self, id: &[u8]) -> Result { + // Only allow opening the root ID + if id != [0] { + return Err(std::io::Error::other( + "single replay mux only supports root ID", + )); + } + let replay = self + .replay + .lock() + .unwrap() + .take() + .ok_or_else(|| std::io::Error::other("channel already opened"))?; + if let Some(limit) = self.max_frame_length { + Ok(Io::from_io_with_limit(replay, limit)) + } else { + Ok(Io::from_io(replay)) + } + } +} + /// Creates a single-threaded context that replays recorded bytes. /// /// The context will read from the recorded bytes and discard all writes. @@ -75,8 +112,8 @@ impl AsyncWrite for ReplayDuplex { /// * `recorded` - The recorded bytes to replay. /// * `max_frame_length` - Maximum frame size in bytes. pub fn replay_st_context(recorded: Vec, max_frame_length: usize) -> Context { - let replay = ReplayDuplex::new(recorded); - new_st_context_with_limit(replay, max_frame_length) + let mux = SingleReplayMux::new(recorded, Some(max_frame_length)); + Context::new(mux).unwrap() } // ============================================================================ @@ -103,12 +140,12 @@ impl ReplayTestMux { } impl Mux for ReplayTestMux { - fn open(&self, id: ThreadId) -> Result { + fn open(&self, id: &[u8]) -> Result { let recorded = self.recorded.clone(); let max_frame_length = self.max_frame_length; let data = { let mut rec = recorded.lock().unwrap(); - rec.channels.remove(&id).unwrap_or_default() + rec.channels.remove(id).unwrap_or_default() }; let replay = ReplayDuplex::new(data); if let Some(limit) = max_frame_length { @@ -128,11 +165,9 @@ impl Mux for ReplayTestMux { /// # Arguments /// /// * `recorded` - The recorded data to replay (per-channel). -pub fn replay_mt_context(recorded: RecordedMtData) -> Multithread { +pub fn replay_mt_context(recorded: RecordedMtData) -> Executor { let mux = ReplayTestMux::new(recorded, None); - let mux: Box = Box::new(mux); - - Multithread::builder().mux(mux).build().unwrap() + ExecutorBuilder::default().build(mux) } /// Creates a multi-threaded context that replays recorded data with a custom @@ -142,63 +177,25 @@ pub fn replay_mt_context(recorded: RecordedMtData) -> Multithread { /// /// * `recorded` - The recorded data to replay (per-channel). /// * `max_frame_length` - Maximum frame size in bytes. -pub fn replay_mt_context_with_limit( - recorded: RecordedMtData, - max_frame_length: usize, -) -> Multithread { +pub fn replay_mt_context_with_limit(recorded: RecordedMtData, max_frame_length: usize) -> Executor { let mux = ReplayTestMux::new(recorded, Some(max_frame_length)); - let mux: Box = Box::new(mux); - - Multithread::builder().mux(mux).build().unwrap() -} - -/// Creates a multi-threaded context that replays recorded data with custom -/// spawn handler. -/// -/// # Arguments -/// -/// * `recorded` - The recorded data to replay (per-channel). -/// * `spawn` - Custom spawn handler for worker threads. -pub fn replay_mt_context_with_spawn(recorded: RecordedMtData, spawn: F) -> Multithread -where - F: FnMut(Box) -> Result<(), SpawnError> + Clone + Send + 'static, -{ - let mux = ReplayTestMux::new(recorded, None); - let mux: Box = Box::new(mux); - - Multithread::builder() - .spawn_handler(spawn) - .mux(mux) - .build() - .unwrap() + ExecutorBuilder::default().build(mux) } -/// Creates a multi-threaded context that replays recorded data with custom -/// spawn handler and frame length limit. -/// -/// # Arguments -/// -/// * `recorded` - The recorded data to replay (per-channel). -/// * `max_frame_length` - Maximum frame size in bytes. -/// * `concurrency` - Maximum parallelism level (max children per parent -/// thread). -/// * `spawn` - Custom spawn handler for worker threads. +/// Like [`replay_mt_context_with_limit`], but uses a custom worker spawn +/// callback (e.g. `web_spawn::spawn` on wasm) and a fixed concurrency level. pub fn replay_mt_context_with_spawn_and_limit( recorded: RecordedMtData, max_frame_length: usize, concurrency: usize, spawn: F, -) -> Multithread +) -> Executor where - F: FnMut(Box) -> Result<(), SpawnError> + Clone + Send + 'static, + F: Fn(Box) -> Result<(), std::io::Error> + Send + Sync + 'static, { let mux = ReplayTestMux::new(recorded, Some(max_frame_length)); - let mux: Box = Box::new(mux); - - Multithread::builder() - .spawn_handler(spawn) - .concurrency(concurrency) - .mux(mux) - .build() - .unwrap() + ExecutorBuilder::default() + .num_threads(concurrency) + .spawn(spawn) + .build(mux) } diff --git a/crates/common/src/context/test/tests.rs b/crates/common/src/context/test/tests.rs index 48c1c35c..33c05fd9 100644 --- a/crates/common/src/context/test/tests.rs +++ b/crates/common/src/context/test/tests.rs @@ -58,7 +58,7 @@ async fn test_recording_determinism() { #[tokio::test] async fn test_recording_mt_context() { - let (mut exec_0, mut exec_1, recorded) = recording_mt_context(1024 * 1024); + let (exec_0, exec_1, recorded) = recording_mt_context(1024 * 1024); let mut ctx_0 = exec_0.new_context().unwrap(); let mut ctx_1 = exec_1.new_context().unwrap(); @@ -89,7 +89,7 @@ async fn test_recording_mt_context() { #[tokio::test] async fn test_replay_mt_context() { // First: record some messages - let (mut exec_0, mut exec_1, recorded) = recording_mt_context(1024 * 1024); + let (exec_0, exec_1, recorded) = recording_mt_context(1024 * 1024); let mut ctx_0 = exec_0.new_context().unwrap(); let mut ctx_1 = exec_1.new_context().unwrap(); @@ -106,7 +106,7 @@ async fn test_replay_mt_context() { let recorded_data = recorded.lock().unwrap().clone(); // Now replay to a new context - let mut replay_exec = replay_mt_context(recorded_data); + let replay_exec = replay_mt_context(recorded_data); let mut replay_ctx = replay_exec.new_context().unwrap(); // Should be able to receive the same messages from replay @@ -120,7 +120,7 @@ async fn test_replay_mt_context() { #[tokio::test] async fn test_recording_mt_multiple_channels() { // Test that recording works correctly with multiple channels via ctx.try_join() - let (mut exec_0, mut exec_1, recorded) = recording_mt_context(1024 * 1024); + let (exec_0, exec_1, recorded) = recording_mt_context(1024 * 1024); let mut ctx_0 = exec_0.new_context().unwrap(); let mut ctx_1 = exec_1.new_context().unwrap(); @@ -129,25 +129,25 @@ async fn test_recording_mt_multiple_channels() { let (result, send_result) = futures::join!( // ctx_0 uses try_join to receive on multiple channels ctx_0.try_join( - async |ctx: &mut Context| { + |ctx: &mut Context| Box::pin(async move { let msg: u32 = ctx.io_mut().expect_next().await.unwrap(); Ok::<_, std::io::Error>(msg) - }, - async |ctx: &mut Context| { + }), + |ctx: &mut Context| Box::pin(async move { let msg: u64 = ctx.io_mut().expect_next().await.unwrap(); Ok::<_, std::io::Error>(msg) - }, + }), ), // ctx_1 uses try_join to send on multiple channels ctx_1.try_join( - async |ctx: &mut Context| { + |ctx: &mut Context| Box::pin(async move { ctx.io_mut().send(42u32).await.unwrap(); Ok::<_, std::io::Error>(()) - }, - async |ctx: &mut Context| { + }), + |ctx: &mut Context| Box::pin(async move { ctx.io_mut().send(123u64).await.unwrap(); Ok::<_, std::io::Error>(()) - }, + }), ) ); @@ -181,39 +181,39 @@ async fn test_recording_mt_multiple_channels() { #[tokio::test] async fn test_recording_mt_try_join3() { - let (mut exec_0, mut exec_1, recorded) = recording_mt_context(1024 * 1024); + let (exec_0, exec_1, recorded) = recording_mt_context(1024 * 1024); let mut ctx_0 = exec_0.new_context().unwrap(); let mut ctx_1 = exec_1.new_context().unwrap(); let (result, send_result) = futures::join!( ctx_0.try_join3( - async |ctx: &mut Context| { + |ctx: &mut Context| Box::pin(async move { let msg: u32 = ctx.io_mut().expect_next().await.unwrap(); Ok::<_, std::io::Error>(msg) - }, - async |ctx: &mut Context| { + }), + |ctx: &mut Context| Box::pin(async move { let msg: u64 = ctx.io_mut().expect_next().await.unwrap(); Ok::<_, std::io::Error>(msg) - }, - async |ctx: &mut Context| { + }), + |ctx: &mut Context| Box::pin(async move { let msg: String = ctx.io_mut().expect_next().await.unwrap(); Ok::<_, std::io::Error>(msg) - }, + }), ), ctx_1.try_join3( - async |ctx: &mut Context| { + |ctx: &mut Context| Box::pin(async move { ctx.io_mut().send(42u32).await.unwrap(); Ok::<_, std::io::Error>(()) - }, - async |ctx: &mut Context| { + }), + |ctx: &mut Context| Box::pin(async move { ctx.io_mut().send(123u64).await.unwrap(); Ok::<_, std::io::Error>(()) - }, - async |ctx: &mut Context| { + }), + |ctx: &mut Context| Box::pin(async move { ctx.io_mut().send("hello".to_string()).await.unwrap(); Ok::<_, std::io::Error>(()) - }, + }), ) ); @@ -241,47 +241,47 @@ async fn test_recording_mt_try_join3() { #[tokio::test] async fn test_recording_mt_try_join4() { - let (mut exec_0, mut exec_1, recorded) = recording_mt_context(1024 * 1024); + let (exec_0, exec_1, recorded) = recording_mt_context(1024 * 1024); let mut ctx_0 = exec_0.new_context().unwrap(); let mut ctx_1 = exec_1.new_context().unwrap(); let (result, send_result) = futures::join!( ctx_0.try_join4( - async |ctx: &mut Context| { + |ctx: &mut Context| Box::pin(async move { let msg: u32 = ctx.io_mut().expect_next().await.unwrap(); Ok::<_, std::io::Error>(msg) - }, - async |ctx: &mut Context| { + }), + |ctx: &mut Context| Box::pin(async move { let msg: u64 = ctx.io_mut().expect_next().await.unwrap(); Ok::<_, std::io::Error>(msg) - }, - async |ctx: &mut Context| { + }), + |ctx: &mut Context| Box::pin(async move { let msg: String = ctx.io_mut().expect_next().await.unwrap(); Ok::<_, std::io::Error>(msg) - }, - async |ctx: &mut Context| { + }), + |ctx: &mut Context| Box::pin(async move { let msg: Vec = ctx.io_mut().expect_next().await.unwrap(); Ok::<_, std::io::Error>(msg) - }, + }), ), ctx_1.try_join4( - async |ctx: &mut Context| { + |ctx: &mut Context| Box::pin(async move { ctx.io_mut().send(42u32).await.unwrap(); Ok::<_, std::io::Error>(()) - }, - async |ctx: &mut Context| { + }), + |ctx: &mut Context| Box::pin(async move { ctx.io_mut().send(123u64).await.unwrap(); Ok::<_, std::io::Error>(()) - }, - async |ctx: &mut Context| { + }), + |ctx: &mut Context| Box::pin(async move { ctx.io_mut().send("hello".to_string()).await.unwrap(); Ok::<_, std::io::Error>(()) - }, - async |ctx: &mut Context| { + }), + |ctx: &mut Context| Box::pin(async move { ctx.io_mut().send(vec![1u8, 2, 3]).await.unwrap(); Ok::<_, std::io::Error>(()) - }, + }), ) ); @@ -309,7 +309,7 @@ async fn test_recording_mt_try_join4() { #[tokio::test] async fn test_recording_mt_map() { - let (mut exec_0, mut exec_1, recorded) = recording_mt_context(1024 * 1024); + let (exec_0, exec_1, recorded) = recording_mt_context(1024 * 1024); let mut ctx_0 = exec_0.new_context().unwrap(); let mut ctx_1 = exec_1.new_context().unwrap(); @@ -318,21 +318,15 @@ async fn test_recording_mt_map() { let items: Vec = (0..8).collect(); let (recv_results, send_results) = futures::join!( - ctx_0.map( - items.clone(), - async |ctx: &mut Context, _item: u32| { + ctx_0.map(items.clone(), |ctx: &mut Context, _item: u32| Box::pin( + async move { let msg: u32 = ctx.io_mut().expect_next().await.unwrap(); msg - }, - |_| 1, // weight - ), - ctx_1.map( - items, - async |ctx: &mut Context, item: u32| { - ctx.io_mut().send(item * 10).await.unwrap(); - }, - |_| 1, - ) + } + ),), + ctx_1.map(items, |ctx: &mut Context, item: u32| Box::pin(async move { + ctx.io_mut().send(item * 10).await.unwrap(); + }),) ); let recv_results = recv_results.unwrap(); @@ -360,7 +354,7 @@ async fn test_recording_mt_map() { #[tokio::test] async fn test_recording_mt_nested_try_join() { - let (mut exec_0, mut exec_1, recorded) = recording_mt_context(1024 * 1024); + let (exec_0, exec_1, recorded) = recording_mt_context(1024 * 1024); let mut ctx_0 = exec_0.new_context().unwrap(); let mut ctx_1 = exec_1.new_context().unwrap(); @@ -369,56 +363,64 @@ async fn test_recording_mt_nested_try_join() { // Outer try_join ctx_0.try_join( // Inner try_join in first branch - async |ctx: &mut Context| { + |ctx: &mut Context| Box::pin(async move { // Receive the outer child's message first let outer_msg: u32 = ctx.io_mut().expect_next().await.unwrap(); assert_eq!(outer_msg, 999); let inner_result = ctx .try_join( - async |ctx: &mut Context| { - let msg: u32 = ctx.io_mut().expect_next().await.unwrap(); - Ok::<_, std::io::Error>(msg) + |ctx: &mut Context| { + Box::pin(async move { + let msg: u32 = ctx.io_mut().expect_next().await.unwrap(); + Ok::<_, std::io::Error>(msg) + }) }, - async |ctx: &mut Context| { - let msg: u64 = ctx.io_mut().expect_next().await.unwrap(); - Ok::<_, std::io::Error>(msg) + |ctx: &mut Context| { + Box::pin(async move { + let msg: u64 = ctx.io_mut().expect_next().await.unwrap(); + Ok::<_, std::io::Error>(msg) + }) }, ) .await .unwrap() .unwrap(); Ok::<_, std::io::Error>(inner_result) - }, + }), // Simple receive in second branch - async |ctx: &mut Context| { + |ctx: &mut Context| Box::pin(async move { let msg: String = ctx.io_mut().expect_next().await.unwrap(); Ok::<_, std::io::Error>(msg) - }, + }), ), // Matching structure on sender side ctx_1.try_join( - async |ctx: &mut Context| { + |ctx: &mut Context| Box::pin(async move { // Write something on outer child before inner try_join ctx.io_mut().send(999u32).await.unwrap(); ctx.try_join( - async |ctx: &mut Context| { - ctx.io_mut().send(42u32).await.unwrap(); - Ok::<_, std::io::Error>(()) + |ctx: &mut Context| { + Box::pin(async move { + ctx.io_mut().send(42u32).await.unwrap(); + Ok::<_, std::io::Error>(()) + }) }, - async |ctx: &mut Context| { - ctx.io_mut().send(123u64).await.unwrap(); - Ok::<_, std::io::Error>(()) + |ctx: &mut Context| { + Box::pin(async move { + ctx.io_mut().send(123u64).await.unwrap(); + Ok::<_, std::io::Error>(()) + }) }, ) .await .unwrap() .unwrap(); Ok::<_, std::io::Error>(()) - }, - async |ctx: &mut Context| { + }), + |ctx: &mut Context| Box::pin(async move { ctx.io_mut().send("nested".to_string()).await.unwrap(); Ok::<_, std::io::Error>(()) - }, + }), ) ); diff --git a/crates/common/src/executor.rs b/crates/common/src/executor.rs new file mode 100644 index 00000000..890a52f0 --- /dev/null +++ b/crates/common/src/executor.rs @@ -0,0 +1,482 @@ +//! Work-stealing async executor. +//! +//! This module provides a work-stealing threadpool executor that integrates +//! with the MPC task model. Each task is assigned a deterministic [`ContextId`] +//! and owns its own I/O channel, allowing tasks to be freely migrated between +//! worker threads while maintaining deterministic execution order for I/O. + +use std::sync::{ + Arc, + atomic::{AtomicBool, AtomicU32, Ordering}, +}; + +use async_task::{Runnable, Task}; +use crossbeam_deque::{Injector, Steal, Stealer, Worker}; +use crossbeam_utils::sync::{Parker, Unparker}; + +use crate::{Context, ContextId, mux::Mux}; + +/// A work-stealing async executor. +#[derive(Debug)] +pub struct Executor { + inner: Arc, +} + +/// Per-worker parking state. +struct WorkerState { + unparker: Unparker, + parked: AtomicBool, +} + +pub(crate) struct Inner { + /// Global task queue for new tasks and cross-thread wakeups. + injector: Injector, + + /// Stealers for each worker's local queue. + stealers: Vec>, + + /// Per-worker parking state, indexed by worker index. + workers: Box<[WorkerState]>, + + /// Shutdown flag. + shutdown: AtomicBool, + + /// Multiplexer for creating I/O channels. + mux: Arc, + + /// Namespace prefix applied to all contexts created by this executor. + prefix: ContextId, + + /// Counter handed out to each new context, ensuring uniqueness. + next_context: AtomicU32, +} + +impl std::fmt::Debug for Inner { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Inner") + .field("workers", &self.workers.len()) + .field("shutdown", &self.shutdown) + .finish_non_exhaustive() + } +} + +/// A worker spawn callback. +/// +/// Receives a worker entry-point and dispatches it on a thread (or +/// platform-equivalent, e.g. `web_spawn::spawn` on wasm). +pub type SpawnFn = + Box) -> Result<(), std::io::Error> + Send + Sync>; + +/// Builder for [`Executor`]. +pub struct ExecutorBuilder { + num_threads: usize, + prefix: ContextId, + spawn: SpawnFn, +} + +impl std::fmt::Debug for ExecutorBuilder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExecutorBuilder") + .field("num_threads", &self.num_threads) + .field("prefix", &self.prefix) + .finish_non_exhaustive() + } +} + +impl Default for ExecutorBuilder { + fn default() -> Self { + Self { + num_threads: std::thread::available_parallelism() + .map(|n| n.get()) + .unwrap_or(4), + prefix: ContextId::from_prefix([]), + spawn: Box::new(default_spawn), + } + } +} + +fn default_spawn(f: Box) -> Result<(), std::io::Error> { + std::thread::Builder::new() + .name("mpz-executor-worker".to_string()) + .spawn(f) + .map(drop) +} + +impl ExecutorBuilder { + /// Sets the number of worker threads. + pub fn num_threads(mut self, n: usize) -> Self { + self.num_threads = n; + self + } + + /// Sets a namespace prefix applied to all contexts created by the + /// executor. + /// + /// Useful when several sub-protocols share a mux and need to be kept in + /// disjoint ID spaces. + pub fn prefix(mut self, prefix: impl AsRef<[u8]>) -> Self { + self.prefix = ContextId::from_prefix(prefix); + self + } + + /// Sets a custom worker spawn callback. + /// + /// Defaults to `std::thread::spawn`. Useful on platforms without OS + /// threads (e.g. wasm32 where workers must be created via + /// `web_spawn::spawn`). + pub fn spawn(mut self, spawn: F) -> Self + where + F: Fn(Box) -> Result<(), std::io::Error> + + Send + + Sync + + 'static, + { + self.spawn = Box::new(spawn); + self + } + + /// Builds the executor with the given multiplexer. + pub fn build(self, mux: M) -> Executor { + let injector = Injector::new(); + + // Create local worker queues and their stealers. + let worker_queues: Vec> = + (0..self.num_threads).map(|_| Worker::new_fifo()).collect(); + + let stealers: Vec> = worker_queues.iter().map(|w| w.stealer()).collect(); + + let parkers: Vec = (0..self.num_threads).map(|_| Parker::new()).collect(); + let workers: Box<[WorkerState]> = parkers + .iter() + .map(|p| WorkerState { + unparker: p.unparker().clone(), + parked: AtomicBool::new(false), + }) + .collect(); + + let inner = Arc::new(Inner { + injector, + stealers, + workers, + shutdown: AtomicBool::new(false), + mux: Arc::new(mux), + prefix: self.prefix, + next_context: AtomicU32::new(0), + }); + + // Spawn worker threads via the configured spawn callback. + for (index, (local, parker)) in worker_queues.into_iter().zip(parkers).enumerate() { + let inner = inner.clone(); + (self.spawn)(Box::new(move || worker_loop(inner, local, index, parker))) + .expect("failed to spawn worker thread"); + } + + Executor { inner } + } +} + +/// Worker thread loop. +fn worker_loop(inner: Arc, local: Worker, index: usize, parker: Parker) { + let state = &inner.workers[index]; + + let drain_local = |local: &Worker| { + // Drop any runnables still sitting in this worker's local queue. + // Dropping cancels the corresponding task so awaiters of `Task` + // see cancellation instead of hanging on a worker that has exited. + while local.pop().is_some() {} + }; + + while !inner.shutdown.load(Ordering::Relaxed) { + if let Some(runnable) = find_task(&inner, &local, index) { + // Poll the task once. If it returns Pending, the waker will + // reschedule it; if it completes, we're done with the task. + runnable.run(); + continue; + } + + // Slow path: announce we're about to park, then recheck. + // + // The recheck after setting `parked = true` closes the race against a + // producer that pushed before we announced (and therefore didn't see + // us as a candidate to unpark). + state.parked.store(true, Ordering::SeqCst); + + if let Some(runnable) = find_task(&inner, &local, index) { + state.parked.store(false, Ordering::SeqCst); + runnable.run(); + continue; + } + + if inner.shutdown.load(Ordering::Relaxed) { + state.parked.store(false, Ordering::SeqCst); + break; + } + + // If a producer fires between the recheck above and `park()`, the + // `unpark` token is remembered by the parker and `park()` returns + // immediately — no lost wakeup. + parker.park(); + state.parked.store(false, Ordering::SeqCst); + } + + drain_local(&local); +} + +/// Finds a task to execute using work-stealing. +fn find_task(inner: &Inner, local: &Worker, index: usize) -> Option { + // 1. Local queue (fast path, cache-friendly). + if let Some(runnable) = local.pop() { + return Some(runnable); + } + + // 2. Global injector queue. + loop { + match inner.injector.steal_batch_and_pop(local) { + Steal::Success(runnable) => return Some(runnable), + Steal::Empty => break, + Steal::Retry => continue, + } + } + + // 3. Steal from other workers. + let num_stealers = inner.stealers.len(); + for i in 1..num_stealers { + let victim = (index + i) % num_stealers; + loop { + match inner.stealers[victim].steal_batch_and_pop(local) { + Steal::Success(runnable) => return Some(runnable), + Steal::Empty => break, + Steal::Retry => continue, + } + } + } + + None +} + +impl Executor { + /// Creates a new builder. + pub fn builder() -> ExecutorBuilder { + ExecutorBuilder::default() + } + + /// Shuts down the executor. + /// + /// Sets the shutdown flag, drains the global queue, and unparks every + /// worker. After this returns, no further runnables will be accepted by + /// the scheduler (newly woken tasks are dropped on arrival), and any + /// runnables still queued at the moment of shutdown are dropped. Dropping + /// a [`Runnable`] cancels its task, so awaiters of `Task` propagate + /// cancellation rather than hanging on a worker that has exited. + pub fn shutdown(&self) { + self.inner.shutdown.store(true, Ordering::SeqCst); + + // Drain the injector before unparking workers. Any push that races + // with this drain is handled by the shutdown check in the schedule + // callback (see `spawn_on`), which drops the runnable. + loop { + match self.inner.injector.steal() { + Steal::Success(_) => continue, + Steal::Empty => break, + Steal::Retry => continue, + } + } + + for w in self.inner.workers.iter() { + w.unparker.unpark(); + } + } + + /// Returns `true` if the executor has been shut down. + pub fn is_shutdown(&self) -> bool { + self.inner.shutdown.load(Ordering::SeqCst) + } + + /// Creates a new context. + /// + /// Each context produced by an executor is given a distinct ID under the + /// executor's configured prefix. + pub fn new_context(&self) -> Result { + let index = self.inner.next_context.fetch_add(1, Ordering::Relaxed); + let id = self.inner.prefix.child(index); + let io = self.inner.mux.open(id.as_ref())?; + Ok(Context::with_executor( + id, + io, + self.inner.mux.clone(), + self.inner.clone(), + )) + } +} + +impl Drop for Executor { + fn drop(&mut self) { + self.shutdown(); + } +} + +/// Spawns a future on the given executor inner. +pub(crate) fn spawn_on(inner: &Arc, future: F) -> Task +where + F: std::future::Future + Send + 'static, + F::Output: Send + 'static, +{ + let inner = Arc::clone(inner); + let schedule = move |runnable: Runnable| { + // After shutdown, no worker will run this. Dropping the runnable + // cancels the task so the awaiter doesn't hang. SeqCst pairs with + // the SeqCst store in `Executor::shutdown` to ensure that any push + // that "loses" the race is then drained by shutdown's pass over + // the injector. + if inner.shutdown.load(Ordering::SeqCst) { + drop(runnable); + return; + } + inner.injector.push(runnable); + // Scan for an idle worker and claim it for this notification. The + // `load` is a cheap filter; the `compare_exchange` is what makes the + // claim race-free against other concurrent producers. Stops at the + // first claimed worker — one push, one wake. + for w in inner.workers.iter() { + if w.parked.load(Ordering::SeqCst) + && w.parked + .compare_exchange(true, false, Ordering::SeqCst, Ordering::SeqCst) + .is_ok() + { + w.unparker.unpark(); + break; + } + } + }; + let (runnable, task) = async_task::spawn(future, schedule); + runnable.schedule(); + task +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::mux::test_framed_mux; + use serio::{SinkExt, StreamExt}; + + #[test] + fn test_executor_spawn() { + let (mux_a, _mux_b) = test_framed_mux(1024); + let executor = Executor::builder().num_threads(2).build(mux_a); + + let mut ctx = executor.new_context().unwrap(); + let (a, b) = futures::executor::block_on(ctx.join( + |_ctx| Box::pin(async move { 21 }), + |_ctx| Box::pin(async move { 21 }), + )) + .unwrap(); + + assert_eq!(a + b, 42); + + executor.shutdown(); + } + + #[test] + fn test_executor_map() { + let (mux_a, _mux_b) = test_framed_mux(1024); + let executor = Executor::builder().num_threads(2).build(mux_a); + + let mut ctx = executor.new_context().unwrap(); + + let items = vec![1, 2, 3, 4, 5]; + let results = + futures::executor::block_on(ctx.map(items, |_ctx, x| Box::pin(async move { x * 2 }))); + + assert_eq!(results.unwrap(), vec![2, 4, 6, 8, 10]); + + executor.shutdown(); + } + + #[test] + fn test_executor_join() { + let (mux_a, _mux_b) = test_framed_mux(1024); + let executor = Executor::builder().num_threads(2).build(mux_a); + + let mut ctx = executor.new_context().unwrap(); + + let result = futures::executor::block_on(ctx.join( + |_ctx| Box::pin(async move { 1 + 1 }), + |_ctx| Box::pin(async move { 2 + 2 }), + )); + + assert_eq!(result.unwrap(), (2, 4)); + + executor.shutdown(); + } + + #[test] + fn test_executor_io() { + // Test that I/O works between two executors (simulating two parties). + let (mux_a, mux_b) = test_framed_mux(1024); + + let executor_a = Executor::builder().num_threads(2).build(mux_a); + let executor_b = Executor::builder().num_threads(2).build(mux_b); + + let mut ctx_a = executor_a.new_context().unwrap(); + let mut ctx_b = executor_b.new_context().unwrap(); + + let (_, (val1, val2)) = futures::executor::block_on(futures::future::join( + async { + ctx_a.io_mut().send(42u32).await.unwrap(); + ctx_a.io_mut().send(123u32).await.unwrap(); + }, + async { + let val1: u32 = ctx_b.io_mut().next().await.unwrap().unwrap(); + let val2: u32 = ctx_b.io_mut().next().await.unwrap().unwrap(); + (val1, val2) + }, + )); + + assert_eq!(val1, 42); + assert_eq!(val2, 123); + + executor_a.shutdown(); + executor_b.shutdown(); + } + + #[test] + fn test_executor_map_with_io() { + // Test that map works with I/O between two parties. + let (mux_a, mux_b) = test_framed_mux(1024); + + let executor_a = Executor::builder().num_threads(4).build(mux_a); + let executor_b = Executor::builder().num_threads(4).build(mux_b); + + let mut ctx_a = executor_a.new_context().unwrap(); + let mut ctx_b = executor_b.new_context().unwrap(); + + let items_a = vec![1u32, 2, 3, 4]; + let items_b = vec![10u32, 20, 30, 40]; + + // Party A sends each item, Party B receives and returns sum. + let task_a = ctx_a.map(items_a, |ctx, x| { + Box::pin(async move { + ctx.io_mut().send(x).await.unwrap(); + }) + }); + + let task_b = ctx_b.map(items_b, |ctx, x| { + Box::pin(async move { + let received: u32 = ctx.io_mut().next().await.unwrap().unwrap(); + received + x + }) + }); + + let (results_a, results_b) = + futures::executor::block_on(futures::future::join(task_a, task_b)); + + assert!(results_a.is_ok()); + let results_b = results_b.unwrap(); + + // Each B task should receive the corresponding A value and add it to B's value. + assert_eq!(results_b, vec![11, 22, 33, 44]); + + executor_a.shutdown(); + executor_b.shutdown(); + } +} diff --git a/crates/common/src/id.rs b/crates/common/src/id.rs index 1ee6db7f..2ea44f8b 100644 --- a/crates/common/src/id.rs +++ b/crates/common/src/id.rs @@ -1,121 +1,91 @@ use core::fmt; -/// A logical thread identifier. +/// A logical context identifier. /// -/// Every thread is assigned a unique identifier, which can be forked to create -/// a child thread. +/// Hierarchical: each level is a `u32` index, serialized big-endian. Both +/// parties derive identical IDs by following the same call sequence. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct ThreadId(Box<[u8]>); +pub struct ContextId(Box<[u8]>); -impl Default for ThreadId { - fn default() -> Self { - Self(vec![0].into()) - } -} +impl ContextId { + const LEVEL_BYTES: usize = 4; -impl ThreadId { - /// Creates a new thread ID with the provided ID. + /// Creates a context ID at the top level with the given index. #[inline] - pub fn new(id: u8) -> Self { - Self(vec![id].into()) + pub fn new(index: u32) -> Self { + Self(index.to_be_bytes().to_vec().into()) } - /// Returns the thread ID as a byte slice. + /// Creates a context ID from an arbitrary byte prefix. + /// + /// Useful for namespacing contexts under a caller-chosen identifier (e.g. + /// a sub-protocol name). Forked children are appended to this prefix + /// using the standard hierarchical layout. #[inline] - pub fn as_bytes(&self) -> &[u8] { - &self.0 + pub fn from_prefix(prefix: impl AsRef<[u8]>) -> Self { + Self(prefix.as_ref().to_vec().into()) } - /// Increments the thread ID, returning `None` if the ID overflows. + /// Returns the ID as a byte slice. #[inline] - pub fn increment(&self) -> Option { - let mut next = self.clone(); - - let id = next.0.last_mut()?; - *id = id.checked_add(1)?; - - Some(next) + pub fn as_bytes(&self) -> &[u8] { + &self.0 } - /// Increments the thread ID in place, returning the original ID if it - /// doesn't overflow. + /// Descends into a child namespace at the given index. #[inline] - pub fn increment_in_place(&mut self) -> Option { - let prev = self.clone(); - - let id = self.0.last_mut()?; - *id = id.checked_add(1)?; - - Some(prev) + pub fn child(&self, index: u32) -> Self { + let mut bytes = Vec::with_capacity(self.0.len() + Self::LEVEL_BYTES); + bytes.extend_from_slice(&self.0); + bytes.extend_from_slice(&index.to_be_bytes()); + Self(bytes.into()) } +} - /// Forks the thread ID. - #[inline] - pub fn fork(&self) -> Self { - let mut id = vec![0; self.0.len() + 1]; - id[0..self.0.len()].copy_from_slice(&self.0); - - Self(id.into()) +impl Default for ContextId { + fn default() -> Self { + Self::new(0) } } -impl AsRef<[u8]> for ThreadId { +impl AsRef<[u8]> for ContextId { #[inline] fn as_ref(&self) -> &[u8] { self.as_bytes() } } -impl fmt::Display for ThreadId { +impl fmt::Display for ContextId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - for (i, byte) in self.0.iter().enumerate() { + for (i, chunk) in self.0.chunks(Self::LEVEL_BYTES).enumerate() { if i > 0 { write!(f, "/")?; } - write!(f, "{byte}")?; + let mut buf = [0u8; 4]; + buf[..chunk.len()].copy_from_slice(chunk); + write!(f, "{}", u32::from_be_bytes(buf))?; } Ok(()) } } -/// A simple counter. -#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct Counter(u32); - -impl Counter { - /// Increments the counter in place, returning the previous value. - #[allow(clippy::should_implement_trait)] - pub fn next(&mut self) -> Self { - let prev = self.0; - self.0 += 1; - Self(prev) - } - - /// Returns the next value without incrementing the counter. - pub fn peek(&self) -> Self { - Self(self.0 + 1) - } -} - -impl fmt::Display for Counter { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.0) - } -} - #[cfg(test)] mod tests { use super::*; #[test] - fn test_thread_id() { - let mut id = ThreadId::new(0); - - assert_eq!(id.as_bytes(), &[0]); - assert_eq!(id.increment_in_place().unwrap().as_bytes(), &[0]); - assert_eq!(id.as_bytes(), &[1]); - assert_eq!(id.increment().unwrap().as_bytes(), &[2]); - assert_eq!(id.fork().as_bytes(), &[1, 0]); + fn test_context_id() { + let id = ContextId::default(); + assert_eq!(id.as_bytes(), &[0, 0, 0, 0]); + + let child0 = id.child(0); + let child1 = id.child(1); + assert_ne!(child0, child1); + assert_eq!(child0.as_bytes(), &[0, 0, 0, 0, 0, 0, 0, 0]); + assert_eq!(child1.as_bytes(), &[0, 0, 0, 0, 0, 0, 0, 1]); + + let grand = child0.child(7); + assert_eq!(grand.as_bytes(), &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7]); } } diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index 721824fd..f00a8a89 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -16,20 +16,21 @@ )] pub mod context; +pub mod executor; #[cfg(any(test, feature = "future"))] pub mod future; mod id; #[cfg(any(test, feature = "ideal"))] pub mod ideal; pub mod io; -pub(crate) mod load_balance; pub mod mux; #[cfg(feature = "sync")] pub mod sync; mod task; pub use context::{Context, ContextError}; -pub use id::{Counter, ThreadId}; +pub use executor::{Executor, ExecutorBuilder}; +pub use id::ContextId; pub use task::Task; use async_trait::async_trait; diff --git a/crates/common/src/load_balance.rs b/crates/common/src/load_balance.rs deleted file mode 100644 index fe0cf435..00000000 --- a/crates/common/src/load_balance.rs +++ /dev/null @@ -1,44 +0,0 @@ -//! Load balancing algorithms. - -/// Evenly distributes items across lanes based on their weight. -pub(crate) fn distribute_by_weight( - items: impl IntoIterator, - f_weight: F, - num_lanes: usize, -) -> Vec> -where - F: Fn(&T) -> usize, -{ - if num_lanes == 0 { - return Vec::new(); - } - - // Compute weights and pair with items - let mut items_with_weights: Vec<(T, usize)> = items - .into_iter() - .map(|item| { - let weight = f_weight(&item); - (item, weight) - }) - .collect(); - - // Sort in decreasing order of weight - items_with_weights.sort_by_key(|item| std::cmp::Reverse(item.1)); - - let mut lanes: Vec> = (0..num_lanes).map(|_| Vec::new()).collect(); - let mut lane_weights = vec![0; num_lanes]; - for (item, weight) in items_with_weights { - // Find the lane with minimum total weight - let idx = lane_weights - .iter() - .enumerate() - .min_by_key(|&(_, w)| w) - .unwrap() - .0; - - lanes[idx].push(item); - lane_weights[idx] += weight; - } - - lanes -} diff --git a/crates/common/src/mux.rs b/crates/common/src/mux.rs index 937fc59e..bc0ad6d2 100644 --- a/crates/common/src/mux.rs +++ b/crates/common/src/mux.rs @@ -1,11 +1,11 @@ //! Multiplexing types. -use crate::{ThreadId, io::Io}; +use crate::io::Io; /// A multiplexer. pub trait Mux { - /// Opens a new I/O channel for the given thread. - fn open(&self, id: ThreadId) -> Result; + /// Opens a new I/O channel for the given context ID. + fn open(&self, id: &[u8]) -> Result; } #[cfg(any(test, feature = "test-utils"))] @@ -17,7 +17,7 @@ mod test_utils { use serio::channel::{MemoryDuplex, duplex}; - use crate::{ThreadId, io::Io, mux::Mux}; + use crate::{io::Io, mux::Mux}; #[derive(Debug, Default)] struct State { @@ -41,16 +41,16 @@ mod test_utils { } impl Mux for TestFramedMux { - fn open(&self, id: ThreadId) -> Result { + fn open(&self, id: &[u8]) -> Result { let mut state = self.state.lock().unwrap(); if let Some(channel) = match self.role { - Role::A => state.waiting_a.remove(id.as_ref()), - Role::B => state.waiting_b.remove(id.as_ref()), + Role::A => state.waiting_a.remove(id), + Role::B => state.waiting_b.remove(id), } { Ok(Io::from_channel(channel)) } else { - if !state.exists.insert(id.as_ref().to_vec()) { + if !state.exists.insert(id.to_vec()) { return Err(std::io::Error::other("duplicate stream id")); } @@ -58,11 +58,11 @@ mod test_utils { match self.role { Role::A => { - state.waiting_b.insert(id.as_ref().to_vec(), b); + state.waiting_b.insert(id.to_vec(), b); Ok(Io::from_channel(a)) } Role::B => { - state.waiting_a.insert(id.as_ref().to_vec(), a); + state.waiting_a.insert(id.to_vec(), a); Ok(Io::from_channel(b)) } } @@ -90,7 +90,7 @@ mod test_utils { #[cfg(test)] mod tests { - use crate::{ThreadId, mux::Mux}; + use crate::mux::Mux; use serio::{SinkExt, StreamExt}; #[test] @@ -98,11 +98,11 @@ mod test_utils { let (a, b) = super::test_framed_mux(1); futures::executor::block_on(async { - let mut a_0 = a.open(ThreadId::new(0)).unwrap(); - let mut b_0 = b.open(ThreadId::new(0)).unwrap(); + let mut a_0 = a.open(&[0]).unwrap(); + let mut b_0 = b.open(&[0]).unwrap(); - let mut a_1 = a.open(ThreadId::new(1)).unwrap(); - let mut b_1 = b.open(ThreadId::new(1)).unwrap(); + let mut a_1 = a.open(&[1]).unwrap(); + let mut b_1 = b.open(&[1]).unwrap(); a_0.send(42u8).await.unwrap(); assert_eq!(b_0.next::().await.unwrap().unwrap(), 42); @@ -117,11 +117,11 @@ mod test_utils { let (a, b) = super::test_framed_mux(1); futures::executor::block_on(async { - let _ = a.open(ThreadId::new(0)).unwrap(); - let _ = b.open(ThreadId::new(0)).unwrap(); + let _ = a.open(&[0]).unwrap(); + let _ = b.open(&[0]).unwrap(); - assert!(a.open(ThreadId::new(0)).is_err()); - assert!(b.open(ThreadId::new(0)).is_err()); + assert!(a.open(&[0]).is_err()); + assert!(b.open(&[0]).is_err()); }) } } diff --git a/crates/garble/benches/evaluator.rs b/crates/garble/benches/evaluator.rs index c1522b11..9e8f6ba5 100644 --- a/crates/garble/benches/evaluator.rs +++ b/crates/garble/benches/evaluator.rs @@ -10,9 +10,12 @@ use std::sync::Arc; use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; use futures::executor::block_on; use mpz_circuits::{AES128, Circuit}; -use mpz_common::context::{ - Multithread, RecordedMtData, recording_mt_context_with_limit, recording_st_context_with_limit, - replay_mt_context_with_limit, replay_st_context, +use mpz_common::{ + Executor, + context::{ + RecordedMtData, recording_mt_context_with_limit, recording_st_context_with_limit, + replay_mt_context_with_limit, replay_st_context, + }, }; use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; use mpz_memory_core::{Array, binary::U8, correlated::Delta}; @@ -165,8 +168,8 @@ async fn run_evaluator_with_replay( /// Runs the full garble protocol with MT contexts. /// Records garbler->evaluator messages. async fn run_protocol_record_garbler_mt( - exec_gb: &mut Multithread, - exec_ev: &mut Multithread, + exec_gb: &mut Executor, + exec_ev: &mut Executor, circuit: Arc, circuit_count: usize, seed: u64, @@ -261,7 +264,7 @@ fn record_for_evaluator_mt( /// Runs MT evaluator only with replay context. async fn run_evaluator_with_replay_mt( - exec: &mut Multithread, + exec: &mut Executor, circuit: Arc, circuit_count: usize, ) { diff --git a/crates/garble/benches/garbler.rs b/crates/garble/benches/garbler.rs index 8691b76d..0a2cfb69 100644 --- a/crates/garble/benches/garbler.rs +++ b/crates/garble/benches/garbler.rs @@ -9,9 +9,12 @@ use std::sync::Arc; use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; use futures::executor::block_on; use mpz_circuits::{AES128, Circuit}; -use mpz_common::context::{ - Multithread, RecordedMtData, recording_mt_context_with_limit, recording_st_context_with_limit, - replay_mt_context_with_limit, replay_st_context, +use mpz_common::{ + Executor, + context::{ + RecordedMtData, recording_mt_context_with_limit, recording_st_context_with_limit, + replay_mt_context_with_limit, replay_st_context, + }, }; use mpz_garble::protocol::semihonest::{Evaluator, Garbler}; use mpz_memory_core::{Array, binary::U8, correlated::Delta}; @@ -167,8 +170,8 @@ async fn run_garbler_with_replay( /// Runs the full garble protocol with MT contexts. /// Records evaluator->garbler messages. async fn run_protocol_record_evaluator_mt( - exec_gb: &mut Multithread, - exec_ev: &mut Multithread, + exec_gb: &mut Executor, + exec_ev: &mut Executor, circuit: Arc, circuit_count: usize, seed: u64, @@ -265,7 +268,7 @@ fn record_for_garbler_mt( /// Runs MT garbler only with replay context. async fn run_garbler_with_replay_mt( - exec: &mut Multithread, + exec: &mut Executor, circuit: Arc, circuit_count: usize, delta: Delta, diff --git a/crates/garble/src/protocol/semihonest/evaluator.rs b/crates/garble/src/protocol/semihonest/evaluator.rs index dfd7e3cd..bbdff9bc 100644 --- a/crates/garble/src/protocol/semihonest/evaluator.rs +++ b/crates/garble/src/protocol/semihonest/evaluator.rs @@ -232,39 +232,42 @@ where let (_, preprocessed) = ctx .try_join( - async move |ctx| { - // This flush is primarily intended to perform OT setup - // concurrently with preprocessing. - cot.flush(ctx).await.map_err(VmError::execute) + move |ctx| { + Box::pin(async move { + // This flush is primarily intended to perform OT setup + // concurrently with preprocessing. + cot.flush(ctx).await.map_err(VmError::execute) + }) }, - async move |ctx| { - let mut preprocessed = Vec::new(); - - while !call_stack.is_empty() { - let calls = take_preprocess_calls(&mut call_stack); - - // There must be at least one call ready for preprocessing - // in a non-empty call stack. - debug_assert!(!calls.is_empty()); - - let mut outputs = ctx - .map( - calls, - async move |ctx, (call, output): (Call, Slice)| { - let garbled_circuit = receive_garbled_circuit(ctx, call.circ()) - .await - .map_err(VmError::execute)?; - Ok::<_, VmError>((call, output, garbled_circuit)) - }, - |(call, _)| call.circ().and_count(), - ) - .await - .map_err(VmError::execute)?; - - preprocessed.append(&mut outputs); - } - - Ok::<_, VmError>(preprocessed) + move |ctx| { + Box::pin(async move { + let mut preprocessed = Vec::new(); + + while !call_stack.is_empty() { + let calls = take_preprocess_calls(&mut call_stack); + + // There must be at least one call ready for preprocessing + // in a non-empty call stack. + debug_assert!(!calls.is_empty()); + + let mut outputs = ctx + .map(calls, move |ctx, (call, output): (Call, Slice)| { + Box::pin(async move { + let garbled_circuit = + receive_garbled_circuit(ctx, call.circ()) + .await + .map_err(VmError::execute)?; + Ok::<_, VmError>((call, output, garbled_circuit)) + }) + }) + .await + .map_err(VmError::execute)?; + + preprocessed.append(&mut outputs); + } + + Ok::<_, VmError>(preprocessed) + }) }, ) .await @@ -310,13 +313,10 @@ where let store = self.store.clone(); let outputs = ctx - .map( - calls, - async move |ctx, (call, output): (Call, Slice)| { - evaluate(ctx, store.clone(), call, output).await - }, - |(call, _)| call.circ().and_count(), - ) + .map(calls, move |ctx, (call, output): (Call, Slice)| { + let store = store.clone(); + Box::pin(async move { evaluate(ctx, store, call, output).await }) + }) .await .map_err(VmError::execute)?; diff --git a/crates/garble/src/protocol/semihonest/garbler.rs b/crates/garble/src/protocol/semihonest/garbler.rs index 3355a48b..d03510e5 100644 --- a/crates/garble/src/protocol/semihonest/garbler.rs +++ b/crates/garble/src/protocol/semihonest/garbler.rs @@ -213,36 +213,42 @@ where .collect::>(); ctx.try_join( - async move |ctx| { - // This flush is primarily intended to perform OT setup - // concurrently with preprocessing. - cot.flush(ctx).await.map_err(VmError::execute) + move |ctx| { + Box::pin(async move { + // This flush is primarily intended to perform OT setup + // concurrently with preprocessing. + cot.flush(ctx).await.map_err(VmError::execute) + }) }, - async move |ctx| { - while !call_stack.is_empty() { - let calls = take_preprocess_calls(&mut call_stack); - - // There must be at least one call ready for preprocessing - // in a non-empty call stack. - debug_assert!(!calls.is_empty()); - - let store = store.clone(); - let outputs = ctx - .map( - calls, - async move |ctx: &mut Context, (call, output): (Call, Slice)| { - generate(ctx, store.clone(), delta, call, output, Mode::Preprocess) - .await - }, - |(call, _)| call.circ().and_count(), - ) - .await - .map_err(VmError::execute)?; - - outputs.into_iter().collect::>()?; - } - - Ok::<_, VmError>(()) + move |ctx| { + Box::pin(async move { + while !call_stack.is_empty() { + let calls = take_preprocess_calls(&mut call_stack); + + // There must be at least one call ready for preprocessing + // in a non-empty call stack. + debug_assert!(!calls.is_empty()); + + let store = store.clone(); + let outputs = ctx + .map( + calls, + move |ctx: &mut Context, (call, output): (Call, Slice)| { + let store = store.clone(); + Box::pin(async move { + generate(ctx, store, delta, call, output, Mode::Preprocess) + .await + }) + }, + ) + .await + .map_err(VmError::execute)?; + + outputs.into_iter().collect::>()?; + } + + Ok::<_, VmError>(()) + }) }, ) .await @@ -281,10 +287,12 @@ where let outputs = ctx .map( calls, - async move |ctx: &mut Context, (call, output): (Call, Slice)| { - generate(ctx, store.clone(), delta, call, output, Mode::Execute).await + move |ctx: &mut Context, (call, output): (Call, Slice)| { + let store = store.clone(); + Box::pin(async move { + generate(ctx, store, delta, call, output, Mode::Execute).await + }) }, - |(call, _)| call.circ().and_count(), ) .await .map_err(VmError::execute)?; diff --git a/crates/garble/src/store/evaluator.rs b/crates/garble/src/store/evaluator.rs index 43738a43..fe8024ae 100644 --- a/crates/garble/src/store/evaluator.rs +++ b/crates/garble/src/store/evaluator.rs @@ -130,18 +130,20 @@ where let expected_size = self.core.flush_view().garbler_flush_size(); let (flush, ()) = ctx .try_join( - async move |ctx| { - ctx.io_mut().with_limit(flush_size).send(flush).await?; - - // Adjust the limit to expected size. - let limit = ctx.io().limit().max(expected_size); - ctx.io_mut() - .with_limit(limit) - .expect_next() - .await - .map_err(Error::from) + move |ctx| { + Box::pin(async move { + ctx.io_mut().with_limit(flush_size).send(flush).await?; + + // Adjust the limit to expected size. + let limit = ctx.io().limit().max(expected_size); + ctx.io_mut() + .with_limit(limit) + .expect_next() + .await + .map_err(Error::from) + }) }, - async move |ctx| cot.flush(ctx).await.map_err(Error::cot), + move |ctx| Box::pin(async move { cot.flush(ctx).await.map_err(Error::cot) }), ) .await??; diff --git a/crates/garble/src/store/garbler.rs b/crates/garble/src/store/garbler.rs index 35dba70b..b8feb6b2 100644 --- a/crates/garble/src/store/garbler.rs +++ b/crates/garble/src/store/garbler.rs @@ -130,18 +130,20 @@ where let expected_size = self.core.flush_view().evaluator_flush_size(); let (flush, ()) = ctx .try_join( - async move |ctx| { - ctx.io_mut().with_limit(flush_size).send(flush).await?; - - // Adjust the limit to expected size. - let limit = ctx.io().limit().max(expected_size); - ctx.io_mut() - .with_limit(limit) - .expect_next() - .await - .map_err(Error::from) + move |ctx| { + Box::pin(async move { + ctx.io_mut().with_limit(flush_size).send(flush).await?; + + // Adjust the limit to expected size. + let limit = ctx.io().limit().max(expected_size); + ctx.io_mut() + .with_limit(limit) + .expect_next() + .await + .map_err(Error::from) + }) }, - async move |ctx| cot.flush(ctx).await.map_err(Error::cot), + move |ctx| Box::pin(async move { cot.flush(ctx).await.map_err(Error::cot) }), ) .await??; diff --git a/crates/ot/Cargo.toml b/crates/ot/Cargo.toml index 7454c525..e8ab4e66 100644 --- a/crates/ot/Cargo.toml +++ b/crates/ot/Cargo.toml @@ -13,7 +13,7 @@ name = "mpz_ot" default = [] rayon = ["mpz-ot-core/rayon"] ideal = ["mpz-common/ideal", "serde"] -test-utils = ["mpz-ot-core/test-utils"] +test-utils = ["mpz-ot-core/test-utils", "mpz-common/test-utils"] [dependencies] mpz-core.workspace = true diff --git a/crates/ot/benches/ferret_receiver.rs b/crates/ot/benches/ferret_receiver.rs index 50ece121..48c7b68f 100644 --- a/crates/ot/benches/ferret_receiver.rs +++ b/crates/ot/benches/ferret_receiver.rs @@ -8,10 +8,10 @@ use criterion::{Criterion, Throughput, criterion_group, criterion_main}; use futures::executor::block_on; use mpz_common::{ - Flush, + Executor, Flush, context::{ - Multithread, RecordedMtData, recording_mt_context_with_limit, - recording_st_context_with_limit, replay_mt_context_with_limit, replay_st_context, + RecordedMtData, recording_mt_context_with_limit, recording_st_context_with_limit, + replay_mt_context_with_limit, replay_st_context, }, }; use mpz_core::Block; @@ -141,8 +141,8 @@ struct RecordedDataMt { /// Runs the full Ferret protocol with MT contexts. /// Records sender->receiver messages. async fn run_protocol_record_sender_mt( - exec_sender: &mut Multithread, - exec_receiver: &mut Multithread, + exec_sender: &mut Executor, + exec_receiver: &mut Executor, config: FerretConfig, delta: Block, cot_seed: Block, @@ -213,7 +213,7 @@ fn record_for_receiver_mt(seed: u64) -> RecordedDataMt { } /// Runs MT receiver only with replay context. -async fn run_receiver_with_replay_mt(exec: &mut Multithread, data: &RecordedDataMt) { +async fn run_receiver_with_replay_mt(exec: &mut Executor, data: &RecordedDataMt) { let cot_recv = IdealRCOTReceiver::from_seed(data.cot_recv_seed); let config = bench_config(); let mut receiver = Receiver::new(config, data.receiver_seed, cot_recv); diff --git a/crates/ot/benches/ferret_sender.rs b/crates/ot/benches/ferret_sender.rs index ce58bfff..a3a45d87 100644 --- a/crates/ot/benches/ferret_sender.rs +++ b/crates/ot/benches/ferret_sender.rs @@ -8,10 +8,10 @@ use criterion::{Criterion, Throughput, criterion_group, criterion_main}; use futures::executor::block_on; use mpz_common::{ - Flush, + Executor, Flush, context::{ - Multithread, RecordedMtData, recording_mt_context_with_limit, - recording_st_context_with_limit, replay_mt_context_with_limit, replay_st_context, + RecordedMtData, recording_mt_context_with_limit, recording_st_context_with_limit, + replay_mt_context_with_limit, replay_st_context, }, }; use mpz_core::Block; @@ -146,8 +146,8 @@ struct RecordedDataMt { /// Runs the full Ferret protocol with MT contexts. /// Records receiver->sender messages. async fn run_protocol_record_receiver_mt( - exec_sender: &mut Multithread, - exec_receiver: &mut Multithread, + exec_sender: &mut Executor, + exec_receiver: &mut Executor, config: FerretConfig, delta: Block, cot_seed: Block, @@ -219,7 +219,7 @@ fn record_for_sender_mt(seed: u64) -> RecordedDataMt { } /// Runs MT sender only with replay context. -async fn run_sender_with_replay_mt(exec: &mut Multithread, data: &RecordedDataMt) { +async fn run_sender_with_replay_mt(exec: &mut Executor, data: &RecordedDataMt) { let cot_send = IdealRCOTSender::new(data.cot_seed, data.delta); let config = bench_config(); let mut sender = Sender::new(config, data.sender_seed, cot_send); diff --git a/crates/wasm-bench/src/garble/evaluator.rs b/crates/wasm-bench/src/garble/evaluator.rs index b0390509..f14ddaa2 100644 --- a/crates/wasm-bench/src/garble/evaluator.rs +++ b/crates/wasm-bench/src/garble/evaluator.rs @@ -9,8 +9,9 @@ use wasm_bindgen::prelude::*; #[cfg(target_arch = "wasm32")] use mpz_circuits::AES128; #[cfg(target_arch = "wasm32")] +use mpz_common::Executor; use mpz_common::context::{ - Multithread, RecordedMtData, recording_mt_context_with_spawn_and_limit, + RecordedMtData, recording_mt_context_with_spawn_and_limit, replay_mt_context_with_spawn_and_limit, }; #[cfg(target_arch = "wasm32")] @@ -44,8 +45,8 @@ fn max_frame_length(circuit: &mpz_circuits::Circuit, circuit_count: usize) -> us #[cfg(target_arch = "wasm32")] async fn run_protocol_record_garbler( - exec_gb: &mut Multithread, - exec_ev: &mut Multithread, + exec_gb: &mut Executor, + exec_ev: &mut Executor, circuit_count: usize, seed: u64, ) { @@ -141,7 +142,7 @@ async fn record_for_evaluator( } #[cfg(target_arch = "wasm32")] -async fn run_evaluator_with_replay(exec: &mut Multithread, circuit_count: usize) { +async fn run_evaluator_with_replay(exec: &mut Executor, circuit_count: usize) { let (_, cot_recv) = ideal_cot([0u8; 16].into()); let mut ev = Evaluator::new(cot_recv); diff --git a/crates/wasm-bench/src/garble/garbler.rs b/crates/wasm-bench/src/garble/garbler.rs index 6d8206bb..ba7f3032 100644 --- a/crates/wasm-bench/src/garble/garbler.rs +++ b/crates/wasm-bench/src/garble/garbler.rs @@ -9,8 +9,9 @@ use wasm_bindgen::prelude::*; #[cfg(target_arch = "wasm32")] use mpz_circuits::AES128; #[cfg(target_arch = "wasm32")] +use mpz_common::Executor; use mpz_common::context::{ - Multithread, RecordedMtData, recording_mt_context_with_spawn_and_limit, + RecordedMtData, recording_mt_context_with_spawn_and_limit, replay_mt_context_with_spawn_and_limit, }; #[cfg(target_arch = "wasm32")] @@ -44,8 +45,8 @@ fn max_frame_length(circuit: &mpz_circuits::Circuit, circuit_count: usize) -> us #[cfg(target_arch = "wasm32")] async fn run_protocol_record_evaluator( - exec_gb: &mut Multithread, - exec_ev: &mut Multithread, + exec_gb: &mut Executor, + exec_ev: &mut Executor, circuit_count: usize, seed: u64, ) { @@ -136,7 +137,7 @@ async fn record_for_garbler(circuit_count: usize, seed: u64, concurrency: usize) } #[cfg(target_arch = "wasm32")] -async fn run_garbler_with_replay(exec: &mut Multithread, circuit_count: usize, delta: Delta) { +async fn run_garbler_with_replay(exec: &mut Executor, circuit_count: usize, delta: Delta) { let (cot_send, _) = ideal_cot(delta.into_inner()); let mut gb = Garbler::new(cot_send, [0u8; 16], delta); diff --git a/crates/wasm-bench/src/lib.rs b/crates/wasm-bench/src/lib.rs index 86cb1112..240cf891 100644 --- a/crates/wasm-bench/src/lib.rs +++ b/crates/wasm-bench/src/lib.rs @@ -128,7 +128,7 @@ pub async fn test_mt_context_only() -> Result { use mpz_common::context::test_mt_context_with_spawn; use serio::{SinkExt, stream::IoStreamExt}; - let (mut mt1, mut mt2) = test_mt_context_with_spawn(8, |f| { + let (mt1, mt2) = test_mt_context_with_spawn(8, |f| { let _ = web_spawn::spawn(f); Ok(()) }); diff --git a/crates/wasm-bench/src/ot/ferret.rs b/crates/wasm-bench/src/ot/ferret.rs index b4d75acf..c9445b1d 100644 --- a/crates/wasm-bench/src/ot/ferret.rs +++ b/crates/wasm-bench/src/ot/ferret.rs @@ -49,8 +49,9 @@ fn bench_config() -> FerretConfig { // ============================================================================ #[cfg(target_arch = "wasm32")] +use mpz_common::Executor; use mpz_common::context::{ - Multithread, RecordedMtData, recording_mt_context_with_spawn_and_limit, + RecordedMtData, recording_mt_context_with_spawn_and_limit, replay_mt_context_with_spawn_and_limit, }; @@ -72,8 +73,8 @@ struct RecordedDataMt { #[cfg(target_arch = "wasm32")] #[allow(clippy::too_many_arguments)] async fn run_protocol_record_receiver_mt( - exec_sender: &mut Multithread, - exec_receiver: &mut Multithread, + exec_sender: &mut Executor, + exec_receiver: &mut Executor, config: FerretConfig, delta: Block, cot_seed: Block, @@ -149,7 +150,7 @@ async fn record_for_sender_mt(seed: u64, concurrency: usize, ot_count: usize) -> /// Runs MT sender only with replay context. #[cfg(target_arch = "wasm32")] -async fn run_sender_with_replay_mt(exec: &mut Multithread, data: &RecordedDataMt, ot_count: usize) { +async fn run_sender_with_replay_mt(exec: &mut Executor, data: &RecordedDataMt, ot_count: usize) { let (cot_send, _) = ideal_rcot(data.cot_seed, data.delta); let config = bench_config(); let mut sender = Sender::new(config, data.sender_seed, cot_send); diff --git a/crates/wasm-bench/src/zk/prover.rs b/crates/wasm-bench/src/zk/prover.rs index 79c1801b..510aeac4 100644 --- a/crates/wasm-bench/src/zk/prover.rs +++ b/crates/wasm-bench/src/zk/prover.rs @@ -7,8 +7,9 @@ use wasm_bindgen::prelude::*; use mpz_circuits::AES128; #[cfg(target_arch = "wasm32")] +use mpz_common::Executor; use mpz_common::context::{ - Multithread, RecordedMtData, recording_mt_context_with_spawn_and_limit, + RecordedMtData, recording_mt_context_with_spawn_and_limit, replay_mt_context_with_spawn_and_limit, }; use mpz_memory_core::{Array, binary::U8, correlated::Delta}; @@ -31,8 +32,8 @@ fn max_frame_length(circuit: &mpz_circuits::Circuit, circuit_count: usize) -> us /// Records verifier->prover messages. #[cfg(target_arch = "wasm32")] async fn run_protocol_record_verifier( - exec_p: &mut Multithread, - exec_v: &mut Multithread, + exec_p: &mut Executor, + exec_v: &mut Executor, seed: u64, circuit_count: usize, ) { @@ -133,7 +134,7 @@ async fn record_for_prover(seed: u64, circuit_count: usize, concurrency: usize) /// Runs prover only with replay context. #[cfg(target_arch = "wasm32")] -async fn run_prover_with_replay(exec: &mut Multithread, circuit_count: usize) { +async fn run_prover_with_replay(exec: &mut Executor, circuit_count: usize) { let (_, ot_recv) = ideal_rcot([0u8; 16].into(), [0u8; 16].into()); let prover_config = ProverConfig::builder().build().unwrap(); let mut prover = Prover::new(prover_config, ot_recv); diff --git a/crates/wasm-bench/src/zk/verifier.rs b/crates/wasm-bench/src/zk/verifier.rs index f90a2d1d..e4e98dcf 100644 --- a/crates/wasm-bench/src/zk/verifier.rs +++ b/crates/wasm-bench/src/zk/verifier.rs @@ -9,8 +9,9 @@ use std::sync::{Arc, Mutex}; use mpz_circuits::AES128; #[cfg(target_arch = "wasm32")] +use mpz_common::Executor; use mpz_common::context::{ - Multithread, RecordedMtData, recording_mt_context_with_spawn_and_limit, + RecordedMtData, recording_mt_context_with_spawn_and_limit, replay_mt_context_with_spawn_and_limit, }; use mpz_core::Block; @@ -34,8 +35,8 @@ fn max_frame_length(circuit: &mpz_circuits::Circuit, circuit_count: usize) -> us /// Records prover->verifier messages. #[cfg(target_arch = "wasm32")] async fn run_protocol_record_prover( - exec_p: &mut Multithread, - exec_v: &mut Multithread, + exec_p: &mut Executor, + exec_v: &mut Executor, seed: u64, circuit_count: usize, ) { @@ -149,7 +150,7 @@ async fn record_for_verifier( /// Runs verifier only with replay context. #[cfg(target_arch = "wasm32")] async fn run_verifier_with_replay( - exec: &mut Multithread, + exec: &mut Executor, circuit_count: usize, delta: Delta, ot_seed: Block, diff --git a/crates/zk/benches/prover.rs b/crates/zk/benches/prover.rs index 58ba765f..cd122287 100644 --- a/crates/zk/benches/prover.rs +++ b/crates/zk/benches/prover.rs @@ -7,9 +7,12 @@ use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; use futures::executor::block_on; use mpz_circuits::AES128; -use mpz_common::context::{ - Multithread, RecordedMtData, recording_mt_context_with_limit, recording_st_context_with_limit, - replay_mt_context_with_limit, replay_st_context, +use mpz_common::{ + Executor, + context::{ + RecordedMtData, recording_mt_context_with_limit, recording_st_context_with_limit, + replay_mt_context_with_limit, replay_st_context, + }, }; use mpz_ot::ideal::rcot::ideal_rcot; use mpz_vm_core::{ @@ -162,8 +165,8 @@ async fn run_prover_with_replay(ctx: &mut mpz_common::Context, circuit_count: us /// Runs the full ZK protocol with MT contexts. async fn run_protocol_record_verifier_mt( - exec_p: &mut Multithread, - exec_v: &mut Multithread, + exec_p: &mut Executor, + exec_v: &mut Executor, circuit_count: usize, seed: u64, ) { @@ -253,7 +256,7 @@ fn record_for_prover_mt(circuit_count: usize, seed: u64) -> RecordedMtData { } /// Runs MT prover only with replay context. -async fn run_prover_with_replay_mt(exec: &mut Multithread, circuit_count: usize) { +async fn run_prover_with_replay_mt(exec: &mut Executor, circuit_count: usize) { let (_, ot_recv) = ideal_rcot([0u8; 16].into(), [0u8; 16].into()); let prover_config = ProverConfig::builder().build().unwrap(); let mut prover = Prover::new(prover_config, ot_recv); diff --git a/crates/zk/benches/verifier.rs b/crates/zk/benches/verifier.rs index d6c3905e..e4d33f79 100644 --- a/crates/zk/benches/verifier.rs +++ b/crates/zk/benches/verifier.rs @@ -8,9 +8,12 @@ use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; use futures::executor::block_on; use mpz_circuits::AES128; -use mpz_common::context::{ - Multithread, RecordedMtData, recording_mt_context_with_limit, recording_st_context_with_limit, - replay_mt_context_with_limit, replay_st_context, +use mpz_common::{ + Executor, + context::{ + RecordedMtData, recording_mt_context_with_limit, recording_st_context_with_limit, + replay_mt_context_with_limit, replay_st_context, + }, }; use mpz_core::Block; use mpz_ot::ideal::rcot::ideal_rcot; @@ -173,8 +176,8 @@ async fn run_verifier_with_replay( /// Runs the full ZK protocol with MT contexts. async fn run_protocol_record_prover_mt( - exec_p: &mut Multithread, - exec_v: &mut Multithread, + exec_p: &mut Executor, + exec_v: &mut Executor, circuit_count: usize, seed: u64, ) { @@ -270,7 +273,7 @@ fn record_for_verifier_mt(circuit_count: usize, seed: u64) -> (RecordedMtData, B /// Runs MT verifier only with replay context. async fn run_verifier_with_replay_mt( - exec: &mut Multithread, + exec: &mut Executor, circuit_count: usize, delta: Delta, ot_seed: Block, diff --git a/crates/zk/src/prover.rs b/crates/zk/src/prover.rs index 0e1d227c..f7f8f7d7 100644 --- a/crates/zk/src/prover.rs +++ b/crates/zk/src/prover.rs @@ -160,9 +160,8 @@ where } let outputs = ctx - .map( - tasks, - async move |ctx, (mut execute, output)| { + .map(tasks, move |ctx, (mut execute, output)| { + Box::pin(async move { let mut iter = execute.iter(); loop { // Stream the `adjust` bits to avoid buffering them in memory. @@ -178,9 +177,8 @@ where let output_macs = execute.finish().map_err(VmError::execute)?; Ok((output, output_macs)) - }, - |(execute, _)| execute.and_count(), - ) + }) + }) .await .map_err(VmError::execute)? .into_iter() diff --git a/crates/zk/src/verifier.rs b/crates/zk/src/verifier.rs index abf8cecc..390e3e90 100644 --- a/crates/zk/src/verifier.rs +++ b/crates/zk/src/verifier.rs @@ -158,9 +158,8 @@ where } let outputs = ctx - .map( - tasks, - async move |ctx, (mut execute, output)| { + .map(tasks, move |ctx, (mut execute, output)| { + Box::pin(async move { let mut consumer = execute.consumer(); while consumer.wants_adjust() { let adjust: BitVec = ctx.io_mut().expect_next().await?; @@ -172,9 +171,8 @@ where let output_keys = execute.finish().map_err(VmError::execute)?; Ok((output, output_keys)) - }, - |(execute, _)| execute.and_count(), - ) + }) + }) .await .map_err(VmError::execute)? .into_iter()