diff --git a/core/layers/throttle/src/lib.rs b/core/layers/throttle/src/lib.rs index 3e8d7b644d64..a4bbdce18d06 100644 --- a/core/layers/throttle/src/lib.rs +++ b/core/layers/throttle/src/lib.rs @@ -20,6 +20,7 @@ #![cfg_attr(docsrs, feature(doc_cfg))] #![deny(missing_docs)] +use std::future::Future; use std::num::NonZeroU32; use std::sync::Arc; @@ -32,6 +33,33 @@ use governor::state::NotKeyed; use opendal_core::raw::*; use opendal_core::*; +/// ThrottleRateLimiter abstracts a rate-limit primitive used by +/// [`ThrottleLayer`]. +pub trait ThrottleRateLimiter: Send + Sync + Clone + Unpin + 'static { + /// Block until `n` units of capacity are available. + /// + /// Returns an error when the request can never be satisfied, for + /// example when `n` exceeds the limiter's burst/capacity. + fn until_n_ready(&self, n: NonZeroU32) -> impl Future> + MaybeSend; +} + +/// Share an atomic RateLimiter instance across all threads in one operator. +/// If want to add more observability in the future, replace the default NoOpMiddleware with other middleware types. +/// Read more about [Middleware](https://docs.rs/governor/latest/governor/middleware/index.html) +pub type SharedRateLimiter = + Arc>; + +impl ThrottleRateLimiter for SharedRateLimiter { + async fn until_n_ready(&self, n: NonZeroU32) -> Result<()> { + self.as_ref().until_n_ready(n).await.map_err(|_| { + Error::new( + ErrorKind::RateLimited, + "burst size is smaller than the request size", + ) + }) + } +} + /// Add a bandwidth rate limiter to the underlying services. /// /// # Throttle @@ -67,12 +95,11 @@ use opendal_core::*; /// # } /// ``` #[derive(Clone)] -pub struct ThrottleLayer { - bandwidth: NonZeroU32, - burst: NonZeroU32, +pub struct ThrottleLayer { + rate_limiter: L, } -impl ThrottleLayer { +impl ThrottleLayer { /// Create a new `ThrottleLayer` with given bandwidth and burst. /// /// - bandwidth: the maximum number of bytes allowed to pass through per second. @@ -80,43 +107,65 @@ impl ThrottleLayer { pub fn new(bandwidth: u32, burst: u32) -> Self { assert!(bandwidth > 0); assert!(burst > 0); - Self { - bandwidth: NonZeroU32::new(bandwidth).unwrap(), - burst: NonZeroU32::new(burst).unwrap(), - } + let bandwidth = NonZeroU32::new(bandwidth).unwrap(); + let burst = NonZeroU32::new(burst).unwrap(); + let rate_limiter = Arc::new(RateLimiter::direct( + Quota::per_second(bandwidth).allow_burst(burst), + )); + Self { rate_limiter } + } +} + +impl ThrottleLayer { + /// Create a layer with any [`ThrottleRateLimiter`] implementation. + /// + /// ``` + /// # use std::num::NonZeroU32; + /// # use std::sync::Arc; + /// # use governor::Quota; + /// # use governor::RateLimiter; + /// # use opendal_layer_throttle::SharedRateLimiter; + /// # use opendal_layer_throttle::ThrottleLayer; + /// let limiter: SharedRateLimiter = Arc::new(RateLimiter::direct( + /// Quota::per_second(NonZeroU32::new(1024).unwrap()) + /// .allow_burst(NonZeroU32::new(1024 * 1024).unwrap()), + /// )); + /// let _layer = ThrottleLayer::with_limiter(limiter); + /// ``` + pub fn with_limiter(rate_limiter: L) -> Self { + Self { rate_limiter } } } -impl Layer for ThrottleLayer { - type LayeredAccess = ThrottleAccessor; +impl Layer for ThrottleLayer { + type LayeredAccess = ThrottleAccessor; fn layer(&self, inner: A) -> Self::LayeredAccess { - let rate_limiter = Arc::new(RateLimiter::direct( - Quota::per_second(self.bandwidth).allow_burst(self.burst), - )); ThrottleAccessor { inner, - rate_limiter, + rate_limiter: self.rate_limiter.clone(), } } } -/// Share an atomic RateLimiter instance across all threads in one operator. -/// If want to add more observability in the future, replace the default NoOpMiddleware with other middleware types. -/// Read more about [Middleware](https://docs.rs/governor/latest/governor/middleware/index.html) -type SharedRateLimiter = Arc>; - #[doc(hidden)] -#[derive(Debug)] -pub struct ThrottleAccessor { +pub struct ThrottleAccessor { inner: A, - rate_limiter: SharedRateLimiter, + rate_limiter: L, } -impl LayeredAccess for ThrottleAccessor { +impl std::fmt::Debug for ThrottleAccessor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ThrottleAccessor") + .field("inner", &self.inner) + .finish_non_exhaustive() + } +} + +impl LayeredAccess for ThrottleAccessor { type Inner = A; - type Reader = ThrottleWrapper; - type Writer = ThrottleWrapper; + type Reader = ThrottleWrapper; + type Writer = ThrottleWrapper; type Lister = A::Lister; type Deleter = A::Deleter; @@ -152,27 +201,24 @@ impl LayeredAccess for ThrottleAccessor { } #[doc(hidden)] -pub struct ThrottleWrapper { +pub struct ThrottleWrapper { inner: R, - limiter: SharedRateLimiter, + limiter: L, } -impl ThrottleWrapper { - fn new(inner: R, rate_limiter: SharedRateLimiter) -> Self { - Self { - inner, - limiter: rate_limiter, - } +impl ThrottleWrapper { + fn new(inner: R, limiter: L) -> Self { + Self { inner, limiter } } } -impl oio::Read for ThrottleWrapper { +impl oio::Read for ThrottleWrapper { async fn read(&mut self) -> Result { self.inner.read().await } } -impl oio::Write for ThrottleWrapper { +impl oio::Write for ThrottleWrapper { async fn write(&mut self, bs: Buffer) -> Result<()> { let len = bs.len(); if len == 0 { @@ -189,12 +235,7 @@ impl oio::Write for ThrottleWrapper { let buf_length = NonZeroU32::new(len as u32).expect("len is non-zero so NonZeroU32 must exist"); - self.limiter.until_n_ready(buf_length).await.map_err(|_| { - Error::new( - ErrorKind::RateLimited, - "burst size is smaller than the request size", - ) - })?; + self.limiter.until_n_ready(buf_length).await?; self.inner.write(bs).await }