Skip to content

Commit 65a946a

Browse files
committed
Separate ports by protocol
1 parent 6f88ec8 commit 65a946a

File tree

4 files changed

+140
-73
lines changed

4 files changed

+140
-73
lines changed

msim-tokio/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,4 @@ real_tokio = { git = "https://github.com/mystenmark/tokio-madsim-fork.git", rev
6565
bytes = { version = "1.1" }
6666
futures = { version = "0.3.0", features = ["async-await"] }
6767
mio = { version = "0.8.1" }
68+
libc = "0.2"

msim-tokio/src/sim/net.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ impl TcpStream {
350350
}
351351

352352
async fn connect_addr(addr: impl ToSocketAddrs) -> io::Result<TcpStream> {
353-
let ep = Arc::new(Endpoint::connect(addr).await?);
353+
let ep = Arc::new(Endpoint::connect(libc::SOCK_STREAM, addr).await?);
354354
trace!("connect {:?}", ep.local_addr());
355355

356356
let remote_sock = ep.peer_addr()?;
@@ -714,7 +714,7 @@ impl TcpSocket {
714714
}
715715

716716
pub fn bind(&self, addr: StdSocketAddr) -> io::Result<()> {
717-
let ep = Endpoint::bind_sync(addr)?;
717+
let ep = Endpoint::bind_sync(libc::SOCK_STREAM, addr)?;
718718
*self.bind_addr.lock().unwrap() = Some(ep.into());
719719
Ok(())
720720
}

msim/src/sim/net/mod.rs

Lines changed: 74 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ unsafe fn accept_impl(
325325
) -> libc::c_int {
326326
let result = HostNetworkState::with_socket(
327327
sock_fd,
328-
|socket| -> Result<SocketAddr, (libc::c_int, libc::c_int)> {
328+
|socket| -> Result<(SocketAddr, libc::c_int), (libc::c_int, libc::c_int)> {
329329
let node = plugin::node();
330330
let net = plugin::simulator::<NetSim>();
331331
let network = net.network.lock().unwrap();
@@ -343,7 +343,8 @@ unsafe fn accept_impl(
343343
// We can't simulate blocking accept in a single-threaded simulator, so if there is no
344344
// connection waiting for us, just bail.
345345
network
346-
.accept_connect(node, endpoint.addr)
346+
.accept_connect(socket.ty, node, endpoint.addr)
347+
.map(|addr| (addr, socket.ty))
347348
.ok_or((-1, libc::ECONNABORTED))
348349
},
349350
)
@@ -352,18 +353,18 @@ unsafe fn accept_impl(
352353
Result::Err((-1, libc::ENOTSOCK))
353354
});
354355

355-
let remote_addr = match result {
356+
let (remote_addr, proto) = match result {
356357
Err((ret, err)) => {
357358
trace!("error status: {} {}", ret, err);
358359
set_errno(err);
359360
return ret;
360361
}
361-
Ok(addr) => addr,
362+
Ok(res) => res,
362363
};
363364

364365
write_socket_addr(address, address_len, remote_addr);
365366

366-
let endpoint = Endpoint::connect_sync(remote_addr)
367+
let endpoint = Endpoint::connect_sync(proto, remote_addr)
367368
.expect("connection failure should already have been detected");
368369

369370
let fd = alloc_fd();
@@ -396,7 +397,7 @@ define_sys_interceptor!(
396397

397398
HostNetworkState::with_socket(sock_fd, |socket| {
398399
assert!(socket.endpoint.is_none(), "socket already bound");
399-
match Endpoint::bind_sync(socket_addr) {
400+
match Endpoint::bind_sync(socket.ty, socket_addr) {
400401
Ok(ep) => {
401402
socket.endpoint = Some(Arc::new(ep));
402403
0
@@ -438,7 +439,7 @@ define_sys_interceptor!(
438439
return Err((-1, libc::EISCONN));
439440
}
440441

441-
let ep = Endpoint::connect_sync(sock_addr).map_err(|e| match e.kind() {
442+
let ep = Endpoint::connect_sync(socket.ty, sock_addr).map_err(|e| match e.kind() {
442443
io::ErrorKind::AddrInUse => (-1, libc::EADDRINUSE),
443444
io::ErrorKind::AddrNotAvailable => (-1, libc::EADDRNOTAVAIL),
444445
_ => {
@@ -453,7 +454,7 @@ define_sys_interceptor!(
453454
// the other end goes away).
454455
let net = plugin::simulator::<NetSim>();
455456
let network = net.network.lock().unwrap();
456-
if !network.signal_connect(ep.addr, sock_addr) {
457+
if !network.signal_connect(socket.ty, ep.addr, sock_addr) {
457458
return Err((-1, libc::ECONNREFUSED));
458459
}
459460

@@ -544,8 +545,7 @@ define_sys_interceptor!(
544545
match (level, name) {
545546
// called by anemo::Network::start (via socket2)
546547
// skip returning any value here since Sui only uses it to log an error anyway
547-
(libc::SOL_SOCKET, libc::SO_RCVBUF) |
548-
(libc::SOL_SOCKET, libc::SO_SNDBUF) => 0,
548+
(libc::SOL_SOCKET, libc::SO_RCVBUF) | (libc::SOL_SOCKET, libc::SO_SNDBUF) => 0,
549549

550550
_ => {
551551
warn!("unhandled getsockopt {} {}", level, name);
@@ -1015,6 +1015,7 @@ pub struct Endpoint {
10151015
net: Arc<NetSim>,
10161016
node: NodeId,
10171017
addr: SocketAddr,
1018+
proto: libc::c_int,
10181019
peer: Option<SocketAddr>,
10191020
live_tcp_ids: Mutex<HashSet<u32>>,
10201021
}
@@ -1030,16 +1031,17 @@ impl std::fmt::Debug for Endpoint {
10301031
}
10311032

10321033
impl Endpoint {
1033-
/// Bind synchronously (for UDP)
1034-
pub fn bind_sync(addr: impl ToSocketAddrs) -> io::Result<Self> {
1034+
/// Bind synchronously
1035+
pub fn bind_sync(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result<Self> {
10351036
let net = plugin::simulator::<NetSim>();
10361037
let node = plugin::node();
10371038
let addr = addr.to_socket_addrs()?.next().unwrap();
1038-
let addr = net.network.lock().unwrap().bind(node, addr)?;
1039+
let addr = net.network.lock().unwrap().bind(node, proto, addr)?;
10391040
let ep = Endpoint {
10401041
net,
10411042
node,
10421043
addr,
1044+
proto,
10431045
peer: None,
10441046
live_tcp_ids: Default::default(),
10451047
};
@@ -1063,30 +1065,31 @@ impl Endpoint {
10631065
}
10641066

10651067
/// Creates a [`Endpoint`] from the given address.
1066-
pub async fn bind(addr: impl ToSocketAddrs) -> io::Result<Self> {
1068+
pub async fn bind(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result<Self> {
10671069
let net = plugin::simulator::<NetSim>();
10681070
let node = plugin::node();
10691071
let addr = addr.to_socket_addrs()?.next().unwrap();
10701072
net.rand_delay().await;
1071-
let addr = net.network.lock().unwrap().bind(node, addr)?;
1073+
let addr = net.network.lock().unwrap().bind(node, proto, addr)?;
10721074
Ok(Endpoint {
10731075
net,
10741076
node,
10751077
addr,
1078+
proto,
10761079
peer: None,
10771080
live_tcp_ids: Default::default(),
10781081
})
10791082
}
10801083

10811084
/// Connects this [`Endpoint`] to a remote address.
1082-
pub async fn connect(addr: impl ToSocketAddrs) -> io::Result<Self> {
1085+
pub async fn connect(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result<Self> {
10831086
let net = plugin::simulator::<NetSim>();
10841087
net.rand_delay().await;
1085-
Self::connect_sync(addr)
1088+
Self::connect_sync(proto, addr)
10861089
}
10871090

10881091
/// For libc::connect()
1089-
pub fn connect_sync(addr: impl ToSocketAddrs) -> io::Result<Self> {
1092+
pub fn connect_sync(proto: libc::c_int, addr: impl ToSocketAddrs) -> io::Result<Self> {
10901093
let net = plugin::simulator::<NetSim>();
10911094
let node = plugin::node();
10921095
let peer = addr.to_socket_addrs()?.next().unwrap();
@@ -1095,11 +1098,12 @@ impl Endpoint {
10951098
} else {
10961099
SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0))
10971100
};
1098-
let addr = net.network.lock().unwrap().bind(node, addr)?;
1101+
let addr = net.network.lock().unwrap().bind(node, proto, addr)?;
10991102
Ok(Endpoint {
11001103
net,
11011104
node,
11021105
addr,
1106+
proto,
11031107
peer: Some(peer),
11041108
live_tcp_ids: Default::default(),
11051109
})
@@ -1128,7 +1132,7 @@ impl Endpoint {
11281132
.network
11291133
.lock()
11301134
.unwrap()
1131-
.deregister_tcp_id(self.node, remote_sock, id);
1135+
.deregister_tcp_id(self.node, self.proto, remote_sock, id);
11321136
}
11331137

11341138
/// Returns the local socket address.
@@ -1234,7 +1238,7 @@ impl Endpoint {
12341238
.network
12351239
.lock()
12361240
.unwrap()
1237-
.send(plugin::node(), self.addr, dst, tag, data)
1241+
.send(plugin::node(), self.proto, self.addr, dst, tag, data)
12381242
}
12391243

12401244
/// Receives a raw message.
@@ -1244,12 +1248,12 @@ impl Endpoint {
12441248
#[cfg_attr(docsrs, doc(cfg(msim)))]
12451249
pub async fn recv_from_raw(&self, tag: u64) -> io::Result<(Payload, SocketAddr)> {
12461250
trace!("awaiting recv: {} tag={:x}", self.addr, tag);
1247-
let recver = self
1248-
.net
1249-
.network
1250-
.lock()
1251-
.unwrap()
1252-
.recv(plugin::node(), self.addr, tag);
1251+
let recver =
1252+
self.net
1253+
.network
1254+
.lock()
1255+
.unwrap()
1256+
.recv(plugin::node(), self.proto, self.addr, tag);
12531257
let msg = recver
12541258
.await
12551259
.map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "network is down"))?;
@@ -1266,7 +1270,7 @@ impl Endpoint {
12661270
.network
12671271
.lock()
12681272
.unwrap()
1269-
.recv_sync(plugin::node(), self.addr, tag)
1273+
.recv_sync(plugin::node(), self.proto, self.addr, tag)
12701274
.ok_or_else(|| io::Error::new(io::ErrorKind::WouldBlock, "recv call would blck"))?;
12711275

12721276
trace!(
@@ -1320,12 +1324,13 @@ impl Endpoint {
13201324
/// Check if there is a message waiting that can be received without blocking.
13211325
/// If not, schedule a wakeup using the context.
13221326
pub fn recv_ready(&self, cx: Option<&mut Context<'_>>, tag: u64) -> io::Result<bool> {
1323-
Ok(self
1324-
.net
1325-
.network
1326-
.lock()
1327-
.unwrap()
1328-
.recv_ready(cx, plugin::node(), self.addr, tag))
1327+
Ok(self.net.network.lock().unwrap().recv_ready(
1328+
cx,
1329+
plugin::node(),
1330+
self.proto,
1331+
self.addr,
1332+
tag,
1333+
))
13291334
}
13301335
}
13311336

@@ -1338,7 +1343,7 @@ impl Drop for Endpoint {
13381343

13391344
// avoid panic on panicking
13401345
if let Ok(mut network) = self.net.network.lock() {
1341-
network.close(self.node, self.addr);
1346+
network.close(self.proto, self.node, self.addr);
13421347
}
13431348
}
13441349
}
@@ -1372,7 +1377,7 @@ mod tests {
13721377

13731378
let barrier_ = barrier.clone();
13741379
node1.spawn(async move {
1375-
let net = Endpoint::bind(addr1).await.unwrap();
1380+
let net = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap();
13761381
barrier_.wait().await;
13771382

13781383
net.send_to(addr2, 1, payload!(vec![1])).await.unwrap();
@@ -1382,7 +1387,7 @@ mod tests {
13821387
});
13831388

13841389
let f = node2.spawn(async move {
1385-
let net = Endpoint::bind(addr2).await.unwrap();
1390+
let net = Endpoint::bind(libc::SOCK_STREAM, addr2).await.unwrap();
13861391
barrier.wait().await;
13871392

13881393
let mut buf = vec![0; 0x10];
@@ -1411,14 +1416,14 @@ mod tests {
14111416

14121417
let barrier_ = barrier.clone();
14131418
node1.spawn(async move {
1414-
let net = Endpoint::bind(addr1).await.unwrap();
1419+
let net = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap();
14151420
barrier_.wait().await;
14161421

14171422
net.send_to(addr2, 1, payload!(vec![1])).await.unwrap();
14181423
});
14191424

14201425
let f = node2.spawn(async move {
1421-
let net = Endpoint::bind(addr2).await.unwrap();
1426+
let net = Endpoint::bind(libc::SOCK_STREAM, addr2).await.unwrap();
14221427
let mut buf = vec![0; 0x10];
14231428
timeout(Duration::from_secs(1), net.recv_from(1, &mut buf))
14241429
.await
@@ -1443,7 +1448,7 @@ mod tests {
14431448
let node1 = runtime.create_node().ip(addr1.ip()).build();
14441449

14451450
let f = node1.spawn(async move {
1446-
let net = Endpoint::bind(addr1).await.unwrap();
1451+
let net = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap();
14471452
let err = net.recv_from(1, &mut []).await.unwrap_err();
14481453
assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
14491454
// FIXME: should still error
@@ -1466,36 +1471,47 @@ mod tests {
14661471

14671472
let f = node.spawn(async move {
14681473
// unspecified
1469-
let ep = Endpoint::bind("0.0.0.0:0").await.unwrap();
1474+
let ep = Endpoint::bind(libc::SOCK_STREAM, "0.0.0.0:0")
1475+
.await
1476+
.unwrap();
14701477
let addr = ep.local_addr().unwrap();
14711478
assert_eq!(addr.ip(), ip);
14721479
assert_ne!(addr.port(), 0);
14731480

14741481
// unspecified v6
1475-
let ep = Endpoint::bind(":::0").await.unwrap();
1482+
let ep = Endpoint::bind(libc::SOCK_STREAM, ":::0").await.unwrap();
14761483
let addr = ep.local_addr().unwrap();
14771484
assert_eq!(addr.ip(), ip);
14781485
assert_ne!(addr.port(), 0);
14791486

14801487
// localhost
1481-
let ep = Endpoint::bind("127.0.0.1:0").await.unwrap();
1488+
let ep = Endpoint::bind(libc::SOCK_STREAM, "127.0.0.1:0")
1489+
.await
1490+
.unwrap();
14821491
let addr = ep.local_addr().unwrap();
14831492
assert_eq!(addr.ip().to_string(), "127.0.0.1");
14841493
assert_ne!(addr.port(), 0);
14851494

14861495
// localhost v6
1487-
let ep = Endpoint::bind("::1:0").await.unwrap();
1496+
let ep = Endpoint::bind(libc::SOCK_STREAM, "::1:0").await.unwrap();
14881497
let addr = ep.local_addr().unwrap();
14891498
assert_eq!(addr.ip().to_string(), "::1");
14901499
assert_ne!(addr.port(), 0);
14911500

14921501
// wrong IP
1493-
let err = Endpoint::bind("10.0.0.2:0").await.err().unwrap();
1502+
let err = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.2:0")
1503+
.await
1504+
.err()
1505+
.unwrap();
14941506
assert_eq!(err.kind(), std::io::ErrorKind::AddrNotAvailable);
14951507

14961508
// drop and reuse port
1497-
let _ = Endpoint::bind("10.0.0.1:100").await.unwrap();
1498-
let _ = Endpoint::bind("10.0.0.1:100").await.unwrap();
1509+
let _ = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.1:100")
1510+
.await
1511+
.unwrap();
1512+
let _ = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.1:100")
1513+
.await
1514+
.unwrap();
14991515
});
15001516
runtime.block_on(f).unwrap();
15011517
}
@@ -1512,8 +1528,12 @@ mod tests {
15121528

15131529
let barrier_ = barrier.clone();
15141530
let f1 = node1.spawn(async move {
1515-
let ep1 = Endpoint::bind("127.0.0.1:1").await.unwrap();
1516-
let ep2 = Endpoint::bind("10.0.0.1:2").await.unwrap();
1531+
let ep1 = Endpoint::bind(libc::SOCK_STREAM, "127.0.0.1:1")
1532+
.await
1533+
.unwrap();
1534+
let ep2 = Endpoint::bind(libc::SOCK_STREAM, "10.0.0.1:2")
1535+
.await
1536+
.unwrap();
15171537
barrier_.wait().await;
15181538

15191539
// FIXME: ep1 should not receive messages from other node
@@ -1525,7 +1545,9 @@ mod tests {
15251545
ep2.recv_from(1, &mut []).await.unwrap();
15261546
});
15271547
let f2 = node2.spawn(async move {
1528-
let ep = Endpoint::bind("127.0.0.1:1").await.unwrap();
1548+
let ep = Endpoint::bind(libc::SOCK_STREAM, "127.0.0.1:1")
1549+
.await
1550+
.unwrap();
15291551
barrier.wait().await;
15301552

15311553
ep.send_to("10.0.0.1:1", 1, payload!(vec![1]))
@@ -1550,7 +1572,7 @@ mod tests {
15501572

15511573
let barrier_ = barrier.clone();
15521574
node1.spawn(async move {
1553-
let ep = Endpoint::bind(addr1).await.unwrap();
1575+
let ep = Endpoint::bind(libc::SOCK_STREAM, addr1).await.unwrap();
15541576
assert_eq!(ep.local_addr().unwrap(), addr1);
15551577
barrier_.wait().await;
15561578

@@ -1565,7 +1587,7 @@ mod tests {
15651587

15661588
let f = node2.spawn(async move {
15671589
barrier.wait().await;
1568-
let ep = Endpoint::connect(addr1).await.unwrap();
1590+
let ep = Endpoint::connect(libc::SOCK_STREAM, addr1).await.unwrap();
15691591
assert_eq!(ep.peer_addr().unwrap(), addr1);
15701592

15711593
ep.send(1, payload!(b"ping".to_vec())).await.unwrap();

0 commit comments

Comments
 (0)