diff --git a/lib/saluki-components/src/common/datadog/io.rs b/lib/saluki-components/src/common/datadog/io.rs index da95de0cf18..8427ca6cee7 100644 --- a/lib/saluki-components/src/common/datadog/io.rs +++ b/lib/saluki-components/src/common/datadog/io.rs @@ -83,7 +83,8 @@ where /// Transaction forwarder for Datadog endpoints. pub struct TransactionForwarder { context: ComponentContext, - config: ForwarderConfiguration, + config: ForwarderConfiguration, // static snapshot of forwarder settings + live_config: Option, // runtime-mutable configuration telemetry: ComponentTelemetry, metrics_builder: MetricsBuilder, client: HttpClient, @@ -141,13 +142,13 @@ where { /// Creates a new `TransactionForwarder` instance from the given configuration. pub fn from_config( - context: ComponentContext, config: ForwarderConfiguration, configuration: Option, + context: ComponentContext, config: ForwarderConfiguration, live_config: Option, endpoint_name: F, telemetry: ComponentTelemetry, metrics_builder: MetricsBuilder, ) -> Result where F: Fn(&Uri) -> Option + Send + Sync + 'static, { - let endpoints = config.endpoint().build_resolved_endpoints(configuration)?; + let endpoints = config.endpoint().build_resolved_endpoints(live_config.clone())?; let mut client_builder = HttpClient::builder() .with_request_timeout(config.request_timeout()) .with_bytes_sent_counter(telemetry.bytes_sent().clone()) @@ -167,6 +168,7 @@ where Ok(Self { context, config, + live_config, telemetry, metrics_builder, client, @@ -186,6 +188,7 @@ where let Self { context, config, + live_config, telemetry, metrics_builder, client, @@ -200,6 +203,7 @@ where io_shutdown_tx, context, config, + live_config, client, telemetry, metrics_builder, @@ -214,10 +218,12 @@ where } } +#[allow(clippy::too_many_arguments)] async fn run_io_loop( mut transactions_rx: mpsc::Receiver>, io_shutdown_tx: oneshot::Sender<()>, - context: ComponentContext, config: ForwarderConfiguration, service: HttpClient, telemetry: ComponentTelemetry, - metrics_builder: MetricsBuilder, resolved_endpoints: Vec, + context: ComponentContext, config: ForwarderConfiguration, live_config: Option, + service: HttpClient, telemetry: ComponentTelemetry, metrics_builder: MetricsBuilder, + resolved_endpoints: Vec, ) where B: Body + Buf + Clone + Send + Sync + 'static, B::Data: Send, @@ -243,6 +249,7 @@ async fn run_io_loop( task_barrier, context.clone(), config.clone(), + live_config.clone(), service.clone(), telemetry.clone(), txnq_telemetry, @@ -279,10 +286,11 @@ async fn run_io_loop( let _ = io_shutdown_tx.send(()); } +#[allow(clippy::too_many_arguments)] async fn run_endpoint_io_loop( mut txns_rx: mpsc::Receiver>, task_barrier: Arc, context: ComponentContext, - config: ForwarderConfiguration, service: HttpClient, telemetry: ComponentTelemetry, - txnq_telemetry: TransactionQueueTelemetry, endpoint: ResolvedEndpoint, + config: ForwarderConfiguration, live_config: Option, service: HttpClient, + telemetry: ComponentTelemetry, txnq_telemetry: TransactionQueueTelemetry, endpoint: ResolvedEndpoint, ) where B: Body + Buf + Clone + Send + Sync + 'static, B::Data: Send, @@ -311,7 +319,7 @@ async fn run_endpoint_io_loop( .map_request(with_version_info()) .concurrency_limit(config.endpoint_concurrency()) .layer(RetryCircuitBreakerLayer::new( - config.retry().to_default_http_retry_policy(), + config.retry().to_default_http_retry_policy(live_config.clone()), )) .map_request(|req: Request>| req.map(into_client_body)) .service(service); @@ -646,23 +654,33 @@ impl PendingTransactions { #[cfg(test)] mod tests { - use std::sync::{Arc, OnceLock}; + use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, OnceLock, + }; use bytes::Bytes; + use http::StatusCode; use http_body_util::Empty; use rcgen::{generate_simple_self_signed, CertifiedKey}; use rustls::{ pki_types::{PrivateKeyDer, PrivatePkcs8KeyDer}, RootCertStore, ServerConfig, }; + use saluki_common::buf::FrozenChunkedBytesBuffer; + use saluki_config::ConfigurationLoader; + use saluki_core::{observability::ComponentMetricsExt as _, topology::ComponentId}; + use serde_json::json; use tokio::{ io::{AsyncReadExt, AsyncWriteExt}, + net::TcpListener, sync::mpsc, time::{timeout, Duration}, }; use tokio_rustls::TlsAcceptor; use super::*; + use crate::common::datadog::transaction::{Metadata as TxnMetadata, Transaction}; fn forwarder_config_from_value(value: serde_json::Value) -> ForwarderConfiguration { serde_json::from_value(value).expect("ForwarderConfiguration should deserialize") @@ -809,4 +827,227 @@ mod tests { .expect("HTTPS request channel closed"); assert!(received_request.starts_with("GET / HTTP/1.1")); } + + /// Mode controlling what status codes the recording HTTP server returns to incoming requests. + enum ServerMode { + /// Always respond with the given status code. + AlwaysStatus(StatusCode), + /// Respond with each status code from the sequence in turn; once exhausted, respond with the final code forever. + StatusSequence(Vec), + } + + /// Starts a minimal HTTP server on `127.0.0.1:0` that records each request and replies based on `mode`. + /// + /// Returns the server's `http://127.0.0.1:PORT/` URL and a counter that increments once per accepted/processed + /// connection (one connection per request, since the server replies with `Connection: close`). + async fn start_recording_http_server(mode: ServerMode) -> (String, Arc) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let counter = Arc::new(AtomicUsize::new(0)); + + let mode = Arc::new(mode); + let counter_for_task = Arc::clone(&counter); + tokio::spawn(async move { + loop { + let (mut stream, _) = match listener.accept().await { + Ok(pair) => pair, + Err(_) => return, + }; + let mode = Arc::clone(&mode); + let counter = Arc::clone(&counter_for_task); + + tokio::spawn(async move { + let mut request = Vec::new(); + let mut buf = [0u8; 1024]; + loop { + match stream.read(&mut buf).await { + Ok(0) => return, + Ok(n) => { + request.extend_from_slice(&buf[..n]); + if request.windows(4).any(|window| window == b"\r\n\r\n") { + break; + } + } + Err(_) => return, + } + } + + // Drain any body bytes that arrived alongside the headers, plus whatever remains based on a + // simple Content-Length parse. We don't actually need to buffer it; we just need to consume it + // so the client doesn't get a connection reset before reading our response. + let request_str = String::from_utf8_lossy(&request).into_owned(); + let content_length = parse_content_length(&request_str).unwrap_or(0); + let header_end = request + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(request.len(), |idx| idx + 4); + let mut already_read_body = request.len().saturating_sub(header_end); + while already_read_body < content_length { + match stream.read(&mut buf).await { + Ok(0) => break, + Ok(n) => already_read_body += n, + Err(_) => return, + } + } + + let nth = counter.fetch_add(1, Ordering::SeqCst); + let status = match mode.as_ref() { + ServerMode::AlwaysStatus(s) => *s, + ServerMode::StatusSequence(seq) => { + let idx = nth.min(seq.len() - 1); + seq[idx] + } + }; + + let response = format!( + "HTTP/1.1 {} {}\r\nContent-Length: 0\r\nConnection: close\r\n\r\n", + status.as_u16(), + status.canonical_reason().unwrap_or(""), + ); + let _ = stream.write_all(response.as_bytes()).await; + let _ = stream.shutdown().await; + }); + } + }); + + (format!("http://127.0.0.1:{port}/"), counter) + } + + fn parse_content_length(request: &str) -> Option { + for line in request.lines() { + if let Some(value) = line + .strip_prefix("Content-Length:") + .or_else(|| line.strip_prefix("content-length:")) + { + return value.trim().parse().ok(); + } + } + None + } + + fn build_test_forwarder( + forwarder_url: &str, live_config: Option, + ) -> TransactionForwarder { + // The HTTP client builder requires the process-wide TLS crypto provider to be initialized, even when the + // forwarder is pointed at a plain HTTP endpoint. + init_tls_crypto_provider(); + + // Tight timeouts and small backoffs keep the test under a couple seconds even with retries. + let value = serde_json::json!({ + "api_key": "test-api-key", + "dd_url": forwarder_url, + "forwarder_timeout": 1u64, + "forwarder_num_workers": 1usize, + "forwarder_high_prio_buffer_size": 4usize, + "forwarder_backoff_base": 0.001, + "forwarder_backoff_max": 0.01, + "forwarder_backoff_factor": 2.0, + "forwarder_recovery_interval": 1u32, + "forwarder_recovery_reset": false, + // The HTTP client builder otherwise requires the process-wide default root certificate store to be + // populated. We are talking to a plain HTTP endpoint anyway, so disable validation to skip that path. + "skip_ssl_validation": true, + }); + let forwarder_config = forwarder_config_from_value(value); + let context = + ComponentContext::forwarder(ComponentId::try_from("test_forwarder").expect("component ID should be valid")); + let metrics_builder = MetricsBuilder::from_component_context(&context); + let telemetry = ComponentTelemetry::from_builder(&metrics_builder); + + TransactionForwarder::::from_config( + context, + forwarder_config, + live_config, + |_uri: &Uri| None, + telemetry, + metrics_builder, + ) + .expect("forwarder should build") + } + + fn build_test_transaction() -> Transaction { + let body = FrozenChunkedBytesBuffer::from(Bytes::from_static(b"test-payload")); + let request = http::Request::builder() + .method("POST") + // The endpoint middleware rewrites the authority to point at our `dd_url`, but preserves the path. Use a + // path that is not the special-cased `/api/v2/logs` or `/api/v0.2/{traces,stats}` routes, so the request + // is dispatched to the configured `dd_url` host directly. + .uri("http://placeholder/api/v2/series") + .body(body) + .expect("request should build"); + Transaction::from_original(TxnMetadata::from_event_count(1), request) + } + + async fn config_with(values: serde_json::Value) -> GenericConfiguration { + let (config, _) = ConfigurationLoader::for_tests(Some(values), None, false).await; + config + } + + async fn wait_for_count_at_least(counter: &Arc, target: usize, deadline: Duration) -> usize { + let start = std::time::Instant::now(); + loop { + let current = counter.load(Ordering::SeqCst); + if current >= target { + return current; + } + if start.elapsed() > deadline { + return current; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + } + + #[tokio::test] + async fn forwarder_does_not_retry_403_without_secrets() { + let (server_url, counter) = start_recording_http_server(ServerMode::AlwaysStatus(StatusCode::FORBIDDEN)).await; + let live_config = config_with(json!({})).await; + let forwarder = build_test_forwarder(&server_url, Some(live_config)); + + let handle = forwarder.spawn().await; + handle + .send_transaction(build_test_transaction()) + .await + .expect("send should succeed"); + + // Wait for the first 403 to be observed, then give the forwarder a generous window to perform any + // (incorrect) additional retries. + let first = wait_for_count_at_least(&counter, 1, Duration::from_secs(2)).await; + assert!(first >= 1, "server should have received at least one request"); + tokio::time::sleep(Duration::from_millis(200)).await; + + // Tear down before asserting so any pending in-flight call is drained. + handle.shutdown().await; + + let final_count = counter.load(Ordering::SeqCst); + assert_eq!( + final_count, 1, + "without secrets configured, 403 must not be retried (saw {} requests)", + final_count + ); + } + + #[tokio::test] + async fn forwarder_retries_403_when_secrets_in_use() { + // The server returns 403 to the first request and 200 to every subsequent request; the forwarder must drive + // at least one retry to observe the second request. + let (server_url, counter) = + start_recording_http_server(ServerMode::StatusSequence(vec![StatusCode::FORBIDDEN, StatusCode::OK])).await; + let live_config = config_with(json!({ "secret_backend_command": "/bin/true" })).await; + let forwarder = build_test_forwarder(&server_url, Some(live_config)); + + let handle = forwarder.spawn().await; + handle + .send_transaction(build_test_transaction()) + .await + .expect("send should succeed"); + + let observed = wait_for_count_at_least(&counter, 2, Duration::from_secs(3)).await; + handle.shutdown().await; + + assert!( + observed >= 2, + "with secrets configured, 403 must be retried at least once (saw {} requests)", + observed + ); + } } diff --git a/lib/saluki-components/src/common/datadog/retry.rs b/lib/saluki-components/src/common/datadog/retry.rs index de3215b97cd..f25325aff23 100644 --- a/lib/saluki-components/src/common/datadog/retry.rs +++ b/lib/saluki-components/src/common/datadog/retry.rs @@ -1,11 +1,15 @@ use std::{ path::{Path, PathBuf}, + sync::Arc, time::Duration, }; use facet::Facet; +use http::StatusCode; use saluki_config::GenericConfiguration; -use saluki_io::net::util::retry::{DefaultHttpRetryPolicy, ExponentialBackoff}; +use saluki_io::net::util::retry::{ + DefaultHttpRetryPolicy, ExponentialBackoff, StandardHttpClassifier, StatusCodeRetryPredicate, +}; use serde::Deserialize; use tracing::debug; @@ -177,26 +181,90 @@ impl RetryConfiguration { } /// Creates a new [`DefaultHttpRetryPolicy`] based on the forwarder configuration. - pub fn to_default_http_retry_policy(&self) -> DefaultHttpRetryPolicy { + /// + /// If a [`GenericConfiguration`] is supplied, it is captured by the policy and consulted on every 403 Forbidden + /// response to decide whether the response should be treated as retriable. When secrets management is in use + /// (see [`is_secrets_in_use`]), 403 responses are retried because they may indicate that the Core Agent is in the + /// middle of refreshing the API key. If no configuration is supplied, 403 responses retain their default + /// non-retriable behavior. + pub fn to_default_http_retry_policy(&self, config: Option) -> DefaultHttpRetryPolicy { let retry_backoff = ExponentialBackoff::with_jitter( Duration::from_secs_f64(self.backoff_base), Duration::from_secs_f64(self.backoff_max), self.backoff_factor, ); + let mut classifier = StandardHttpClassifier::new(); + if let Some(config) = config { + let secrets_gate: StatusCodeRetryPredicate = Arc::new(move || is_secrets_in_use(&config)); + classifier.set_status_code_predicate(StatusCode::FORBIDDEN, secrets_gate); + } + let recovery_error_decrease_factor = (!self.recovery_reset).then_some(self.recovery_error_decrease_factor); - DefaultHttpRetryPolicy::with_backoff(retry_backoff) + DefaultHttpRetryPolicy::with_backoff_and_classifier(retry_backoff, classifier) .with_recovery_error_decrease_factor(recovery_error_decrease_factor) } } +/// Returns `true` if the supplied configuration indicates that the Core Agent has secrets management in use. +/// +/// This is computed dynamically from `config` so that runtime configuration updates (delivered over the dynamic +/// configuration stream from the Core Agent) are reflected on every call. The gate is enabled when any of the +/// following are true: +/// +/// - `secret_refresh_on_api_key_failure_interval` is greater than zero +/// - `secret_refresh_interval` is greater than zero +/// - `secret_backend_type` is set to a non-empty string +/// - `secret_backend_command` is set to a non-empty string +pub fn is_secrets_in_use(config: &GenericConfiguration) -> bool { + matches!(config.try_get_typed::("secret_refresh_on_api_key_failure_interval"), Ok(Some(v)) if v > 0) + || matches!(config.try_get_typed::("secret_refresh_interval"), Ok(Some(v)) if v > 0) + || matches!(config.try_get_typed::("secret_backend_type"), Ok(Some(s)) if !s.trim().is_empty()) + || matches!(config.try_get_typed::("secret_backend_command"), Ok(Some(s)) if !s.trim().is_empty()) +} + #[cfg(test)] mod tests { - use saluki_config::ConfigurationLoader; + use std::time::Duration as StdDuration; + + use http::{Request, Response}; + use saluki_config::{dynamic::ConfigUpdate, ConfigurationLoader}; use serde_json::json; + use tower::retry::Policy; use super::*; + type BoxError = Box; + type TestRequest = Request<()>; + type TestResponse = Result, BoxError>; + + fn ok_response(status: StatusCode) -> TestResponse { + Ok(Response::builder().status(status).body(()).unwrap()) + } + + fn test_request() -> TestRequest { + Request::builder() + .method("POST") + .uri("http://localhost/intake") + .body(()) + .unwrap() + } + + fn test_retry_config() -> RetryConfiguration { + // Use small backoffs so that any returned `Sleep` futures are cheap; we never await them, but build them. + serde_json::from_value(json!({ + "forwarder_backoff_base": 0.001, + "forwarder_backoff_max": 0.01, + "forwarder_backoff_factor": 2.0, + })) + .expect("RetryConfiguration should deserialize") + } + + fn would_retry(policy: &mut DefaultHttpRetryPolicy, mut response: TestResponse) -> bool { + let mut request = test_request(); + Policy::, BoxError>::retry(policy, &mut request, &mut response).is_some() + } + #[tokio::test] async fn fix_empty_storage_path_sets_path_from_run_path() { const RUN_PATH: &str = "/my/little/run_path"; @@ -281,4 +349,130 @@ mod tests { let retry_config: RetryConfiguration = config.as_typed().expect("should deserialize"); assert_eq!(retry_config.queue_max_size_bytes(), OVERRIDE_PRIMARY_SIZE_BYTES); } + + #[tokio::test] + async fn is_secrets_in_use_returns_false_when_no_keys_set() { + let (config, _) = ConfigurationLoader::for_tests(None, None, false).await; + assert!(!is_secrets_in_use(&config)); + } + + #[tokio::test] + async fn is_secrets_in_use_returns_true_for_each_individual_key() { + let cases = [ + json!({ "secret_refresh_on_api_key_failure_interval": 30u64 }), + json!({ "secret_refresh_interval": 60u64 }), + json!({ "secret_backend_type": "file.json" }), + json!({ "secret_backend_command": "/usr/local/bin/secret-helper" }), + ]; + + for values in cases { + let label = values.to_string(); + let (config, _) = ConfigurationLoader::for_tests(Some(values), None, false).await; + assert!( + is_secrets_in_use(&config), + "expected secrets gate to be true for config {}", + label + ); + } + } + + #[tokio::test] + async fn is_secrets_in_use_returns_false_for_zero_intervals() { + let values = json!({ + "secret_refresh_on_api_key_failure_interval": 0u64, + "secret_refresh_interval": 0u64, + }); + let (config, _) = ConfigurationLoader::for_tests(Some(values), None, false).await; + assert!(!is_secrets_in_use(&config)); + } + + #[tokio::test] + async fn is_secrets_in_use_returns_false_for_empty_or_whitespace_strings() { + let values = json!({ + "secret_backend_type": "", + "secret_backend_command": " ", + }); + let (config, _) = ConfigurationLoader::for_tests(Some(values), None, false).await; + assert!(!is_secrets_in_use(&config)); + } + + #[tokio::test] + async fn policy_with_no_config_does_not_retry_403() { + let retry_config = test_retry_config(); + let mut policy = retry_config.to_default_http_retry_policy(None); + + assert!(!would_retry(&mut policy, ok_response(StatusCode::FORBIDDEN))); + } + + #[tokio::test] + async fn policy_with_config_but_no_secrets_does_not_retry_403() { + let (config, _) = ConfigurationLoader::for_tests(None, None, false).await; + let retry_config = test_retry_config(); + let mut policy = retry_config.to_default_http_retry_policy(Some(config)); + + assert!(!would_retry(&mut policy, ok_response(StatusCode::FORBIDDEN))); + } + + #[tokio::test] + async fn policy_with_config_and_secrets_retries_403() { + let values = json!({ "secret_backend_command": "/bin/true" }); + let (config, _) = ConfigurationLoader::for_tests(Some(values), None, false).await; + let retry_config = test_retry_config(); + let mut policy = retry_config.to_default_http_retry_policy(Some(config)); + + assert!(would_retry(&mut policy, ok_response(StatusCode::FORBIDDEN))); + } + + #[tokio::test] + async fn policy_403_decision_reflects_dynamic_config_changes() { + let (config, sender) = ConfigurationLoader::for_tests(Some(json!({})), None, true).await; + let sender = sender.expect("dynamic configuration sender should be present"); + + // Apply an empty initial snapshot and wait for readiness so the figment is in a known state. + sender + .send(ConfigUpdate::Snapshot(json!({}))) + .await + .expect("should send initial snapshot"); + config.ready().await; + + let retry_config = test_retry_config(); + let mut policy = retry_config.to_default_http_retry_policy(Some(config.clone())); + + // Before the dynamic update, the secrets gate is off, so 403 is not retried. + assert!(!would_retry(&mut policy, ok_response(StatusCode::FORBIDDEN))); + + // Subscribe to changes for the secrets key, then push an update that flips the gate. + let mut watcher = config.watch_for_updates("secret_backend_command"); + sender + .send(ConfigUpdate::Partial { + key: "secret_backend_command".to_string(), + value: json!("/bin/true"), + }) + .await + .expect("should send partial update"); + + let (_, new) = tokio::time::timeout(StdDuration::from_secs(2), watcher.changed::()) + .await + .expect("timed out waiting for secret_backend_command update"); + assert_eq!(new.as_deref(), Some("/bin/true")); + + // The same policy instance now retries 403, proving the predicate consults live config rather than a snapshot. + assert!(would_retry(&mut policy, ok_response(StatusCode::FORBIDDEN))); + } + + #[tokio::test] + async fn policy_does_not_change_other_status_codes() { + // Use a config that flips the secrets gate ON, to make sure the 403-only override does not bleed into others. + let values = json!({ "secret_backend_command": "/bin/true" }); + let (config, _) = ConfigurationLoader::for_tests(Some(values), None, false).await; + let retry_config = test_retry_config(); + let mut policy = retry_config.to_default_http_retry_policy(Some(config)); + + assert!(!would_retry(&mut policy, ok_response(StatusCode::OK))); + assert!(!would_retry(&mut policy, ok_response(StatusCode::BAD_REQUEST))); + assert!(!would_retry(&mut policy, ok_response(StatusCode::UNAUTHORIZED))); + assert!(!would_retry(&mut policy, ok_response(StatusCode::PAYLOAD_TOO_LARGE))); + assert!(would_retry(&mut policy, ok_response(StatusCode::INTERNAL_SERVER_ERROR))); + assert!(would_retry(&mut policy, ok_response(StatusCode::TOO_MANY_REQUESTS))); + } } diff --git a/lib/saluki-io/src/net/util/retry/classifier/http.rs b/lib/saluki-io/src/net/util/retry/classifier/http.rs index 83f09c6712c..2cfc7979512 100644 --- a/lib/saluki-io/src/net/util/retry/classifier/http.rs +++ b/lib/saluki-io/src/net/util/retry/classifier/http.rs @@ -1,7 +1,15 @@ +use std::{collections::HashMap, sync::Arc}; + use http::StatusCode; use super::RetryClassifier; +/// A predicate that decides whether a response with a particular [`StatusCode`] should be treated as retriable. +/// +/// The predicate is invoked at classification time, on every response that matches its status code, allowing the +/// decision to be re-evaluated dynamically (for example, against runtime-updated configuration). +pub type StatusCodeRetryPredicate = Arc bool + Send + Sync>; + /// A standard HTTP response classifier. /// /// Generally treats all client (4xx) and server (5xx) errors as retriable, with the exception of a few specific client @@ -11,24 +19,207 @@ use super::RetryClassifier; /// - 401 Unauthorized (likely a client-side misconfiguration) /// - 403 Forbidden (likely a client-side misconfiguration) /// - 413 Payload Too Large (likely a client-side bug) -#[derive(Clone)] -pub struct StandardHttpClassifier; +/// +/// The default classification for any [`StatusCode`] can be overridden by registering a [`StatusCodeRetryPredicate`] +/// for that status via [`StandardHttpClassifier::with_status_code_predicate`] (or +/// [`StandardHttpClassifier::set_status_code_predicate`]). When a response is received whose status code has a +/// predicate registered, the predicate is consulted and its return value is used as the retriability decision, +/// overriding the default behavior described above. +#[derive(Clone, Default)] +pub struct StandardHttpClassifier { + status_code_predicates: HashMap, +} + +impl StandardHttpClassifier { + /// Creates a new [`StandardHttpClassifier`] with no per-status-code predicates installed. + pub fn new() -> Self { + Self { + status_code_predicates: HashMap::new(), + } + } + + /// Builder-style: registers `predicate` as the retriability override for `status`. + /// + /// Replaces any previously registered predicate for the same status code. See + /// [`StandardHttpClassifier::set_status_code_predicate`] for the in-place equivalent. + pub fn with_status_code_predicate(mut self, status: StatusCode, predicate: StatusCodeRetryPredicate) -> Self { + self.set_status_code_predicate(status, predicate); + self + } + + /// Registers `predicate` as the retriability override for `status`. + /// + /// Replaces any previously registered predicate for the same status code. + pub fn set_status_code_predicate(&mut self, status: StatusCode, predicate: StatusCodeRetryPredicate) { + self.status_code_predicates.insert(status, predicate); + } + + /// Removes any predicate previously registered for `status`. + /// + /// After removal, responses with `status` revert to the classifier's default behavior. If no predicate was + /// registered for `status`, this is a no-op. + pub fn remove_status_code_predicate(&mut self, status: StatusCode) { + self.status_code_predicates.remove(&status); + } +} impl RetryClassifier, Error> for StandardHttpClassifier { fn should_retry(&self, response: &Result, Error>) -> bool { match response { - Ok(resp) => match resp.status() { - // There's some status codes that likely indicate a fundamental misconfiguration or bug on the client - // side which won't be resolved by retrying the request. - StatusCode::BAD_REQUEST - | StatusCode::UNAUTHORIZED - | StatusCode::FORBIDDEN - | StatusCode::PAYLOAD_TOO_LARGE => false, - - // For all other status codes, we'll only retry if they're in the client/server error range. - status => status.is_client_error() || status.is_server_error(), - }, + Ok(resp) => { + let status = resp.status(); + + // If a per-status-code predicate is installed, it takes precedence over the default classification. + if let Some(predicate) = self.status_code_predicates.get(&status) { + return predicate(); + } + + match status { + // There's some status codes that likely indicate a fundamental misconfiguration or bug on the + // client side which won't be resolved by retrying the request. + StatusCode::BAD_REQUEST + | StatusCode::UNAUTHORIZED + | StatusCode::FORBIDDEN + | StatusCode::PAYLOAD_TOO_LARGE => false, + + // For all other status codes, we'll only retry if they're in the client/server error range. + status => status.is_client_error() || status.is_server_error(), + } + } Err(_) => true, } } } + +#[cfg(test)] +mod tests { + use std::sync::atomic::{AtomicBool, Ordering}; + + use super::*; + + type TestResponse = Result, ()>; + + fn ok(status: StatusCode) -> TestResponse { + Ok(http::Response::builder().status(status).body(()).unwrap()) + } + + fn err() -> TestResponse { + Err(()) + } + + fn classify(classifier: &StandardHttpClassifier, response: &TestResponse) -> bool { + , ()>>::should_retry(classifier, response) + } + + fn always(value: bool) -> StatusCodeRetryPredicate { + Arc::new(move || value) + } + + #[test] + fn default_classifier_retries_5xx_and_most_4xx() { + let classifier = StandardHttpClassifier::new(); + + assert!(!classify(&classifier, &ok(StatusCode::OK))); + assert!(!classify(&classifier, &ok(StatusCode::NO_CONTENT))); + + for status in [ + StatusCode::INTERNAL_SERVER_ERROR, + StatusCode::BAD_GATEWAY, + StatusCode::SERVICE_UNAVAILABLE, + StatusCode::GATEWAY_TIMEOUT, + ] { + assert!(classify(&classifier, &ok(status)), "{} should be retried", status); + } + + for status in [ + StatusCode::REQUEST_TIMEOUT, + StatusCode::TOO_MANY_REQUESTS, + StatusCode::NOT_FOUND, + ] { + assert!(classify(&classifier, &ok(status)), "{} should be retried", status); + } + } + + #[test] + fn default_classifier_does_not_retry_known_client_misconfig() { + let classifier = StandardHttpClassifier::new(); + + for status in [ + StatusCode::BAD_REQUEST, + StatusCode::UNAUTHORIZED, + StatusCode::FORBIDDEN, + StatusCode::PAYLOAD_TOO_LARGE, + ] { + assert!(!classify(&classifier, &ok(status)), "{} should not be retried", status); + } + } + + #[test] + fn default_classifier_retries_transport_error() { + let classifier = StandardHttpClassifier::new(); + assert!(classify(&classifier, &err())); + } + + #[test] + fn predicate_overrides_default_for_403_when_true() { + let classifier = StandardHttpClassifier::new().with_status_code_predicate(StatusCode::FORBIDDEN, always(true)); + + assert!(classify(&classifier, &ok(StatusCode::FORBIDDEN))); + // Sibling client-misconfig statuses without a predicate keep their default (non-retriable) behavior. + assert!(!classify(&classifier, &ok(StatusCode::UNAUTHORIZED))); + assert!(!classify(&classifier, &ok(StatusCode::BAD_REQUEST))); + } + + #[test] + fn predicate_overrides_default_for_403_when_false() { + // A `false` predicate on 403 matches the default, so to *prove* the predicate was consulted we also install a + // `false` predicate on 500 (which would otherwise be retried) and assert it flips. + let classifier = StandardHttpClassifier::new() + .with_status_code_predicate(StatusCode::FORBIDDEN, always(false)) + .with_status_code_predicate(StatusCode::INTERNAL_SERVER_ERROR, always(false)); + + assert!(!classify(&classifier, &ok(StatusCode::FORBIDDEN))); + assert!(!classify(&classifier, &ok(StatusCode::INTERNAL_SERVER_ERROR))); + } + + #[test] + fn predicate_is_re_evaluated_each_call() { + let flag = Arc::new(AtomicBool::new(false)); + let flag_clone = Arc::clone(&flag); + let predicate: StatusCodeRetryPredicate = Arc::new(move || flag_clone.load(Ordering::SeqCst)); + + let classifier = StandardHttpClassifier::new().with_status_code_predicate(StatusCode::FORBIDDEN, predicate); + + assert!(!classify(&classifier, &ok(StatusCode::FORBIDDEN))); + + flag.store(true, Ordering::SeqCst); + assert!(classify(&classifier, &ok(StatusCode::FORBIDDEN))); + + flag.store(false, Ordering::SeqCst); + assert!(!classify(&classifier, &ok(StatusCode::FORBIDDEN))); + } + + #[test] + fn remove_status_code_predicate_restores_default() { + let mut classifier = + StandardHttpClassifier::new().with_status_code_predicate(StatusCode::FORBIDDEN, always(true)); + assert!(classify(&classifier, &ok(StatusCode::FORBIDDEN))); + + classifier.remove_status_code_predicate(StatusCode::FORBIDDEN); + assert!(!classify(&classifier, &ok(StatusCode::FORBIDDEN))); + + // Removing a predicate that was never installed is a no-op. + classifier.remove_status_code_predicate(StatusCode::IM_A_TEAPOT); + assert!(!classify(&classifier, &ok(StatusCode::FORBIDDEN))); + } + + #[test] + fn set_status_code_predicate_replaces_existing() { + let mut classifier = + StandardHttpClassifier::new().with_status_code_predicate(StatusCode::FORBIDDEN, always(true)); + assert!(classify(&classifier, &ok(StatusCode::FORBIDDEN))); + + classifier.set_status_code_predicate(StatusCode::FORBIDDEN, always(false)); + assert!(!classify(&classifier, &ok(StatusCode::FORBIDDEN))); + } +} diff --git a/lib/saluki-io/src/net/util/retry/classifier/mod.rs b/lib/saluki-io/src/net/util/retry/classifier/mod.rs index bb1e5faf585..26ea6cc8765 100644 --- a/lib/saluki-io/src/net/util/retry/classifier/mod.rs +++ b/lib/saluki-io/src/net/util/retry/classifier/mod.rs @@ -1,5 +1,5 @@ mod http; -pub use self::http::StandardHttpClassifier; +pub use self::http::{StandardHttpClassifier, StatusCodeRetryPredicate}; /// Determines whether or not a request should be retried. /// diff --git a/lib/saluki-io/src/net/util/retry/mod.rs b/lib/saluki-io/src/net/util/retry/mod.rs index 7f029273d25..750c1f82871 100644 --- a/lib/saluki-io/src/net/util/retry/mod.rs +++ b/lib/saluki-io/src/net/util/retry/mod.rs @@ -2,7 +2,7 @@ mod backoff; pub use self::backoff::ExponentialBackoff; mod classifier; -pub use self::classifier::{RetryClassifier, StandardHttpClassifier}; +pub use self::classifier::{RetryClassifier, StandardHttpClassifier, StatusCodeRetryPredicate}; mod lifecycle; pub use self::lifecycle::StandardHttpRetryLifecycle; @@ -22,7 +22,75 @@ impl DefaultHttpRetryPolicy { /// /// This policy uses the standard HTTP classifier ([`StandardHttpClassifier`]) and retry lifecycle ([`StandardHttpRetryLifecycle`]). pub fn with_backoff(backoff: ExponentialBackoff) -> Self { - RollingExponentialBackoffRetryPolicy::new(StandardHttpClassifier, backoff) - .with_retry_lifecycle(StandardHttpRetryLifecycle) + Self::with_backoff_and_classifier(backoff, StandardHttpClassifier::new()) + } + + /// Creates a new retry policy adapted to HTTP-based clients with the given exponential backoff strategy and a + /// pre-built [`StandardHttpClassifier`]. + /// + /// This is the same as [`DefaultHttpRetryPolicy::with_backoff`], but allows the caller to supply a classifier that + /// has been customized (for example, with a [`StatusCodeRetryPredicate`]). + pub fn with_backoff_and_classifier(backoff: ExponentialBackoff, classifier: StandardHttpClassifier) -> Self { + RollingExponentialBackoffRetryPolicy::new(classifier, backoff).with_retry_lifecycle(StandardHttpRetryLifecycle) + } +} + +#[cfg(test)] +mod tests { + use std::{sync::Arc, time::Duration}; + + use http::{Request, Response, StatusCode}; + use tower::retry::Policy; + + use super::*; + + type BoxError = Box; + type TestRequest = Request<()>; + + fn test_backoff() -> ExponentialBackoff { + ExponentialBackoff::with_jitter(Duration::from_millis(1), Duration::from_millis(10), 2.0) + } + + fn test_request() -> TestRequest { + Request::builder() + .method("POST") + .uri("http://localhost/intake") + .body(()) + .unwrap() + } + + fn ok_response(status: StatusCode) -> Result, BoxError> { + Ok(Response::builder().status(status).body(()).unwrap()) + } + + fn would_retry(policy: &mut DefaultHttpRetryPolicy, status: StatusCode) -> bool { + let mut request = test_request(); + let mut response = ok_response(status); + Policy::, BoxError>::retry(policy, &mut request, &mut response).is_some() + } + + #[tokio::test] + async fn default_http_retry_policy_with_backoff_uses_default_classifier() { + let mut policy = DefaultHttpRetryPolicy::with_backoff(test_backoff()); + + assert!(!would_retry(&mut policy, StatusCode::OK)); + assert!(!would_retry(&mut policy, StatusCode::FORBIDDEN)); + assert!(!would_retry(&mut policy, StatusCode::BAD_REQUEST)); + assert!(would_retry(&mut policy, StatusCode::INTERNAL_SERVER_ERROR)); + assert!(would_retry(&mut policy, StatusCode::TOO_MANY_REQUESTS)); + } + + #[tokio::test] + async fn default_http_retry_policy_with_backoff_and_classifier_threads_predicate() { + // Build a classifier that flips 403 to retriable, then ensure the constructed policy honors it. + let predicate: StatusCodeRetryPredicate = Arc::new(|| true); + let classifier = StandardHttpClassifier::new().with_status_code_predicate(StatusCode::FORBIDDEN, predicate); + let mut policy = DefaultHttpRetryPolicy::with_backoff_and_classifier(test_backoff(), classifier); + + assert!(would_retry(&mut policy, StatusCode::FORBIDDEN)); + // Other status codes still follow default classifier behavior. + assert!(!would_retry(&mut policy, StatusCode::OK)); + assert!(!would_retry(&mut policy, StatusCode::UNAUTHORIZED)); + assert!(would_retry(&mut policy, StatusCode::INTERNAL_SERVER_ERROR)); } }