diff --git a/iroh-relay/src/server/client.rs b/iroh-relay/src/server/client.rs index f941b9dd0c..66476da0ff 100644 --- a/iroh-relay/src/server/client.rs +++ b/iroh-relay/src/server/client.rs @@ -48,6 +48,8 @@ pub(super) struct Config { pub(super) struct Client { /// Identity of the connected peer. node_id: NodeId, + /// Connection identifier. + connection_id: u64, /// Used to close the connection loop. done: CancellationToken, /// Actor handle. @@ -64,7 +66,7 @@ impl Client { /// Creates a client from a connection & starts a read and write loop to handle io to and from /// the client /// Call [`Client::shutdown`] to close the read and write loops before dropping the [`Client`] - pub(super) fn new(config: Config, clients: &Clients) -> Client { + pub(super) fn new(config: Config, connection_id: u64, clients: &Clients) -> Client { let Config { node_id, stream: io, @@ -98,29 +100,21 @@ impl Client { disco_send_queue: disco_send_queue_r, node_gone: peer_gone_r, node_id, + connection_id, clients: clients.clone(), }; // start io loop let io_done = done.clone(); - let handle = tokio::task::spawn( - async move { - match actor.run(io_done).await { - Err(e) => { - warn!("writer closed in error {e:#?}"); - } - Ok(()) => { - debug!("writer closed"); - } - } - } - .instrument( - tracing::info_span!("client connection actor", remote_node = %node_id.fmt_short()), - ), - ); + let handle = tokio::task::spawn(actor.run(io_done).instrument(tracing::info_span!( + "client connection actor", + remote_node = %node_id.fmt_short(), + connection_id = connection_id + ))); Client { node_id, + connection_id, handle: AbortOnDropHandle::new(handle), done, send_queue: send_queue_s, @@ -129,11 +123,15 @@ impl Client { } } + pub(super) fn connection_id(&self) -> u64 { + self.connection_id + } + /// Shutdown the reader and writer loops and closes the connection. /// /// Any shutdown errors will be logged as warnings. pub(super) async fn shutdown(self) { - self.done.cancel(); + self.start_shutdown(); if let Err(e) = self.handle.await { warn!( remote_node = %self.node_id.fmt_short(), @@ -142,6 +140,11 @@ impl Client { }; } + /// Starts the process of shutdown. + pub(super) fn start_shutdown(&self) { + self.done.cancel(); + } + pub(super) fn try_send_packet( &self, src: NodeId, @@ -194,12 +197,29 @@ struct Actor { node_gone: mpsc::Receiver, /// [`NodeId`] of this client node_id: NodeId, + /// Connection identifier. + connection_id: u64, /// Reference to the other connected clients. clients: Clients, } impl Actor { - async fn run(mut self, done: CancellationToken) -> Result<()> { + async fn run(mut self, done: CancellationToken) { + match self.run_inner(done).await { + Err(e) => { + warn!("actor errored {e:#?}, exiting"); + } + Ok(()) => { + debug!("actor finished, exiting"); + } + } + + self.clients + .unregister(self.connection_id, self.node_id) + .await; + } + + async fn run_inner(&mut self, done: CancellationToken) -> Result<()> { let jitter = Duration::from_secs(5); let mut keep_alive = tokio::time::interval(KEEP_ALIVE + jitter); // ticks immediately @@ -304,7 +324,7 @@ impl Actor { match frame { Frame::SendPacket { dst_key, packet } => { let packet_len = packet.len(); - self.handle_frame_send_packet(dst_key, packet).await?; + self.handle_frame_send_packet(dst_key, packet)?; inc_by!(Metrics, bytes_recv, packet_len as u64); } Frame::Ping { data } => { @@ -323,15 +343,13 @@ impl Actor { Ok(()) } - async fn handle_frame_send_packet(&self, dst: NodeId, data: Bytes) -> Result<()> { + fn handle_frame_send_packet(&self, dst: NodeId, data: Bytes) -> Result<()> { if disco::looks_like_disco_wrapper(&data) { inc!(Metrics, disco_packets_recv); - self.clients - .send_disco_packet(dst, data, self.node_id) - .await?; + self.clients.send_disco_packet(dst, data, self.node_id)?; } else { inc!(Metrics, send_packets_recv); - self.clients.send_packet(dst, data, self.node_id).await?; + self.clients.send_packet(dst, data, self.node_id)?; } Ok(()) } @@ -546,6 +564,7 @@ mod tests { send_queue: send_queue_r, disco_send_queue: disco_send_queue_r, node_gone: peer_gone_r, + connection_id: 0, node_id, clients: clients.clone(), }; @@ -630,7 +649,7 @@ mod tests { .await?; done.cancel(); - handle.await??; + handle.await?; Ok(()) } diff --git a/iroh-relay/src/server/clients.rs b/iroh-relay/src/server/clients.rs index 322df35b78..c7d2238fca 100644 --- a/iroh-relay/src/server/clients.rs +++ b/iroh-relay/src/server/clients.rs @@ -1,7 +1,13 @@ //! The "Server" side of the client. Uses the `ClientConnManager`. // Based on tailscale/derp/derp_server.go -use std::{collections::HashSet, sync::Arc}; +use std::{ + collections::HashSet, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, +}; use anyhow::{bail, Result}; use bytes::Bytes; @@ -24,6 +30,8 @@ struct Inner { clients: DashMap, /// Map of which client has sent where sent_to: DashMap>, + /// Connection ID Counter + next_connection_id: AtomicU64, } impl Clients { @@ -41,9 +49,10 @@ impl Clients { /// Builds the client handler and starts the read & write loops for the connection. pub async fn register(&self, client_config: Config) { let node_id = client_config.node_id; + let connection_id = self.get_connection_id(); trace!(remote_node = node_id.fmt_short(), "registering client"); - let client = Client::new(client_config, self); + let client = Client::new(client_config, connection_id, self); if let Some(old_client) = self.0.clients.insert(node_id, client) { debug!( remote_node = node_id.fmt_short(), @@ -53,20 +62,23 @@ impl Clients { } } + fn get_connection_id(&self) -> u64 { + self.0.next_connection_id.fetch_add(1, Ordering::Relaxed) + } + /// Removes the client from the map of clients, & sends a notification /// to each client that peers has sent data to, to let them know that /// peer is gone from the network. /// /// Explicitly drops the reference to the client to avoid deadlock. - async fn unregister<'a>( - &self, - client: dashmap::mapref::one::Ref<'a, iroh_base::PublicKey, Client>, - node_id: NodeId, - ) { - drop(client); // avoid deadlock + pub(super) async fn unregister<'a>(&self, connection_id: u64, node_id: NodeId) { trace!(node_id = node_id.fmt_short(), "unregistering client"); - if let Some((_, client)) = self.0.clients.remove(&node_id) { + if let Some((_, client)) = self + .0 + .clients + .remove_if(&node_id, |_, c| c.connection_id() == connection_id) + { if let Some((_, sent_to)) = self.0.sent_to.remove(&node_id) { for key in sent_to { match client.try_send_peer_gone(key) { @@ -91,7 +103,7 @@ impl Clients { } /// Attempt to send a packet to client with [`NodeId`] `dst`. - pub(super) async fn send_packet(&self, dst: NodeId, data: Bytes, src: NodeId) -> Result<()> { + pub(super) fn send_packet(&self, dst: NodeId, data: Bytes, src: NodeId) -> Result<()> { let Some(client) = self.0.clients.get(&dst) else { debug!(dst = dst.fmt_short(), "no connected client, dropped packet"); inc!(Metrics, send_packets_dropped); @@ -115,19 +127,14 @@ impl Clients { dst = dst.fmt_short(), "can no longer write to client, dropping message and pruning connection" ); - self.unregister(client, dst).await; + client.start_shutdown(); bail!("failed to send message: gone"); } } } /// Attempt to send a disco packet to client with [`NodeId`] `dst`. - pub(super) async fn send_disco_packet( - &self, - dst: NodeId, - data: Bytes, - src: NodeId, - ) -> Result<()> { + pub(super) fn send_disco_packet(&self, dst: NodeId, data: Bytes, src: NodeId) -> Result<()> { let Some(client) = self.0.clients.get(&dst) else { debug!( dst = dst.fmt_short(), @@ -154,7 +161,7 @@ impl Clients { dst = dst.fmt_short(), "can no longer write to client, dropping disco message and pruning connection" ); - self.unregister(client, dst).await; + client.start_shutdown(); bail!("failed to send message: gone"); } } @@ -205,9 +212,7 @@ mod tests { // send packet let data = b"hello world!"; - clients - .send_packet(a_key, Bytes::from(&data[..]), b_key) - .await?; + clients.send_packet(a_key, Bytes::from(&data[..]), b_key)?; let frame = recv_frame(FrameType::RecvPacket, &mut a_rw).await?; assert_eq!( frame, @@ -218,9 +223,7 @@ mod tests { ); // send disco packet - clients - .send_disco_packet(a_key, Bytes::from(&data[..]), b_key) - .await?; + clients.send_disco_packet(a_key, Bytes::from(&data[..]), b_key)?; let frame = recv_frame(FrameType::RecvPacket, &mut a_rw).await?; assert_eq!( frame, @@ -230,13 +233,23 @@ mod tests { } ); - let client = clients.0.clients.get(&a_key).unwrap(); - - // send peer_gone. Also, tests that we do not get a deadlock - // when unregistering. - clients.unregister(client, a_key).await; + { + let client = clients.0.clients.get(&a_key).unwrap(); + // shutdown client a, this should trigger the removal from the clients list + client.start_shutdown(); + } - assert!(!clients.0.clients.contains_key(&a_key)); + // need to wait a moment for the removal to be processed + let c = clients.clone(); + tokio::time::timeout(Duration::from_secs(1), async move { + loop { + if !c.0.clients.contains_key(&a_key) { + break; + } + tokio::time::sleep(Duration::from_millis(100)).await; + } + }) + .await?; clients.shutdown().await; Ok(())