diff --git a/rumqttc/CHANGELOG.md b/rumqttc/CHANGELOG.md index ce94f005..1435021a 100644 --- a/rumqttc/CHANGELOG.md +++ b/rumqttc/CHANGELOG.md @@ -23,6 +23,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 * use `Login` to store credentials * Made `DisconnectProperties` struct public. * Replace `Vec>` with `FixedBitSet` for managing packet ids of released QoS 2 publishes and incoming QoS 2 publishes in `MqttState`. +* Replace `Vec>` with `HashMap` for storing outgoing pub packet ids in `MqttState`. This dramatically reduces memory usage in processes with a large number of clients. ### Deprecated diff --git a/rumqttc/src/v5/state.rs b/rumqttc/src/v5/state.rs index 6f37a471..329b2cae 100644 --- a/rumqttc/src/v5/state.rs +++ b/rumqttc/src/v5/state.rs @@ -103,7 +103,7 @@ pub struct MqttState { /// Number of outgoing inflight publishes pub(crate) inflight: u16, /// Outgoing QoS 1, 2 publishes which aren't acked yet - pub(crate) outgoing_pub: Vec>, + pub(crate) outgoing_pub: HashMap, /// Packet ids of released QoS 2 publishes pub(crate) outgoing_rel: FixedBitSet, /// Packet ids on incoming QoS 2 publishes @@ -137,7 +137,7 @@ impl MqttState { last_pkid: 0, inflight: 0, // index 0 is wasted as 0 is not a valid packet id - outgoing_pub: vec![None; max_inflight as usize + 1], + outgoing_pub: HashMap::new(), outgoing_rel: FixedBitSet::with_capacity(max_inflight as usize + 1), incoming_pub: FixedBitSet::with_capacity(u16::MAX as usize + 1), collision: None, @@ -156,11 +156,9 @@ impl MqttState { pub fn clean(&mut self) -> Vec { let mut pending = Vec::with_capacity(100); // remove and collect pending publishes - for publish in self.outgoing_pub.iter_mut() { - if let Some(publish) = publish.take() { - let request = Request::Publish(publish); - pending.push(request); - } + for (_, publish) in self.outgoing_pub.drain() { + let request = Request::Publish(publish); + pending.push(request); } // remove and collect pending releases @@ -359,12 +357,9 @@ impl MqttState { } fn handle_incoming_puback(&mut self, puback: &PubAck) -> Result, StateError> { - let publish = self - .outgoing_pub - .get_mut(puback.pkid as usize) - .ok_or(StateError::Unsolicited(puback.pkid))?; - - if publish.take().is_none() { + if puback.pkid > self.max_outgoing_inflight + || self.outgoing_pub.remove(&puback.pkid).is_none() + { error!("Unsolicited puback packet: {:?}", puback.pkid); return Err(StateError::Unsolicited(puback.pkid)); } @@ -380,7 +375,7 @@ impl MqttState { } if let Some(publish) = self.check_collision(puback.pkid) { - self.outgoing_pub[publish.pkid as usize] = Some(publish.clone()); + self.outgoing_pub.insert(publish.pkid, publish.clone()); self.inflight += 1; let pkid = publish.pkid; @@ -395,13 +390,10 @@ impl MqttState { } fn handle_incoming_pubrec(&mut self, pubrec: &PubRec) -> Result, StateError> { - let publish = self - .outgoing_pub - .get_mut(pubrec.pkid as usize) - .ok_or(StateError::Unsolicited(pubrec.pkid))?; - - if publish.take().is_none() { - error!("Unsolicited pubrec packet: {:?}", pubrec.pkid); + if pubrec.pkid > self.max_outgoing_inflight + || self.outgoing_pub.remove(&pubrec.pkid).is_none() + { + error!("Unsolicited puback packet: {:?}", pubrec.pkid); return Err(StateError::Unsolicited(pubrec.pkid)); } @@ -480,12 +472,12 @@ impl MqttState { } let pkid = publish.pkid; - if self - .outgoing_pub - .get(publish.pkid as usize) - .ok_or(StateError::Unsolicited(publish.pkid))? - .is_some() - { + + if publish.pkid > self.max_outgoing_inflight { + error!("Unsolicited puback packet: {:?}", publish.pkid); + return Err(StateError::Unsolicited(publish.pkid)); + } + if self.outgoing_pub.contains_key(&publish.pkid) { info!("Collision on packet id = {:?}", publish.pkid); self.collision = Some(publish); let event = Event::Outgoing(Outgoing::AwaitAck(pkid)); @@ -495,7 +487,7 @@ impl MqttState { // if there is an existing publish at this pkid, this implies that broker hasn't acked this // packet yet. This error is possible only when broker isn't acking sequentially - self.outgoing_pub[pkid as usize] = Some(publish.clone()); + self.outgoing_pub.insert(pkid, publish.clone()); self.inflight += 1; }; @@ -690,6 +682,7 @@ mod test { use super::mqttbytes::*; use super::{Event, Incoming, Outgoing, Request}; use super::{MqttState, StateError}; + use std::mem; fn build_outgoing_publish(qos: QoS) -> Publish { let topic = "hello/world".to_owned(); @@ -891,8 +884,8 @@ mod test { mqtt.handle_incoming_puback(&PubAck::new(2, None)).unwrap(); assert_eq!(mqtt.inflight, 0); - assert!(mqtt.outgoing_pub[1].is_none()); - assert!(mqtt.outgoing_pub[2].is_none()); + assert!(!mqtt.outgoing_pub.contains_key(&1)); + assert!(!mqtt.outgoing_pub.contains_key(&2)); } #[test] @@ -911,6 +904,9 @@ mod test { #[test] fn incoming_pubrec_should_release_publish_from_queue_and_add_relid_to_rel_queue() { + let publish_size = mem::size_of::(); + + println!("{publish_size}"); let mut mqtt = build_mqttstate(); let publish1 = build_outgoing_publish(QoS::AtLeastOnce); @@ -923,7 +919,7 @@ mod test { assert_eq!(mqtt.inflight, 2); // check if the remaining element's pkid is 1 - let backup = mqtt.outgoing_pub[1].clone(); + let backup = mqtt.outgoing_pub.get(&1).clone(); assert_eq!(backup.unwrap().pkid, 1); // check if the qos2 element's release pkid is 2