Skip to content

Commit

Permalink
Bugfix/133- fix util unit test memory leak (#625)
Browse files Browse the repository at this point in the history
* Remove cyclic dependency in vnet::UdpConn

- vnet::UdpConn holds Arc to vnet::ConnObserver, which would most of the time hold the Arc to vnet::UdpConn. We use Weak when pointing upward (i.e. parent)

* Remove cyclic dependency in vnet::resolver

- We replace the Arc pointing to parent with Weak

* Fix util::UdpConn owns it's own Arc

- util::UdpConn should not own the table of UdpConn because newly created util::UdpConn will always be added into the table. This table of UdpConn is sufficiently owned by Listener and the ListenConfig::read_loop.

* Remove cyclic dependency in vnet::router

- The Arc pointing to parent is replaced by Weak
- Nic and RouterInternal should not point to each other. Make the nics table in RouterInternal Weak.

* Fix clippy warning

---------

Co-authored-by: mutexd <[email protected]>
  • Loading branch information
mutexd and mutexd authored Nov 3, 2024
1 parent a1611af commit 2ec027f
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 35 deletions.
12 changes: 2 additions & 10 deletions util/src/conn/conn_udp_listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ impl ListenConfig {
}
}

let udp_conn = Arc::new(UdpConn::new(Arc::clone(pconn), Arc::clone(conns), raddr));
let udp_conn = Arc::new(UdpConn::new(Arc::clone(pconn), raddr));
{
let accept_ch = accept_ch_tx.lock().await;
if let Some(tx) = &*accept_ch {
Expand All @@ -235,20 +235,14 @@ impl ListenConfig {
/// UdpConn augments a connection-oriented connection over a UdpSocket
pub struct UdpConn {
pconn: Arc<dyn Conn + Send + Sync>,
conns: Arc<Mutex<HashMap<String, Arc<UdpConn>>>>,
raddr: SocketAddr,
buffer: Buffer,
}

impl UdpConn {
fn new(
pconn: Arc<dyn Conn + Send + Sync>,
conns: Arc<Mutex<HashMap<String, Arc<UdpConn>>>>,
raddr: SocketAddr,
) -> Self {
fn new(pconn: Arc<dyn Conn + Send + Sync>, raddr: SocketAddr) -> Self {
UdpConn {
pconn,
conns,
raddr,
buffer: Buffer::new(0, 0),
}
Expand Down Expand Up @@ -287,8 +281,6 @@ impl Conn for UdpConn {
}

async fn close(&self) -> Result<()> {
let mut conns = self.conns.lock().await;
conns.remove(self.raddr.to_string().as_str());
Ok(())
}

Expand Down
17 changes: 11 additions & 6 deletions util/src/vnet/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ mod conn_test;

use std::net::{IpAddr, SocketAddr};
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::sync::{Arc, Weak};

use async_trait::async_trait;
use portable_atomic::AtomicBool;
Expand Down Expand Up @@ -34,7 +34,7 @@ pub(crate) struct UdpConn {
read_ch_tx: Arc<Mutex<Option<ChunkChTx>>>,
read_ch_rx: Mutex<mpsc::Receiver<Box<dyn Chunk + Send + Sync>>>,
closed: AtomicBool,
obs: Arc<Mutex<dyn ConnObserver + Send + Sync>>,
obs: Weak<Mutex<dyn ConnObserver + Send + Sync>>,
}

impl UdpConn {
Expand All @@ -45,13 +45,14 @@ impl UdpConn {
) -> Self {
let (read_ch_tx, read_ch_rx) = mpsc::channel(MAX_READ_QUEUE_SIZE);

let weak_obs = Arc::downgrade(&obs);
UdpConn {
loc_addr,
rem_addr: RwLock::new(rem_addr),
read_ch_tx: Arc::new(Mutex::new(Some(read_ch_tx))),
read_ch_rx: Mutex::new(read_ch_rx),
closed: AtomicBool::new(false),
obs,
obs: weak_obs,
}
}

Expand Down Expand Up @@ -112,8 +113,10 @@ impl Conn for UdpConn {
/// send_to writes a packet with payload p to addr.
/// send_to can be made to time out and return
async fn send_to(&self, buf: &[u8], target: SocketAddr) -> Result<usize> {
let obs = self.obs.upgrade().ok_or_else(|| Error::ErrVnetDisabled)?;

let src_ip = {
let obs = self.obs.lock().await;
let obs = obs.lock().await;
match obs.determine_source_ip(self.loc_addr.ip(), target.ip()) {
Some(ip) => ip,
None => return Err(Error::ErrLocAddr),
Expand All @@ -126,7 +129,7 @@ impl Conn for UdpConn {
chunk.user_data = buf.to_vec();
{
let c: Box<dyn Chunk + Send + Sync> = Box::new(chunk);
let obs = self.obs.lock().await;
let obs = obs.lock().await;
obs.write(c).await?
}

Expand All @@ -142,6 +145,8 @@ impl Conn for UdpConn {
}

async fn close(&self) -> Result<()> {
let obs = self.obs.upgrade().ok_or_else(|| Error::ErrVnetDisabled)?;

if self.closed.load(Ordering::SeqCst) {
return Err(Error::ErrAlreadyClosed);
}
Expand All @@ -151,7 +156,7 @@ impl Conn for UdpConn {
reach_ch.take();
}
{
let obs = self.obs.lock().await;
let obs = obs.lock().await;
obs.on_closed(self.loc_addr).await;
}

Expand Down
11 changes: 5 additions & 6 deletions util/src/vnet/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@ use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::sync::Weak;

use tokio::sync::Mutex;

use crate::error::*;

#[derive(Default)]
pub(crate) struct Resolver {
parent: Option<Arc<Mutex<Resolver>>>,
parent: Option<Weak<Mutex<Resolver>>>,
hosts: HashMap<String, IpAddr>,
}

Expand All @@ -31,7 +31,7 @@ impl Resolver {
r
}

pub(crate) fn set_parent(&mut self, p: Arc<Mutex<Resolver>>) {
pub(crate) fn set_parent(&mut self, p: Weak<Mutex<Resolver>>) {
self.parent = Some(p);
}

Expand All @@ -55,10 +55,9 @@ impl Resolver {
}

// mutex must be unlocked before calling into parent Resolver
if let Some(parent) = &self.parent {
let parent2 = Arc::clone(parent);
if let Some(parent) = self.parent.clone().and_then(|p| p.upgrade()).clone() {
Box::pin(async move {
let p = parent2.lock().await;
let p = parent.lock().await;
p.lookup(host_name).await
})
} else {
Expand Down
4 changes: 3 additions & 1 deletion util/src/vnet/resolver/resolver_test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use std::sync::Arc;

const DEMO_IP: &str = "1.2.3.4";

Expand Down Expand Up @@ -49,7 +50,8 @@ async fn test_resolver_cascaded() -> Result<()> {
let ip1 = IpAddr::from_str(ip_addr1)?;
r1.add_host(name1.to_owned(), ip_addr1.to_owned())?;

r1.set_parent(Arc::new(Mutex::new(r0)));
let resolver0 = Arc::new(Mutex::new(r0));
r1.set_parent(Arc::downgrade(&resolver0));

if let Some(resolved) = r1.lookup(name0.to_owned()).await {
assert_eq!(resolved, ip0, "should match");
Expand Down
24 changes: 12 additions & 12 deletions util/src/vnet/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::ops::{Add, Sub};
use std::pin::Pin;
use std::str::FromStr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::sync::{Arc, Weak};
use std::time::SystemTime;

use async_trait::async_trait;
Expand Down Expand Up @@ -77,12 +77,12 @@ pub type ChunkFilterFn = Box<dyn (Fn(&(dyn Chunk + Send + Sync)) -> bool) + Send

#[derive(Default)]
pub struct RouterInternal {
pub(crate) nat_type: Option<NatType>, // read-only
pub(crate) ipv4net: IpNet, // read-only
pub(crate) parent: Option<Arc<Mutex<Router>>>, // read-only
pub(crate) nat: NetworkAddressTranslator, // read-only
pub(crate) nics: HashMap<String, Arc<Mutex<dyn Nic + Send + Sync>>>, // read-only
pub(crate) chunk_filters: Vec<ChunkFilterFn>, // requires mutex [x]
pub(crate) nat_type: Option<NatType>, // read-only
pub(crate) ipv4net: IpNet, // read-only
pub(crate) parent: Option<Weak<Mutex<Router>>>, // read-only
pub(crate) nat: NetworkAddressTranslator, // read-only
pub(crate) nics: HashMap<String, Weak<Mutex<dyn Nic + Send + Sync>>>, // read-only
pub(crate) chunk_filters: Vec<ChunkFilterFn>, // requires mutex [x]
pub(crate) last_id: u8, // requires mutex [x], used to assign the last digit of IPv4 address
}

Expand Down Expand Up @@ -157,7 +157,7 @@ impl Nic for Router {
async fn set_router(&self, parent: Arc<Mutex<Router>>) -> Result<()> {
{
let mut router_internal = self.router_internal.lock().await;
router_internal.parent = Some(Arc::clone(&parent));
router_internal.parent = Some(Arc::downgrade(&parent));
}

let parent_resolver = {
Expand All @@ -166,7 +166,7 @@ impl Nic for Router {
};
{
let mut resolver = self.resolver.lock().await;
resolver.set_parent(parent_resolver);
resolver.set_parent(Arc::downgrade(&parent_resolver));
}

let mut mapped_ips = vec![];
Expand Down Expand Up @@ -492,7 +492,7 @@ impl Router {
// check if the destination is in our subnet
if ipv4net.contains(&dst_ip) {
// search for the destination NIC
if let Some(nic) = ri.nics.get(&dst_ip.to_string()) {
if let Some(nic) = ri.nics.get(&dst_ip.to_string()).and_then(|p| p.upgrade()) {
// found the NIC, forward the chunk to the NIC.
// call to NIC must unlock mutex
let ni = nic.lock().await;
Expand All @@ -504,7 +504,7 @@ impl Router {
} else {
// the destination is outside of this subnet
// is this WAN?
if let Some(parent) = &ri.parent {
if let Some(parent) = &ri.parent.clone().and_then(|p| p.upgrade()) {
// Pass it to the parent via NAT
if let Some(to_parent) = ri.nat.translate_outbound(&*c).await? {
// call to parent router mutex unlock mutex
Expand Down Expand Up @@ -545,7 +545,7 @@ impl RouterInternal {
if !self.ipv4net.contains(ip) {
return Err(Error::ErrStaticIpIsBeyondSubnet);
}
self.nics.insert(ip.to_string(), Arc::clone(&nic));
self.nics.insert(ip.to_string(), Arc::downgrade(&nic));
ipnets.push(IpNet::from_str(&format!(
"{}/{}",
ip,
Expand Down

0 comments on commit 2ec027f

Please sign in to comment.