Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 21 additions & 105 deletions futures-util/src/future/future/shared.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::task::{waker_ref, ArcWake};
use crate::task::waker_ref;
use crate::wakerset::{WakerKey, WakerSet};
use alloc::sync::{Arc, Weak};
use core::cell::UnsafeCell;
use core::fmt;
Expand All @@ -8,29 +9,19 @@ use core::ptr;
use core::sync::atomic::AtomicUsize;
use core::sync::atomic::Ordering::{Acquire, SeqCst};
use futures_core::future::{FusedFuture, Future};
use futures_core::task::{Context, Poll, Waker};
use slab::Slab;

#[cfg(feature = "std")]
type Mutex<T> = std::sync::Mutex<T>;
#[cfg(not(feature = "std"))]
type Mutex<T> = spin::Mutex<T>;
use futures_core::task::{Context, Poll};

/// Future for the [`shared`](super::FutureExt::shared) method.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Shared<Fut: Future> {
inner: Option<Arc<Inner<Fut>>>,
waker_key: usize,
waker_key: WakerKey,
}

struct Inner<Fut: Future> {
future_or_output: UnsafeCell<FutureOrOutput<Fut>>,
notifier: Arc<Notifier>,
}

struct Notifier {
state: AtomicUsize,
wakers: Mutex<Option<Slab<Option<Waker>>>>,
notifier: Arc<WakerSet>,
}

/// A weak reference to a [`Shared`] that can be upgraded much like an `Arc`.
Expand Down Expand Up @@ -87,19 +78,15 @@ const POLLING: usize = 1;
const COMPLETE: usize = 2;
const POISONED: usize = 3;

const NULL_WAKER_KEY: usize = usize::MAX;

impl<Fut: Future> Shared<Fut> {
pub(super) fn new(future: Fut) -> Self {
let inner = Inner {
future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)),
notifier: Arc::new(Notifier {
state: AtomicUsize::new(IDLE),
wakers: Mutex::new(Some(Slab::new())),
}),
state: AtomicUsize::new(IDLE),
notifier: Arc::new(WakerSet::new()),
};

Self { inner: Some(Arc::new(inner)), waker_key: NULL_WAKER_KEY }
Self { inner: Some(Arc::new(inner)), waker_key: WakerKey::NULL }
}
}

