diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index 751e8caba6..fb5883af0a 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -77,12 +77,27 @@ impl From for Error { } } +/// Type that stores the information required when processing inbound messages. +#[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] +#[scale_info(skip_type_params(T))] +pub enum InboundEntry { + Message { + domain_address: DomainAddress, + message: T::Message, + expected_proof_count: u32, + }, + Proof { + current_count: u32, + }, +} + #[frame_support::pallet] pub mod pallet { const BYTES_U32: usize = 4; const BYTES_ACCOUNT_20: usize = 20; use orml_traits::arithmetic::One; + use sp_arithmetic::traits::EnsureSub; use sp_core::H256; use super::*; @@ -205,6 +220,7 @@ pub mod pallet { /// Inbound routers were set. InboundRoutersSet { + domain: Domain, router_hashes: BoundedVec, }, } @@ -250,39 +266,48 @@ pub mod pallet { pub(crate) type PackedMessage = StorageMap<_, Blake2_128Concat, (T::AccountId, Domain), T::Message>; - /// Storage for routers. + /// Storage for outbound routers. /// /// This can only be set by an admin. #[pallet::storage] #[pallet::getter(fn routers)] - pub type Routers = StorageMap<_, Blake2_128Concat, T::Hash, T::Router>; + pub type OutboundRouters = StorageMap<_, Blake2_128Concat, T::Hash, T::Router>; - /// Storage for domain multi-routers. + /// Storage for outbound routers specific for a domain. /// /// This can only be set by an admin. #[pallet::storage] - #[pallet::getter(fn domain_multi_routers)] - pub type DomainMultiRouters = + #[pallet::getter(fn outbound_domain_routers)] + pub type OutboundDomainRouters = StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; - /// Storage that keeps track of incoming message proofs. - #[pallet::storage] - #[pallet::getter(fn inbound_message_proof_count)] - pub type InboundMessageProofCount = - StorageMap<_, Blake2_128Concat, Proof, u32, ValueQuery>; - - /// Storage that keeps track of incoming messages and the expected proof - /// count. + /// Storage for pending inbound messages. #[pallet::storage] - #[pallet::getter(fn inbound_messages)] - pub type InboundMessages = - StorageMap<_, Blake2_128Concat, Proof, (DomainAddress, T::Message, u32)>; - + #[pallet::getter(fn pending_inbound_entries)] + pub type PendingInboundEntries = StorageDoubleMap< + _, + Blake2_128Concat, + T::SessionId, + Blake2_128Concat, + (Proof, T::Hash), + InboundEntry, + >; + + /// Storage for inbound routers specific for a domain. + /// + /// This can only be set by an admin. #[pallet::storage] #[pallet::getter(fn inbound_routers)] pub type InboundRouters = - StorageValue<_, BoundedVec, ValueQuery>; + StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; + /// Storage for the session ID of an inbound domain. + #[pallet::storage] + #[pallet::getter(fn inbound_domain_sessions)] + pub type InboundDomainSessions = + StorageMap<_, Blake2_128Concat, Domain, T::SessionId>; + + /// Storage for inbound router session IDs. #[pallet::storage] pub type SessionIdStore = StorageValue<_, T::SessionId, ValueQuery>; @@ -608,10 +633,10 @@ pub mod pallet { router_hashes.push(router_hash); - Routers::::insert(router_hash, router); + OutboundRouters::::insert(router_hash, router); } - >::insert( + >::insert( domain.clone(), BoundedVec::try_from(router_hashes).map_err(|_| Error::::InvalidMultiRouter)?, ); @@ -626,6 +651,7 @@ pub mod pallet { #[pallet::call_index(12)] pub fn set_inbound_routers( origin: OriginFor, + domain: Domain, router_hashes: BoundedVec, ) -> DispatchResult { T::AdminOrigin::ensure_origin(origin)?; @@ -635,14 +661,18 @@ pub mod pallet { Error::::InvalidMultiRouter ); - SessionIdStore::::try_mutate(|n| { + let session_id = SessionIdStore::::try_mutate(|n| { n.ensure_add_assign(One::one())?; - Ok::<(), DispatchError>(()) + Ok::(*n) })?; - InboundRouters::::set(router_hashes.clone()); + InboundRouters::::insert(domain.clone(), router_hashes.clone()); + InboundDomainSessions::::insert(domain.clone(), session_id); - Self::deposit_event(Event::InboundRoutersSet { router_hashes }); + Self::deposit_event(Event::InboundRoutersSet { + domain, + router_hashes, + }); Ok(()) } @@ -664,76 +694,261 @@ pub mod pallet { } impl Pallet { - fn clear_storages_for_inbound_messages() { - let _ = InboundMessages::::clear(u32::MAX, None); - let _ = InboundMessageProofCount::::clear(u32::MAX, None); - } - //TODO(cdamian): Use safe math fn get_expected_message_proof_count() -> u32 { T::MultiRouterCount::get() - 1 } - /// Inserts a message and its expected proof count, or increases the - /// message proof count for a particular message. - fn get_proof_and_current_count( + fn get_message_proof(message: T::Message) -> Proof { + match message.get_message_proof() { + None => message + .to_message_proof() + .get_message_proof() + .expect("message proof ensured by 'to_message_proof'"), + Some(proof) => proof, + } + } + + fn create_inbound_entry( domain_address: DomainAddress, message: T::Message, - weight: &mut Weight, - ) -> Result<(Proof, u32), DispatchError> { + ) -> InboundEntry { match message.get_message_proof() { - None => { - let message_proof = message - .to_message_proof() - .get_message_proof() - .expect("message proof ensured by 'to_message_proof'"); - - match InboundMessages::::try_mutate(message_proof, |storage_entry| { - match storage_entry { - None => { - *storage_entry = Some(( - domain_address, - message, - Self::get_expected_message_proof_count(), - )); + None => InboundEntry::Message { + domain_address, + message, + expected_proof_count: Self::get_expected_message_proof_count(), + }, + Some(_) => InboundEntry::Proof { current_count: 1 }, + } + } + + /// Validation ensures that: + /// + /// - the router that sent the inbound message is a valid router for the + /// specific domain. + /// - messages are only sent by the first inbound router. + /// - proofs are not sent by the first inbound router. + fn validate_inbound_entry( + domain: Domain, + router_hash: T::Hash, + inbound_entry: &InboundEntry, + ) -> DispatchResult { + let inbound_routers = + //TODO(cdamian): Add new error + InboundRouters::::get(domain).ok_or(Error::::InvalidMultiRouter)?; + + ensure!( + inbound_routers.iter().any(|x| x == &router_hash), + //TODO(cdamian): Add error + Error::::InvalidMultiRouter + ); + + match inbound_entry { + InboundEntry::Message { .. } => { + ensure!( + inbound_routers.get(0) == Some(&router_hash), + //TODO(cdamian): Add error + Error::::InvalidMultiRouter + ); + + Ok(()) + } + InboundEntry::Proof { .. } => { + ensure!( + inbound_routers.get(0) != Some(&router_hash), + //TODO(cdamian): Add error + Error::::InvalidMultiRouter + ); + + Ok(()) + } + } + } + + fn update_storage_entry(old: &mut InboundEntry, new: InboundEntry) -> DispatchResult { + match old { + InboundEntry::Message { + expected_proof_count, + .. + } => match new { + InboundEntry::Message { .. } => { + expected_proof_count + .ensure_add_assign(Self::get_expected_message_proof_count())?; + + Ok(()) + } + //TODO(cdamian): Update error + InboundEntry::Proof { .. } => Err(Error::::InvalidMultiRouter.into()), + }, + InboundEntry::Proof { current_count } => match new { + InboundEntry::Proof { .. } => { + current_count.ensure_add_assign(1)?; + + Ok(()) + } + //TODO(cdamian): Update error + InboundEntry::Message { .. } => Err(Error::::InvalidMultiRouter.into()), + }, + } + } + + fn update_pending_entry( + session_id: T::SessionId, + message_proof: Proof, + router_hash: T::Hash, + inbound_entry: InboundEntry, + weight: &mut Weight, + ) -> DispatchResult { + weight.saturating_accrue(T::DbWeight::get().writes(1)); + + PendingInboundEntries::::try_mutate( + session_id, + (message_proof, router_hash), + |storage_entry| match storage_entry { + None => { + *storage_entry = Some(inbound_entry); + + Ok::<(), DispatchError>(()) + } + Some(stored_inbound_entry) => match stored_inbound_entry { + InboundEntry::Message { + expected_proof_count: old, + .. + } => match inbound_entry { + InboundEntry::Message { + expected_proof_count: new, + .. + } => old.ensure_add_assign(new).map_err(|e| e.into()), + InboundEntry::Proof { .. } => { + // TODO(cdamian): Add new error. + Err(Error::::InvalidMultiRouter.into()) + } + }, + InboundEntry::Proof { current_count: old } => match inbound_entry { + InboundEntry::Proof { current_count: new } => { + old.ensure_add_assign(new).map_err(|e| e.into()) } - Some((_, _, expected_proof_count)) => { - // We already have a message, in this case we should expect another - // set of message proofs. - expected_proof_count - .ensure_add_assign(Self::get_expected_message_proof_count())?; + InboundEntry::Message { .. } => { + // TODO(cdamian): Add new error. + Err(Error::::InvalidMultiRouter.into()) } - }; + }, + }, + }, + ) + } - Ok(()) - }) { - Ok(_) => {} - Err(e) => return Err(e), - }; + fn validate_and_update_pending_entries( + session_id: T::SessionId, + message_proof: Proof, + router_hash: T::Hash, + domain_address: DomainAddress, + message: T::Message, + weight: &mut Weight, + ) -> DispatchResult { + let session_id = InboundDomainSessions::::get(domain_address.domain()) + .ok_or(Error::::InvalidMultiRouter)?; - *weight = weight.saturating_add(T::DbWeight::get().reads_writes(1, 1)); + let message_proof = Self::get_message_proof(message.clone()); - Ok(( - message_proof, - InboundMessageProofCount::::get(message_proof), - )) - } - Some(message_proof) => { - let message_proof_count = - match InboundMessageProofCount::::try_mutate(message_proof, |count| { - count.ensure_add_assign(1)?; + let inbound_entry = Self::create_inbound_entry(domain_address.clone(), message); + + Self::validate_inbound_entry(domain_address.domain(), router_hash, &inbound_entry)?; - Ok(*count) - }) { - Ok(r) => r, - Err(e) => return Err(e), - }; + Self::update_pending_entry( + session_id, + message_proof, + router_hash, + inbound_entry, + weight, + )?; - *weight = weight.saturating_add(T::DbWeight::get().writes(1)); + Ok(()) + } - Ok((message_proof, message_proof_count)) + fn get_executable_message( + inbound_routers: BoundedVec, + session_id: T::SessionId, + message_proof: Proof, + ) -> Option { + let mut message = None; + let mut proof_count = 0; + + for inbound_router in inbound_routers { + match PendingInboundEntries::::get(session_id, (message_proof, inbound_router)) { + // We expected one InboundEntry for each router, if that's not the case, + // we can return. + None => return None, + Some(inbound_entry) => match inbound_entry { + InboundEntry::Message { + message: stored_message, + .. + } => message = Some(stored_message), + InboundEntry::Proof { current_count } => { + if current_count > 0 { + proof_count += 1; + } + } + }, + }; + } + + if proof_count == Self::get_expected_message_proof_count() { + return message; + } + + None + } + + fn decrease_pending_entries_counts( + inbound_routers: BoundedVec, + session_id: T::SessionId, + message_proof: Proof, + ) -> DispatchResult { + for inbound_router in inbound_routers { + match PendingInboundEntries::::try_mutate( + session_id, + (message_proof, inbound_router), + |storage_entry| match storage_entry { + // TODO(cdamian): Add new error + None => Err(Error::::InvalidMultiRouter.into()), + Some(stored_inbound_entry) => match stored_inbound_entry { + InboundEntry::Message { + expected_proof_count, + .. + } => { + let updated_count = (*expected_proof_count) + .ensure_sub(Self::get_expected_message_proof_count())?; + + if updated_count == 0 { + *storage_entry = None; + } else { + *expected_proof_count = updated_count; + } + + Ok::<(), DispatchError>(()) + } + InboundEntry::Proof { current_count } => { + let updated_count = (*current_count).ensure_sub(1)?; + + if updated_count == 0 { + *storage_entry = None; + } else { + *current_count = updated_count; + } + + Ok::<(), DispatchError>(()) + } + }, + }, + ) { + Ok(()) => {} + Err(e) => return Err(e), } } + + Ok(()) } /// Give the message to the `InboundMessageHandler` to be processed. @@ -742,78 +957,64 @@ pub mod pallet { message: T::Message, router_hash: T::Hash, ) -> (DispatchResult, Weight) { + let mut weight = T::DbWeight::get().reads(1); + + let Some(inbound_routers) = InboundRouters::::get(domain_address.domain()) else { + //TODO(cdamian): Add new error + return (Err(Error::::InvalidMultiRouter.into()), weight); + }; + + if inbound_routers.len() == 0 {} + + let Some(session_id) = InboundDomainSessions::::get(domain_address.domain()) else { + //TODO(cdamian): Add error + return (Err(Error::::InvalidMultiRouter.into()), weight); + }; + + let message_proof = Self::get_message_proof(message.clone()); + + weight.saturating_accrue( + Weight::from_parts(0, T::Message::max_encoded_len() as u64) + .saturating_add(LP_DEFENSIVE_WEIGHT), + ); + let mut count = 0; for submessage in message.submessages() { count += 1; - let (message_proof, mut current_message_proof_count) = - match Self::get_proof_and_current_count( - domain_address.clone(), - message.clone(), - &mut weight, - ) { - Ok(r) => r, - Err(e) => return (Err(e), weight), - }; - - let (_, message, mut total_expected_proof_count) = - match InboundMessages::::get(message_proof) { - None => return (Ok(()), weight), - Some(r) => r, - }; + if let Err(e) = Self::validate_and_update_pending_entries( + session_id, + message_proof, + router_hash, + domain_address.clone(), + submessage.clone(), + &mut weight, + ) { + return (Err(e), weight); + } - weight = weight.saturating_add(T::DbWeight::get().reads(1)); - - let expected_message_proof_count = Self::get_expected_message_proof_count(); - - match current_message_proof_count.cmp(&expected_message_proof_count) { - Ordering::Less => return (Ok(()), weight), - Ordering::Equal => { - InboundMessageProofCount::::remove(message_proof); - total_expected_proof_count -= expected_message_proof_count; - - if total_expected_proof_count == 0 { - InboundMessages::::remove(message_proof); - } else { - InboundMessages::::insert( - message_proof, - ( - domain_address.clone(), - message.clone(), - total_expected_proof_count, - ), - ); - } - } - Ordering::Greater => { - current_message_proof_count -= expected_message_proof_count; - InboundMessageProofCount::::insert( + match Self::get_executable_message( + inbound_routers.clone(), + session_id, + message_proof, + ) { + Some(m) => { + if let Err(e) = Self::decrease_pending_entries_counts( + inbound_routers.clone(), + session_id, message_proof, - current_message_proof_count, - ); - - total_expected_proof_count -= expected_message_proof_count; - - if total_expected_proof_count == 0 { - InboundMessages::::remove(message_proof); - } else { - InboundMessages::::insert( - message_proof, - ( - domain_address.clone(), - message.clone(), - total_expected_proof_count, - ), - ); + ) { + return (Err(e), weight.saturating_mul(count)); } - } - } - if let Err(e) = T::InboundMessageHandler::handle(domain_address.clone(), submessage) - { - // We only consume the processed weight if error during the batch - return (Err(e), LP_DEFENSIVE_WEIGHT.saturating_mul(count)); + if let Err(e) = T::InboundMessageHandler::handle(domain_address.clone(), m) + { + // We only consume the processed weight if error during the batch + return (Err(e), weight.saturating_mul(count)); + } + } + None => continue, } } @@ -830,7 +1031,7 @@ pub mod pallet { ) -> (DispatchResult, Weight) { let read_weight = T::DbWeight::get().reads(1); - let Some(router) = Routers::::get(router_hash) else { + let Some(router) = OutboundRouters::::get(router_hash) else { return (Err(Error::::RouterNotFound.into()), read_weight); }; @@ -843,7 +1044,7 @@ pub mod pallet { } fn queue_message(destination: Domain, message: T::Message) -> DispatchResult { - let router_hashes = DomainMultiRouters::::get(destination.clone()) + let router_hashes = OutboundDomainRouters::::get(destination.clone()) .ok_or(Error::::MultiRouterNotFound)?; let message_proof = message.to_message_proof();