diff --git a/xtra/src/message_channel.rs b/xtra/src/message_channel.rs index ad693c1..01f6089 100644 --- a/xtra/src/message_channel.rs +++ b/xtra/src/message_channel.rs @@ -3,6 +3,7 @@ //! the message type rather than the actor type. use std::fmt; +use std::hash::{Hash, Hasher}; use crate::address::{ActorJoinHandle, Address}; use crate::chan::RefCounter; @@ -207,6 +208,17 @@ where } } +impl Hash for MessageChannel +where + M: Send + 'static, + R: Send + 'static, + Rc: Send + 'static, +{ + fn hash(&self, state: &mut H) { + self.inner.hash(state) + } +} + impl Clone for MessageChannel where R: Send + 'static, @@ -295,6 +307,8 @@ trait MessageChannelTrait { fn to_either( &self, ) -> Box + Send + Sync + 'static>; + + fn hash(&self, state: &mut dyn Hasher); } impl MessageChannelTrait for Address @@ -366,4 +380,101 @@ where { Box::new(Address(self.0.to_tx_either())) } + + fn hash(&self, state: &mut dyn Hasher) { + state.write_usize(self.0.inner_ptr() as *const _ as usize); + state.write_u8(self.0.is_strong() as u8); + state.finish(); + } +} + +#[cfg(test)] +mod test { + use std::hash::{Hash, Hasher}; + + use crate::{Actor, Handler, Mailbox}; + + type TestMessageChannel = super::MessageChannel; + + struct TestActor; + struct TestMessage; + + impl Actor for TestActor { + type Stop = (); + + async fn stopped(self) -> Self::Stop {} + } + + impl Handler for TestActor { + type Return = (); + + async fn handle(&mut self, _: TestMessage, _: &mut crate::Context) -> Self::Return {} + } + + struct RecordingHasher(Vec); + + impl RecordingHasher { + fn record_hash(value: &H) -> Vec { + let mut h = Self(Vec::new()); + value.hash(&mut h); + assert!(!h.0.is_empty(), "the hash data not be empty"); + h.0 + } + } + + impl Hasher for RecordingHasher { + fn finish(&self) -> u64 { + 0 + } + + fn write(&mut self, bytes: &[u8]) { + self.0.extend_from_slice(bytes) + } + } + + #[test] + fn hashcode() { + let (a1, _) = Mailbox::::unbounded(); + let c1 = TestMessageChannel::new(a1.clone()); + + let h1 = RecordingHasher::record_hash(&c1); + let h2 = RecordingHasher::record_hash(&c1.clone()); + let h3 = RecordingHasher::record_hash(&TestMessageChannel::new(a1)); + + assert_eq!(h1, h2, "hashes from cloned channels should match"); + assert_eq!( + h1, h3, + "hashes channels created against the same address should match" + ); + + let h4 = RecordingHasher::record_hash(&TestMessageChannel::new( + Mailbox::::unbounded().0, + )); + + assert_ne!( + h1, h4, + "hashes from channels created against different addresses should differ" + ); + } + + #[test] + fn partial_eq() { + let (a1, _) = Mailbox::::unbounded(); + let c1 = TestMessageChannel::new(a1.clone()); + let c2 = c1.clone(); + let c3 = TestMessageChannel::new(a1); + + assert_eq!(c1, c2, "cloned channels should match"); + assert_eq!( + c1, c3, + "channels created against the same address should match" + ); + + let c4 = TestMessageChannel::new(Mailbox::::unbounded().0); + + assert_ne!( + c1, c4, + "channels created against different addresses should differ" + ); + } }