diff --git a/Cargo.toml b/Cargo.toml index b92e4d7e..e1b4eced 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -55,6 +55,7 @@ md-5 = { version = "0.11.0", default-features = false, optional = true } quick-xml = { version = "0.39.0", features = ["serialize", "overlapped-lists"], optional = true } rand = { version = "0.10", default-features = false, features = ["std", "std_rng", "thread_rng"], optional = true } reqwest = { version = "0.12", default-features = false, features = ["rustls-tls-native-roots", "http2"], optional = true } +reqwest-middleware = { version = "0.4", optional = true } ring = { version = "0.17", default-features = false, features = ["std"], optional = true } rustls-pki-types = { version = "1.9", default-features = false, features = ["std"], optional = true } serde = { version = "1.0", default-features = false, features = ["derive"], optional = true } @@ -90,6 +91,7 @@ http = ["cloud"] tls-webpki-roots = ["reqwest?/rustls-tls-webpki-roots"] integration = ["rand", "tokio"] tokio = ["dep:tokio", "dep:tracing"] +reqwest-middleware = ["dep:reqwest-middleware", "cloud"] [dev-dependencies] # In alphabetical order futures-executor = "0.3" diff --git a/src/client/http/connection.rs b/src/client/http/connection.rs index 69c8436d..ec5961e4 100644 --- a/src/client/http/connection.rs +++ b/src/client/http/connection.rs @@ -136,6 +136,22 @@ impl HttpError { pub fn kind(&self) -> HttpErrorKind { self.kind } + + /// Build an [`HttpError`] from a [`reqwest_middleware::Error`]. + /// + /// `Reqwest(_)` variants delegate to [`HttpError::reqwest`] for the + /// existing classification logic; opaque `Middleware(_)` variants + /// fall back to [`HttpErrorKind::Unknown`]. + #[cfg(all(feature = "reqwest-middleware", not(target_arch = "wasm32")))] + pub(crate) fn reqwest_middleware(e: reqwest_middleware::Error) -> Self { + match e { + reqwest_middleware::Error::Reqwest(re) => Self::reqwest(re), + reqwest_middleware::Error::Middleware(me) => Self { + kind: HttpErrorKind::Unknown, + source: me.into(), + }, + } + } } /// An asynchronous function from a [`HttpRequest`] to a [`HttpResponse`]. diff --git a/src/client/http/mod.rs b/src/client/http/mod.rs index 86e1e11d..4a78a5e9 100644 --- a/src/client/http/mod.rs +++ b/src/client/http/mod.rs @@ -25,3 +25,8 @@ pub use connection::*; mod spawn; pub use spawn::*; + +// `reqwest-middleware` is not supported on wasm32: its `Middleware` trait +// requires `Send + Sync` whereas the wasm `HttpService` ecosystem uses `?Send`. +#[cfg(all(feature = "reqwest-middleware", not(target_arch = "wasm32")))] +mod reqwest_middleware; diff --git a/src/client/http/reqwest_middleware.rs b/src/client/http/reqwest_middleware.rs new file mode 100644 index 00000000..2b3fa360 --- /dev/null +++ b/src/client/http/reqwest_middleware.rs @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `HttpService` impl for [`reqwest_middleware::ClientWithMiddleware`]. +//! +//! Mirrors the [`reqwest::Client`] impl in `connection.rs` so consumers +//! that have already composed a `reqwest_middleware` middleware stack +//! (e.g. `reqwest_tracing::TracingMiddleware`) can hand it directly to +//! [`crate::client::HttpClient::new`]. + +use crate::client::{HttpError, HttpRequest, HttpResponse, HttpResponseBody, HttpService}; +use async_trait::async_trait; +use http_body_util::BodyExt; + +#[async_trait] +impl HttpService for reqwest_middleware::ClientWithMiddleware { + async fn call(&self, req: HttpRequest) -> Result { + let (parts, body) = req.into_parts(); + + let url = parts.uri.to_string().parse().unwrap(); + let mut req = reqwest::Request::new(parts.method, url); + *req.headers_mut() = parts.headers; + *req.body_mut() = Some(body.into_reqwest()); + + let r = self + .execute(req) + .await + .map_err(HttpError::reqwest_middleware)?; + let res: http::Response = r.into(); + let (parts, body) = res.into_parts(); + + let body = HttpResponseBody::new(body.map_err(HttpError::reqwest)); + Ok(HttpResponse::from_parts(parts, body)) + } +} + +#[cfg(test)] +mod tests { + use crate::RetryConfig; + use crate::client::HttpClient; + use crate::client::mock_server::MockServer; + use crate::client::retry::RetryExt; + use http::HeaderValue; + use hyper::Response; + use reqwest::Request; + use reqwest_middleware::{ClientBuilder, Middleware, Next}; + use std::sync::Arc; + use std::sync::atomic::{AtomicBool, Ordering}; + + /// Middleware that injects a known header so the test can assert + /// the middleware stack actually ran for each request. + struct InjectHeader; + + #[async_trait::async_trait] + impl Middleware for InjectHeader { + async fn handle( + &self, + mut req: Request, + extensions: &mut http::Extensions, + next: Next<'_>, + ) -> reqwest_middleware::Result { + req.headers_mut() + .insert("x-test-middleware", HeaderValue::from_static("hit")); + next.run(req, extensions).await + } + } + + #[tokio::test] + async fn client_with_middleware_runs_middleware() { + let mock = MockServer::new().await; + + let saw_header = Arc::new(AtomicBool::new(false)); + let saw_header_in_handler = Arc::clone(&saw_header); + mock.push_fn(move |req| { + if req.headers().get("x-test-middleware") == Some(&HeaderValue::from_static("hit")) { + saw_header_in_handler.store(true, Ordering::SeqCst); + } + Response::new("BANANAS".to_string()) + }); + + let inner = reqwest::Client::new(); + let client_with_mw = ClientBuilder::new(inner).with(InjectHeader).build(); + let http_client = HttpClient::new(client_with_mw); + + let url = mock.url().to_string(); + let retry = RetryConfig::default(); + let resp = http_client.get(url).send_retry(&retry).await.unwrap(); + let payload = resp.into_body().bytes().await.unwrap(); + assert_eq!(payload.as_ref(), b"BANANAS"); + + assert!( + saw_header.load(Ordering::SeqCst), + "middleware did not run before the server saw the request", + ); + } +}