Skip to content

Commit

Permalink
feat(ress): limit active connections (#14928)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkrasiuk authored Mar 10, 2025
1 parent 38fc49f commit c0a4c3b
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 17 deletions.
35 changes: 31 additions & 4 deletions crates/net/ress/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ use std::{
collections::HashMap,
future::Future,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
task::{Context, Poll},
};
use tokio::sync::oneshot;
Expand All @@ -33,8 +37,12 @@ pub struct RessProtocolConnection<P> {
conn: ProtocolConnection,
/// Stream of incoming commands.
commands: UnboundedReceiverStream<RessPeerRequest>,
/// The total number of active connections.
active_connections: Arc<AtomicU64>,
/// Flag indicating whether the node type was sent to the peer.
node_type_sent: bool,
/// Flag indicating whether this stream has previously been terminated.
terminated: bool,
/// Incremental counter for request ids.
next_id: u64,
/// Collection of inflight requests.
Expand All @@ -52,6 +60,7 @@ impl<P> RessProtocolConnection<P> {
peer_id: PeerId,
conn: ProtocolConnection,
commands: UnboundedReceiverStream<RessPeerRequest>,
active_connections: Arc<AtomicU64>,
) -> Self {
Self {
provider,
Expand All @@ -60,7 +69,9 @@ impl<P> RessProtocolConnection<P> {
peer_id,
conn,
commands,
active_connections,
node_type_sent: false,
terminated: false,
next_id: 0,
inflight_requests: HashMap::default(),
pending_witnesses: FuturesUnordered::new(),
Expand Down Expand Up @@ -241,6 +252,14 @@ where
}
}

impl<P> Drop for RessProtocolConnection<P> {
fn drop(&mut self) {
let _ = self
.active_connections
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |c| Some(c.saturating_sub(1)));
}
}

impl<P> Stream for RessProtocolConnection<P>
where
P: RessProtocolProvider + Clone + Unpin + 'static,
Expand All @@ -250,14 +269,18 @@ where
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = self.get_mut();

if this.terminated {
return Poll::Ready(None)
}

if !this.node_type_sent {
this.node_type_sent = true;
return Poll::Ready(Some(RessProtocolMessage::node_type(this.node_type).encoded()))
}

loop {
'conn: loop {
if let Poll::Ready(maybe_cmd) = this.commands.poll_next_unpin(cx) {
let Some(cmd) = maybe_cmd else { return Poll::Ready(None) };
let Some(cmd) = maybe_cmd else { break 'conn };
let message = this.on_command(cmd);
let encoded = message.encoded();
trace!(target: "ress::net::connection", peer_id = %this.peer_id, ?message, encoded = alloy_primitives::hex::encode(&encoded), "Sending peer command");
Expand All @@ -272,7 +295,7 @@ where
}

if let Poll::Ready(maybe_msg) = this.conn.poll_next_unpin(cx) {
let Some(next) = maybe_msg else { return Poll::Ready(None) };
let Some(next) = maybe_msg else { break 'conn };
let msg = match RessProtocolMessage::decode_message(&mut &next[..]) {
Ok(msg) => {
trace!(target: "ress::net::connection", peer_id = %this.peer_id, message = ?msg.message_type, "Processing message");
Expand All @@ -287,7 +310,7 @@ where

match this.on_ress_message(msg) {
OnRessMessageOutcome::Response(bytes) => return Poll::Ready(Some(bytes)),
OnRessMessageOutcome::Terminate => return Poll::Ready(None),
OnRessMessageOutcome::Terminate => break 'conn,
OnRessMessageOutcome::None => {}
};

Expand All @@ -296,6 +319,10 @@ where

return Poll::Pending;
}

// Terminating the connection.
this.terminated = true;
Poll::Ready(None)
}
}

Expand Down
75 changes: 69 additions & 6 deletions crates/net/ress/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@ use reth_eth_wire::{
};
use reth_network::protocol::{ConnectionHandler, OnNotSupported, ProtocolHandler};
use reth_network_api::{test_utils::PeersHandle, Direction, PeerId};
use std::{fmt, net::SocketAddr};
use std::{
fmt,
net::SocketAddr,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::*;

/// The events that can be emitted by our custom protocol.
#[derive(Debug)]
Expand All @@ -23,13 +31,32 @@ pub enum ProtocolEvent {
/// Sender part for forwarding commands.
to_connection: mpsc::UnboundedSender<RessPeerRequest>,
},
/// Number of max active connections exceeded. New connection was rejected.
MaxActiveConnectionsExceeded {
/// The current number
num_active: u64,
},
}

/// Protocol state is an helper struct to store the protocol events.
#[derive(Clone, Debug)]
pub struct ProtocolState {
/// Protocol event sender.
pub events_sender: mpsc::UnboundedSender<ProtocolEvent>,
/// The number of active connections.
pub active_connections: Arc<AtomicU64>,
}

impl ProtocolState {
/// Create new protocol state.
pub fn new(events_sender: mpsc::UnboundedSender<ProtocolEvent>) -> Self {
Self { events_sender, active_connections: Arc::default() }
}

/// Returns the current number of active connections.
pub fn active_connections(&self) -> u64 {
self.active_connections.load(Ordering::Relaxed)
}
}

/// The protocol handler takes care of incoming and outgoing connections.
Expand All @@ -41,6 +68,8 @@ pub struct RessProtocolHandler<P> {
pub node_type: NodeType,
/// Peers handle.
pub peers_handle: PeersHandle,
/// The maximum number of active connections.
pub max_active_connections: u64,
/// Current state of the protocol.
pub state: ProtocolState,
}
Expand All @@ -50,6 +79,7 @@ impl<P> fmt::Debug for RessProtocolHandler<P> {
f.debug_struct("RessProtocolHandler")
.field("node_type", &self.node_type)
.field("peers_handle", &self.peers_handle)
.field("max_active_connections", &self.max_active_connections)
.field("state", &self.state)
.finish_non_exhaustive()
}
Expand All @@ -61,16 +91,44 @@ where
{
type ConnectionHandler = Self;

fn on_incoming(&self, _socket_addr: SocketAddr) -> Option<Self::ConnectionHandler> {
Some(self.clone())
fn on_incoming(&self, socket_addr: SocketAddr) -> Option<Self::ConnectionHandler> {
let num_active = self.state.active_connections();
if num_active >= self.max_active_connections {
trace!(
target: "ress::net",
num_active, max_connections = self.max_active_connections, %socket_addr,
"ignoring incoming connection, max active reached"
);
let _ = self
.state
.events_sender
.send(ProtocolEvent::MaxActiveConnectionsExceeded { num_active });
None
} else {
Some(self.clone())
}
}

fn on_outgoing(
&self,
_socket_addr: SocketAddr,
_peer_id: PeerId,
socket_addr: SocketAddr,
peer_id: PeerId,
) -> Option<Self::ConnectionHandler> {
Some(self.clone())
let num_active = self.state.active_connections();
if num_active >= self.max_active_connections {
trace!(
target: "ress::net",
num_active, max_connections = self.max_active_connections, %socket_addr, %peer_id,
"ignoring outgoing connection, max active reached"
);
let _ = self
.state
.events_sender
.send(ProtocolEvent::MaxActiveConnectionsExceeded { num_active });
None
} else {
Some(self.clone())
}
}
}

Expand Down Expand Up @@ -105,18 +163,23 @@ where
) -> Self::Connection {
let (tx, rx) = mpsc::unbounded_channel();

// Emit connection established event.
self.state
.events_sender
.send(ProtocolEvent::Established { direction, peer_id, to_connection: tx })
.ok();

// Increment the number of active sessions.
self.state.active_connections.fetch_add(1, Ordering::Relaxed);

RessProtocolConnection::new(
self.provider.clone(),
self.node_type,
self.peers_handle,
peer_id,
conn,
UnboundedReceiverStream::from(rx),
self.state.active_connections,
)
}
}
Loading

0 comments on commit c0a4c3b

Please sign in to comment.