Skip to content

Commit 2757df1

Browse files
committed
Represent zero-length CIDs by specifying no CID generator
1 parent d0379f5 commit 2757df1

File tree

10 files changed

+103
-61
lines changed

10 files changed

+103
-61
lines changed

quinn-proto/src/cid_generator.rs

+29-11
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ pub trait ConnectionIdGenerator: Send + Sync + ConnectionIdParser {
2424
Ok(())
2525
}
2626

27-
/// Returns the length of a CID for connections created by this generator
28-
fn cid_len(&self) -> usize;
2927
/// Returns the lifetime of generated Connection IDs
3028
///
3129
/// Connection IDs will be retired after the returned `Duration`, if any. Assumed to be constant.
@@ -61,6 +59,10 @@ impl RandomConnectionIdGenerator {
6159
/// The given length must be less than or equal to MAX_CID_SIZE.
6260
pub fn new(cid_len: usize) -> Self {
6361
debug_assert!(cid_len <= MAX_CID_SIZE);
62+
assert!(
63+
cid_len > 0,
64+
"connection ID generators must produce non-empty IDs"
65+
);
6466
Self {
6567
cid_len,
6668
..Self::default()
@@ -90,11 +92,6 @@ impl ConnectionIdGenerator for RandomConnectionIdGenerator {
9092
ConnectionId::new(&bytes_arr[..self.cid_len])
9193
}
9294

93-
/// Provide the length of dst_cid in short header packet
94-
fn cid_len(&self) -> usize {
95-
self.cid_len
96-
}
97-
9895
fn cid_lifetime(&self) -> Option<Duration> {
9996
self.lifetime
10097
}
@@ -171,10 +168,6 @@ impl ConnectionIdGenerator for HashedConnectionIdGenerator {
171168
}
172169
}
173170

174-
fn cid_len(&self) -> usize {
175-
HASHED_CID_LEN
176-
}
177-
178171
fn cid_lifetime(&self) -> Option<Duration> {
179172
self.lifetime
180173
}
@@ -184,6 +177,31 @@ const NONCE_LEN: usize = 3; // Good for more than 16 million connections
184177
const SIGNATURE_LEN: usize = 8 - NONCE_LEN; // 8-byte total CID length
185178
const HASHED_CID_LEN: usize = NONCE_LEN + SIGNATURE_LEN;
186179

180+
/// HACK: Replace uses with `ZeroLengthConnectionIdParser` once [trait upcasting] is stable
181+
///
182+
/// CID generators should produce nonempty CIDs. We should be able to use
183+
/// `ZeroLengthConnectionIdParser` everywhere this would be needed, but that will require
184+
/// construction of `&dyn ConnectionIdParser` from `&dyn ConnectionIdGenerator`.
185+
///
186+
/// [trait upcasting]: https://github.com/rust-lang/rust/issues/65991
187+
pub(crate) struct ZeroLengthConnectionIdGenerator;
188+
189+
impl ConnectionIdParser for ZeroLengthConnectionIdGenerator {
190+
fn parse(&self, _: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
191+
Ok(ConnectionId::new(&[]))
192+
}
193+
}
194+
195+
impl ConnectionIdGenerator for ZeroLengthConnectionIdGenerator {
196+
fn generate_cid(&self) -> ConnectionId {
197+
unreachable!()
198+
}
199+
200+
fn cid_lifetime(&self) -> Option<Duration> {
201+
None
202+
}
203+
}
204+
187205
#[cfg(test)]
188206
mod tests {
189207
use super::*;

quinn-proto/src/config.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,7 @@ impl Default for MtuDiscoveryConfig {
616616
pub struct EndpointConfig {
617617
pub(crate) reset_key: Arc<dyn HmacKey>,
618618
pub(crate) max_udp_payload_size: VarInt,
619-
pub(crate) connection_id_generator: Arc<dyn ConnectionIdGenerator>,
619+
pub(crate) connection_id_generator: Option<Arc<dyn ConnectionIdGenerator>>,
620620
pub(crate) supported_versions: Vec<u32>,
621621
pub(crate) grease_quic_bit: bool,
622622
/// Minimum interval between outgoing stateless reset packets
@@ -629,7 +629,7 @@ impl EndpointConfig {
629629
Self {
630630
reset_key,
631631
max_udp_payload_size: (1500u32 - 28).into(), // Ethernet MTU minus IP + UDP headers
632-
connection_id_generator: Arc::<HashedConnectionIdGenerator>::default(),
632+
connection_id_generator: Some(Arc::<HashedConnectionIdGenerator>::default()),
633633
supported_versions: DEFAULT_SUPPORTED_VERSIONS.to_vec(),
634634
grease_quic_bit: true,
635635
min_reset_interval: Duration::from_millis(20),
@@ -644,7 +644,10 @@ impl EndpointConfig {
644644
/// information in local connection IDs, e.g. to support stateless packet-level load balancers.
645645
///
646646
/// Defaults to [`HashedConnectionIdGenerator`].
647-
pub fn cid_generator(&mut self, generator: Arc<dyn ConnectionIdGenerator>) -> &mut Self {
647+
pub fn cid_generator(
648+
&mut self,
649+
generator: Option<Arc<dyn ConnectionIdGenerator>>,
650+
) -> &mut Self {
648651
self.connection_id_generator = generator;
649652
self
650653
}

quinn-proto/src/connection/mod.rs

+13-12
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@ use thiserror::Error;
1515
use tracing::{debug, error, trace, trace_span, warn};
1616

1717
use crate::{
18-
cid_generator::ConnectionIdGenerator,
18+
cid_generator::{ConnectionIdGenerator, ZeroLengthConnectionIdGenerator},
1919
cid_queue::CidQueue,
2020
coding::BufMutExt,
2121
config::{ServerConfig, TransportConfig},
2222
crypto::{self, KeyPair, Keys, PacketKey},
23-
frame,
24-
frame::{Close, Datagram, FrameStruct},
23+
frame::{self, Close, Datagram, FrameStruct},
2524
packet::{
2625
Header, InitialHeader, InitialPacket, LongType, Packet, PacketNumber, PartialDecode,
2726
SpaceId,
@@ -197,7 +196,7 @@ pub struct Connection {
197196
retry_token: Bytes,
198197
/// Identifies Data-space packet numbers to skip. Not used in earlier spaces.
199198
packet_number_filter: PacketNumberFilter,
200-
cid_gen: Arc<dyn ConnectionIdGenerator>,
199+
cid_gen: Option<Arc<dyn ConnectionIdGenerator>>,
201200

202201
//
203202
// Queued non-retransmittable 1-RTT data
@@ -253,7 +252,7 @@ impl Connection {
253252
remote: SocketAddr,
254253
local_ip: Option<IpAddr>,
255254
crypto: Box<dyn crypto::Session>,
256-
cid_gen: Arc<dyn ConnectionIdGenerator>,
255+
cid_gen: Option<Arc<dyn ConnectionIdGenerator>>,
257256
now: Instant,
258257
version: u32,
259258
allow_mtud: bool,
@@ -281,14 +280,13 @@ impl Connection {
281280
crypto,
282281
handshake_cid: loc_cid,
283282
rem_handshake_cid: rem_cid,
284-
local_cid_state: match cid_gen.cid_len() {
285-
0 => None,
286-
_ => Some(CidState::new(
287-
cid_gen.cid_lifetime(),
283+
local_cid_state: cid_gen.as_ref().map(|gen| {
284+
CidState::new(
285+
gen.cid_lifetime(),
288286
now,
289287
if pref_addr_cid.is_some() { 2 } else { 1 },
290-
)),
291-
},
288+
)
289+
}),
292290
path: PathData::new(remote, allow_mtud, None, now, path_validated, &config),
293291
allow_mtud,
294292
local_ip,
@@ -2103,7 +2101,10 @@ impl Connection {
21032101
while let Some(data) = remaining {
21042102
match PartialDecode::new(
21052103
data,
2106-
&*self.cid_gen,
2104+
self.cid_gen.as_ref().map_or(
2105+
&ZeroLengthConnectionIdGenerator as &dyn ConnectionIdGenerator,
2106+
|x| &**x,
2107+
),
21072108
&[self.version],
21082109
self.endpoint_config.grease_quic_bit,
21092110
) {

quinn-proto/src/endpoint.rs

+25-18
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ use thiserror::Error;
1616
use tracing::{debug, error, trace, warn};
1717

1818
use crate::{
19-
cid_generator::{ConnectionIdGenerator, RandomConnectionIdGenerator},
19+
cid_generator::{
20+
ConnectionIdGenerator, RandomConnectionIdGenerator, ZeroLengthConnectionIdGenerator,
21+
},
2022
coding::BufMutExt,
2123
config::{ClientConfig, EndpointConfig, ServerConfig},
2224
connection::{Connection, ConnectionError},
@@ -44,7 +46,7 @@ pub struct Endpoint {
4446
rng: StdRng,
4547
index: ConnectionIndex,
4648
connections: Slab<ConnectionMeta>,
47-
local_cid_generator: Arc<dyn ConnectionIdGenerator>,
49+
local_cid_generator: Option<Arc<dyn ConnectionIdGenerator>>,
4850
config: Arc<EndpointConfig>,
4951
server_config: Option<Arc<ServerConfig>>,
5052
/// Whether the underlying UDP socket promises not to fragment packets
@@ -144,7 +146,10 @@ impl Endpoint {
144146
let datagram_len = data.len();
145147
let (first_decode, remaining) = match PartialDecode::new(
146148
data,
147-
&*self.local_cid_generator,
149+
self.local_cid_generator.as_ref().map_or(
150+
&ZeroLengthConnectionIdGenerator as &dyn ConnectionIdGenerator,
151+
|x| &**x,
152+
),
148153
&self.config.supported_versions,
149154
self.config.grease_quic_bit,
150155
) {
@@ -302,8 +307,8 @@ impl Endpoint {
302307
if !first_decode.is_initial()
303308
&& self
304309
.local_cid_generator
305-
.validate(first_decode.dst_cid())
306-
.is_err()
310+
.as_ref()
311+
.map_or(false, |gen| gen.validate(first_decode.dst_cid()).is_err())
307312
{
308313
debug!("dropping packet with invalid CID");
309314
return None;
@@ -400,7 +405,7 @@ impl Endpoint {
400405
let params = TransportParameters::new(
401406
&config.transport,
402407
&self.config,
403-
self.local_cid_generator.as_ref(),
408+
self.local_cid_generator.is_some(),
404409
loc_cid,
405410
None,
406411
);
@@ -453,12 +458,11 @@ impl Endpoint {
453458
/// Generate a connection ID for `ch`
454459
fn new_cid(&mut self, ch: ConnectionHandle) -> ConnectionId {
455460
loop {
456-
let cid = self.local_cid_generator.generate_cid();
457-
if cid.len() == 0 {
461+
let Some(cid_generator) = self.local_cid_generator.as_ref() else {
458462
// Zero-length CID; nothing to track
459-
debug_assert_eq!(self.local_cid_generator.cid_len(), 0);
460-
return cid;
461-
}
463+
return ConnectionId::EMPTY;
464+
};
465+
let cid = cid_generator.generate_cid();
462466
if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) {
463467
e.insert(ch);
464468
break cid;
@@ -589,7 +593,7 @@ impl Endpoint {
589593
let mut params = TransportParameters::new(
590594
&server_config.transport,
591595
&self.config,
592-
self.local_cid_generator.as_ref(),
596+
self.local_cid_generator.is_some(),
593597
loc_cid,
594598
Some(&server_config),
595599
);
@@ -680,10 +684,7 @@ impl Endpoint {
680684
// bytes. If this is a Retry packet, then the length must instead match our usual CID
681685
// length. If we ever issue non-Retry address validation tokens via `NEW_TOKEN`, then we'll
682686
// also need to validate CID length for those after decoding the token.
683-
if header.dst_cid.len() < 8
684-
&& (!header.token_pos.is_empty()
685-
&& header.dst_cid.len() != self.local_cid_generator.cid_len())
686-
{
687+
if header.dst_cid.len() < 8 && !header.token_pos.is_empty() {
687688
debug!(
688689
"rejecting connection due to invalid DCID length {}",
689690
header.dst_cid.len()
@@ -730,7 +731,10 @@ impl Endpoint {
730731
// with established connections. In the unlikely event that a collision occurs
731732
// between two connections in the initial phase, both will fail fast and may be
732733
// retried by the application layer.
733-
let loc_cid = self.local_cid_generator.generate_cid();
734+
let loc_cid = self
735+
.local_cid_generator
736+
.as_ref()
737+
.map_or(ConnectionId::EMPTY, |gen| gen.generate_cid());
734738

735739
let token = RetryToken {
736740
orig_dst_cid: incoming.packet.header.dst_cid,
@@ -860,7 +864,10 @@ impl Endpoint {
860864
// We don't need to worry about CID collisions in initial closes because the peer
861865
// shouldn't respond, and if it does, and the CID collides, we'll just drop the
862866
// unexpected response.
863-
let local_id = self.local_cid_generator.generate_cid();
867+
let local_id = self
868+
.local_cid_generator
869+
.as_ref()
870+
.map_or(ConnectionId::EMPTY, |gen| gen.generate_cid());
864871
let number = PacketNumber::U8(0);
865872
let header = Header::Initial(InitialHeader {
866873
dst_cid: *remote_id,

quinn-proto/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ pub use crate::endpoint::{
6767
mod packet;
6868
pub use packet::{
6969
ConnectionIdParser, LongType, PacketDecodeError, PartialDecode, ProtectedHeader,
70-
ProtectedInitialHeader,
70+
ProtectedInitialHeader, ZeroLengthConnectionIdParser,
7171
};
7272

7373
mod shared;

quinn-proto/src/packet.rs

+12-2
Original file line numberDiff line numberDiff line change
@@ -773,6 +773,16 @@ pub trait ConnectionIdParser {
773773
fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
774774
}
775775

776+
/// Trivial parser for zero-length connection IDs
777+
pub struct ZeroLengthConnectionIdParser;
778+
779+
impl ConnectionIdParser for ZeroLengthConnectionIdParser {
780+
#[inline]
781+
fn parse(&self, _: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
782+
Ok(ConnectionId::new(&[]))
783+
}
784+
}
785+
776786
/// Long packet type including non-uniform cases
777787
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
778788
pub(crate) enum LongHeaderType {
@@ -908,7 +918,7 @@ mod tests {
908918
#[test]
909919
fn header_encoding() {
910920
use crate::crypto::rustls::{initial_keys, initial_suite_from_provider};
911-
use crate::{RandomConnectionIdGenerator, Side};
921+
use crate::Side;
912922
use rustls::crypto::ring::default_provider;
913923
use rustls::quic::Version;
914924

@@ -950,7 +960,7 @@ mod tests {
950960
let supported_versions = DEFAULT_SUPPORTED_VERSIONS.to_vec();
951961
let decode = PartialDecode::new(
952962
buf.as_slice().into(),
953-
&RandomConnectionIdGenerator::new(0),
963+
&ZeroLengthConnectionIdParser,
954964
&supported_versions,
955965
false,
956966
)

quinn-proto/src/shared.rs

+6
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ pub struct ConnectionId {
7272
}
7373

7474
impl ConnectionId {
75+
/// The zero-length connection ID
76+
pub const EMPTY: ConnectionId = ConnectionId {
77+
len: 0,
78+
bytes: [0; MAX_CID_SIZE],
79+
};
80+
7581
/// Construct cid from byte array
7682
pub fn new(bytes: &[u8]) -> Self {
7783
debug_assert!(bytes.len() <= MAX_CID_SIZE);

quinn-proto/src/tests/mod.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ fn version_negotiate_client() {
6666
// packet
6767
let mut client = Endpoint::new(
6868
Arc::new(EndpointConfig {
69-
connection_id_generator: Arc::new(RandomConnectionIdGenerator::new(0)),
69+
connection_id_generator: None,
7070
..Default::default()
7171
}),
7272
None,
@@ -181,7 +181,7 @@ fn server_stateless_reset() {
181181
rng.fill_bytes(&mut key_material);
182182

183183
let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key));
184-
endpoint_config.cid_generator(Arc::new(HashedConnectionIdGenerator::from_key(0)));
184+
endpoint_config.cid_generator(Some(Arc::new(HashedConnectionIdGenerator::from_key(0))));
185185
let endpoint_config = Arc::new(endpoint_config);
186186

187187
let mut pair = Pair::new(endpoint_config.clone(), server_config());
@@ -211,7 +211,7 @@ fn client_stateless_reset() {
211211
rng.fill_bytes(&mut key_material);
212212

213213
let mut endpoint_config = EndpointConfig::new(Arc::new(reset_key));
214-
endpoint_config.cid_generator(Arc::new(HashedConnectionIdGenerator::from_key(0)));
214+
endpoint_config.cid_generator(Some(Arc::new(HashedConnectionIdGenerator::from_key(0))));
215215
let endpoint_config = Arc::new(endpoint_config);
216216

217217
let mut pair = Pair::new(endpoint_config.clone(), server_config());
@@ -240,7 +240,7 @@ fn stateless_reset_limit() {
240240
let _guard = subscribe();
241241
let remote = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 42);
242242
let mut endpoint_config = EndpointConfig::default();
243-
endpoint_config.cid_generator(Arc::new(RandomConnectionIdGenerator::new(8)));
243+
endpoint_config.cid_generator(Some(Arc::new(RandomConnectionIdGenerator::new(8))));
244244
let endpoint_config = Arc::new(endpoint_config);
245245
let mut endpoint = Endpoint::new(
246246
endpoint_config.clone(),
@@ -1468,7 +1468,7 @@ fn zero_length_cid() {
14681468
let _guard = subscribe();
14691469
let mut pair = Pair::new(
14701470
Arc::new(EndpointConfig {
1471-
connection_id_generator: Arc::new(RandomConnectionIdGenerator::new(0)),
1471+
connection_id_generator: None,
14721472
..EndpointConfig::default()
14731473
}),
14741474
server_config(),
@@ -1525,9 +1525,9 @@ fn cid_rotation() {
15251525
// Only test cid rotation on server side to have a clear output trace
15261526
let server = Endpoint::new(
15271527
Arc::new(EndpointConfig {
1528-
connection_id_generator: Arc::new(
1528+
connection_id_generator: Some(Arc::new(
15291529
*RandomConnectionIdGenerator::new(8).set_lifetime(CID_TIMEOUT),
1530-
),
1530+
)),
15311531
..EndpointConfig::default()
15321532
}),
15331533
Some(Arc::new(server_config())),

0 commit comments

Comments
 (0)