Skip to content

Commit 615a1ed

Browse files
committed
Represent zero-length CIDs by specifying no CID generator
1 parent 61dbea6 commit 615a1ed

File tree

11 files changed

+113
-63
lines changed

11 files changed

+113
-63
lines changed

fuzz/fuzz_targets/packet.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,23 @@ extern crate proto;
55
use libfuzzer_sys::fuzz_target;
66
use proto::{
77
fuzzing::{PacketParams, PartialDecode},
8-
RandomConnectionIdGenerator, DEFAULT_SUPPORTED_VERSIONS,
8+
ConnectionIdParser, RandomConnectionIdGenerator, ZeroLengthConnectionIdParser,
9+
DEFAULT_SUPPORTED_VERSIONS,
910
};
1011

1112
fuzz_target!(|data: PacketParams| {
1213
let len = data.buf.len();
1314
let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec();
15+
let cid_gen;
1416
if let Ok(decoded) = PartialDecode::new(
1517
data.buf,
16-
&RandomConnectionIdGenerator::new(data.local_cid_len),
18+
match data.local_cid_len {
19+
0 => &ZeroLengthConnectionIdParser as &dyn ConnectionIdParser,
20+
_ => {
21+
cid_gen = RandomConnectionIdGenerator::new(data.local_cid_len);
22+
&cid_gen as &dyn ConnectionIdParser
23+
}
24+
},
1725
&supported_versions,
1826
data.grease_quic_bit,
1927
) {

quinn-proto/src/cid_generator.rs

+29-11
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ pub trait ConnectionIdGenerator: Send + Sync + ConnectionIdParser {
2626
Ok(())
2727
}
2828

29-
/// Returns the length of a CID for connections created by this generator
30-
fn cid_len(&self) -> usize;
3129
/// Returns the lifetime of generated Connection IDs
3230
///
3331
/// Connection IDs will be retired after the returned `Duration`, if any. Assumed to be constant.
@@ -63,6 +61,10 @@ impl RandomConnectionIdGenerator {
6361
/// The given length must be less than or equal to MAX_CID_SIZE.
6462
pub fn new(cid_len: usize) -> Self {
6563
debug_assert!(cid_len <= MAX_CID_SIZE);
64+
assert!(
65+
cid_len > 0,
66+
"connection ID generators must produce non-empty IDs"
67+
);
6668
Self {
6769
cid_len,
6870
..Self::default()
@@ -92,11 +94,6 @@ impl ConnectionIdGenerator for RandomConnectionIdGenerator {
9294
ConnectionId::new(&bytes_arr[..self.cid_len])
9395
}
9496

95-
/// Provide the length of dst_cid in short header packet
96-
fn cid_len(&self) -> usize {
97-
self.cid_len
98-
}
99-
10097
fn cid_lifetime(&self) -> Option<Duration> {
10198
self.lifetime
10299
}
@@ -173,10 +170,6 @@ impl ConnectionIdGenerator for HashedConnectionIdGenerator {
173170
}
174171
}
175172

176-
fn cid_len(&self) -> usize {
177-
HASHED_CID_LEN
178-
}
179-
180173
fn cid_lifetime(&self) -> Option<Duration> {
181174
self.lifetime
182175
}
@@ -186,6 +179,31 @@ const NONCE_LEN: usize = 3; // Good for more than 16 million connections
186179
const SIGNATURE_LEN: usize = 8 - NONCE_LEN; // 8-byte total CID length
187180
const HASHED_CID_LEN: usize = NONCE_LEN + SIGNATURE_LEN;
188181

182+
/// HACK: Replace uses with `ZeroLengthConnectionIdParser` once [trait upcasting] is stable
183+
///
184+
/// CID generators should produce nonempty CIDs. We should be able to use
185+
/// `ZeroLengthConnectionIdParser` everywhere this would be needed, but that will require
186+
/// construction of `&dyn ConnectionIdParser` from `&dyn ConnectionIdGenerator`.
187+
///
188+
/// [trait upcasting]: https://github.com/rust-lang/rust/issues/65991
189+
pub(crate) struct ZeroLengthConnectionIdGenerator;
190+
191+
impl ConnectionIdParser for ZeroLengthConnectionIdGenerator {
192+
fn parse(&self, _: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
193+
Ok(ConnectionId::new(&[]))
194+
}
195+
}
196+
197+
impl ConnectionIdGenerator for ZeroLengthConnectionIdGenerator {
198+
fn generate_cid(&self) -> ConnectionId {
199+
unreachable!()
200+
}
201+
202+
fn cid_lifetime(&self) -> Option<Duration> {
203+
None
204+
}
205+
}
206+
189207
#[cfg(test)]
190208
mod tests {
191209
use super::*;

quinn-proto/src/config.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,7 @@ impl Default for MtuDiscoveryConfig {
616616
pub struct EndpointConfig {
617617
pub(crate) reset_key: Arc<dyn HmacKey>,
618618
pub(crate) max_udp_payload_size: VarInt,
619-
pub(crate) connection_id_generator: Arc<dyn ConnectionIdGenerator>,
619+
pub(crate) connection_id_generator: Option<Arc<dyn ConnectionIdGenerator>>,
620620
pub(crate) supported_versions: Vec<u32>,
621621
pub(crate) grease_quic_bit: bool,
622622
/// Minimum interval between outgoing stateless reset packets
@@ -629,7 +629,7 @@ impl EndpointConfig {
629629
Self {
630630
reset_key,
631631
max_udp_payload_size: (1500u32 - 28).into(), // Ethernet MTU minus IP + UDP headers
632-
connection_id_generator: Arc::<HashedConnectionIdGenerator>::default(),
632+
connection_id_generator: Some(Arc::<HashedConnectionIdGenerator>::default()),
633633
supported_versions: DEFAULT_SUPPORTED_VERSIONS.to_vec(),
634634
grease_quic_bit: true,
635635
min_reset_interval: Duration::from_millis(20),
@@ -644,7 +644,10 @@ impl EndpointConfig {
644644
/// information in local connection IDs, e.g. to support stateless packet-level load balancers.
645645
///
646646
/// Defaults to [`HashedConnectionIdGenerator`].
647-
pub fn cid_generator(&mut self, generator: Arc<dyn ConnectionIdGenerator>) -> &mut Self {
647+
pub fn cid_generator(
648+
&mut self,
649+
generator: Option<Arc<dyn ConnectionIdGenerator>>,
650+
) -> &mut Self {
648651
self.connection_id_generator = generator;
649652
self
650653
}

quinn-proto/src/connection/mod.rs

+13-12
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@ use thiserror::Error;
1515
use tracing::{debug, error, trace, trace_span, warn};
1616

1717
use crate::{
18-
cid_generator::ConnectionIdGenerator,
18+
cid_generator::{ConnectionIdGenerator, ZeroLengthConnectionIdGenerator},
1919
cid_queue::CidQueue,
2020
coding::BufMutExt,
2121
config::{ServerConfig, TransportConfig},
2222
crypto::{self, KeyPair, Keys, PacketKey},
23-
frame,
24-
frame::{Close, Datagram, FrameStruct},
23+
frame::{self, Close, Datagram, FrameStruct},
2524
packet::{
2625
Header, InitialHeader, InitialPacket, LongType, Packet, PacketNumber, PartialDecode,
2726
SpaceId,
@@ -197,7 +196,7 @@ pub struct Connection {
197196
retry_token: Bytes,
198197
/// Identifies Data-space packet numbers to skip. Not used in earlier spaces.
199198
packet_number_filter: PacketNumberFilter,
200-
cid_gen: Arc<dyn ConnectionIdGenerator>,
199+
cid_gen: Option<Arc<dyn ConnectionIdGenerator>>,
201200

202201
//
203202
// Queued non-retransmittable 1-RTT data
@@ -253,7 +252,7 @@ impl Connection {
253252
remote: SocketAddr,
254253
local_ip: Option<IpAddr>,
255254
crypto: Box<dyn crypto::Session>,
256-
cid_gen: Arc<dyn ConnectionIdGenerator>,
255+
cid_gen: Option<Arc<dyn ConnectionIdGenerator>>,
257256
now: Instant,
258257
version: u32,
259258
allow_mtud: bool,
@@ -281,14 +280,13 @@ impl Connection {
281280
crypto,
282281
handshake_cid: loc_cid,
283282
rem_handshake_cid: rem_cid,
284-
local_cid_state: match cid_gen.cid_len() {
285-
0 => None,
286-
_ => Some(CidState::new(
287-
cid_gen.cid_lifetime(),
283+
local_cid_state: cid_gen.as_ref().map(|gen| {
284+
CidState::new(
285+
gen.cid_lifetime(),
288286
now,
289287
if pref_addr_cid.is_some() { 2 } else { 1 },
290-
)),
291-
},
288+
)
289+
}),
292290
path: PathData::new(remote, allow_mtud, None, now, path_validated, &config),
293291
allow_mtud,
294292
local_ip,
@@ -2103,7 +2101,10 @@ impl Connection {
21032101
while let Some(data) = remaining {
21042102
match PartialDecode::new(
21052103
data,
2106-
&*self.cid_gen,
2104+
self.cid_gen.as_ref().map_or(
2105+
&ZeroLengthConnectionIdGenerator as &dyn ConnectionIdGenerator,
2106+
|x| &**x,
2107+
),
21072108
&[self.version],
21082109
self.endpoint_config.grease_quic_bit,
21092110
) {

quinn-proto/src/endpoint.rs

+25-18
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ use thiserror::Error;
1616
use tracing::{debug, error, trace, warn};
1717

1818
use crate::{
19-
cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator},
19+
cid_generator::{
20+
ConnectionIdGenerator, RandomConnectionIdGenerator, ZeroLengthConnectionIdGenerator,
21+
},
2022
coding::BufMutExt,
2123
config::{ClientConfig, EndpointConfig, ServerConfig},
2224
connection::{Connection, ConnectionError},
@@ -44,7 +46,7 @@ pub struct Endpoint {
4446
rng: StdRng,
4547
index: ConnectionIndex,
4648
connections: Slab<ConnectionMeta>,
47-
local_cid_generator: Arc<dyn ConnectionIdGenerator>,
49+
local_cid_generator: Option<Arc<dyn ConnectionIdGenerator>>,
4850
config: Arc<EndpointConfig>,
4951
server_config: Option<Arc<ServerConfig>>,
5052
/// Whether the underlying UDP socket promises not to fragment packets
@@ -144,7 +146,10 @@ impl Endpoint {
144146
let datagram_len = data.len();
145147
let (first_decode, remaining) = match PartialDecode::new(
146148
data,
147-
&*self.local_cid_generator,
149+
self.local_cid_generator.as_ref().map_or(
150+
&ZeroLengthConnectionIdGenerator as &dyn ConnectionIdGenerator,
151+
|x| &**x,
152+
),
148153
&self.config.supported_versions,
149154
self.config.grease_quic_bit,
150155
) {
@@ -302,8 +307,8 @@ impl Endpoint {
302307
if !first_decode.is_initial()
303308
&& self
304309
.local_cid_generator
305-
.validate(first_decode.dst_cid())
306-
.is_err()
310+
.as_ref()
311+
.map_or(false, |gen| gen.validate(first_decode.dst_cid()).is_err())
307312
{
308313
debug!("dropping packet with invalid CID");
309314
return None;
@@ -400,7 +405,7 @@ impl Endpoint {
400405
let params = TransportParameters::new(
401406
&config.transport,
402407
&self.config,
403-
self.local_cid_generator.as_ref(),
408+
self.local_cid_generator.is_some(),
404409
loc_cid,
405410
None,
406411
);
@@ -453,12 +458,11 @@ impl Endpoint {
453458
/// Generate a connection ID for `ch`
454459
fn new_cid(&mut self, ch: ConnectionHandle) -> ConnectionId {
455460
loop {
456-
let cid = self.local_cid_generator.generate_cid();
457-
if cid.len() == 0 {
461+
let Some(cid_generator) = self.local_cid_generator.as_ref() else {
458462
// Zero-length CID; nothing to track
459-
debug_assert_eq!(self.local_cid_generator.cid_len(), 0);
460-
return cid;
461-
}
463+
return ConnectionId::EMPTY;
464+
};
465+
let cid = cid_generator.generate_cid();
462466
if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) {
463467
e.insert(ch);
464468
break cid;
@@ -589,7 +593,7 @@ impl Endpoint {
589593
let mut params = TransportParameters::new(
590594
&server_config.transport,
591595
&self.config,
592-
self.local_cid_generator.as_ref(),
596+
self.local_cid_generator.is_some(),
593597
loc_cid,
594598
Some(&server_config),
595599
);
@@ -680,10 +684,7 @@ impl Endpoint {
680684
// bytes. If this is a Retry packet, then the length must instead match our usual CID
681685
// length. If we ever issue non-Retry address validation tokens via `NEW_TOKEN`, then we'll
682686
// also need to validate CID length for those after decoding the token.
683-
if header.dst_cid.len() < 8
684-
&& (!header.token_pos.is_empty()
685-
&& header.dst_cid.len() != self.local_cid_generator.cid_len())
686-
{
687+
if header.dst_cid.len() < 8 && !header.token_pos.is_empty() {
687688
debug!(
688689
"rejecting connection due to invalid DCID length {}",
689690
header.dst_cid.len()
@@ -730,7 +731,10 @@ impl Endpoint {
730731
// with established connections. In the unlikely event that a collision occurs
731732
// between two connections in the initial phase, both will fail fast and may be
732733
// retried by the application layer.
733-
let loc_cid = self.local_cid_generator.generate_cid();
734+
let loc_cid = self
735+
.local_cid_generator
736+
.as_ref()
737+
.map_or(ConnectionId::EMPTY, |gen| gen.generate_cid());
734738

735739
let token = RetryToken {
736740
orig_dst_cid: incoming.packet.header.dst_cid,
@@ -860,7 +864,10 @@ impl Endpoint {
860864
// We don't need to worry about CID collisions in initial closes because the peer
861865
// shouldn't respond, and if it does, and the CID collides, we'll just drop the
862866
// unexpected response.
863-
let local_id = self.local_cid_generator.generate_cid();
867+
let local_id = self
868+
.local_cid_generator
869+
.as_ref()
870+
.map_or(ConnectionId::EMPTY, |gen| gen.generate_cid());
864871
let number = PacketNumber::U8(0);
865872
let header = Header::Initial(InitialHeader {
866873
dst_cid: *remote_id,

quinn-proto/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ pub use crate::endpoint::{
6767
mod packet;
6868
pub use packet::{
6969
ConnectionIdParser, LongType, PacketDecodeError, PartialDecode, ProtectedHeader,
70-
ProtectedInitialHeader,
70+
ProtectedInitialHeader, ZeroLengthConnectionIdParser,
7171
};
7272

7373
mod shared;

quinn-proto/src/packet.rs

+12-2
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,16 @@ pub trait ConnectionIdParser {
773773
fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
774774
}
775775

776+
/// Trivial parser for zero-length connection IDs
777+
pub struct ZeroLengthConnectionIdParser;
778+
779+
impl ConnectionIdParser for ZeroLengthConnectionIdParser {
780+
#[inline]
781+
fn parse(&self, _: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
782+
Ok(ConnectionId::new(&[]))
783+
}
784+
}
785+
776786
/// Long packet type including non-uniform cases
777787
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
778788
pub(crate) enum LongHeaderType {
@@ -908,7 +918,7 @@ mod tests {
908918
#[test]
909919
fn header_encoding() {
910920
use crate::crypto::rustls::{initial_keys, initial_suite_from_provider};
911-
use crate::{RandomConnectionIdGenerator, Side};
921+
use crate::Side;
912922
use rustls::crypto::ring::default_provider;
913923
use rustls::quic::Version;
914924

@@ -950,7 +960,7 @@ mod tests {
950960
let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec();
951961
let decode = PartialDecode::new(
952962
buf.as_slice().into(),
953-
&RandomConnectionIdGenerator::new(0),
963+
&ZeroLengthConnectionIdParser,
954964
&supported_versions,
955965
false,
956966
)

quinn-proto/src/shared.rs

+6
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ pub struct ConnectionId {
7272
}
7373

7474
impl ConnectionId {
75+
/// The zero-length connection ID
76+
pub const EMPTY: Self = Self {
77+
len: 0,
78+
bytes: [0; MAX_CID_SIZE],
79+
};
80+
7581
/// Construct cid from byte array
7682
pub fn new(bytes: &[u8]) -> Self {
7783
debug_assert!(bytes.len() <= MAX_CID_SIZE);

0 commit comments

Comments
 (0)