diff --git a/Cargo.toml b/Cargo.toml index fbea1f38..5ac03448 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -61,7 +61,7 @@ log = "0.4.21" parking_lot = "0.12.1" thiserror = "1.0.58" pin-project = "1.0.12" -tokio = { version = "1.27.0", features = ["net", "time", "sync"] } +tokio = { version = "1.27.0", features = ["net", "time", "sync", "rt"] } tokio-stream = "0.1.12" tokio-util = { version = "0.7.7", features = ["codec"] } diff --git a/src/client/mod.rs b/src/client/mod.rs index 8c3f08c6..e741ddad 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -435,8 +435,6 @@ macro_rules! pub_sender_base { pub struct ClientStream { state: Arc, stream: SplitStream, - // In case the client stream also handles outgoing messages. - outgoing: Option, } impl ClientStream { @@ -465,21 +463,6 @@ impl Stream for ClientStream { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - if let Some(outgoing) = self.as_mut().outgoing.as_mut() { - match Pin::new(outgoing).poll(cx) { - Poll::Ready(Ok(())) => { - // assure that we wake up again to check the incoming stream. - cx.waker().wake_by_ref(); - return Poll::Ready(None); - } - Poll::Ready(Err(e)) => { - cx.waker().wake_by_ref(); - return Poll::Ready(Some(Err(e))); - } - Poll::Pending => (), - } - } - match ready!(Pin::new(&mut self.as_mut().stream).poll_next(cx)) { Some(Ok(msg)) => { self.state.handle_message(&msg)?; @@ -988,10 +971,20 @@ impl Client { .take() .ok_or(error::Error::StreamAlreadyConfigured)?; + // Spawn the outgoing message handler as an independent task so that + // sent messages are flushed immediately, rather than waiting for the + // incoming stream to be polled. + if let Some(outgoing) = self.outgoing.take() { + tokio::spawn(async move { + if let Err(e) = outgoing.await { + log::error!("error in outgoing message handler: {}", e); + } + }); + } + Ok(ClientStream { state: Arc::clone(&self.state), stream, - outgoing: self.outgoing.take(), }) } @@ -1091,7 +1084,7 @@ impl Client { #[cfg(test)] mod test { - use std::{collections::HashMap, default::Default, thread, time::Duration}; + use std::{collections::HashMap, default::Default, time::Duration}; use super::Client; #[cfg(feature = "channel-lists")] @@ -1120,10 +1113,9 @@ mod test { } } - pub fn get_client_value(client: Client) -> String { - // We sleep here because of synchronization issues. - // We can't guarantee that everything will have been sent by the time of this call. - thread::sleep(Duration::from_millis(100)); + pub async fn get_client_value(client: Client) -> String { + // We yield here to allow the spawned outgoing task to flush messages. + tokio::time::sleep(Duration::from_millis(100)).await; client .log_view() .sent() @@ -1164,7 +1156,7 @@ mod test { .await?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "JOIN #test\r\nJOIN #test2\r\n" ); Ok(()) @@ -1182,7 +1174,7 @@ mod test { .await?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "NICKSERV IDENTIFY password\r\nJOIN #test\r\n\ JOIN #test2\r\n" ); @@ -1206,7 +1198,7 @@ mod test { .await?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "JOIN #test\r\nJOIN #test2 password\r\n" ); Ok(()) @@ -1228,7 +1220,7 @@ mod test { .await?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "NICK test2\r\nNICKSERV GHOST test password\r\n\ NICK test\r\nNICKSERV IDENTIFY password\r\nJOIN #test\r\nJOIN #test2\r\n" ); @@ -1252,7 +1244,7 @@ mod test { .await?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "NICK test2\r\nNICKSERV RECOVER test password\ \r\nNICKSERV RELEASE test password\r\nNICK test\r\nNICKSERV IDENTIFY password\ \r\nJOIN #test\r\nJOIN #test2\r\n" @@ -1273,7 +1265,7 @@ mod test { .await?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "MODE test +B\r\nJOIN #test\r\nJOIN #test2\r\n" ); Ok(()) @@ -1288,7 +1280,7 @@ mod test { }) .await?; client.stream()?.collect().await?; - assert_eq!(&get_client_value(client)[..], "NICK test2\r\n"); + assert_eq!(&get_client_value(client).await[..], "NICK test2\r\n"); Ok(()) } @@ -1317,7 +1309,7 @@ mod test { .is_ok()); client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "PRIVMSG #test :Hi there!\r\n" ); Ok(()) @@ -1334,7 +1326,7 @@ mod test { .is_ok()); client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "PRIVMSG #test :Hi there!\r\n" ); Ok(()) @@ -1351,7 +1343,7 @@ mod test { .is_ok()); client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "PASS password\r\nNICK test\r\n" ); Ok(()) @@ -1537,7 +1529,7 @@ mod test { .await?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "NOTICE test :\u{001}FINGER :test (test)\u{001}\r\n" ); Ok(()) @@ -1554,7 +1546,7 @@ mod test { .await?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], &format!( "NOTICE test :\u{001}VERSION {}\u{001}\r\n", crate::VERSION_STR, @@ -1574,7 +1566,7 @@ mod test { .await?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "NOTICE test :\u{001}SOURCE https://github.com/aatxe/irc\u{001}\r\n" ); Ok(()) @@ -1591,7 +1583,7 @@ mod test { .await?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "NOTICE test :\u{001}PING test\u{001}\r\n" ); Ok(()) @@ -1607,7 +1599,7 @@ mod test { }) .await?; client.stream()?.collect().await?; - let val = get_client_value(client); + let val = get_client_value(client).await; assert!(val.starts_with("NOTICE test :\u{001}TIME :")); assert!(val.ends_with("\u{001}\r\n")); Ok(()) @@ -1624,7 +1616,7 @@ mod test { .await?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "NOTICE test :\u{001}USERINFO :Testing.\u{001}\ \r\n" ); @@ -1641,7 +1633,7 @@ mod test { }) .await?; client.stream()?.collect().await?; - assert_eq!(&get_client_value(client)[..], ""); + assert_eq!(&get_client_value(client).await[..], ""); Ok(()) } @@ -1651,7 +1643,7 @@ mod test { client.identify()?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "CAP END\r\nNICK test\r\n\ USER test 0 * test\r\n" ); @@ -1669,7 +1661,7 @@ mod test { client.identify()?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "CAP END\r\nPASS password\r\nNICK test\r\n\ USER test 0 * test\r\n" ); @@ -1681,7 +1673,7 @@ mod test { let mut client = Client::from_config(test_config()).await?; client.send_pong("irc.test.net")?; client.stream()?.collect().await?; - assert_eq!(&get_client_value(client)[..], "PONG irc.test.net\r\n"); + assert_eq!(&get_client_value(client).await[..], "PONG irc.test.net\r\n"); Ok(()) } @@ -1691,7 +1683,7 @@ mod test { client.send_join("#test,#test2,#test3")?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "JOIN #test,#test2,#test3\r\n" ); Ok(()) @@ -1702,7 +1694,7 @@ mod test { let mut client = Client::from_config(test_config()).await?; client.send_part("#test")?; client.stream()?.collect().await?; - assert_eq!(&get_client_value(client)[..], "PART #test\r\n"); + assert_eq!(&get_client_value(client).await[..], "PART #test\r\n"); Ok(()) } @@ -1711,7 +1703,7 @@ mod test { let mut client = Client::from_config(test_config()).await?; client.send_oper("test", "test")?; client.stream()?.collect().await?; - assert_eq!(&get_client_value(client)[..], "OPER test test\r\n"); + assert_eq!(&get_client_value(client).await[..], "OPER test test\r\n"); Ok(()) } @@ -1721,7 +1713,7 @@ mod test { client.send_privmsg("#test", "Hi, everybody!")?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "PRIVMSG #test :Hi, everybody!\r\n" ); Ok(()) @@ -1733,7 +1725,7 @@ mod test { client.send_notice("#test", "Hi, everybody!")?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "NOTICE #test :Hi, everybody!\r\n" ); Ok(()) @@ -1744,7 +1736,7 @@ mod test { let mut client = Client::from_config(test_config()).await?; client.send_topic("#test", "")?; client.stream()?.collect().await?; - assert_eq!(&get_client_value(client)[..], "TOPIC #test\r\n"); + assert_eq!(&get_client_value(client).await[..], "TOPIC #test\r\n"); Ok(()) } @@ -1754,7 +1746,7 @@ mod test { client.send_topic("#test", "Testing stuff.")?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "TOPIC #test :Testing stuff.\r\n" ); Ok(()) @@ -1766,7 +1758,7 @@ mod test { client.send_kill("test", "Testing kills.")?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "KILL test :Testing kills.\r\n" ); Ok(()) @@ -1777,7 +1769,7 @@ mod test { let mut client = Client::from_config(test_config()).await?; client.send_kick("#test", "test", "")?; client.stream()?.collect().await?; - assert_eq!(&get_client_value(client)[..], "KICK #test test\r\n"); + assert_eq!(&get_client_value(client).await[..], "KICK #test test\r\n"); Ok(()) } @@ -1787,7 +1779,7 @@ mod test { client.send_kick("#test", "test", "Testing kicks.")?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "KICK #test test :Testing kicks.\r\n" ); Ok(()) @@ -1798,7 +1790,7 @@ mod test { let mut client = Client::from_config(test_config()).await?; client.send_mode("#test", &[Mode::Plus(ChannelMode::InviteOnly, None)])?; client.stream()?.collect().await?; - assert_eq!(&get_client_value(client)[..], "MODE #test +i\r\n"); + assert_eq!(&get_client_value(client).await[..], "MODE #test +i\r\n"); Ok(()) } @@ -1814,7 +1806,7 @@ mod test { )?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "MODE #test +o-o test test2\r\n" ); Ok(()) @@ -1831,7 +1823,7 @@ mod test { ], )?; client.stream()?.collect().await?; - assert_eq!(&get_client_value(client)[..], "MODE test +i+x\r\n"); + assert_eq!(&get_client_value(client).await[..], "MODE test +i+x\r\n"); Ok(()) } @@ -1840,7 +1832,7 @@ mod test { let mut client = Client::from_config(test_config()).await?; client.send_samode("#test", "+i", "")?; client.stream()?.collect().await?; - assert_eq!(&get_client_value(client)[..], "SAMODE #test +i\r\n"); + assert_eq!(&get_client_value(client).await[..], "SAMODE #test +i\r\n"); Ok(()) } @@ -1849,7 +1841,7 @@ mod test { let mut client = Client::from_config(test_config()).await?; client.send_samode("#test", "+o", "test")?; client.stream()?.collect().await?; - assert_eq!(&get_client_value(client)[..], "SAMODE #test +o test\r\n"); + assert_eq!(&get_client_value(client).await[..], "SAMODE #test +o test\r\n"); Ok(()) } @@ -1858,7 +1850,7 @@ mod test { let mut client = Client::from_config(test_config()).await?; client.send_sanick("test", "test2")?; client.stream()?.collect().await?; - assert_eq!(&get_client_value(client)[..], "SANICK test test2\r\n"); + assert_eq!(&get_client_value(client).await[..], "SANICK test test2\r\n"); Ok(()) } @@ -1867,7 +1859,7 @@ mod test { let mut client = Client::from_config(test_config()).await?; client.send_invite("test", "#test")?; client.stream()?.collect().await?; - assert_eq!(&get_client_value(client)[..], "INVITE test #test\r\n"); + assert_eq!(&get_client_value(client).await[..], "INVITE test #test\r\n"); Ok(()) } @@ -1878,7 +1870,7 @@ mod test { client.send_ctcp("test", "LINE1\r\nLINE2\r\nLINE3")?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "PRIVMSG test \u{001}LINE1\u{001}\r\nPRIVMSG test \u{001}LINE2\u{001}\r\nPRIVMSG test \u{001}LINE3\u{001}\r\n" ); Ok(()) @@ -1891,7 +1883,7 @@ mod test { client.send_action("test", "tests.")?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "PRIVMSG test :\u{001}ACTION tests.\u{001}\r\n" ); Ok(()) @@ -1904,7 +1896,7 @@ mod test { client.send_finger("test")?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "PRIVMSG test \u{001}FINGER\u{001}\r\n" ); Ok(()) @@ -1917,7 +1909,7 @@ mod test { client.send_version("test")?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "PRIVMSG test \u{001}VERSION\u{001}\r\n" ); Ok(()) @@ -1930,7 +1922,7 @@ mod test { client.send_source("test")?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "PRIVMSG test \u{001}SOURCE\u{001}\r\n" ); Ok(()) @@ -1943,7 +1935,7 @@ mod test { client.send_user_info("test")?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "PRIVMSG test \u{001}USERINFO\u{001}\r\n" ); Ok(()) @@ -1955,7 +1947,7 @@ mod test { let mut client = Client::from_config(test_config()).await?; client.send_ctcp_ping("test")?; client.stream()?.collect().await?; - let val = get_client_value(client); + let val = get_client_value(client).await; println!("{}", val); assert!(val.starts_with("PRIVMSG test :\u{001}PING ")); assert!(val.ends_with("\u{001}\r\n")); @@ -1969,7 +1961,7 @@ mod test { client.send_time("test")?; client.stream()?.collect().await?; assert_eq!( - &get_client_value(client)[..], + &get_client_value(client).await[..], "PRIVMSG test \u{001}TIME\u{001}\r\n" ); Ok(())