Expand All @@ -113,7 +100,7 @@ where
/// [`poll`](Future::poll).
pub fn peek(&self) -> Option<&Fut::Output> {
if let Some(inner) = self.inner.as_ref() {
match inner.notifier.state.load(SeqCst) {
match inner.state.load(SeqCst) {
COMPLETE => unsafe { return Some(inner.output()) },
POISONED => panic!("inner future panicked during poll"),
_ => {}
Expand Down Expand Up @@ -207,34 +194,6 @@ where
Fut: Future,
Fut::Output: Clone,
{
/// Registers the current task to receive a wakeup when we are awoken.
fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) {
#[cfg(feature = "std")]
let mut wakers_guard = self.notifier.wakers.lock().unwrap();
#[cfg(not(feature = "std"))]
let mut wakers_guard = self.notifier.wakers.lock();

let wakers_mut = wakers_guard.as_mut();

let wakers = match wakers_mut {
Some(wakers) => wakers,
None => return,
};

let new_waker = cx.waker();

if *waker_key == NULL_WAKER_KEY {
*waker_key = wakers.insert(Some(new_waker.clone()));
} else {
match wakers[*waker_key] {
Some(ref old_waker) if new_waker.will_wake(old_waker) => {}
// Could use clone_from here, but Waker doesn't specialize it.
ref mut slot => *slot = Some(new_waker.clone()),
}
}
debug_assert!(*waker_key != NULL_WAKER_KEY);
}

/// Safety: callers must first ensure that `inner.state`
/// is `COMPLETE`
unsafe fn take_or_clone_output(self: Arc<Self>) -> Fut::Output {
Expand Down Expand Up @@ -271,19 +230,14 @@ where
let inner = this.inner.take().expect("Shared future polled again after completion");

// Fast path for when the wrapped future has already completed
if inner.notifier.state.load(Acquire) == COMPLETE {
if inner.state.load(Acquire) == COMPLETE {
// Safety: We're in the COMPLETE state
return unsafe { Poll::Ready(inner.take_or_clone_output()) };
}

inner.record_waker(&mut this.waker_key, cx);
inner.notifier.record_waker(&mut this.waker_key, cx);

match inner
.notifier
.state
.compare_exchange(IDLE, POLLING, SeqCst, SeqCst)
.unwrap_or_else(|x| x)
{
match inner.state.compare_exchange(IDLE, POLLING, SeqCst, SeqCst).unwrap_or_else(|x| x) {
IDLE => {
// Lock acquired, fall through
}
Expand Down Expand Up @@ -317,7 +271,7 @@ where
}
}

let mut reset = Reset { state: &inner.notifier.state, did_not_panic: false };
let mut reset = Reset { state: &inner.state, did_not_panic: false };

let output = {
let future = unsafe {
Expand All @@ -332,8 +286,7 @@ where

match poll_result {
Poll::Pending => {
if inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst).is_ok()
{
if inner.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst).is_ok() {
// Success
drop(reset);
this.inner = Some(inner);
Expand All @@ -350,21 +303,12 @@ where
*inner.future_or_output.get() = FutureOrOutput::Output(output);
}

inner.notifier.state.store(COMPLETE, SeqCst);
inner.state.store(COMPLETE, SeqCst);

// Wake all tasks and drop the slab
#[cfg(feature = "std")]
let mut wakers_guard = inner.notifier.wakers.lock().unwrap();
#[cfg(not(feature = "std"))]
let mut wakers_guard = inner.notifier.wakers.lock();

let mut wakers = wakers_guard.take().unwrap();
for waker in wakers.drain().flatten() {
waker.wake();
}
// Wake all tasks
inner.notifier.wake_and_finish();

drop(reset); // Make borrow checker happy
drop(wakers_guard);

// Safety: We're in the COMPLETE state
unsafe { Poll::Ready(inner.take_or_clone_output()) }
Expand All @@ -376,7 +320,7 @@ where
Fut: Future,
{
fn clone(&self) -> Self {
Self { inner: self.inner.clone(), waker_key: NULL_WAKER_KEY }
Self { inner: self.inner.clone(), waker_key: WakerKey::NULL }
}
}

Expand All @@ -385,36 +329,8 @@ where
Fut: Future,
{
fn drop(&mut self) {
if self.waker_key != NULL_WAKER_KEY {
if let Some(ref inner) = self.inner {
#[cfg(feature = "std")]
if let Ok(mut wakers) = inner.notifier.wakers.lock() {
if let Some(wakers) = wakers.as_mut() {
wakers.remove(self.waker_key);
}
}
#[cfg(not(feature = "std"))]
if let Some(wakers) = inner.notifier.wakers.lock().as_mut() {
wakers.remove(self.waker_key);
}
}
}
}
}

impl ArcWake for Notifier {
fn wake_by_ref(arc_self: &Arc<Self>) {
#[cfg(feature = "std")]
let wakers = &mut *arc_self.wakers.lock().unwrap();
#[cfg(not(feature = "std"))]
let wakers = &mut *arc_self.wakers.lock();

if let Some(wakers) = wakers.as_mut() {
for (_key, opt_waker) in wakers {
if let Some(waker) = opt_waker.take() {
waker.wake();
}
}
if let Some(ref inner) = self.inner {
inner.notifier.unregister(self.waker_key)
}
}
}
Expand All @@ -425,6 +341,6 @@ impl<Fut: Future> WeakShared<Fut> {
/// Returns [`None`] if all clones of the [`Shared`] have been dropped or polled
/// to completion.
pub fn upgrade(&self) -> Option<Shared<Fut>> {
Some(Shared { inner: Some(self.0.upgrade()?), waker_key: NULL_WAKER_KEY })
Some(Shared { inner: Some(self.0.upgrade()?), waker_key: WakerKey::NULL })
}
}
2 changes: 2 additions & 0 deletions futures-util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,3 +328,5 @@ mod abortable;

mod fns;
mod unfold_state;
#[cfg(any(feature = "std", all(feature = "alloc", feature = "spin")))]
mod wakerset;
3 changes: 3 additions & 0 deletions futures-util/src/stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ pub use self::stream::Chunks;
#[cfg(feature = "alloc")]
pub use self::stream::ReadyChunks;

#[cfg(any(feature = "std", all(feature = "alloc", feature = "spin")))]
pub use self::stream::Shared;

#[cfg(feature = "sink")]
#[cfg_attr(docsrs, doc(cfg(feature = "sink")))]
pub use self::stream::Forward;
Expand Down
73 changes: 73 additions & 0 deletions futures-util/src/stream/stream/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ mod catch_unwind;
#[cfg(feature = "std")]
pub use self::catch_unwind::CatchUnwind;

#[cfg(feature = "alloc")]
mod shared;
#[cfg(feature = "alloc")]
pub use self::shared::Shared;

impl<T: ?Sized> StreamExt for T where T: Stream {}

/// An extension trait for `Stream`s that provides a variety of convenient
Expand Down Expand Up @@ -1469,6 +1474,74 @@ pub trait StreamExt: Stream {
assert_stream::<Self::Item, _>(Box::pin(self))
}

/// Create a cloneable handle to this stream where all handles will resolve
/// to the same result.
///
/// The shared() method provides a method to convert any stream into a
/// cloneable stream. It enables a stream to be polled by multiple threads.
///
/// This method is only available when the `std` feature of this library is
/// activiated, and it is activated by default.
///
/// # Panics
/// If the capacity is zero. It must have space for at least one item.
///
/// If the capacity is too large. The maximum size may be les than `usize::MAX`.
///
/// # Examples
///
/// ```
/// use futures::executor::block_on;
/// use futures::stream::{self, StreamExt};
///
/// let stream = stream::iter(1..=3);
/// let shared1 = stream.shared(4);
/// let shared2 = shared1.clone();
///
/// assert_eq!(vec![1,2,3], block_on(shared1.collect::<Vec<_>>()));
/// assert_eq!(vec![1,2,3], block_on(shared2.collect::<Vec<_>>()));
/// ```
///
/// ```
/// use futures::executor::block_on;
/// use futures::stream::{self, StreamExt};
/// use std::thread;
///
/// let stream = stream::iter(1..=3);
/// let shared1 = stream.shared(4);
/// let shared2 = shared1.clone();
/// let join_handle = thread::spawn(move || {
/// assert_eq!(vec![1,2,3], block_on(shared2.collect::<Vec<_>>()));
/// });
/// assert_eq!(vec![1,2,3], block_on(shared1.collect::<Vec<_>>()));
/// join_handle.join().unwrap();
/// ```
///
/// ```
/// # futures::executor::block_on(async {
/// use futures::stream::{self, StreamExt};
///
/// let stream = stream::iter(vec![1,2,3]);
/// let mut shared1 = stream.shared(4);
///
/// assert_eq!(Some(1), shared1.next().await);
///
/// let mut shared2 = shared1.clone();
/// assert_eq!(Some(2), shared2.next().await);
/// assert_eq!(Some(3), shared2.next().await);
/// assert_eq!(vec![2,3], shared1.collect::<Vec<_>>().await);
/// assert_eq!(None, shared2.next().await);
/// # });
/// ```
#[cfg(feature = "std")]
fn shared(self, capacity: usize) -> Shared<Self>
where
Self: Sized,
Self::Item: Clone,
{
Shared::new(self, capacity)
}

/// An adaptor for creating a buffered list of pending futures.
///
/// If this stream's item can be converted into a future, then this adaptor
Expand Down
Loading
Loading