Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion src/executor/vsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ pub(crate) const RAW_SOCKET_BUFFER_SIZE: usize = 256 * 1024;
pub(crate) struct RawSocket {
pub remote_cid: u32,
pub remote_port: u32,
/// The listen port this connection was accepted on. Zero for listener and
/// outbound-connect sockets.
pub listen_port: u32,
pub fwd_cnt: u32,
pub peer_fwd_cnt: u32,
pub peer_buf_alloc: u32,
Expand All @@ -48,6 +51,7 @@ impl RawSocket {
Self {
remote_cid: 0,
remote_port: 0,
listen_port: 0,
fwd_cnt: 0,
peer_fwd_cnt: 0,
peer_buf_alloc: 0,
Expand Down Expand Up @@ -78,7 +82,22 @@ async fn vsock_run() {
let mut vsock_guard = VSOCK_MAP.lock();
let header_cid: u32 = header.src_cid.to_ne().try_into().unwrap();

let Some(raw) = vsock_guard.get_mut_socket(port) else {
// For data, shutdown, credit-update and response packets, prefer a connected
// socket that was accepted on this port over the listener entry itself.
let header_cid_inner: u32 = header_cid;
let raw_port = header.src_port.to_ne();
let raw = if matches!(op, Op::Rw | Op::Shutdown | Op::CreditUpdate | Op::Response) {
if let Some(conn) = vsock_guard.get_mut_connected(port, header_cid_inner, raw_port)
{
conn
} else if let Some(s) = vsock_guard.get_mut_socket(port) {
s
} else {
return;
}
} else if let Some(s) = vsock_guard.get_mut_socket(port) {
s
} else {
return;
};

Expand Down Expand Up @@ -206,9 +225,48 @@ impl VsockMap {
self.port_map.get_mut(&port)
}

/// Find the connected socket that was accepted on `listen_port` from the
/// peer identified by `remote_cid` and `remote_port`.
///
/// Needed to route packets once `move_to_ephemeral` has re-keyed a connection
/// under an ephemeral port: the socket is then no longer stored under the
/// port it was accepted on, so it has to be matched on its recorded
/// `listen_port` and remote endpoint.
///
/// Returns the first matching socket in `Connected` state, or `None` if no
/// entry matches.
pub fn get_mut_connected(
    &mut self,
    listen_port: u32,
    remote_cid: u32,
    remote_port: u32,
) -> Option<&mut RawSocket> {
    for candidate in self.port_map.values_mut() {
        // Only established connections qualify; listeners and sockets in
        // other states are never returned.
        if candidate.state == VsockState::Connected
            && (candidate.listen_port, candidate.remote_cid, candidate.remote_port)
                == (listen_port, remote_cid, remote_port)
        {
            return Some(candidate);
        }
    }

    None
}

/// Drop the socket registered under `port`.
///
/// A port without an entry is silently ignored; the removed socket (if any)
/// is discarded.
pub fn remove_socket(&mut self, port: u32) {
    let _ = self.port_map.remove(&port);
}

/// Move the socket at `listen_port` to a fresh ephemeral port, reset the
/// listener entry to `Listen`, and return the ephemeral port.
///
/// The moved socket is marked `Connected` and records `listen_port`, which is
/// what `get_mut_connected` matches on when routing later packets.
///
/// # Errors
///
/// * `Errno::Inval` if no socket is registered under `listen_port`; the map
///   is left untouched.
/// * `Errno::Badf` if every port in `u32::MAX / 4..u32::MAX` is taken. The
///   connection is dropped in that case, but the listener entry is still
///   reset to `Listen` so later connection attempts can be accepted.
pub fn move_to_ephemeral(&mut self, listen_port: u32) -> io::Result<u32> {
    let mut conn = self.port_map.remove(&listen_port).ok_or(Errno::Inval)?;
    conn.state = VsockState::Connected;
    conn.listen_port = listen_port;

    // Re-arm the listener *before* searching. With the listen port occupied
    // again, the vacancy scan below can never hand the connection the listen
    // port itself (possible whenever `listen_port >= u32::MAX / 4`), where a
    // later listener insert would silently overwrite the connection.
    self.port_map
        .insert(listen_port, RawSocket::new(VsockState::Listen));

    for ep in u32::MAX / 4..u32::MAX {
        if let btree_map::Entry::Vacant(slot) = self.port_map.entry(ep) {
            slot.insert(conn);
            return Ok(ep);
        }
    }

    // No ephemeral port available: the connection is dropped here, the
    // listener entry has already been reset above.
    Err(Errno::Badf)
}
}

pub(crate) fn init() {
Expand Down
69 changes: 38 additions & 31 deletions src/fd/socket/vsock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ impl ObjectInterface for NullSocket {}

/// Vsock socket object; the fields below track which port and CID it is
/// currently associated with.
pub struct Socket {
/// Port this socket currently refers to: set from the listen endpoint and
/// replaced by the ephemeral connection port once `accept` succeeds.
port: u32,
/// The port this socket is bound/listening on. Stays fixed across accepts
/// while `port` is updated to the ephemeral connection port after each accept.
listen_port: u32,
/// CID taken from the listen endpoint when one is supplied. Starts as
/// `u32::MAX`, which `accept` treats as "no CID restriction"; any other value
/// that differs from the local CID makes `accept` answer with a reset.
cid: u32,
/// Non-blocking flag; starts out `false`.
/// NOTE(review): where it is toggled and consulted is not visible here.
is_nonblocking: bool,
}
Expand All @@ -59,6 +62,7 @@ impl Socket {
pub fn new() -> Self {
Self {
port: 0,
listen_port: 0,
cid: u32::MAX,
is_nonblocking: false,
}
Expand Down Expand Up @@ -139,6 +143,7 @@ impl ObjectInterface for Socket {
match endpoint {
ListenEndpoint::Vsock(ep) => {
self.port = ep.port;
self.listen_port = ep.port;
if let Some(cid) = ep.cid {
self.cid = cid;
} else {
Expand Down Expand Up @@ -234,10 +239,10 @@ impl ObjectInterface for Socket {
}

async fn accept(&mut self) -> io::Result<(Arc<async_lock::RwLock<Fd>>, Endpoint)> {
let port = self.port;
let port = self.listen_port;
let cid = self.cid;

let endpoint = future::poll_fn(|cx| {
let (conn_port, endpoint) = future::poll_fn(|cx| {
let mut guard = VSOCK_MAP.lock();
let raw = guard.get_mut_socket(port).ok_or(Errno::Inval)?;

Expand All @@ -251,44 +256,46 @@ impl ObjectInterface for Socket {
}
}
VsockState::ReceiveRequest => {
let result = {
const HEADER_SIZE: usize = size_of::<Hdr>();
let mut driver_guard = hardware::get_vsock_driver().unwrap().lock();
let local_cid = driver_guard.get_cid();

driver_guard.send_packet(HEADER_SIZE, |buffer| {
let response = unsafe { &mut *buffer.as_mut_ptr().cast::<Hdr>() };

response.src_cid = le64::from_ne(local_cid);
response.dst_cid = le64::from_ne(raw.remote_cid.into());
response.src_port = le32::from_ne(port);
response.dst_port = le32::from_ne(raw.remote_port);
response.len = le32::from_ne(0);
response.type_ = le16::from_ne(Type::Stream.into());
if local_cid != u64::from(cid) && cid != u32::MAX {
response.op = le16::from_ne(Op::Rst.into());
} else {
response.op = le16::from_ne(Op::Response.into());
}
response.flags = le32::from_ne(0);
response.buf_alloc = le32::from_ne(
crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32,
);
response.fwd_cnt = le32::from_ne(raw.fwd_cnt);
});
const HEADER_SIZE: usize = size_of::<Hdr>();
let mut driver_guard = hardware::get_vsock_driver().unwrap().lock();
let local_cid = driver_guard.get_cid();

driver_guard.send_packet(HEADER_SIZE, |buffer| {
let response = unsafe { &mut *buffer.as_mut_ptr().cast::<Hdr>() };

response.src_cid = le64::from_ne(local_cid);
response.dst_cid = le64::from_ne(raw.remote_cid.into());
response.src_port = le32::from_ne(port);
response.dst_port = le32::from_ne(raw.remote_port);
response.len = le32::from_ne(0);
response.type_ = le16::from_ne(Type::Stream.into());
if local_cid != u64::from(cid) && cid != u32::MAX {
response.op = le16::from_ne(Op::Rst.into());
} else {
response.op = le16::from_ne(Op::Response.into());
}
response.flags = le32::from_ne(0);
response.buf_alloc =
le32::from_ne(crate::executor::vsock::RAW_SOCKET_BUFFER_SIZE as u32);
response.fwd_cnt = le32::from_ne(raw.fwd_cnt);
});

raw.state = VsockState::Connected;
let endpoint = VsockEndpoint::new(raw.remote_port, raw.remote_cid);

Ok(VsockEndpoint::new(raw.remote_port, raw.remote_cid))
};
// Move the accepted connection to an ephemeral port so the
// listener entry can be reset to Listen for the next accept.
let conn_port = guard.move_to_ephemeral(port)?;

Poll::Ready(result)
Poll::Ready(Ok((conn_port, endpoint)))
}
_ => Poll::Ready(Err(Errno::Badf)),
}
})
.await?;

// This Socket now tracks the accepted connection, not the listener.
self.port = conn_port;

Ok((
Arc::new(async_lock::RwLock::new(NullSocket::new().into())),
Endpoint::Vsock(endpoint),
Expand Down
33 changes: 32 additions & 1 deletion xtask/src/ci/qemu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,19 @@ impl Qemu {
.any(|feature| feature == "client");
test_vsock(has_client)?;
}
"vsock_server" => {
test_vsock_server()?;
}
_ => {}
}

if matches!(
image_name,
"axum-example" | "http_server" | "http_server_poll" | "http_server_select" | "vsock"
"axum-example"
| "http_server"
| "http_server_poll"
| "http_server_select"
| "vsock" | "vsock_server"
) || self.devices.contains(&Device::CadenceGem)
// sifive_u, on which we test CadenceGem, does not support software shutdowns, so we have to terminate the machine ourselves.
{
Expand Down Expand Up @@ -624,6 +631,30 @@ fn test_vsock(has_client: bool) -> Result<()> {
Ok(())
}

/// Exercise the guest's vsock server with several back-to-back connections.
///
/// After a fixed start-up delay, connects to CID 3 on `PORT` `CONNECTIONS`
/// times in sequence. Each connection sends `ping` and expects a single read
/// to yield exactly `pong`; the stream is closed before the next connection
/// is opened.
///
/// # Errors
///
/// Fails on any connect/read/write error, on a non-UTF-8 reply, or when the
/// reply is not `pong`.
fn test_vsock_server() -> Result<()> {
    const PORT: u32 = 9975;
    const CONNECTIONS: usize = 2;

    // Give the guest time to boot and start listening.
    // NOTE(review): CID 3 is presumably the guest CID configured for the QEMU
    // instance; confirm against the launch arguments.
    thread::sleep(Duration::from_secs(10));

    for _ in 0..CONNECTIONS {
        let mut stream = VsockStream::connect_with_cid_port(3, PORT)?;
        stream.write_all(b"ping")?;

        let mut reply = [0u8; 64];
        let len = stream.read(&mut reply)?;
        let msg = from_utf8(&reply[..len])?;
        ensure!(msg == "pong", "expected 'pong', got {msg:?}");
        // `stream` is dropped here, closing the connection before the next
        // iteration connects again.
    }

    Ok(())
}

fn test_http_server(guest_ip: IpAddr) -> Result<()> {
thread::sleep(Duration::from_secs(10));
let url = format!("http://{guest_ip}:9975");
Expand Down
Loading