diff --git a/rocketmq-client/benches/produce_accumulator_benchmark.rs b/rocketmq-client/benches/produce_accumulator_benchmark.rs index 481d48ec0..cf876ac2b 100644 --- a/rocketmq-client/benches/produce_accumulator_benchmark.rs +++ b/rocketmq-client/benches/produce_accumulator_benchmark.rs @@ -17,6 +17,7 @@ //! //! Run with: cargo bench --bench produce_accumulator_benchmark +use std::collections::BinaryHeap; use std::collections::HashMap; use std::hint::black_box; use std::sync::atomic::AtomicU64; @@ -211,6 +212,66 @@ async fn bench_capacity_reservation( start.elapsed() } +#[derive(Clone, Eq, PartialEq)] +struct BenchDeadline { + deadline_tick: usize, + sequence: usize, +} + +impl Ord for BenchDeadline { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + other + .deadline_tick + .cmp(&self.deadline_tick) + .then_with(|| other.sequence.cmp(&self.sequence)) + } +} + +impl PartialOrd for BenchDeadline { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +fn bench_guard_full_scan_low_activity(num_keys: usize, ticks: usize) -> usize { + let mut deadlines = vec![ticks + 1; num_keys]; + deadlines[0] = ticks / 2; + let mut active = vec![true; num_keys]; + let mut processed = 0usize; + + for tick in 0..ticks { + for index in 0..num_keys { + if active[index] && deadlines[index] <= tick { + active[index] = false; + processed += 1; + } + } + } + + black_box(processed) +} + +fn bench_guard_deadline_heap_low_activity(num_keys: usize, ticks: usize) -> usize { + let mut heap = BinaryHeap::new(); + for sequence in 0..num_keys { + let deadline_tick = if sequence == 0 { ticks / 2 } else { ticks + 1 }; + heap.push(BenchDeadline { + deadline_tick, + sequence, + }); + } + + let mut processed = 0usize; + for tick in 0..ticks { + while heap.peek().is_some_and(|deadline| deadline.deadline_tick <= tick) { + heap.pop(); + processed += 1; + } + } + + black_box(processed) +} + /// Benchmark: Concurrent reads and writes with HashMap async fn bench_hashmap_mutex_mixed_ops(num_keys: usize, read_ops: usize, write_ops: usize) -> Duration { let map: Arc>>>> = Arc::new(Mutex::new(HashMap::new())); @@ -379,6 +440,22 @@ fn bench_capacity_reservation_control(c: &mut Criterion) { group.finish(); } +fn bench_guard_deadline_scheduler(c: &mut Criterion) { + let mut group = c.benchmark_group("guard_deadline_scheduler"); + let num_keys = 64; + let ticks = 1024; + + group.throughput(Throughput::Elements((num_keys * ticks) as u64)); + group.bench_function("FullScan_64keys_low_activity", |b| { + b.iter(|| bench_guard_full_scan_low_activity(num_keys, ticks)); + }); + group.bench_function("DeadlineHeap_64keys_low_activity", |b| { + b.iter(|| bench_guard_deadline_heap_low_activity(num_keys, ticks)); + }); + + group.finish(); +} + /// Benchmark Group: Mixed Read/Write Operations fn bench_mixed_operations(c: &mut Criterion) { let rt = Runtime::new().unwrap(); @@ -464,6 +541,7 @@ criterion_group!( benches, bench_concurrent_insert, bench_capacity_reservation_control, + bench_guard_deadline_scheduler, bench_mixed_operations, bench_multi_topic_scenario, bench_high_contention diff --git a/rocketmq-client/src/producer/produce_accumulator.rs b/rocketmq-client/src/producer/produce_accumulator.rs index c5a2e487c..a02a2459b 100644 --- a/rocketmq-client/src/producer/produce_accumulator.rs +++ b/rocketmq-client/src/producer/produce_accumulator.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::BinaryHeap; use std::collections::HashSet; use std::future::Future; use std::hash::Hash; @@ -35,6 +36,7 @@ use rocketmq_common::common::message::MessageTrait; use rocketmq_common::TimeUtils::current_millis; use rocketmq_rust::ArcMut; use serde::Serialize; +use tokio::sync::mpsc; use tokio::sync::watch; use tokio::sync::Mutex; @@ -46,14 +48,59 @@ use crate::runtime::ClientTrackedTaskHandle; const GUARD_TASK_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5); +type BatchMap = Arc>>>; + #[derive(Debug, Clone, Serialize)] pub struct ProduceAccumulatorGuardLifecycleProbe { pub task_count_before_shutdown: usize, pub task_count_after_shutdown: usize, pub shutdown_elapsed_us: u128, + pub guard_metrics: ProduceAccumulatorGuardMetricsSnapshot, pub healthy: bool, } +#[derive(Debug, Clone, Copy, Default, Serialize)] +pub struct ProduceAccumulatorGuardMetricsSnapshot { + pub sync: BatchGuardMetricsSnapshot, + pub async_send: BatchGuardMetricsSnapshot, +} + +#[derive(Debug, Clone, Copy, Default, Serialize)] +pub struct BatchGuardMetricsSnapshot { + pub wakeup_count: u64, + pub flush_count: u64, + pub idle_wakeup_count: u64, +} + +#[derive(Default)] +struct BatchGuardMetrics { + wakeup_count: AtomicU64, + flush_count: AtomicU64, + idle_wakeup_count: AtomicU64, +} + +impl BatchGuardMetrics { + fn record_wakeup(&self) { + self.wakeup_count.fetch_add(1, Ordering::Relaxed); + } + + fn record_flush(&self) { + self.flush_count.fetch_add(1, Ordering::Relaxed); + } + + fn record_idle(&self) { + self.idle_wakeup_count.fetch_add(1, Ordering::Relaxed); + } + + fn snapshot(&self) -> BatchGuardMetricsSnapshot { + BatchGuardMetricsSnapshot { + wakeup_count: self.wakeup_count.load(Ordering::Relaxed), + flush_count: self.flush_count.load(Ordering::Relaxed), + idle_wakeup_count: self.idle_wakeup_count.load(Ordering::Relaxed), + } + } +} + #[derive(Default)] pub struct ProduceAccumulator { total_hold_size: usize, @@ -74,16 +121,8 @@ impl ProduceAccumulator { hold_size: 1024 * 32, hold_ms: 10, instance_name: instance_name.to_string(), - guard_thread_for_async_send: GuardForAsyncSendService { - service_name: instance_name.to_string(), - stopped: Arc::new(AtomicBool::new(false)), - task_handle: None, - }, - guard_thread_for_sync_send: GuardForSyncSendService { - service_name: instance_name.to_string(), - stopped: Arc::new(AtomicBool::new(false)), - task_handle: None, - }, + guard_thread_for_async_send: GuardForAsyncSendService::new(instance_name), + guard_thread_for_sync_send: GuardForSyncSendService::new(instance_name), ..Default::default() } } @@ -158,6 +197,13 @@ impl ProduceAccumulator { self.guard_thread_for_sync_send.task_count() + self.guard_thread_for_async_send.task_count() } + pub fn guard_metrics_snapshot(&self) -> ProduceAccumulatorGuardMetricsSnapshot { + ProduceAccumulatorGuardMetricsSnapshot { + sync: self.guard_thread_for_sync_send.metrics_snapshot(), + async_send: self.guard_thread_for_async_send.metrics_snapshot(), + } + } + pub(crate) fn try_add_message(&self, message: &T) -> bool { let body_size = message.get_body().map_or(0, |body| body.len()) as u64; if body_size == 0 { @@ -369,15 +415,20 @@ impl ProduceAccumulator { aggregate_key: AggregateKey, default_mq_producer: &DefaultMQProducer, ) -> Arc> { - self.sync_send_batchs - .entry(aggregate_key.clone()) - .or_insert_with(|| { - Arc::new(Mutex::new(MessageAccumulation::new( - aggregate_key, - ArcMut::new(default_mq_producer.clone()), - ))) - }) - .clone() + match self.sync_send_batchs.entry(aggregate_key.clone()) { + dashmap::mapref::entry::Entry::Occupied(entry) => entry.get().clone(), + dashmap::mapref::entry::Entry::Vacant(entry) => { + let accumulation = + MessageAccumulation::new(aggregate_key.clone(), ArcMut::new(default_mq_producer.clone())); + let create_time = accumulation.create_time; + let deadline_ms = accumulation.deadline_ms(self.hold_ms as u64); + let batch = Arc::new(Mutex::new(accumulation)); + entry.insert(batch.clone()); + self.guard_thread_for_sync_send + .schedule_batch(aggregate_key, create_time, deadline_ms); + batch + } + } } async fn get_or_create_async_send_batch( @@ -385,15 +436,20 @@ impl ProduceAccumulator { aggregate_key: AggregateKey, default_mq_producer: &DefaultMQProducer, ) -> Arc> { - self.async_send_batchs - .entry(aggregate_key.clone()) - .or_insert_with(|| { - Arc::new(Mutex::new(MessageAccumulation::new( - aggregate_key, - ArcMut::new(default_mq_producer.clone()), - ))) - }) - .clone() + match self.async_send_batchs.entry(aggregate_key.clone()) { + dashmap::mapref::entry::Entry::Occupied(entry) => entry.get().clone(), + dashmap::mapref::entry::Entry::Vacant(entry) => { + let accumulation = + MessageAccumulation::new(aggregate_key.clone(), ArcMut::new(default_mq_producer.clone())); + let create_time = accumulation.create_time; + let deadline_ms = accumulation.deadline_ms(self.hold_ms as u64); + let batch = Arc::new(Mutex::new(accumulation)); + entry.insert(batch.clone()); + self.guard_thread_for_async_send + .schedule_batch(aggregate_key, create_time, deadline_ms); + batch + } + } } /// Send a batch synchronously (extracted to avoid holding lock across await) @@ -703,12 +759,14 @@ pub async fn run_produce_accumulator_guard_lifecycle_probe() -> ProduceAccumulat accumulator.shutdown_async().await; let shutdown_elapsed_us = shutdown_started_at.elapsed().as_micros(); let task_count_after_shutdown = accumulator.guard_task_count(); + let guard_metrics = accumulator.guard_metrics_snapshot(); let healthy = task_count_before_shutdown == 2 && task_count_after_shutdown == 0; ProduceAccumulatorGuardLifecycleProbe { task_count_before_shutdown, task_count_after_shutdown, shutdown_elapsed_us, + guard_metrics, healthy, } } @@ -999,6 +1057,67 @@ mod tests { assert!(accumulator.async_send_batchs.is_empty()); } + #[tokio::test] + async fn sync_guard_notifies_batch_at_deadline() { + let mut accumulator = ProduceAccumulator::new("accumulator-sync-deadline-test"); + accumulator.set_batch_max_delay_ms(100).unwrap(); + accumulator.start(); + + let aggregate_key = AggregateKey::new(CheetahString::from("test-topic"), None, true, None); + let batch = accumulator + .get_or_create_sync_send_batch(aggregate_key, &DefaultMQProducer::default()) + .await; + let notify = { + let mut batch_guard = batch.lock().await; + let message = Message::builder() + .topic("test-topic") + .body_slice(b"hello") + .build_unchecked(); + assert!(batch_guard.add(message, None, usize::MAX, 100).unwrap().is_some()); + batch_guard.completion_notify.clone() + }; + + tokio::time::timeout(Duration::from_secs(1), notify.notified()) + .await + .expect("sync guard should notify when the batch deadline expires"); + + let metrics = accumulator.guard_metrics_snapshot(); + assert!(metrics.sync.wakeup_count >= 1); + assert_eq!(metrics.sync.flush_count, 1); + + accumulator.shutdown_async().await; + } + + #[tokio::test] + async fn guard_deadline_queue_lazily_ignores_removed_batch() { + let mut accumulator = ProduceAccumulator::new("accumulator-stale-deadline-test"); + accumulator.set_batch_max_delay_ms(30).unwrap(); + accumulator.start(); + + let aggregate_key = AggregateKey::new(CheetahString::from("test-topic"), None, true, None); + let batch = accumulator + .get_or_create_sync_send_batch(aggregate_key.clone(), &DefaultMQProducer::default()) + .await; + { + let mut batch_guard = batch.lock().await; + let message = Message::builder() + .topic("test-topic") + .body_slice(b"hello") + .build_unchecked(); + assert!(batch_guard.add(message, None, usize::MAX, 30).unwrap().is_some()); + } + accumulator.sync_send_batchs.remove(&aggregate_key); + + tokio::time::sleep(Duration::from_millis(150)).await; + + let metrics = accumulator.guard_metrics_snapshot(); + assert!(metrics.sync.wakeup_count >= 1); + assert_eq!(metrics.sync.flush_count, 0); + assert!(metrics.sync.idle_wakeup_count >= 1); + + accumulator.shutdown_async().await; + } + #[tokio::test] async fn produce_accumulator_guard_lifecycle_probe_reports_clean_shutdown() { let probe = run_produce_accumulator_guard_lifecycle_probe().await; @@ -1227,6 +1346,45 @@ impl BatchState { } } +#[derive(Clone)] +struct GuardScheduleCommand { + aggregate_key: AggregateKey, + create_time: u64, + deadline_ms: u64, +} + +#[derive(Clone, Eq)] +struct GuardDeadline { + aggregate_key: AggregateKey, + create_time: u64, + deadline_ms: u64, + sequence: u64, +} + +impl PartialEq for GuardDeadline { + fn eq(&self, other: &Self) -> bool { + self.deadline_ms == other.deadline_ms + && self.sequence == other.sequence + && self.create_time == other.create_time + && self.aggregate_key == other.aggregate_key + } +} + +impl Ord for GuardDeadline { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + other + .deadline_ms + .cmp(&self.deadline_ms) + .then_with(|| other.sequence.cmp(&self.sequence)) + } +} + +impl PartialOrd for GuardDeadline { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + impl MessageAccumulation { pub fn new(aggregate_key: AggregateKey, default_mq_producer: ArcMut) -> Self { Self { @@ -1245,6 +1403,10 @@ impl MessageAccumulation { } } + fn deadline_ms(&self, hold_ms: u64) -> u64 { + self.create_time.saturating_add(hold_ms) + } + fn state(&self) -> BatchState { BatchState::from_u8(self.state.load(Ordering::Acquire)) } @@ -1334,6 +1496,8 @@ struct GuardForSyncSendService { service_name: String, stopped: Arc, task_handle: Option, + schedule_tx: Option>, + metrics: Arc, } enum GuardTaskHandle { @@ -1400,17 +1564,165 @@ where } } -async fn wait_guard_tick_or_shutdown( +enum GuardWakeEvent { + Deadline, + Schedule(GuardScheduleCommand), + Shutdown, +} + +fn push_guard_deadline(deadlines: &mut BinaryHeap, command: GuardScheduleCommand, sequence: &mut u64) { + deadlines.push(GuardDeadline { + aggregate_key: command.aggregate_key, + create_time: command.create_time, + deadline_ms: command.deadline_ms, + sequence: *sequence, + }); + *sequence = sequence.wrapping_add(1); +} + +async fn wait_guard_deadline_or_schedule( shutdown_rx: &mut watch::Receiver, - interval: &mut tokio::time::Interval, -) -> bool { + schedule_rx: &mut mpsc::UnboundedReceiver, + next_deadline_ms: Option, +) -> GuardWakeEvent { if *shutdown_rx.borrow() { - return false; + return GuardWakeEvent::Shutdown; + } + + if let Some(deadline_ms) = next_deadline_ms { + let delay = Duration::from_millis(deadline_ms.saturating_sub(current_millis())); + if delay.is_zero() { + return GuardWakeEvent::Deadline; + } + + tokio::select! { + _ = tokio::time::sleep(delay) => GuardWakeEvent::Deadline, + command = schedule_rx.recv() => { + command.map_or(GuardWakeEvent::Shutdown, GuardWakeEvent::Schedule) + } + _ = shutdown_rx.changed() => GuardWakeEvent::Shutdown, + } + } else { + tokio::select! { + command = schedule_rx.recv() => { + command.map_or(GuardWakeEvent::Shutdown, GuardWakeEvent::Schedule) + } + _ = shutdown_rx.changed() => GuardWakeEvent::Shutdown, + } + } +} + +fn remove_batch_if_same( + batches: &DashMap>>, + key: &AggregateKey, + expected: &Arc>, +) -> Option>> { + let should_remove = batches + .get(key) + .is_some_and(|current| Arc::ptr_eq(current.value(), expected)); + if should_remove { + batches.remove(key).map(|(_, batch)| batch) + } else { + None } +} + +async fn drain_due_sync_batches( + deadlines: &mut BinaryHeap, + batches: &DashMap>>, + hold_ms: u64, + metrics: &BatchGuardMetrics, +) { + let now = current_millis(); + while deadlines.peek().is_some_and(|deadline| deadline.deadline_ms <= now) { + let deadline = deadlines.pop().expect("deadline should exist after peek"); + let Some(batch) = batches.get(&deadline.aggregate_key).map(|entry| entry.value().clone()) else { + metrics.record_idle(); + continue; + }; + + let mut remove_empty = false; + let mut notify = None; + { + let batch_guard = batch.lock().await; + if batch_guard.create_time != deadline.create_time || batch_guard.state() != BatchState::Open { + metrics.record_idle(); + continue; + } + + if batch_guard.messages_size.load(Ordering::Acquire) == 0 { + batch_guard.mark_closed(); + remove_empty = true; + } else if batch_guard.deadline_ms(hold_ms) <= now { + notify = Some(batch_guard.completion_notify.clone()); + } else { + metrics.record_idle(); + } + } - tokio::select! { - _ = interval.tick() => !*shutdown_rx.borrow(), - changed = shutdown_rx.changed() => changed.is_ok() && !*shutdown_rx.borrow(), + if remove_empty { + remove_batch_if_same(batches, &deadline.aggregate_key, &batch); + metrics.record_idle(); + } else if let Some(notify) = notify { + metrics.record_flush(); + notify.notify_waiters(); + } + } +} + +async fn drain_due_async_batches( + deadlines: &mut BinaryHeap, + batches: &DashMap>>, + currently_hold_size: &Arc, + hold_size: usize, + hold_ms: u64, + metrics: &BatchGuardMetrics, +) { + let now = current_millis(); + while deadlines.peek().is_some_and(|deadline| deadline.deadline_ms <= now) { + let deadline = deadlines.pop().expect("deadline should exist after peek"); + let Some(batch) = batches.get(&deadline.aggregate_key).map(|entry| entry.value().clone()) else { + metrics.record_idle(); + continue; + }; + + let mut remove_empty = false; + let mut should_send = false; + { + let batch_guard = batch.lock().await; + if batch_guard.create_time != deadline.create_time || batch_guard.state() != BatchState::Open { + metrics.record_idle(); + continue; + } + + if batch_guard.messages_size.load(Ordering::Acquire) == 0 { + batch_guard.mark_closed(); + remove_empty = true; + } else if batch_guard.ready_to_send(hold_size, hold_ms) { + should_send = true; + } else { + metrics.record_idle(); + } + } + + if remove_empty { + remove_batch_if_same(batches, &deadline.aggregate_key, &batch); + metrics.record_idle(); + continue; + } + + if should_send { + if let Some(batch) = remove_batch_if_same(batches, &deadline.aggregate_key, &batch) { + metrics.record_flush(); + if let Err(error) = + GuardForAsyncSendService::send_batch_async_internal(batch, currently_hold_size.clone()).await + { + tracing::error!("Failed to send batch via guard thread: {:?}", error); + } + } else { + metrics.record_idle(); + } + } } } @@ -1420,10 +1732,12 @@ impl GuardForSyncSendService { service_name: service_name.to_string(), stopped: Arc::new(AtomicBool::new(false)), task_handle: None, + schedule_tx: None, + metrics: Arc::new(BatchGuardMetrics::default()), } } - pub fn start(&mut self, batches: Arc>>>, hold_ms: u32) { + pub fn start(&mut self, batches: BatchMap, hold_ms: u32) { if self.task_handle.as_ref().is_some_and(|handle| !handle.is_finished()) { tracing::warn!("{} sync batch guard already started", self.service_name); return; @@ -1431,39 +1745,33 @@ impl GuardForSyncSendService { let service_name = self.service_name.clone(); let stopped = self.stopped.clone(); - let sleep_time = std::cmp::max(1, hold_ms / 2) as u64; + let metrics = self.metrics.clone(); + let (schedule_tx, mut schedule_rx) = mpsc::unbounded_channel(); let (shutdown_tx, mut shutdown_rx) = watch::channel(false); + self.schedule_tx = Some(schedule_tx); self.stopped.store(false, Ordering::Release); self.task_handle = spawn_guard_task("rocketmq-client-sync-batch-guard", shutdown_tx, async move { tracing::info!("{} service started", service_name); - let mut interval = tokio::time::interval(Duration::from_millis(sleep_time)); + let mut deadlines = BinaryHeap::new(); + let mut sequence = 0u64; loop { - if !wait_guard_tick_or_shutdown(&mut shutdown_rx, &mut interval).await - || stopped.load(Ordering::Acquire) - { + if stopped.load(Ordering::Acquire) { break; } - // Process batches - DashMap provides concurrent iteration - // Collect empty batches to remove - let mut to_remove = Vec::new(); - for item in batches.iter() { - let key = item.key(); - let batch = item.value(); - let batch_guard = batch.lock().await; - let messages_size = batch_guard.messages_size.load(Ordering::Acquire); - if messages_size == 0 { - batch_guard.mark_closed(); - to_remove.push(key.clone()); + let next_deadline_ms = deadlines.peek().map(|deadline: &GuardDeadline| deadline.deadline_ms); + match wait_guard_deadline_or_schedule(&mut shutdown_rx, &mut schedule_rx, next_deadline_ms).await { + GuardWakeEvent::Schedule(command) => { + push_guard_deadline(&mut deadlines, command, &mut sequence); } - } - - // Remove empty batches - for key in to_remove { - batches.remove(&key); + GuardWakeEvent::Deadline => { + metrics.record_wakeup(); + drain_due_sync_batches(&mut deadlines, &batches, hold_ms as u64, &metrics).await; + } + GuardWakeEvent::Shutdown => break, } } @@ -1471,8 +1779,23 @@ impl GuardForSyncSendService { }); } + fn schedule_batch(&self, aggregate_key: AggregateKey, create_time: u64, deadline_ms: u64) { + if let Some(schedule_tx) = &self.schedule_tx { + let _ = schedule_tx.send(GuardScheduleCommand { + aggregate_key, + create_time, + deadline_ms, + }); + } + } + + fn metrics_snapshot(&self) -> BatchGuardMetricsSnapshot { + self.metrics.snapshot() + } + pub fn shutdown(&mut self) { self.stopped.store(true, Ordering::Release); + self.schedule_tx = None; if let Some(handle) = self.task_handle.take() { if !handle.shutdown_blocking(GUARD_TASK_SHUTDOWN_TIMEOUT) { tracing::warn!( @@ -1485,6 +1808,7 @@ impl GuardForSyncSendService { pub async fn shutdown_async(&mut self) { self.stopped.store(true, Ordering::Release); + self.schedule_tx = None; if let Some(handle) = self.task_handle.take() { if !handle.shutdown_async(GUARD_TASK_SHUTDOWN_TIMEOUT).await { tracing::warn!( @@ -1508,6 +1832,8 @@ struct GuardForAsyncSendService { service_name: String, stopped: Arc, task_handle: Option, + schedule_tx: Option>, + metrics: Arc, } impl GuardForAsyncSendService { @@ -1516,16 +1842,12 @@ impl GuardForAsyncSendService { service_name: service_name.to_string(), stopped: Arc::new(AtomicBool::new(false)), task_handle: None, + schedule_tx: None, + metrics: Arc::new(BatchGuardMetrics::default()), } } - pub fn start( - &mut self, - batches: Arc>>>, - currently_hold_size: Arc, - hold_size: usize, - hold_ms: u32, - ) { + pub fn start(&mut self, batches: BatchMap, currently_hold_size: Arc, hold_size: usize, hold_ms: u32) { if self.task_handle.as_ref().is_some_and(|handle| !handle.is_finished()) { tracing::warn!("{} async batch guard already started", self.service_name); return; @@ -1533,71 +1855,41 @@ impl GuardForAsyncSendService { let service_name = self.service_name.clone(); let stopped = self.stopped.clone(); - let sleep_time = std::cmp::max(1, hold_ms / 2) as u64; + let metrics = self.metrics.clone(); + let (schedule_tx, mut schedule_rx) = mpsc::unbounded_channel(); let (shutdown_tx, mut shutdown_rx) = watch::channel(false); + self.schedule_tx = Some(schedule_tx); self.stopped.store(false, Ordering::Release); self.task_handle = spawn_guard_task("rocketmq-client-async-batch-guard", shutdown_tx, async move { tracing::info!("{} service started", service_name); - let mut interval = tokio::time::interval(Duration::from_millis(sleep_time)); + let mut deadlines = BinaryHeap::new(); + let mut sequence = 0u64; loop { - if !wait_guard_tick_or_shutdown(&mut shutdown_rx, &mut interval).await - || stopped.load(Ordering::Acquire) - { + if stopped.load(Ordering::Acquire) { break; } - // Collect keys of ready batches (without holding locks during iteration) - let mut ready_keys = Vec::new(); - for item in batches.iter() { - let key = item.key(); - let batch = item.value(); - - // Quick check without locking first - let should_check = { - let batch_guard = batch.lock().await; - batch_guard.state() == BatchState::Open && batch_guard.ready_to_send(hold_size, hold_ms as u64) - }; - - if should_check { - ready_keys.push(key.clone()); - } - } - - // Send ready batches - for key in ready_keys { - if let Some((_, batch)) = batches.remove(&key) { - // Send the batch asynchronously - if let Err(e) = Self::send_batch_async_internal(batch, currently_hold_size.clone()).await { - tracing::error!("Failed to send batch via guard thread: {:?}", e); - } + let next_deadline_ms = deadlines.peek().map(|deadline: &GuardDeadline| deadline.deadline_ms); + match wait_guard_deadline_or_schedule(&mut shutdown_rx, &mut schedule_rx, next_deadline_ms).await { + GuardWakeEvent::Schedule(command) => { + push_guard_deadline(&mut deadlines, command, &mut sequence); } - } - - // Collect empty batches to remove - let mut empty_keys = Vec::new(); - for item in batches.iter() { - let key = item.key(); - let batch = item.value(); - - let is_empty = { - let batch_guard = batch.lock().await; - batch_guard.messages_size.load(Ordering::Acquire) == 0 - }; - - if is_empty { - empty_keys.push(key.clone()); - } - } - - // Remove empty batches - for key in empty_keys { - if let Some((_, batch)) = batches.remove(&key) { - let batch_guard = batch.lock().await; - batch_guard.mark_closed(); + GuardWakeEvent::Deadline => { + metrics.record_wakeup(); + drain_due_async_batches( + &mut deadlines, + &batches, + ¤tly_hold_size, + hold_size, + hold_ms as u64, + &metrics, + ) + .await; } + GuardWakeEvent::Shutdown => break, } } @@ -1605,6 +1897,20 @@ impl GuardForAsyncSendService { }); } + fn schedule_batch(&self, aggregate_key: AggregateKey, create_time: u64, deadline_ms: u64) { + if let Some(schedule_tx) = &self.schedule_tx { + let _ = schedule_tx.send(GuardScheduleCommand { + aggregate_key, + create_time, + deadline_ms, + }); + } + } + + fn metrics_snapshot(&self) -> BatchGuardMetricsSnapshot { + self.metrics.snapshot() + } + /// Internal method to send batch (used by guard thread) async fn send_batch_async_internal( batch: Arc>, @@ -1713,6 +2019,7 @@ impl GuardForAsyncSendService { pub fn shutdown(&mut self) { self.stopped.store(true, Ordering::Release); + self.schedule_tx = None; if let Some(handle) = self.task_handle.take() { if !handle.shutdown_blocking(GUARD_TASK_SHUTDOWN_TIMEOUT) { tracing::warn!( @@ -1725,6 +2032,7 @@ impl GuardForAsyncSendService { pub async fn shutdown_async(&mut self) { self.stopped.store(true, Ordering::Release); + self.schedule_tx = None; if let Some(handle) = self.task_handle.take() { if !handle.shutdown_async(GUARD_TASK_SHUTDOWN_TIMEOUT).await { tracing::warn!(