diff --git a/Cargo.lock b/Cargo.lock index 3743f9dbbb..8f8b2d2768 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5450,6 +5450,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -8233,9 +8242,12 @@ dependencies = [ "frame-support", "frame-system", "hex", + "itertools 0.13.0", + "lazy_static", "orml-traits", "parity-scale-codec", "scale-info", + "sp-arithmetic", "sp-core", "sp-io", "sp-runtime", diff --git a/Cargo.toml b/Cargo.toml index f5cac551f1..9e6599e871 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -85,6 +85,7 @@ impl-trait-for-tuples = "0.2.2" num-traits = { version = "0.2.17", default-features = false } num_enum = { version = "0.5.3", default-features = false } chrono = { version = "0.4", default-features = false } +itertools = { version = "0.13.0", default-features = false } # Cumulus cumulus-pallet-aura-ext = { git = "https://github.com/paritytech/polkadot-sdk", default-features = false, branch = "release-polkadot-v1.7.2" } diff --git a/libs/mocks/src/liquidity_pools.rs b/libs/mocks/src/liquidity_pools.rs index acee16d4db..dcae66f60f 100644 --- a/libs/mocks/src/liquidity_pools.rs +++ b/libs/mocks/src/liquidity_pools.rs @@ -2,7 +2,7 @@ pub mod pallet { use cfg_traits::liquidity_pools::InboundMessageHandler; use frame_support::pallet_prelude::*; - use mock_builder::{execute_call, register_call}; + use mock_builder::{execute_call, register_call, CallHandler}; #[pallet::config] pub trait Config: frame_system::Config { @@ -17,8 +17,10 @@ pub mod pallet { type CallIds = StorageMap<_, _, String, mock_builder::CallId>; impl Pallet { - pub fn mock_handle(f: impl Fn(T::DomainAddress, T::Message) -> DispatchResult + 'static) { - register_call!(move |(sender, msg)| f(sender, msg)); + pub fn mock_handle( + f: impl Fn(T::DomainAddress, T::Message) -> DispatchResult + 'static, + ) -> CallHandler { + register_call!(move |(sender, msg)| f(sender, msg)) } } diff --git a/libs/mocks/src/router_message.rs b/libs/mocks/src/router_message.rs index 4d21688a11..58e7166d38 100644 --- a/libs/mocks/src/router_message.rs +++ b/libs/mocks/src/router_message.rs @@ -2,7 +2,7 @@ pub mod pallet { use cfg_traits::liquidity_pools::{MessageReceiver, MessageSender}; use frame_support::pallet_prelude::*; - use mock_builder::{execute_call, register_call}; + use mock_builder::{execute_call, register_call, CallHandler}; #[pallet::config] pub trait Config: frame_system::Config { @@ -25,8 +25,8 @@ pub mod pallet { pub fn mock_send( f: impl Fn(T::Middleware, T::Origin, Vec) -> DispatchResult + 'static, - ) { - register_call!(move |(a, b, c)| f(a, b, c)); + ) -> CallHandler { + register_call!(move |(a, b, c)| f(a, b, c)) } } diff --git a/libs/primitives/src/lib.rs b/libs/primitives/src/lib.rs index d7200cfa29..323cf554f6 100644 --- a/libs/primitives/src/lib.rs +++ b/libs/primitives/src/lib.rs @@ -161,6 +161,9 @@ pub mod types { /// The type for LP gateway message nonces. pub type LPGatewayQueueMessageNonce = u64; + + /// The type for LP gateway session IDs. + pub type LPGatewaySessionId = u64; } /// Common constants for all runtimes diff --git a/libs/traits/src/liquidity_pools.rs b/libs/traits/src/liquidity_pools.rs index 0a90a072f9..0f80787337 100644 --- a/libs/traits/src/liquidity_pools.rs +++ b/libs/traits/src/liquidity_pools.rs @@ -15,6 +15,8 @@ use frame_support::{dispatch::DispatchResult, weights::Weight}; use sp_runtime::DispatchError; use sp_std::vec::Vec; +pub type Proof = [u8; 32]; + /// An encoding & decoding trait for the purpose of meeting the /// LiquidityPools General Message Passing Format pub trait LPEncoding: Sized { @@ -31,11 +33,20 @@ pub trait LPEncoding: Sized { /// Creates an empty message. /// It's the identity message for composing messages with pack_with fn empty() -> Self; + + /// Retrieves the message proof hash, if the message is a proof type. + fn get_proof(&self) -> Option; + + /// Converts the message into a message proof type. + fn to_proof_message(&self) -> Self; } -pub trait RouterSupport: Sized { +pub trait RouterProvider: Sized { + /// The router identifier. + type RouterId; + /// Returns a list of routers supported for the given domain. - fn for_domain(domain: Domain) -> Vec; + fn routers_for_domain(domain: Domain) -> Vec; } /// The behavior of an entity that can send messages diff --git a/pallets/axelar-router/src/lib.rs b/pallets/axelar-router/src/lib.rs index d18b96bd42..21abe3687f 100644 --- a/pallets/axelar-router/src/lib.rs +++ b/pallets/axelar-router/src/lib.rs @@ -62,6 +62,12 @@ pub enum AxelarId { Evm(EVMChainId), } +impl Default for AxelarId { + fn default() -> Self { + Self::Evm(1) + } +} + /// Configuration for outbound messages though axelar #[derive(Debug, Encode, Decode, Clone, PartialEq, Eq, TypeInfo, MaxEncodedLen)] pub struct AxelarConfig { diff --git a/pallets/liquidity-pools-gateway/Cargo.toml b/pallets/liquidity-pools-gateway/Cargo.toml index 5f01d96284..14c3f2c8eb 100644 --- a/pallets/liquidity-pools-gateway/Cargo.toml +++ b/pallets/liquidity-pools-gateway/Cargo.toml @@ -19,6 +19,7 @@ hex = { workspace = true } orml-traits = { workspace = true } parity-scale-codec = { workspace = true } scale-info = { workspace = true } +sp-arithmetic = { workspace = true } sp-core = { workspace = true } sp-runtime = { workspace = true } sp-std = { workspace = true } @@ -34,6 +35,8 @@ cfg-utils = { workspace = true } [dev-dependencies] cfg-mocks = { workspace = true, default-features = true } +itertools = { workspace = true, default-features = true } +lazy_static = { workspace = true, default-features = true } sp-io = { workspace = true, default-features = true } [features] @@ -53,6 +56,7 @@ std = [ "cfg-utils/std", "hex/std", "cfg-primitives/std", + "sp-arithmetic/std", ] try-runtime = [ "cfg-traits/try-runtime", diff --git a/pallets/liquidity-pools-gateway/queue/src/lib.rs b/pallets/liquidity-pools-gateway/queue/src/lib.rs index 57c5817e17..f473c20b42 100644 --- a/pallets/liquidity-pools-gateway/queue/src/lib.rs +++ b/pallets/liquidity-pools-gateway/queue/src/lib.rs @@ -38,7 +38,7 @@ pub mod pallet { type RuntimeEvent: From> + IsType<::RuntimeEvent>; /// The message type. - type Message: Clone + Debug + PartialEq + MaxEncodedLen + TypeInfo + FullCodec + Default; + type Message: Clone + Debug + PartialEq + MaxEncodedLen + TypeInfo + FullCodec; /// Type used for message identification. type MessageNonce: Parameter diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index f1b1c62336..6e815c45b2 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -31,7 +31,7 @@ use core::fmt::Debug; use cfg_primitives::LP_DEFENSIVE_WEIGHT; use cfg_traits::liquidity_pools::{ InboundMessageHandler, LPEncoding, MessageProcessor, MessageQueue, MessageReceiver, - MessageSender, OutboundMessageHandler, RouterSupport, + MessageSender, OutboundMessageHandler, Proof, RouterProvider, }; use cfg_types::domain_address::{Domain, DomainAddress}; use frame_support::{dispatch::DispatchResult, pallet_prelude::*}; @@ -40,9 +40,11 @@ use message::GatewayMessage; use orml_traits::GetByKey; pub use pallet::*; use parity_scale_codec::FullCodec; +use sp_arithmetic::traits::{BaseArithmetic, EnsureAddAssign, One}; +use sp_runtime::SaturatedConversion; use sp_std::{convert::TryInto, vec::Vec}; -use crate::weights::WeightInfo; +use crate::{message_processing::InboundEntry, weights::WeightInfo}; mod origin; pub use origin::*; @@ -54,12 +56,14 @@ pub mod weights; #[cfg(test)] mod mock; +mod message_processing; #[cfg(test)] mod tests; #[frame_support::pallet] pub mod pallet { use super::*; + use crate::message_processing::ProofEntry; const STORAGE_VERSION: StorageVersion = StorageVersion::new(1); @@ -91,13 +95,23 @@ pub mod pallet { type AdminOrigin: EnsureOrigin<::RuntimeOrigin>; /// The Liquidity Pools message type. - type Message: LPEncoding + Clone + Debug + PartialEq + MaxEncodedLen + TypeInfo + FullCodec; - - /// The target of of the messages comming from this chain + type Message: LPEncoding + + Clone + + Debug + + PartialEq + + Eq + + MaxEncodedLen + + TypeInfo + + FullCodec; + + /// The target of the messages coming from this chain type MessageSender: MessageSender; /// An identification of a router - type RouterId: RouterSupport + Parameter; + type RouterId: Parameter + MaxEncodedLen; + + /// The type that provides the router available for a domain. + type RouterProvider: RouterProvider; /// The type that processes inbound messages. type InboundMessageHandler: InboundMessageHandler< @@ -117,12 +131,25 @@ pub mod pallet { type Sender: Get; /// Type used for queueing messages. - type MessageQueue: MessageQueue>; + type MessageQueue: MessageQueue>; + + /// Maximum number of routers allowed for a domain. + #[pallet::constant] + type MaxRouterCount: Get; + + /// Type for identifying sessions of inbound routers. + type SessionId: Parameter + Member + BaseArithmetic + Default + Copy + MaxEncodedLen; } #[pallet::event] #[pallet::generate_deposit(pub (super) fn deposit_event)] pub enum Event { + /// The routers for a given domain were set. + RoutersSet { + router_ids: BoundedVec, + session_id: T::SessionId, + }, + /// An instance was added to a domain. InstanceAdded { instance: DomainAddress }, @@ -134,8 +161,22 @@ pub mod pallet { domain: Domain, hook_address: [u8; 20], }, + + /// Message recovery was executed. + MessageRecoveryExecuted { + proof: Proof, + router_id: T::RouterId, + }, } + /// Storage for routers. + /// + /// This can only be set by an admin. + #[pallet::storage] + #[pallet::getter(fn routers)] + pub type Routers = + StorageValue<_, BoundedVec, ValueQuery>; + /// Storage that contains a limited number of whitelisted instances of /// deployed liquidity pools for a particular domain. /// @@ -154,13 +195,29 @@ pub mod pallet { pub type DomainHookAddress = StorageMap<_, Blake2_128Concat, Domain, [u8; 20], OptionQuery>; - /// Stores a batch message, not ready yet to be enqueue. + /// Stores a batch message, not ready yet to be enqueued. /// Lifetime handled by `start_batch_message()` and `end_batch_message()` /// extrinsics. #[pallet::storage] pub(crate) type PackedMessage = StorageMap<_, Blake2_128Concat, (T::AccountId, Domain), T::Message>; + /// Storage for pending inbound messages. + #[pallet::storage] + #[pallet::getter(fn pending_inbound_entries)] + pub type PendingInboundEntries = StorageDoubleMap< + _, + Blake2_128Concat, + Proof, + Blake2_128Concat, + T::RouterId, + InboundEntry, + >; + + /// Storage for inbound message session IDs. + #[pallet::storage] + pub type SessionIdStore = StorageValue<_, T::SessionId, ValueQuery>; + #[pallet::error] pub enum Error { /// The origin of the message to be processed is invalid. @@ -178,8 +235,8 @@ pub mod pallet { /// Unknown instance. UnknownInstance, - /// Router not found. - RouterConfigurationNotFound, + /// Routers not found. + RoutersNotFound, /// Emitted when you call `start_batch_messages()` but that was already /// called. You should finalize the message with `end_batch_messages()` @@ -188,10 +245,69 @@ pub mod pallet { /// Emitted when you can `end_batch_message()` but the packing process /// was not started by `start_batch_message()`. MessagePackingNotStarted, + + /// Unknown router. + UnknownRouter, + + /// The router that sent the message is not the first one. + MessageExpectedFromFirstRouter, + + /// The router that sent the proof should not be the first one. + ProofNotExpectedFromFirstRouter, + + /// A message was expected instead of a proof. + ExpectedMessageType, + + /// A message proof was expected instead of a message. + ExpectedMessageProofType, + + /// Pending inbound entry not found. + PendingInboundEntryNotFound, + + /// Message proof cannot be retrieved. + MessageProofRetrieval, + + /// Recovery message not found. + RecoveryMessageNotFound, + + /// Not enough routers are stored for a domain. + NotEnoughRoutersForDomain, + + /// The messages of 2 inbound entries do not match. + InboundEntryMessageMismatch, + + /// The domain addresses of 2 inbound entries do not match. + InboundEntryDomainAddressMismatch, } #[pallet::call] impl Pallet { + /// Sets the IDs of the routers that are used when receiving and sending + /// messages. + #[pallet::weight(T::WeightInfo::set_routers())] + #[pallet::call_index(0)] + pub fn set_routers( + origin: OriginFor, + router_ids: BoundedVec, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; + + >::set(router_ids.clone()); + + let new_session_id = SessionIdStore::::try_mutate(|n| { + n.ensure_add_assign(One::one())?; + + Ok::(*n) + })?; + + Self::deposit_event(Event::RoutersSet { + router_ids, + session_id: new_session_id, + }); + + Ok(()) + } + /// Add a known instance of a deployed liquidity pools integration for a /// specific domain. #[pallet::weight(T::WeightInfo::add_instance())] @@ -254,7 +370,7 @@ pub mod pallet { /// Set the address of the domain hook /// /// Can only be called by `AdminOrigin`. - #[pallet::weight(T::WeightInfo::set_domain_router())] + #[pallet::weight(T::WeightInfo::set_domain_hook_address())] #[pallet::call_index(8)] pub fn set_domain_hook_address( origin: OriginFor, @@ -301,69 +417,66 @@ pub mod pallet { match PackedMessage::::take((&sender, &destination)) { Some(msg) if msg.submessages().is_empty() => Ok(()), //No-op - Some(message) => Self::queue_message(destination, message), + Some(message) => Self::queue_outbound_message(destination, message), None => Err(Error::::MessagePackingNotStarted.into()), } } - } - impl Pallet { - /// Give the message to the `InboundMessageHandler` to be processed. - fn process_inbound_message( + /// Manually increase the proof count for a particular message and + /// executes it if the required count is reached. + /// + /// Can only be called by `AdminOrigin`. + #[pallet::weight(T::WeightInfo::execute_message_recovery())] + #[pallet::call_index(11)] + pub fn execute_message_recovery( + origin: OriginFor, domain_address: DomainAddress, - message: T::Message, - ) -> (DispatchResult, Weight) { - let mut count = 0; + proof: Proof, + router_id: T::RouterId, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; - for submessage in message.submessages() { - count += 1; + let router_ids = Self::get_router_ids_for_domain(domain_address.domain())?; - if let Err(e) = T::InboundMessageHandler::handle(domain_address.clone(), submessage) - { - // We only consume the processed weight if error during the batch - return (Err(e), LP_DEFENSIVE_WEIGHT.saturating_mul(count)); + ensure!( + router_ids.iter().any(|x| x == &router_id), + Error::::UnknownRouter + ); + // Message recovery shouldn't be supported for setups that have less than 1 + // router since no proofs are required in that case. + ensure!(router_ids.len() > 1, Error::::NotEnoughRoutersForDomain); + + let session_id = SessionIdStore::::get(); + + PendingInboundEntries::::try_mutate(proof, router_id.clone(), |storage_entry| { + match storage_entry { + Some(stored_inbound_entry) => { + stored_inbound_entry.increment_proof_count(session_id) + } + None => { + *storage_entry = Some(InboundEntry::::Proof(ProofEntry { + session_id, + current_count: 1, + })); + + Ok::<(), DispatchError>(()) + } } - } - - (Ok(()), LP_DEFENSIVE_WEIGHT.saturating_mul(count)) - } - - /// Retrieves the router stored for the provided domain, sends the - /// message using the router, and calculates and returns the required - /// weight for these operations in the `DispatchResultWithPostInfo`. - fn process_outbound_message( - sender: DomainAddress, - domain: Domain, - message: T::Message, - ) -> (DispatchResult, Weight) { - let router_ids = T::RouterId::for_domain(domain); - - // TODO handle router ids logic + })?; - let mut count = 0; - let bytes = message.serialize(); - - for router_id in router_ids { - count += 1; - if let Err(e) = T::MessageSender::send(router_id, sender.clone(), bytes.clone()) { - return (Err(e), LP_DEFENSIVE_WEIGHT.saturating_mul(count)); - } - } + let expected_proof_count = Self::get_expected_proof_count(&router_ids)?; - // TODO: Should we fix weights? - (Ok(()), LP_DEFENSIVE_WEIGHT.saturating_mul(count)) - } + Self::execute_if_requirements_are_met( + proof, + &router_ids, + session_id, + expected_proof_count, + domain_address, + )?; - fn queue_message(destination: Domain, message: T::Message) -> DispatchResult { - // We are using the sender specified in the pallet config so that we can - // ensure that the account is funded - let gateway_message = GatewayMessage::::Outbound { - sender: T::Sender::get(), - destination, - message, - }; + Self::deposit_event(Event::::MessageRecoveryExecuted { proof, router_id }); - T::MessageQueue::submit(gateway_message) + Ok(()) } } @@ -384,7 +497,7 @@ pub mod pallet { PackedMessage::::mutate((&from, destination.clone()), |batch| match batch { Some(batch) => batch.pack_with(message), - None => Self::queue_message(destination, message), + None => Self::queue_outbound_message(destination, message), }) } } @@ -396,27 +509,36 @@ pub mod pallet { } impl MessageProcessor for Pallet { - type Message = GatewayMessage; + type Message = GatewayMessage; fn process(msg: Self::Message) -> (DispatchResult, Weight) { match msg { GatewayMessage::Inbound { domain_address, message, - } => Self::process_inbound_message(domain_address, message), + router_id, + } => Self::process_inbound_message(domain_address, message, router_id), GatewayMessage::Outbound { sender, - destination, message, - } => Self::process_outbound_message(sender, destination, message), + router_id, + } => { + let weight = LP_DEFENSIVE_WEIGHT; + + match T::MessageSender::send(router_id, sender, message.serialize()) { + Ok(_) => (Ok(()), weight), + Err(e) => (Err(e), weight), + } + } } } - /// Process a message. + /// Returns the max processing weight for a message, based on its + /// direction. fn max_processing_weight(msg: &Self::Message) -> Weight { match msg { GatewayMessage::Inbound { message, .. } => { - LP_DEFENSIVE_WEIGHT.saturating_mul(message.submessages().len() as u64) + LP_DEFENSIVE_WEIGHT.saturating_mul(message.submessages().len().saturated_into()) } GatewayMessage::Outbound { .. } => LP_DEFENSIVE_WEIGHT, } @@ -428,20 +550,19 @@ pub mod pallet { type Origin = DomainAddress; fn receive( - _router_id: T::RouterId, + router_id: T::RouterId, origin_address: DomainAddress, message: Vec, ) -> DispatchResult { - // TODO handle router ids logic with votes and session_id - ensure!( Allowlist::::contains_key(origin_address.domain(), origin_address.clone()), Error::::UnknownInstance, ); - let gateway_message = GatewayMessage::::Inbound { + let gateway_message = GatewayMessage::::Inbound { domain_address: origin_address, message: T::Message::deserialize(&message)?, + router_id, }; T::MessageQueue::submit(gateway_message) diff --git a/pallets/liquidity-pools-gateway/src/message.rs b/pallets/liquidity-pools-gateway/src/message.rs index 505e075b0f..1c604f5d0d 100644 --- a/pallets/liquidity-pools-gateway/src/message.rs +++ b/pallets/liquidity-pools-gateway/src/message.rs @@ -1,25 +1,17 @@ -use cfg_types::domain_address::{Domain, DomainAddress}; +use cfg_types::domain_address::DomainAddress; use frame_support::pallet_prelude::{Decode, Encode, MaxEncodedLen, TypeInfo}; /// Message type used by the LP gateway. #[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] -pub enum GatewayMessage { +pub enum GatewayMessage { Inbound { domain_address: DomainAddress, message: Message, + router_id: RouterId, }, Outbound { sender: DomainAddress, - destination: Domain, message: Message, + router_id: RouterId, }, } - -impl Default for GatewayMessage { - fn default() -> Self { - GatewayMessage::Inbound { - domain_address: Default::default(), - message: Default::default(), - } - } -} diff --git a/pallets/liquidity-pools-gateway/src/message_processing.rs b/pallets/liquidity-pools-gateway/src/message_processing.rs new file mode 100644 index 0000000000..1783c7e869 --- /dev/null +++ b/pallets/liquidity-pools-gateway/src/message_processing.rs @@ -0,0 +1,495 @@ +use cfg_primitives::LP_DEFENSIVE_WEIGHT; +use cfg_traits::liquidity_pools::{ + InboundMessageHandler, LPEncoding, MessageQueue, Proof, RouterProvider, +}; +use cfg_types::domain_address::{Domain, DomainAddress}; +use frame_support::{ + dispatch::DispatchResult, + ensure, + pallet_prelude::{Decode, Encode, Get, TypeInfo}, + weights::Weight, +}; +use parity_scale_codec::MaxEncodedLen; +use sp_arithmetic::traits::{EnsureAddAssign, EnsureSub, SaturatedConversion}; +use sp_runtime::DispatchError; +use sp_std::vec::Vec; + +use crate::{ + message::GatewayMessage, Config, Error, Pallet, PendingInboundEntries, Routers, SessionIdStore, +}; + +/// Type that holds the information needed for inbound message entries. +#[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] +#[scale_info(skip_type_params(T))] +pub struct MessageEntry { + /// The session ID for this entry. + pub session_id: T::SessionId, + + /// The sender of the inbound message. + /// + /// NOTE - the `RouterProvider` ensures that we cannot have the same message + /// entry, for the same message, for different domain addresses. + pub domain_address: DomainAddress, + + /// The LP message. + pub message: T::Message, + + /// The expected proof count for processing one or more of the provided + /// message. + /// + /// NOTE - this gets increased by the `expected_proof_count` for a set of + /// routers (see `get_expected_proof_count`) every time a new identical + /// message is submitted. + pub expected_proof_count: u32, +} + +/// Type that holds the information needed for inbound proof entries. +#[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] +#[scale_info(skip_type_params(T))] +pub struct ProofEntry { + /// The session ID for this entry. + pub session_id: T::SessionId, + + /// The number of proofs received for a particular message. + /// + /// NOTE - this gets increased by 1 every time a new identical message is + /// submitted. + pub current_count: u32, +} + +impl ProofEntry { + /// Returns `true` if all the following conditions are true: + /// - the session IDs match + /// - the `current_count` is greater than 0 + pub fn has_valid_vote_for_session(&self, session_id: T::SessionId) -> bool { + self.session_id == session_id && self.current_count > 0 + } +} + +/// Type used when storing inbound message information. +#[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] +#[scale_info(skip_type_params(T))] +pub enum InboundEntry { + Message(MessageEntry), + Proof(ProofEntry), +} + +impl From> for InboundEntry { + fn from(message_entry: MessageEntry) -> Self { + Self::Message(message_entry) + } +} + +impl From> for InboundEntry { + fn from(proof_entry: ProofEntry) -> Self { + Self::Proof(proof_entry) + } +} + +impl InboundEntry { + /// Creates an inbound entry based on the type of message. + pub fn create( + message: T::Message, + session_id: T::SessionId, + domain_address: DomainAddress, + expected_proof_count: u32, + ) -> Self { + match message.get_proof() { + None => InboundEntry::Message(MessageEntry { + session_id, + domain_address, + message, + expected_proof_count, + }), + Some(_) => InboundEntry::Proof(ProofEntry { + session_id, + current_count: 1, + }), + } + } + + /// Creates a new `InboundEntry` based on the information provided. + /// + /// If the updated counts reach 0, it means that a new entry is no longer + /// required, otherwise, the counts are decreased accordingly, based on the + /// entry type. + pub fn create_post_voting_entry( + inbound_entry: &InboundEntry, + expected_proof_count: u32, + ) -> Result, DispatchError> { + match inbound_entry { + InboundEntry::Message(message_entry) => { + let updated_count = message_entry + .expected_proof_count + .ensure_sub(expected_proof_count)?; + + if updated_count == 0 { + return Ok(None); + } + + Ok(Some( + MessageEntry { + expected_proof_count: updated_count, + ..message_entry.clone() + } + .into(), + )) + } + InboundEntry::Proof(proof_entry) => { + let updated_count = proof_entry.current_count.ensure_sub(1)?; + + if updated_count == 0 { + return Ok(None); + } + + Ok(Some( + ProofEntry { + current_count: updated_count, + ..proof_entry.clone() + } + .into(), + )) + } + } + } + + /// Validation ensures that: + /// + /// - the router that sent the inbound message is a valid router for the + /// specific domain. + /// - messages are only sent by the first inbound router. + /// - proofs are not sent by the first inbound router. + pub fn validate(&self, router_ids: &[T::RouterId], router_id: &T::RouterId) -> DispatchResult { + ensure!( + router_ids.iter().any(|x| x == router_id), + Error::::UnknownRouter + ); + + match self { + InboundEntry::Message { .. } => { + ensure!( + router_ids.first() == Some(router_id), + Error::::MessageExpectedFromFirstRouter + ); + + Ok(()) + } + InboundEntry::Proof { .. } => { + ensure!( + router_ids.first() != Some(router_id), + Error::::ProofNotExpectedFromFirstRouter + ); + + Ok(()) + } + } + } + + /// Checks if the entry type is a proof and increments the count by 1 + /// or sets it to 1 if the session is changed. + pub fn increment_proof_count(&mut self, session_id: T::SessionId) -> DispatchResult { + match self { + InboundEntry::Proof(proof_entry) => { + if proof_entry.session_id != session_id { + proof_entry.session_id = session_id; + proof_entry.current_count = 1; + } else { + proof_entry.current_count.ensure_add_assign(1)?; + } + + Ok::<(), DispatchError>(()) + } + InboundEntry::Message(_) => Err(Error::::ExpectedMessageProofType.into()), + } + } + + /// A pre-dispatch update involves increasing the `expected_proof_count` or + /// `current_count` of `self` with the one of `other`. + /// + /// If a session ID change is detected, `self` is replaced completely by + /// `other`. + pub fn pre_dispatch_update(&mut self, other: Self) -> DispatchResult { + match (&mut *self, &other) { + // Message entries + ( + InboundEntry::Message(self_message_entry), + InboundEntry::Message(other_message_entry), + ) => { + if self_message_entry.session_id != other_message_entry.session_id { + *self = other; + + return Ok(()); + } + + self_message_entry + .expected_proof_count + .ensure_add_assign(other_message_entry.expected_proof_count)?; + + Ok(()) + } + // Proof entries + (InboundEntry::Proof(self_proof_entry), InboundEntry::Proof(other_proof_entry)) => { + if self_proof_entry.session_id != other_proof_entry.session_id { + *self = other; + + return Ok(()); + } + + self_proof_entry + .current_count + .ensure_add_assign(other_proof_entry.current_count)?; + + Ok(()) + } + // Mismatches + (InboundEntry::Message(_), InboundEntry::Proof(_)) => { + Err(Error::::ExpectedMessageType.into()) + } + (InboundEntry::Proof(_), InboundEntry::Message(_)) => { + Err(Error::::ExpectedMessageProofType.into()) + } + } + } +} + +impl Pallet { + /// Retrieves all stored routers and then filters them based + /// on the available routers for the provided domain. + pub(crate) fn get_router_ids_for_domain( + domain: Domain, + ) -> Result, DispatchError> { + let stored_routers = Routers::::get(); + + let all_routers_for_domain = T::RouterProvider::routers_for_domain(domain); + + let res = stored_routers + .iter() + .filter(|stored_router| { + all_routers_for_domain + .iter() + .any(|available_router| *stored_router == available_router) + }) + .cloned() + .collect::>(); + + if res.is_empty() { + return Err(Error::::NotEnoughRoutersForDomain.into()); + } + + Ok(res) + } + + /// Calculates and returns the proof count required for processing one + /// inbound message. + pub(crate) fn get_expected_proof_count( + router_ids: &[T::RouterId], + ) -> Result { + let expected_proof_count = router_ids + .len() + .ensure_sub(1) + .map_err(|_| Error::::NotEnoughRoutersForDomain)?; + + Ok(expected_proof_count.saturated_into()) + } + + /// Gets the message proof for a message. + pub(crate) fn get_message_proof(message: T::Message) -> Proof { + match message.get_proof() { + None => message + .to_proof_message() + .get_proof() + .expect("message proof ensured by 'to_message_proof'"), + Some(proof) => proof, + } + } + + /// Upserts an inbound entry for a particular message, increasing the + /// relevant counts accordingly. + pub(crate) fn upsert_pending_entry( + message_proof: Proof, + router_id: &T::RouterId, + new_inbound_entry: InboundEntry, + ) -> DispatchResult { + PendingInboundEntries::::try_mutate(message_proof, router_id, |storage_entry| { + match storage_entry { + None => { + *storage_entry = Some(new_inbound_entry); + + Ok::<(), DispatchError>(()) + } + Some(stored_inbound_entry) => { + stored_inbound_entry.pre_dispatch_update(new_inbound_entry) + } + } + }) + } + + /// Checks if the number of proofs required for executing one message + /// were received, and if so, decreases the counts accordingly and executes + /// the message. + pub(crate) fn execute_if_requirements_are_met( + message_proof: Proof, + router_ids: &[T::RouterId], + session_id: T::SessionId, + expected_proof_count: u32, + domain_address: DomainAddress, + ) -> DispatchResult { + let mut message = None; + let mut votes = 0; + + for router_id in router_ids { + match PendingInboundEntries::::get(message_proof, router_id) { + // We expected one InboundEntry for each router, if that's not the case, + // we can return. + None => return Ok(()), + Some(stored_inbound_entry) => match stored_inbound_entry { + InboundEntry::Message(message_entry) => message = Some(message_entry.message), + InboundEntry::Proof(proof_entry) + if proof_entry.has_valid_vote_for_session(session_id) => + { + votes.ensure_add_assign(1)?; + } + _ => {} + }, + }; + } + + if votes < expected_proof_count { + return Ok(()); + } + + if let Some(msg) = message { + Self::execute_post_voting_dispatch(message_proof, router_ids, expected_proof_count)?; + + T::InboundMessageHandler::handle(domain_address, msg)?; + } + + Ok(()) + } + + /// Decreases the counts for inbound entries and removes them if the + /// counts reach 0. + pub(crate) fn execute_post_voting_dispatch( + message_proof: Proof, + router_ids: &[T::RouterId], + expected_proof_count: u32, + ) -> DispatchResult { + for router_id in router_ids { + PendingInboundEntries::::try_mutate(message_proof, router_id, |storage_entry| { + match storage_entry { + None => { + // This case cannot be reproduced in production since this function is + // called only if a message is submitted for further processing, which + // means that all the pending inbound entries are present. + Err::<(), DispatchError>(Error::::PendingInboundEntryNotFound.into()) + } + Some(stored_inbound_entry) => { + let post_dispatch_entry = InboundEntry::create_post_voting_entry( + stored_inbound_entry, + expected_proof_count, + )?; + + *storage_entry = post_dispatch_entry; + + Ok(()) + } + } + })?; + } + + Ok(()) + } + + /// Iterates over a batch of messages and checks if the requirements for + /// processing each message are met. + pub(crate) fn process_inbound_message( + domain_address: DomainAddress, + message: T::Message, + router_id: T::RouterId, + ) -> (DispatchResult, Weight) { + let weight = LP_DEFENSIVE_WEIGHT; + + let router_ids = match Self::get_router_ids_for_domain(domain_address.domain()) { + Ok(r) => r, + Err(e) => return (Err(e), weight), + }; + + let session_id = SessionIdStore::::get(); + + let expected_proof_count = match Self::get_expected_proof_count(&router_ids) { + Ok(r) => r, + Err(e) => return (Err(e), weight), + }; + + let mut count = 0; + + for submessage in message.submessages() { + if let Err(e) = count.ensure_add_assign(1) { + return (Err(e.into()), weight.saturating_mul(count)); + } + + let message_proof = Self::get_message_proof(submessage.clone()); + + let inbound_entry: InboundEntry = InboundEntry::create( + submessage, + session_id, + domain_address.clone(), + expected_proof_count, + ); + + if let Err(e) = inbound_entry.validate(&router_ids, &router_id.clone()) { + return (Err(e), weight.saturating_mul(count)); + } + + if let Err(e) = Self::upsert_pending_entry(message_proof, &router_id, inbound_entry) { + return (Err(e), weight.saturating_mul(count)); + } + + match Self::execute_if_requirements_are_met( + message_proof, + &router_ids, + session_id, + expected_proof_count, + domain_address.clone(), + ) { + Err(e) => return (Err(e), weight.saturating_mul(count)), + Ok(_) => continue, + } + } + + (Ok(()), weight.saturating_mul(count)) + } + + /// Retrieves the IDs of the routers set for a domain and queues the + /// message and proofs accordingly. + pub(crate) fn queue_outbound_message( + destination: Domain, + message: T::Message, + ) -> DispatchResult { + let router_ids = Self::get_router_ids_for_domain(destination)?; + + let message_proof = message.to_proof_message(); + let mut message_opt = Some(message); + + for router_id in router_ids { + // Ensure that we only send the actual message once, using one router. + // The remaining routers will send the message proof. + let router_msg = match message_opt.take() { + Some(m) => m, + None => message_proof.clone(), + }; + + // We are using the sender specified in the pallet config so that we can + // ensure that the account is funded + let gateway_message = GatewayMessage::::Outbound { + sender: T::Sender::get(), + message: router_msg, + router_id, + }; + + T::MessageQueue::submit(gateway_message)?; + } + + Ok(()) + } +} diff --git a/pallets/liquidity-pools-gateway/src/mock.rs b/pallets/liquidity-pools-gateway/src/mock.rs index 1ea75c1885..9296a0b534 100644 --- a/pallets/liquidity-pools-gateway/src/mock.rs +++ b/pallets/liquidity-pools-gateway/src/mock.rs @@ -1,6 +1,11 @@ +use std::fmt::{Debug, Formatter}; + use cfg_mocks::{pallet_mock_liquidity_pools, pallet_mock_liquidity_pools_gateway_queue}; -use cfg_traits::liquidity_pools::{LPEncoding, RouterSupport}; -use cfg_types::domain_address::{Domain, DomainAddress}; +use cfg_traits::liquidity_pools::{LPEncoding, Proof, RouterProvider}; +use cfg_types::{ + domain_address::{Domain, DomainAddress}, + EVMChainId, +}; use frame_support::{derive_impl, weights::constants::RocksDbWeight}; use frame_system::EnsureRoot; use parity_scale_codec::{Decode, Encode, MaxEncodedLen}; @@ -10,16 +15,36 @@ use sp_runtime::{traits::IdentityLookup, DispatchError, DispatchResult}; use crate::{pallet as pallet_liquidity_pools_gateway, EnsureLocal, GatewayMessage}; +pub const TEST_SESSION_ID: u32 = 1; +pub const TEST_EVM_CHAIN: EVMChainId = 1; +pub const TEST_DOMAIN_ADDRESS: DomainAddress = DomainAddress::EVM(TEST_EVM_CHAIN, [1; 20]); + +pub const ROUTER_ID_1: RouterId = RouterId(1); +pub const ROUTER_ID_2: RouterId = RouterId(2); +pub const ROUTER_ID_3: RouterId = RouterId(3); + pub const LP_ADMIN_ACCOUNT: AccountId32 = AccountId32::new([u8::MAX; 32]); pub const MAX_PACKED_MESSAGES_ERR: &str = "packed limit error"; pub const MAX_PACKED_MESSAGES: usize = 10; -#[derive(Default, Debug, Eq, PartialEq, Clone, Encode, Decode, TypeInfo)] +pub const MESSAGE_PROOF: [u8; 32] = [1; 32]; + +#[derive(Eq, PartialEq, Clone, Encode, Decode, TypeInfo, Hash)] pub enum Message { - #[default] Simple, Pack(Vec), + Proof([u8; 32]), +} + +impl Debug for Message { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Message::Simple => write!(f, "Simple"), + Message::Pack(p) => write!(f, "Pack - {:?}", p), + Message::Proof(_) => write!(f, "Proof"), + } + } } /// Avoiding automatic infinity loop with the MaxEncodedLen derive @@ -32,8 +57,8 @@ impl MaxEncodedLen for Message { impl LPEncoding for Message { fn serialize(&self) -> Vec { match self { - Self::Simple => vec![0x42], Self::Pack(list) => list.iter().map(|_| 0x42).collect(), + _ => vec![0x42], } } @@ -47,10 +72,6 @@ impl LPEncoding for Message { fn pack_with(&mut self, other: Self) -> DispatchResult { match self { - Self::Simple => { - *self = Self::Pack(vec![Self::Simple, other]); - Ok(()) - } Self::Pack(list) if list.len() == MAX_PACKED_MESSAGES => { Err(MAX_PACKED_MESSAGES_ERR.into()) } @@ -58,27 +79,52 @@ impl LPEncoding for Message { list.push(other); Ok(()) } + _ => { + *self = Self::Pack(vec![self.clone(), other]); + Ok(()) + } } } fn submessages(&self) -> Vec { match self { - Self::Simple => vec![Self::Simple], Self::Pack(list) => list.clone(), + _ => vec![self.clone()], } } fn empty() -> Self { Self::Pack(vec![]) } + + fn get_proof(&self) -> Option { + match self { + Message::Proof(p) => Some(p.clone()), + _ => None, + } + } + + fn to_proof_message(&self) -> Self { + match self { + Message::Proof(_) => self.clone(), + _ => Message::Proof(MESSAGE_PROOF), + } + } } -#[derive(Debug, Encode, Decode, Clone, PartialEq, Eq, TypeInfo, MaxEncodedLen)] -pub struct RouterId(u32); +#[derive(Default, Debug, Encode, Decode, Clone, PartialEq, Eq, TypeInfo, MaxEncodedLen, Hash)] +pub struct RouterId(pub u32); + +pub struct TestRouterProvider; -impl RouterSupport for RouterId { - fn for_domain(_domain: Domain) -> Vec { - vec![] // TODO +impl RouterProvider for TestRouterProvider { + type RouterId = RouterId; + + fn routers_for_domain(domain: Domain) -> Vec { + match domain { + Domain::Centrifuge => vec![], + Domain::EVM(_) => vec![ROUTER_ID_1, ROUTER_ID_2, ROUTER_ID_3], + } } } @@ -106,7 +152,7 @@ impl pallet_mock_liquidity_pools::Config for Runtime { } impl pallet_mock_liquidity_pools_gateway_queue::Config for Runtime { - type Message = GatewayMessage; + type Message = GatewayMessage; } impl cfg_mocks::router_message::pallet::Config for Runtime { @@ -115,9 +161,10 @@ impl cfg_mocks::router_message::pallet::Config for Runtime { } frame_support::parameter_types! { - pub Sender: DomainAddress = DomainAddress::Centrifuge(AccountId32::from(H256::from_low_u64_be(1).to_fixed_bytes()).into()); + pub Sender: DomainAddress = DomainAddress::Centrifuge(AccountId32::from(H256::from_low_u64_be(123).to_fixed_bytes()).into()); pub const MaxIncomingMessageSize: u32 = 1024; pub const LpAdminAccount: AccountId32 = LP_ADMIN_ACCOUNT; + pub const MaxRouterCount: u32 = 8; } impl pallet_liquidity_pools_gateway::Config for Runtime { @@ -125,18 +172,19 @@ impl pallet_liquidity_pools_gateway::Config for Runtime { type InboundMessageHandler = MockLiquidityPools; type LocalEVMOrigin = EnsureLocal; type MaxIncomingMessageSize = MaxIncomingMessageSize; + type MaxRouterCount = MaxRouterCount; type Message = Message; type MessageQueue = MockLiquidityPoolsGatewayQueue; type MessageSender = MockMessageSender; type RouterId = RouterId; + type RouterProvider = TestRouterProvider; type RuntimeEvent = RuntimeEvent; type RuntimeOrigin = RuntimeOrigin; type Sender = Sender; + type SessionId = u32; type WeightInfo = (); } -/* pub fn new_test_ext() -> sp_io::TestExternalities { System::externalities() } -*/ diff --git a/pallets/liquidity-pools-gateway/src/tests.rs b/pallets/liquidity-pools-gateway/src/tests.rs index 61e54fde63..d687a9cd57 100644 --- a/pallets/liquidity-pools-gateway/src/tests.rs +++ b/pallets/liquidity-pools-gateway/src/tests.rs @@ -1,14 +1,17 @@ -/* -use cfg_mocks::*; +use std::collections::HashMap; + use cfg_primitives::LP_DEFENSIVE_WEIGHT; use cfg_traits::liquidity_pools::{LPEncoding, MessageProcessor, OutboundMessageHandler}; use cfg_types::domain_address::*; -use frame_support::{ - assert_err, assert_noop, assert_ok, dispatch::PostDispatchInfo, pallet_prelude::Pays, - weights::Weight, -}; +use frame_support::{assert_err, assert_noop, assert_ok}; +use itertools::Itertools; +use lazy_static::lazy_static; +use sp_arithmetic::ArithmeticError::{Overflow, Underflow}; use sp_core::{bounded::BoundedVec, crypto::AccountId32, ByteArray, H160}; -use sp_runtime::{DispatchError, DispatchError::BadOrigin, DispatchErrorWithPostInfo}; +use sp_runtime::{ + DispatchError, + DispatchError::{Arithmetic, BadOrigin}, +}; use sp_std::sync::{ atomic::{AtomicU32, Ordering}, Arc, @@ -19,7 +22,10 @@ use super::{ origin::*, pallet::*, }; -use crate::GatewayMessage; +use crate::{ + message_processing::{InboundEntry, MessageEntry, ProofEntry}, + GatewayMessage, +}; mod utils { use super::*; @@ -42,812 +48,4041 @@ mod utils { use utils::*; -mod set_domain_router { +mod extrinsics { use super::*; - #[test] - fn success() { - new_test_ext().execute_with(|| { - let domain = Domain::EVM(0); - let router = RouterMock::::default(); - router.mock_init(move || Ok(())); + mod set_routers { + use super::*; - assert_ok!(LiquidityPoolsGateway::set_domain_router( - RuntimeOrigin::root(), - domain.clone(), - router.clone(), - )); + #[test] + fn success() { + new_test_ext().execute_with(|| { + let mut session_id = 1; - let storage_entry = DomainRouters::::get(domain.clone()); - assert_eq!(storage_entry.unwrap(), router); + let mut router_ids = + BoundedVec::try_from(vec![ROUTER_ID_1, ROUTER_ID_2, ROUTER_ID_3]).unwrap(); - event_exists(Event::::DomainRouterSet { domain, router }); - }); - } - #[test] - fn router_init_error() { - new_test_ext().execute_with(|| { - let domain = Domain::EVM(0); - let router = RouterMock::::default(); - router.mock_init(move || Err(DispatchError::Other("error"))); - - assert_noop!( - LiquidityPoolsGateway::set_domain_router( + assert_ok!(LiquidityPoolsGateway::set_routers( RuntimeOrigin::root(), - domain.clone(), - router, - ), - Error::::RouterInitFailed, - ); - }); - } + router_ids.clone(), + )); - #[test] - fn bad_origin() { - new_test_ext().execute_with(|| { - let domain = Domain::EVM(0); - let router = RouterMock::::default(); - - assert_noop!( - LiquidityPoolsGateway::set_domain_router( - RuntimeOrigin::signed(get_test_account_id()), - domain.clone(), - router, - ), - BadOrigin - ); - - let storage_entry = DomainRouters::::get(domain); - assert!(storage_entry.is_none()); - }); - } + assert_eq!(Routers::::get(), router_ids.clone()); + assert_eq!(SessionIdStore::::get(), session_id); - #[test] - fn unsupported_domain() { - new_test_ext().execute_with(|| { - let domain = Domain::Centrifuge; - let router = RouterMock::::default(); + event_exists(Event::::RoutersSet { + router_ids, + session_id, + }); - assert_noop!( - LiquidityPoolsGateway::set_domain_router( - RuntimeOrigin::root(), - domain.clone(), - router - ), - Error::::DomainNotSupported - ); - - let storage_entry = DomainRouters::::get(domain); - assert!(storage_entry.is_none()); - }); - } -} + router_ids = + BoundedVec::try_from(vec![ROUTER_ID_3, ROUTER_ID_2, ROUTER_ID_1]).unwrap(); -mod add_instance { - use super::*; + session_id += 1; - #[test] - fn success() { - new_test_ext().execute_with(|| { - let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); - let domain_address = DomainAddress::EVM(0, address.into()); + assert_ok!(LiquidityPoolsGateway::set_routers( + RuntimeOrigin::root(), + router_ids.clone(), + )); - assert_ok!(LiquidityPoolsGateway::add_instance( - RuntimeOrigin::root(), - domain_address.clone(), - )); + assert_eq!(Routers::::get(), router_ids.clone()); + assert_eq!(SessionIdStore::::get(), session_id); - assert!(Allowlist::::contains_key( - domain_address.domain(), - domain_address.clone() - )); + event_exists(Event::::RoutersSet { + router_ids, + session_id, + }); + }); + } - event_exists(Event::::InstanceAdded { - instance: domain_address, + #[test] + fn bad_origin() { + new_test_ext().execute_with(|| { + assert_noop!( + LiquidityPoolsGateway::set_routers( + RuntimeOrigin::signed(get_test_account_id()), + BoundedVec::try_from(vec![]).unwrap(), + ), + BadOrigin + ); + + assert!(Routers::::get().is_empty()); + assert_eq!(SessionIdStore::::get(), 0); }); - }); + } + + #[test] + fn session_id_overflow() { + new_test_ext().execute_with(|| { + SessionIdStore::::set(u32::MAX); + + assert_noop!( + LiquidityPoolsGateway::set_routers( + RuntimeOrigin::root(), + BoundedVec::try_from(vec![ROUTER_ID_1]).unwrap(), + ), + Arithmetic(Overflow) + ); + }); + } } - #[test] - fn bad_origin() { - new_test_ext().execute_with(|| { - let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); - let domain_address = DomainAddress::EVM(0, address.into()); + mod add_instance { + use super::*; + + #[test] + fn success() { + new_test_ext().execute_with(|| { + let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); + let domain_address = DomainAddress::EVM(0, address.into()); - assert_noop!( - LiquidityPoolsGateway::add_instance( - RuntimeOrigin::signed(get_test_account_id()), + assert_ok!(LiquidityPoolsGateway::add_instance( + RuntimeOrigin::root(), domain_address.clone(), - ), - BadOrigin - ); - - assert!(!Allowlist::::contains_key( - domain_address.domain(), - domain_address.clone() - )); - }); - } + )); - #[test] - fn unsupported_domain() { - new_test_ext().execute_with(|| { - let domain_address = DomainAddress::Centrifuge(get_test_account_id().into()); - - assert_noop!( - LiquidityPoolsGateway::add_instance(RuntimeOrigin::root(), domain_address.clone()), - Error::::DomainNotSupported - ); - - assert!(!Allowlist::::contains_key( - domain_address.domain(), - domain_address.clone() - )); - }); - } + assert!(Allowlist::::contains_key( + domain_address.domain(), + domain_address.clone() + )); - #[test] - fn instance_already_added() { - new_test_ext().execute_with(|| { - let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); - let domain_address = DomainAddress::EVM(0, address.into()); - - assert_ok!(LiquidityPoolsGateway::add_instance( - RuntimeOrigin::root(), - domain_address.clone(), - )); - - assert!(Allowlist::::contains_key( - domain_address.domain(), - domain_address.clone() - )); - - assert_noop!( - LiquidityPoolsGateway::add_instance(RuntimeOrigin::root(), domain_address), - Error::::InstanceAlreadyAdded - ); - }); - } -} + event_exists(Event::::InstanceAdded { + instance: domain_address, + }); + }); + } -mod remove_instance { - use super::*; + #[test] + fn bad_origin() { + new_test_ext().execute_with(|| { + let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); + let domain_address = DomainAddress::EVM(0, address.into()); + + assert_noop!( + LiquidityPoolsGateway::add_instance( + RuntimeOrigin::signed(get_test_account_id()), + domain_address.clone(), + ), + BadOrigin + ); + + assert!(!Allowlist::::contains_key( + domain_address.domain(), + domain_address.clone() + )); + }); + } - #[test] - fn success() { - new_test_ext().execute_with(|| { - let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); - let domain_address = DomainAddress::EVM(0, address.into()); + #[test] + fn unsupported_domain() { + new_test_ext().execute_with(|| { + let domain_address = DomainAddress::Centrifuge(get_test_account_id().into()); + + assert_noop!( + LiquidityPoolsGateway::add_instance( + RuntimeOrigin::root(), + domain_address.clone() + ), + Error::::DomainNotSupported + ); + + assert!(!Allowlist::::contains_key( + domain_address.domain(), + domain_address.clone() + )); + }); + } - assert_ok!(LiquidityPoolsGateway::add_instance( - RuntimeOrigin::root(), - domain_address.clone(), - )); + #[test] + fn instance_already_added() { + new_test_ext().execute_with(|| { + let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); + let domain_address = DomainAddress::EVM(0, address.into()); - assert_ok!(LiquidityPoolsGateway::remove_instance( - RuntimeOrigin::root(), - domain_address.clone(), - )); + assert_ok!(LiquidityPoolsGateway::add_instance( + RuntimeOrigin::root(), + domain_address.clone(), + )); - assert!(!Allowlist::::contains_key( - domain_address.domain(), - domain_address.clone() - )); + assert!(Allowlist::::contains_key( + domain_address.domain(), + domain_address.clone() + )); - event_exists(Event::::InstanceRemoved { - instance: domain_address.clone(), + assert_noop!( + LiquidityPoolsGateway::add_instance(RuntimeOrigin::root(), domain_address), + Error::::InstanceAlreadyAdded + ); }); - }); + } } - #[test] - fn bad_origin() { - new_test_ext().execute_with(|| { - let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); - let domain_address = DomainAddress::EVM(0, address.into()); + mod remove_instance { + use super::*; + + #[test] + fn success() { + new_test_ext().execute_with(|| { + let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); + let domain_address = DomainAddress::EVM(0, address.into()); + + assert_ok!(LiquidityPoolsGateway::add_instance( + RuntimeOrigin::root(), + domain_address.clone(), + )); - assert_noop!( - LiquidityPoolsGateway::remove_instance( - RuntimeOrigin::signed(get_test_account_id()), + assert_ok!(LiquidityPoolsGateway::remove_instance( + RuntimeOrigin::root(), domain_address.clone(), - ), - BadOrigin - ); - }); + )); + + assert!(!Allowlist::::contains_key( + domain_address.domain(), + domain_address.clone() + )); + + event_exists(Event::::InstanceRemoved { + instance: domain_address.clone(), + }); + }); + } + + #[test] + fn bad_origin() { + new_test_ext().execute_with(|| { + let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); + let domain_address = DomainAddress::EVM(0, address.into()); + + assert_noop!( + LiquidityPoolsGateway::remove_instance( + RuntimeOrigin::signed(get_test_account_id()), + domain_address.clone(), + ), + BadOrigin + ); + }); + } + + #[test] + fn instance_not_found() { + new_test_ext().execute_with(|| { + let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); + let domain_address = DomainAddress::EVM(0, address.into()); + + assert_noop!( + LiquidityPoolsGateway::remove_instance( + RuntimeOrigin::root(), + domain_address.clone(), + ), + Error::::UnknownInstance, + ); + }); + } } - #[test] - fn instance_not_found() { - new_test_ext().execute_with(|| { - let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); - let domain_address = DomainAddress::EVM(0, address.into()); + mod receive_message { + use super::*; + + #[test] + fn success() { + new_test_ext().execute_with(|| { + let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); + let domain_address = DomainAddress::EVM(0, address.into()); + let message = Message::Simple; + + let router_id = ROUTER_ID_1; - assert_noop!( - LiquidityPoolsGateway::remove_instance( + assert_ok!(LiquidityPoolsGateway::add_instance( RuntimeOrigin::root(), domain_address.clone(), - ), - Error::::UnknownInstance, - ); - }); - } -} + )); -mod receive_message_domain { - use super::*; + let encoded_msg = message.serialize(); + + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + router_id: router_id.clone(), + }; - #[test] - fn success() { - new_test_ext().execute_with(|| { - let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); - let domain_address = DomainAddress::EVM(0, address.into()); - let message = Message::Simple; + let handler = MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_message| { + assert_eq!(mock_message, gateway_message); + Ok(()) + }); - assert_ok!(LiquidityPoolsGateway::add_instance( - RuntimeOrigin::root(), - domain_address.clone(), - )); + assert_ok!(LiquidityPoolsGateway::receive_message( + GatewayOrigin::Domain(domain_address).into(), + router_id, + BoundedVec::::try_from(encoded_msg).unwrap() + )); - let encoded_msg = message.serialize(); + assert_eq!(handler.times(), 1); + }); + } - let gateway_message = GatewayMessage::::Inbound { - domain_address: domain_address.clone(), - message: message.clone(), - }; + #[test] + fn bad_origin() { + new_test_ext().execute_with(|| { + let encoded_msg = Message::Simple.serialize(); + + let router_id = ROUTER_ID_1; + + assert_noop!( + LiquidityPoolsGateway::receive_message( + RuntimeOrigin::signed(AccountId32::new([0u8; 32])), + router_id, + BoundedVec::::try_from(encoded_msg).unwrap() + ), + BadOrigin, + ); + }); + } - MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_message| { - assert_eq!(mock_message, gateway_message); - Ok(()) + #[test] + fn invalid_message_origin() { + new_test_ext().execute_with(|| { + let domain_address = DomainAddress::Centrifuge(get_test_account_id().into()); + let encoded_msg = Message::Simple.serialize(); + let router_id = ROUTER_ID_1; + + assert_noop!( + LiquidityPoolsGateway::receive_message( + GatewayOrigin::Domain(domain_address).into(), + router_id, + BoundedVec::::try_from(encoded_msg).unwrap() + ), + Error::::InvalidMessageOrigin, + ); }); + } - assert_ok!(LiquidityPoolsGateway::receive_message( - GatewayOrigin::Domain(domain_address).into(), - BoundedVec::::try_from(encoded_msg).unwrap() - )); - }); - } + #[test] + fn unknown_instance() { + new_test_ext().execute_with(|| { + let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); + let domain_address = DomainAddress::EVM(0, address.into()); + let encoded_msg = Message::Simple.serialize(); + let router_id = ROUTER_ID_1; + + assert_noop!( + LiquidityPoolsGateway::receive_message( + GatewayOrigin::Domain(domain_address).into(), + router_id, + BoundedVec::::try_from(encoded_msg).unwrap() + ), + Error::::UnknownInstance, + ); + }); + } - #[test] - fn bad_origin() { - new_test_ext().execute_with(|| { - let encoded_msg = Message::Simple.serialize(); + #[test] + fn message_queue_error() { + new_test_ext().execute_with(|| { + let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); + let domain_address = DomainAddress::EVM(0, address.into()); + let message = Message::Simple; - assert_noop!( - LiquidityPoolsGateway::receive_message( - RuntimeOrigin::signed(AccountId32::new([0u8; 32])), - BoundedVec::::try_from(encoded_msg).unwrap() - ), - BadOrigin, - ); - }); - } + let router_id = ROUTER_ID_1; - #[test] - fn invalid_message_origin() { - new_test_ext().execute_with(|| { - let domain_address = DomainAddress::Centrifuge(get_test_account_id().into()); - let encoded_msg = Message::Simple.serialize(); + assert_ok!(LiquidityPoolsGateway::add_instance( + RuntimeOrigin::root(), + domain_address.clone(), + )); - assert_noop!( - LiquidityPoolsGateway::receive_message( - GatewayOrigin::Domain(domain_address).into(), - BoundedVec::::try_from(encoded_msg).unwrap() - ), - Error::::InvalidMessageOrigin, - ); - }); + let encoded_msg = message.serialize(); + + let err = sp_runtime::DispatchError::from("liquidity_pools error"); + + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + router_id: router_id.clone(), + }; + + MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_message| { + assert_eq!(mock_message, gateway_message); + Err(err) + }); + + assert_noop!( + LiquidityPoolsGateway::receive_message( + GatewayOrigin::Domain(domain_address).into(), + router_id, + BoundedVec::::try_from(encoded_msg).unwrap() + ), + err, + ); + }); + } } - #[test] - fn unknown_instance() { - new_test_ext().execute_with(|| { - let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); - let domain_address = DomainAddress::EVM(0, address.into()); - let encoded_msg = Message::Simple.serialize(); + mod set_domain_hook { + use super::*; + + #[test] + fn success() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); - assert_noop!( - LiquidityPoolsGateway::receive_message( - GatewayOrigin::Domain(domain_address).into(), - BoundedVec::::try_from(encoded_msg).unwrap() - ), - Error::::UnknownInstance, - ); - }); + assert_ok!(LiquidityPoolsGateway::set_domain_hook_address( + RuntimeOrigin::root(), + domain, + get_test_hook_bytes() + )); + }); + } + + #[test] + fn bad_origin() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + + assert_noop!( + LiquidityPoolsGateway::set_domain_hook_address( + RuntimeOrigin::signed(AccountId32::new([0u8; 32])), + domain, + get_test_hook_bytes() + ), + BadOrigin + ); + }); + } + + #[test] + fn domain_not_supported() { + new_test_ext().execute_with(|| { + let domain = Domain::Centrifuge; + + assert_noop!( + LiquidityPoolsGateway::set_domain_hook_address( + RuntimeOrigin::root(), + domain, + get_test_hook_bytes() + ), + Error::::DomainNotSupported + ); + }); + } } - #[test] - fn message_queue_error() { - new_test_ext().execute_with(|| { - let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); - let domain_address = DomainAddress::EVM(0, address.into()); - let message = Message::Simple; + mod batches { + use super::*; + + const USER: AccountId32 = AccountId32::new([1; 32]); + const OTHER: AccountId32 = AccountId32::new([2; 32]); + const DOMAIN: Domain = Domain::EVM(TEST_EVM_CHAIN); - assert_ok!(LiquidityPoolsGateway::add_instance( - RuntimeOrigin::root(), - domain_address.clone(), - )); + #[test] + fn pack_empty() { + new_test_ext().execute_with(|| { + assert_ok!(LiquidityPoolsGateway::start_batch_message( + RuntimeOrigin::signed(USER), + DOMAIN + )); + assert_ok!(LiquidityPoolsGateway::end_batch_message( + RuntimeOrigin::signed(USER), + DOMAIN + )); + }); + } + + #[test] + fn pack_several() { + new_test_ext().execute_with(|| { + assert_ok!(LiquidityPoolsGateway::start_batch_message( + RuntimeOrigin::signed(USER), + DOMAIN + )); + + let handler = MockLiquidityPoolsGatewayQueue::mock_submit(|_| Ok(())); - let encoded_msg = message.serialize(); + // Ok Batched + assert_ok!(LiquidityPoolsGateway::handle(USER, DOMAIN, Message::Simple)); + + Routers::::set(BoundedVec::try_from(vec![ROUTER_ID_1]).unwrap()); + + // Not batched, it belongs to OTHER + assert_ok!(LiquidityPoolsGateway::handle( + OTHER, + DOMAIN, + Message::Simple + )); + + // Not batched, it belongs to EVM 2 + assert_ok!(LiquidityPoolsGateway::handle( + USER, + Domain::EVM(2), + Message::Simple + )); - let err = sp_runtime::DispatchError::from("liquidity_pools error"); + // Ok Batched + assert_ok!(LiquidityPoolsGateway::handle(USER, DOMAIN, Message::Simple)); + + // Two non-packed messages + assert_eq!(handler.times(), 2); - let gateway_message = GatewayMessage::::Inbound { - domain_address: domain_address.clone(), - message: message.clone(), - }; + assert_ok!(LiquidityPoolsGateway::end_batch_message( + RuntimeOrigin::signed(USER), + DOMAIN + )); - MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_message| { - assert_eq!(mock_message, gateway_message); - Err(err) + // Packed message queued + assert_eq!(handler.times(), 3); }); + } - assert_noop!( - LiquidityPoolsGateway::receive_message( - GatewayOrigin::Domain(domain_address).into(), - BoundedVec::::try_from(encoded_msg).unwrap() - ), - err, - ); - }); - } -} + #[test] + fn pack_over_limit() { + new_test_ext().execute_with(|| { + assert_ok!(LiquidityPoolsGateway::start_batch_message( + RuntimeOrigin::signed(USER), + DOMAIN + )); -mod outbound_message_handler_impl { - use super::*; + let handler = MockLiquidityPoolsGatewayQueue::mock_submit(|_| Ok(())); - #[test] - fn success() { - new_test_ext().execute_with(|| { - let domain = Domain::EVM(0); - let sender = get_test_account_id(); - let msg = Message::Simple; + (0..MAX_PACKED_MESSAGES).for_each(|_| { + assert_ok!(LiquidityPoolsGateway::handle(USER, DOMAIN, Message::Simple)); + }); - let router = RouterMock::::default(); - router.mock_init(move || Ok(())); + assert_err!( + LiquidityPoolsGateway::handle(USER, DOMAIN, Message::Simple), + DispatchError::Other(MAX_PACKED_MESSAGES_ERR) + ); - assert_ok!(LiquidityPoolsGateway::set_domain_router( - RuntimeOrigin::root(), - domain.clone(), - router.clone(), - )); + let router_id_1 = ROUTER_ID_1; - let gateway_message = GatewayMessage::::Outbound { - sender: ::Sender::get(), - destination: domain.clone(), - message: msg.clone(), - }; + Routers::::set(BoundedVec::try_from(vec![router_id_1]).unwrap()); - MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_msg| { - assert_eq!(mock_msg, gateway_message); + assert_ok!(LiquidityPoolsGateway::end_batch_message( + RuntimeOrigin::signed(USER), + DOMAIN + )); + assert_eq!(handler.times(), 1); + }); + } - Ok(()) + #[test] + fn end_before_start() { + new_test_ext().execute_with(|| { + assert_err!( + LiquidityPoolsGateway::end_batch_message(RuntimeOrigin::signed(USER), DOMAIN), + Error::::MessagePackingNotStarted + ); }); + } - assert_ok!(LiquidityPoolsGateway::handle(sender, domain, msg)); - }); - } + #[test] + fn start_before_end() { + new_test_ext().execute_with(|| { + assert_ok!(LiquidityPoolsGateway::start_batch_message( + RuntimeOrigin::signed(USER), + DOMAIN + )); + + assert_err!( + LiquidityPoolsGateway::start_batch_message(RuntimeOrigin::signed(USER), DOMAIN), + Error::::MessagePackingAlreadyStarted + ); + }); + } - #[test] - fn local_domain() { - new_test_ext().execute_with(|| { - let domain = Domain::Centrifuge; - let sender = get_test_account_id(); - let msg = Message::Simple; - - assert_noop!( - LiquidityPoolsGateway::handle(sender, domain, msg), - Error::::DomainNotSupported - ); - }); - } + #[test] + fn process_inbound() { + new_test_ext().execute_with(|| { + let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); + let domain_address = DomainAddress::EVM(TEST_EVM_CHAIN, address.into()); - #[test] - fn message_queue_error() { - new_test_ext().execute_with(|| { - let domain = Domain::EVM(0); - let sender = get_test_account_id(); - let msg = Message::Simple; + let router_id_1 = ROUTER_ID_1; - let router = RouterMock::::default(); - router.mock_init(move || Ok(())); + Routers::::set(BoundedVec::try_from(vec![router_id_1]).unwrap()); + SessionIdStore::::set(1); - assert_ok!(LiquidityPoolsGateway::set_domain_router( - RuntimeOrigin::root(), - domain.clone(), - router.clone(), - )); + let handler = MockLiquidityPools::mock_handle(|_, _| Ok(())); - let gateway_message = GatewayMessage::::Outbound { - sender: ::Sender::get(), - destination: domain.clone(), - message: msg.clone(), - }; + let submessage_count = 5; - let err = DispatchError::Unavailable; + let (result, weight) = LiquidityPoolsGateway::process(GatewayMessage::Inbound { + domain_address, + message: Message::deserialize(&(1..=submessage_count).collect::>()) + .unwrap(), + router_id: ROUTER_ID_1, + }); - MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_msg| { - assert_eq!(mock_msg, gateway_message); + let expected_weight = LP_DEFENSIVE_WEIGHT.saturating_mul(submessage_count.into()); - Err(err) + assert_ok!(result); + assert_eq!(weight, expected_weight); + assert_eq!(handler.times(), submessage_count as u32); }); + } - assert_noop!(LiquidityPoolsGateway::handle(sender, domain, msg), err); - }); - } -} + #[test] + fn process_inbound_with_errors() { + new_test_ext().execute_with(|| { + let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); + let domain_address = DomainAddress::EVM(1, address.into()); -mod set_domain_hook { - use super::*; + let router_id_1 = ROUTER_ID_1; - #[test] - fn success() { - new_test_ext().execute_with(|| { - let domain = Domain::EVM(0); - - assert_ok!(LiquidityPoolsGateway::set_domain_hook_address( - RuntimeOrigin::root(), - domain, - get_test_hook_bytes() - )); - }); - } + Routers::::set(BoundedVec::try_from(vec![router_id_1]).unwrap()); + SessionIdStore::::set(1); - #[test] - fn failure_bad_origin() { - new_test_ext().execute_with(|| { - let domain = Domain::EVM(0); + let counter = Arc::new(AtomicU32::new(0)); - assert_noop!( - LiquidityPoolsGateway::set_domain_hook_address( - RuntimeOrigin::signed(AccountId32::new([0u8; 32])), - domain, - get_test_hook_bytes() - ), - BadOrigin - ); - }); + let handler = MockLiquidityPools::mock_handle(move |_, _| { + match counter.fetch_add(1, Ordering::Relaxed) { + 2 => Err(DispatchError::Unavailable), + _ => Ok(()), + } + }); + + let (result, _) = LiquidityPoolsGateway::process(GatewayMessage::Inbound { + domain_address, + message: Message::deserialize(&(1..=5).collect::>()).unwrap(), + router_id: ROUTER_ID_1, + }); + + assert_err!(result, DispatchError::Unavailable); + // 2 correct messages and 1 failed message processed. + assert_eq!(handler.times(), 3); + }); + } } - #[test] - fn failure_centrifuge_domain() { - new_test_ext().execute_with(|| { - let domain = Domain::Centrifuge; + mod execute_message_recovery { + use super::*; - assert_noop!( - LiquidityPoolsGateway::set_domain_hook_address( + #[test] + fn success_with_execution() { + new_test_ext().execute_with(|| { + let session_id = 1; + + Routers::::set( + BoundedVec::try_from(vec![ROUTER_ID_1, ROUTER_ID_2]).unwrap(), + ); + SessionIdStore::::set(session_id); + + PendingInboundEntries::::insert( + MESSAGE_PROOF, + ROUTER_ID_1, + InboundEntry::Message(MessageEntry { + session_id, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 1, + }), + ); + + let handler = + MockLiquidityPools::mock_handle(move |mock_domain_address, mock_message| { + assert_eq!(mock_domain_address, TEST_DOMAIN_ADDRESS); + assert_eq!(mock_message, Message::Simple); + + Ok(()) + }); + + assert_ok!(LiquidityPoolsGateway::execute_message_recovery( RuntimeOrigin::root(), - domain, - get_test_hook_bytes() - ), - Error::::DomainNotSupported - ); - }); + TEST_DOMAIN_ADDRESS, + MESSAGE_PROOF, + ROUTER_ID_2, + )); + + event_exists(Event::::MessageRecoveryExecuted { + proof: MESSAGE_PROOF, + router_id: ROUTER_ID_2, + }); + + assert_eq!(handler.times(), 1); + + assert!( + PendingInboundEntries::::get(MESSAGE_PROOF, ROUTER_ID_1).is_none() + ); + assert!( + PendingInboundEntries::::get(MESSAGE_PROOF, ROUTER_ID_2).is_none() + ); + }); + } + + #[test] + fn success_without_execution() { + new_test_ext().execute_with(|| { + let session_id = 1; + + Routers::::set( + BoundedVec::try_from(vec![ROUTER_ID_1, ROUTER_ID_2, ROUTER_ID_3]).unwrap(), + ); + SessionIdStore::::set(session_id); + + PendingInboundEntries::::insert( + MESSAGE_PROOF, + ROUTER_ID_1, + InboundEntry::Message(MessageEntry { + session_id, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ); + + assert_ok!(LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + TEST_DOMAIN_ADDRESS, + MESSAGE_PROOF, + ROUTER_ID_2, + )); + + event_exists(Event::::MessageRecoveryExecuted { + proof: MESSAGE_PROOF, + router_id: ROUTER_ID_2, + }); + + assert_eq!( + PendingInboundEntries::::get(MESSAGE_PROOF, ROUTER_ID_1), + Some( + MessageEntry { + session_id, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into() + ) + ); + assert_eq!( + PendingInboundEntries::::get(MESSAGE_PROOF, ROUTER_ID_2), + Some( + ProofEntry { + session_id, + current_count: 1 + } + .into() + ) + ); + assert!(PendingInboundEntries::::get(MESSAGE_PROOF, ROUTER_ID_3).is_none()) + }); + } + + #[test] + fn not_enough_routers_for_domain() { + new_test_ext().execute_with(|| { + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + TEST_DOMAIN_ADDRESS, + MESSAGE_PROOF, + ROUTER_ID_1, + ), + Error::::NotEnoughRoutersForDomain + ); + + Routers::::set(BoundedVec::try_from(vec![ROUTER_ID_1]).unwrap()); + + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + TEST_DOMAIN_ADDRESS, + MESSAGE_PROOF, + ROUTER_ID_1, + ), + Error::::NotEnoughRoutersForDomain + ); + }); + } + + #[test] + fn unknown_router() { + new_test_ext().execute_with(|| { + Routers::::set(BoundedVec::try_from(vec![ROUTER_ID_1]).unwrap()); + SessionIdStore::::set(1); + + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + TEST_DOMAIN_ADDRESS, + MESSAGE_PROOF, + ROUTER_ID_2 + ), + Error::::UnknownRouter + ); + }); + } + + #[test] + fn proof_count_overflow() { + new_test_ext().execute_with(|| { + let session_id = 1; + + Routers::::set( + BoundedVec::try_from(vec![ROUTER_ID_1, ROUTER_ID_2]).unwrap(), + ); + SessionIdStore::::set(session_id); + PendingInboundEntries::::insert( + MESSAGE_PROOF, + ROUTER_ID_2, + InboundEntry::Proof(ProofEntry { + session_id, + current_count: u32::MAX, + }), + ); + + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + TEST_DOMAIN_ADDRESS, + MESSAGE_PROOF, + ROUTER_ID_2 + ), + Arithmetic(Overflow) + ); + }); + } + + #[test] + fn expected_message_proof_type() { + new_test_ext().execute_with(|| { + let domain_address = TEST_DOMAIN_ADDRESS; + let session_id = 1; + + Routers::::set( + BoundedVec::try_from(vec![ROUTER_ID_1, ROUTER_ID_2]).unwrap(), + ); + SessionIdStore::::set(session_id); + PendingInboundEntries::::insert( + MESSAGE_PROOF, + ROUTER_ID_2, + InboundEntry::Message(MessageEntry { + session_id, + domain_address: domain_address.clone(), + message: Message::Simple, + expected_proof_count: 2, + }), + ); + + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + TEST_DOMAIN_ADDRESS, + MESSAGE_PROOF, + ROUTER_ID_2 + ), + Error::::ExpectedMessageProofType + ); + }); + } } } -mod message_processor_impl { +mod implementations { use super::*; - mod inbound { + mod outbound_message_handler { use super::*; #[test] fn success() { new_test_ext().execute_with(|| { - let domain_address = DomainAddress::EVM(1, [1; 20]); - let message = Message::Simple; - let gateway_message = GatewayMessage::::Inbound { - domain_address: domain_address.clone(), - message: message.clone(), - }; + let domain = Domain::EVM(0); + let sender = get_test_account_id(); + let msg = Message::Simple; + let message_proof = msg.to_proof_message().get_proof().unwrap(); - MockLiquidityPools::mock_handle(move |mock_domain_address, mock_mesage| { - assert_eq!(mock_domain_address, domain_address); - assert_eq!(mock_mesage, message); + assert_ok!(LiquidityPoolsGateway::set_routers( + RuntimeOrigin::root(), + BoundedVec::try_from(vec![ROUTER_ID_1, ROUTER_ID_2, ROUTER_ID_3]).unwrap(), + )); + + let handler = MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_msg| { + match mock_msg { + GatewayMessage::Inbound { .. } => { + assert!(false, "expected outbound message") + } + GatewayMessage::Outbound { + sender, message, .. + } => { + assert_eq!(sender, ::Sender::get()); + + match message { + Message::Proof(p) => { + assert_eq!(p, message_proof); + } + _ => {} + } + } + } Ok(()) }); - let (res, _) = LiquidityPoolsGateway::process(gateway_message); - assert_ok!(res); + assert_ok!(LiquidityPoolsGateway::handle(sender, domain, msg)); + assert_eq!(handler.times(), 3); }); } #[test] - fn inbound_message_handler_error() { + fn domain_not_supported() { new_test_ext().execute_with(|| { - let domain_address = DomainAddress::EVM(1, [1; 20]); - let message = Message::Simple; - let gateway_message = GatewayMessage::::Inbound { - domain_address: domain_address.clone(), - message: message.clone(), + let domain = Domain::Centrifuge; + let sender = get_test_account_id(); + let msg = Message::Simple; + + assert_noop!( + LiquidityPoolsGateway::handle(sender, domain, msg), + Error::::DomainNotSupported + ); + }); + } + + #[test] + fn routers_not_found() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + let sender = get_test_account_id(); + let msg = Message::Simple; + + assert_noop!( + LiquidityPoolsGateway::handle(sender, domain, msg), + Error::::NotEnoughRoutersForDomain + ); + }); + } + + #[test] + fn message_queue_error() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + let sender = get_test_account_id(); + let msg = Message::Simple; + + assert_ok!(LiquidityPoolsGateway::set_routers( + RuntimeOrigin::root(), + BoundedVec::try_from(vec![ROUTER_ID_1, ROUTER_ID_2, ROUTER_ID_3]).unwrap(), + )); + + let gateway_message = GatewayMessage::Outbound { + sender: ::Sender::get(), + message: msg.clone(), + router_id: ROUTER_ID_1, }; let err = DispatchError::Unavailable; - MockLiquidityPools::mock_handle(move |mock_domain_address, mock_mesage| { - assert_eq!(mock_domain_address, domain_address); - assert_eq!(mock_mesage, message); + let handler = MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_msg| { + assert_eq!(mock_msg, gateway_message); Err(err) }); - let (res, weight) = LiquidityPoolsGateway::process(gateway_message); - assert_noop!(res, err); - assert_eq!(weight, LP_DEFENSIVE_WEIGHT); + assert_noop!(LiquidityPoolsGateway::handle(sender, domain, msg), err); + assert_eq!(handler.times(), 1); }); } } - mod outbound { + mod message_processor { use super::*; - #[test] - fn success() { - new_test_ext().execute_with(|| { - let sender = get_test_account_id(); - let domain = Domain::EVM(1); - let message = Message::Simple; + mod inbound { + use super::*; + + #[macro_use] + mod util { + use super::*; + + pub fn run_inbound_message_test_suite(suite: InboundMessageTestSuite) { + let test_routers = suite.routers; + + for test in suite.tests { + println!("Executing test for - {:?}", test.router_messages); + + new_test_ext().execute_with(|| { + let session_id = TEST_SESSION_ID; + + Routers::::set( + BoundedVec::try_from(test_routers.clone()).unwrap(), + ); + SessionIdStore::::set( + session_id, + ); + + let handler = MockLiquidityPools::mock_handle(move |_, _| Ok(())); + + for router_message in test.router_messages { + let gateway_message = GatewayMessage::Inbound { + domain_address: TEST_DOMAIN_ADDRESS, + message: router_message.1, + router_id: router_message.0, + }; + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_ok!(res); + } + + let expected_message_submitted_times = + test.expected_test_result.message_submitted_times; + let message_submitted_times = handler.times(); + + assert_eq!( + message_submitted_times, + expected_message_submitted_times, + "Expected message to be submitted {expected_message_submitted_times} times, was {message_submitted_times}" + ); + + for expected_storage_entry in + test.expected_test_result.expected_storage_entries + { + let expected_storage_entry_router_id = expected_storage_entry.0; + let expected_inbound_entry = expected_storage_entry.1; + + let storage_entry = PendingInboundEntries::::get( + MESSAGE_PROOF, expected_storage_entry_router_id, + ); + assert_eq!(storage_entry, expected_inbound_entry, "Expected inbound entry {expected_inbound_entry:?}, found {storage_entry:?}"); + } + }); + } + } - let expected_sender = sender.clone(); - let expected_message = message.clone(); + /// Used for generating all `RouterMessage` combinations like: + /// + /// vec![ + /// (ROUTER_ID_1, Message::Simple), + /// (ROUTER_ID_1, Message::Simple), + /// ] + /// vec![ + /// (ROUTER_ID_1, Message::Simple), + /// (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + /// ] + /// vec![ + /// (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + /// (ROUTER_ID_1, Message::Simple), + /// ] + /// vec![ + /// (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + /// (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + /// ] + pub fn generate_test_combinations( + t: T, + count: usize, + ) -> Vec::Item>> + where + T: IntoIterator + Clone, + T::IntoIter: Clone, + T::Item: Clone, + { + std::iter::repeat(t.clone().into_iter()) + .take(count) + .multi_cartesian_product() + .collect::>() + } - let router_post_info = PostDispatchInfo { - actual_weight: Some(Weight::from_parts(1, 1)), - pays_fee: Pays::Yes, - }; + /// Type used for mapping a message to a router hash. + pub type RouterMessage = (RouterId, Message); + + /// Type used for aggregating tests for inbound messages. + pub struct InboundMessageTestSuite { + pub routers: Vec, + pub tests: Vec, + } + + /// Type used for defining a test which contains a set of + /// `RouterMessage` combinations and the expected test result. + pub struct InboundMessageTest { + pub router_messages: Vec, + pub expected_test_result: ExpectedTestResult, + } + + /// Type used for defining the number of expected inbound + /// message submission and the exected storage state. + #[derive(Clone, Debug)] + pub struct ExpectedTestResult { + pub message_submitted_times: u32, + pub expected_storage_entries: Vec<(RouterId, Option>)>, + } + + /// Generates the combinations of `RouterMessage` used when + /// testing, maps the `ExpectedTestResult` for each and + /// creates the `InboundMessageTestSuite`. + pub fn generate_test_suite( + routers: Vec, + test_data: Vec, + expected_results: HashMap, ExpectedTestResult>, + message_count: usize, + ) -> InboundMessageTestSuite { + let tests = generate_test_combinations(test_data, message_count); + + let tests = tests + .into_iter() + .map(|router_messages| { + let expected_test_result = expected_results + .get(&router_messages) + .expect( + format!("test for {router_messages:?} should be covered") + .as_str(), + ) + .clone(); + + InboundMessageTest { + router_messages, + expected_test_result, + } + }) + .collect::>(); + + InboundMessageTestSuite { routers, tests } + } + } + + use util::*; + + mod one_router { + use super::*; + + #[test] + fn success() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + let message_proof = message.to_proof_message().get_proof().unwrap(); + let session_id = 1; + let domain_address = DomainAddress::EVM(1, [1; 20]); + let router_id = ROUTER_ID_1; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + router_id: router_id.clone(), + }; + + Routers::::set( + BoundedVec::<_, _>::try_from(vec![router_id.clone()]).unwrap(), + ); + SessionIdStore::::set(session_id); + + let handler = MockLiquidityPools::mock_handle( + move |mock_domain_address, mock_message| { + assert_eq!(mock_domain_address, domain_address); + assert_eq!(mock_message, message); + + Ok(()) + }, + ); + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_ok!(res); + assert_eq!(handler.times(), 1); + + assert!( + PendingInboundEntries::::get(message_proof, router_id) + .is_none() + ); + }); + } + + #[test] + fn multi_router_not_found() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + let domain_address = DomainAddress::EVM(1, [1; 20]); + let router_hash = ROUTER_ID_1; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + router_id: router_hash, + }; + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::NotEnoughRoutersForDomain); + }); + } + + #[test] + fn unknown_inbound_message_router() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + let session_id = 1; + let domain_address = DomainAddress::EVM(1, [1; 20]); + let router_hash = ROUTER_ID_1; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + // The router stored has a different hash, this should trigger the + // expected error. + router_id: ROUTER_ID_2, + }; + + Routers::::set( + BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), + ); + SessionIdStore::::set(session_id); + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::UnknownRouter); + }); + } + + #[test] + fn expected_message_proof_type() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + let message_proof = message.to_proof_message().get_proof().unwrap(); + let session_id = 1; + let domain_address = DomainAddress::EVM(1, [1; 20]); + let router_id = ROUTER_ID_1; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + router_id: router_id.clone(), + }; + + Routers::::set( + BoundedVec::<_, _>::try_from(vec![router_id.clone()]).unwrap(), + ); + SessionIdStore::::set(session_id); + PendingInboundEntries::::insert( + message_proof, + router_id, + InboundEntry::Proof(ProofEntry { + session_id, + current_count: 0, + }), + ); + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::ExpectedMessageProofType); + }); + } + } + + mod two_routers { + use super::*; + + mod success { + use super::*; + + lazy_static! { + static ref TEST_DATA: Vec = vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ]; + } + + mod two_messages { + use super::*; + + const MESSAGE_COUNT: usize = 2; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![ROUTER_ID_1, ROUTER_ID_2], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } + + mod three_messages { + use super::*; + + const MESSAGE_COUNT: usize = 3; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 3, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 3, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 1, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 1, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 1, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![ROUTER_ID_1, ROUTER_ID_2], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } + + mod four_messages { + use super::*; + + const MESSAGE_COUNT: usize = 4; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 4, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![ROUTER_ID_1, ROUTER_ID_2], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } + } + + mod failure { + use super::*; + + #[test] + fn message_expected_from_first_router() { + new_test_ext().execute_with(|| { + let session_id = 1; + + Routers::::set( + BoundedVec::<_, _>::try_from(vec![ROUTER_ID_1, ROUTER_ID_2]) + .unwrap(), + ); + SessionIdStore::::set(session_id); + + let gateway_message = GatewayMessage::Inbound { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + router_id: ROUTER_ID_2, + }; + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::MessageExpectedFromFirstRouter); + }); + } + + #[test] + fn proof_not_expected_from_first_router() { + new_test_ext().execute_with(|| { + let session_id = 1; + + Routers::::set( + BoundedVec::<_, _>::try_from(vec![ROUTER_ID_1, ROUTER_ID_2]) + .unwrap(), + ); + SessionIdStore::::set(session_id); + + let gateway_message = GatewayMessage::Inbound { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Proof(MESSAGE_PROOF), + router_id: ROUTER_ID_1, + }; + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::ProofNotExpectedFromFirstRouter); + }); + } + } + } + + mod three_routers { + use super::*; + + lazy_static! { + static ref TEST_DATA: Vec = vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + ]; + } + + mod two_messages { + use super::*; + + const MESSAGE_COUNT: usize = 2; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![ROUTER_ID_1, ROUTER_ID_2, ROUTER_ID_3], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } + + mod three_messages { + use super::*; + + const MESSAGE_COUNT: usize = 3; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 6, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 3, + } + .into(), + ), + ), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 3, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + } + .into(), + ), + ), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + } + .into(), + ), + ), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + ), + ), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + } + .into(), + ), + ), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + ROUTER_ID_1, + Some( + MessageEntry { + session_id: TEST_SESSION_ID, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + } + .into(), + ), + ), + (ROUTER_ID_2, None), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_1, Message::Simple), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + (ROUTER_ID_2, None), + (ROUTER_ID_3, None), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ], + }, + ), + ( + vec![ + (ROUTER_ID_2, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + (ROUTER_ID_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (ROUTER_ID_1, None), + ( + ROUTER_ID_2, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 1, + } + .into(), + ), + ), + ( + ROUTER_ID_3, + Some( + ProofEntry { + session_id: TEST_SESSION_ID, + current_count: 2, + } + .into(), + ), + ), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![ROUTER_ID_1, ROUTER_ID_2, ROUTER_ID_3], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } + } + + #[test] + fn inbound_message_handler_error() { + new_test_ext().execute_with(|| { + let domain_address = DomainAddress::EVM(1, [1; 20]); + + Routers::::set( + BoundedVec::try_from(vec![ROUTER_ID_1.clone()]).unwrap(), + ); + SessionIdStore::::set(1); + + let message = Message::Simple; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + router_id: ROUTER_ID_1, + }; + + let err = DispatchError::Unavailable; - let router_mock = RouterMock::::default(); - router_mock.mock_send(move |mock_sender, mock_message| { - assert_eq!(mock_sender, expected_sender); - assert_eq!(mock_message, expected_message.serialize()); + MockLiquidityPools::mock_handle(move |mock_domain_address, mock_mesage| { + assert_eq!(mock_domain_address, domain_address); + assert_eq!(mock_mesage, message); - Ok(router_post_info) + Err(err) + }); + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, err); + }); + } + } + + mod outbound { + use super::*; + + #[test] + fn success() { + new_test_ext().execute_with(|| { + let sender = TEST_DOMAIN_ADDRESS; + let message = Message::Simple; + + let gateway_message = GatewayMessage::Outbound { + sender: sender.clone(), + message: message.clone(), + router_id: ROUTER_ID_1, + }; + + let handler = MockMessageSender::mock_send( + move |mock_router_id, mock_sender, mock_message| { + assert_eq!(mock_router_id, ROUTER_ID_1); + assert_eq!(mock_sender, sender); + assert_eq!(mock_message, message.serialize()); + + Ok(()) + }, + ); + + let (res, weight) = LiquidityPoolsGateway::process(gateway_message); + assert_ok!(res); + assert!(weight.eq(&LP_DEFENSIVE_WEIGHT)); + assert_eq!(handler.times(), 1); + }); + } + + #[test] + fn message_sender_error() { + new_test_ext().execute_with(|| { + let sender = TEST_DOMAIN_ADDRESS; + let message = Message::Simple; + + let gateway_message = GatewayMessage::Outbound { + sender: sender.clone(), + message: message.clone(), + router_id: ROUTER_ID_1, + }; + + let router_err = DispatchError::Unavailable; + + MockMessageSender::mock_send( + move |mock_router_id, mock_sender, mock_message| { + assert_eq!(mock_router_id, ROUTER_ID_1); + assert_eq!(mock_sender, sender); + assert_eq!(mock_message, message.serialize()); + + Err(router_err) + }, + ); + + let (res, weight) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, router_err); + assert!(weight.eq(&LP_DEFENSIVE_WEIGHT)); + }); + } + } + } + + mod pallet { + use super::*; + + mod get_router_ids_for_domain { + use super::*; + + #[test] + fn success() { + new_test_ext().execute_with(|| { + let domain = TEST_DOMAIN_ADDRESS.domain(); + let test_routers = vec![ROUTER_ID_1]; + + Routers::::set(BoundedVec::try_from(test_routers.clone()).unwrap()); + + let res = LiquidityPoolsGateway::get_router_ids_for_domain(domain).unwrap(); + assert_eq!(res, test_routers); + }); + } + + #[test] + fn not_enough_routers_for_domain() { + new_test_ext().execute_with(|| { + let domain = TEST_DOMAIN_ADDRESS.domain(); + + let res = LiquidityPoolsGateway::get_router_ids_for_domain(domain.clone()); + + assert_eq!( + res.err().unwrap(), + Error::::NotEnoughRoutersForDomain.into() + ); + + let test_routers = vec![RouterId(4)]; + + Routers::::set(BoundedVec::try_from(test_routers.clone()).unwrap()); + + let res = LiquidityPoolsGateway::get_router_ids_for_domain(domain); + + assert_eq!( + res.err().unwrap(), + Error::::NotEnoughRoutersForDomain.into() + ); + }); + } + } + + mod get_expected_proof_count { + use super::*; + + #[test] + fn success() { + new_test_ext().execute_with(|| { + let tests = vec![ + vec![ROUTER_ID_1], + vec![ROUTER_ID_1, ROUTER_ID_2], + vec![ROUTER_ID_1, ROUTER_ID_2, ROUTER_ID_3], + ]; + + for test in tests { + let res = LiquidityPoolsGateway::get_expected_proof_count(&test).unwrap(); + + assert_eq!(res, (test.len() - 1) as u32); + } + }); + } + + #[test] + fn not_enough_routers_for_domain() { + new_test_ext().execute_with(|| { + let res = LiquidityPoolsGateway::get_expected_proof_count(&vec![]); + + assert_eq!( + res.err().unwrap(), + Error::::NotEnoughRoutersForDomain.into() + ); }); + } + } + + mod create_inbound_entry { + use super::*; + + #[test] + fn create_inbound_entry() { + new_test_ext().execute_with(|| { + let domain_address = TEST_DOMAIN_ADDRESS; + let session_id = 1; + let expected_proof_count = 2; + + let tests: Vec<(Message, InboundEntry)> = vec![ + ( + Message::Simple, + MessageEntry { + session_id, + domain_address: domain_address.clone(), + message: Message::Simple, + expected_proof_count, + } + .into(), + ), + ( + Message::Proof(MESSAGE_PROOF), + ProofEntry { + session_id, + current_count: 1, + } + .into(), + ), + ]; + + for (test_message, expected_inbound_entry) in tests { + let res = InboundEntry::create( + test_message, + session_id, + domain_address.clone(), + expected_proof_count, + ); + + assert_eq!(res, expected_inbound_entry) + } + }); + } + } - DomainRouters::::insert(domain.clone(), router_mock); + mod upsert_pending_entry { + use super::*; + + #[test] + fn no_stored_entry() { + new_test_ext().execute_with(|| { + let domain_address = TEST_DOMAIN_ADDRESS; + let session_id = 1; + let expected_proof_count = 2; + + let tests: Vec<(RouterId, InboundEntry)> = vec![ + ( + ROUTER_ID_1, + MessageEntry { + session_id, + domain_address, + message: Message::Simple, + expected_proof_count, + } + .into(), + ), + ( + ROUTER_ID_2, + ProofEntry { + session_id, + current_count: 1, + } + .into(), + ), + ]; + + for (test_router_id, test_inbound_entry) in tests { + assert_ok!(LiquidityPoolsGateway::upsert_pending_entry( + MESSAGE_PROOF, + &test_router_id.clone(), + test_inbound_entry.clone(), + )); + + let res = + PendingInboundEntries::::get(MESSAGE_PROOF, test_router_id) + .unwrap(); + + assert_eq!(res, test_inbound_entry); + } + }); + } + + #[test] + fn message_entry_same_session() { + new_test_ext().execute_with(|| { + let inbound_entry: InboundEntry = MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(); + + PendingInboundEntries::::insert( + MESSAGE_PROOF, + ROUTER_ID_1, + inbound_entry.clone(), + ); + + assert_ok!(LiquidityPoolsGateway::upsert_pending_entry( + MESSAGE_PROOF, + &ROUTER_ID_1, + inbound_entry, + )); + + let res = + PendingInboundEntries::::get(MESSAGE_PROOF, ROUTER_ID_1).unwrap(); + assert_eq!( + res, + MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + } + .into() + ); + }); + } + + #[test] + fn message_entry_new_session() { + new_test_ext().execute_with(|| { + PendingInboundEntries::::insert( + MESSAGE_PROOF, + ROUTER_ID_1, + InboundEntry::Message(MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ); + + assert_ok!(LiquidityPoolsGateway::upsert_pending_entry( + MESSAGE_PROOF, + &ROUTER_ID_1, + MessageEntry { + session_id: 2, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(), + )); + + let res = + PendingInboundEntries::::get(MESSAGE_PROOF, ROUTER_ID_1).unwrap(); + assert_eq!( + res, + MessageEntry { + session_id: 2, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into() + ); + }); + } + + #[test] + fn expected_message_type() { + new_test_ext().execute_with(|| { + let inbound_entry: InboundEntry = MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + } + .into(); + + PendingInboundEntries::::insert( + MESSAGE_PROOF, + ROUTER_ID_1, + inbound_entry.clone(), + ); + + assert_noop!( + LiquidityPoolsGateway::upsert_pending_entry( + MESSAGE_PROOF, + &ROUTER_ID_1, + InboundEntry::Proof(ProofEntry { + session_id: 1, + current_count: 1 + }), + ), + Error::::ExpectedMessageType + ); + }); + } + + #[test] + fn proof_entry_same_session() { + new_test_ext().execute_with(|| { + let inbound_entry: InboundEntry = ProofEntry { + session_id: 1, + current_count: 1, + } + .into(); + + PendingInboundEntries::::insert( + MESSAGE_PROOF, + ROUTER_ID_1, + inbound_entry.clone(), + ); + + assert_ok!(LiquidityPoolsGateway::upsert_pending_entry( + MESSAGE_PROOF, + &ROUTER_ID_1, + inbound_entry, + )); + + let res = + PendingInboundEntries::::get(MESSAGE_PROOF, ROUTER_ID_1).unwrap(); + assert_eq!( + res, + ProofEntry { + session_id: 1, + current_count: 2, + } + .into() + ); + }); + } + + #[test] + fn proof_entry_new_session() { + new_test_ext().execute_with(|| { + PendingInboundEntries::::insert( + MESSAGE_PROOF, + ROUTER_ID_1, + InboundEntry::Proof(ProofEntry { + session_id: 1, + current_count: 2, + }), + ); + + assert_ok!(LiquidityPoolsGateway::upsert_pending_entry( + MESSAGE_PROOF, + &ROUTER_ID_1, + ProofEntry { + session_id: 2, + current_count: 1, + } + .into(), + )); + + let res = + PendingInboundEntries::::get(MESSAGE_PROOF, ROUTER_ID_1).unwrap(); + assert_eq!( + res, + ProofEntry { + session_id: 2, + current_count: 1, + } + .into() + ); + }); + } + + #[test] + fn expected_message_proof_type() { + new_test_ext().execute_with(|| { + let inbound_entry: InboundEntry = ProofEntry { + session_id: 1, + current_count: 1, + } + .into(); + + PendingInboundEntries::::insert( + MESSAGE_PROOF, + ROUTER_ID_1, + inbound_entry.clone(), + ); + + assert_noop!( + LiquidityPoolsGateway::upsert_pending_entry( + MESSAGE_PROOF, + &ROUTER_ID_1, + InboundEntry::Message(MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 1, + }), + ), + Error::::ExpectedMessageProofType + ); + }); + } + } - let min_expected_weight = ::DbWeight::get() - .reads(1) + router_post_info.actual_weight.unwrap() - + Weight::from_parts(0, message.serialize().len() as u64); + mod execute_if_requirements_are_met { + use super::*; + + #[test] + fn entries_with_invalid_session_are_ignored() { + new_test_ext().execute_with(|| { + let domain_address = TEST_DOMAIN_ADDRESS; + let router_ids = vec![ROUTER_ID_1, ROUTER_ID_2, ROUTER_ID_3]; + let session_id = 1; + let expected_proof_count = 2; + + PendingInboundEntries::::insert( + MESSAGE_PROOF, + ROUTER_ID_1, + InboundEntry::Message(MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ); + PendingInboundEntries::::insert( + MESSAGE_PROOF, + ROUTER_ID_2, + InboundEntry::Proof(ProofEntry { + session_id: 2, + current_count: 1, + }), + ); + PendingInboundEntries::::insert( + MESSAGE_PROOF, + ROUTER_ID_3, + InboundEntry::Proof(ProofEntry { + session_id: 3, + current_count: 1, + }), + ); + + assert_ok!(LiquidityPoolsGateway::execute_if_requirements_are_met( + MESSAGE_PROOF, + &router_ids, + session_id, + expected_proof_count, + domain_address, + )); + assert!( + PendingInboundEntries::::get(MESSAGE_PROOF, ROUTER_ID_1).is_some() + ); + assert!( + PendingInboundEntries::::get(MESSAGE_PROOF, ROUTER_ID_2).is_some() + ); + assert!( + PendingInboundEntries::::get(MESSAGE_PROOF, ROUTER_ID_3).is_some() + ); + }); + } + } - let gateway_message = GatewayMessage::::Outbound { - sender, - destination: domain, + mod execute_post_voting_dispatch { + use super::*; + + #[test] + fn pending_inbound_entry_not_found() { + new_test_ext().execute_with(|| { + let router_ids = vec![ROUTER_ID_1]; + let expected_proof_count = 2; + + assert_noop!( + LiquidityPoolsGateway::execute_post_voting_dispatch( + MESSAGE_PROOF, + &router_ids, + expected_proof_count, + ), + Error::::PendingInboundEntryNotFound + ); + }); + } + } + } +} + +mod inbound_entry { + use super::*; + + mod create_post_voting_entry { + use super::*; + + #[test] + fn message_entry_some() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + + let inbound_entry = InboundEntry::::Message(MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, message: message.clone(), - }; + expected_proof_count: 4, + }); - let (res, weight) = LiquidityPoolsGateway::process(gateway_message); - assert_ok!(res); - assert!(weight.all_lte(min_expected_weight)); + let expected_proof_count = 2; + + let res = + InboundEntry::create_post_voting_entry(&inbound_entry, expected_proof_count) + .unwrap(); + + assert_eq!( + res, + Some(InboundEntry::::Message(MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, + message, + expected_proof_count: 2, + })) + ); }); } #[test] - fn router_not_found() { + fn message_entry_count_underflow() { new_test_ext().execute_with(|| { - let sender = get_test_account_id(); - let domain = Domain::EVM(1); let message = Message::Simple; - let expected_weight = ::DbWeight::get().reads(1); + let inbound_entry = InboundEntry::::Message(MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, + message: message.clone(), + expected_proof_count: 2, + }); + + let expected_proof_count = 3; - let gateway_message = GatewayMessage::::Outbound { - sender, - destination: domain, - message, - }; + let res = + InboundEntry::create_post_voting_entry(&inbound_entry, expected_proof_count); - let (res, weight) = LiquidityPoolsGateway::process(gateway_message); - assert_noop!(res, Error::::RouterNotFound); - assert_eq!(weight, expected_weight); + assert_noop!(res, Arithmetic(Underflow)); }); } #[test] - fn router_error() { + fn message_entry_zero_updated_count() { new_test_ext().execute_with(|| { - let sender = get_test_account_id(); - let domain = Domain::EVM(1); let message = Message::Simple; - let expected_sender = sender.clone(); - let expected_message = message.clone(); + let inbound_entry = InboundEntry::::Message(MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, + message: message.clone(), + expected_proof_count: 2, + }); - let router_post_info = PostDispatchInfo { - actual_weight: Some(Weight::from_parts(1, 1)), - pays_fee: Pays::Yes, - }; + let expected_proof_count = 2; - let router_err = DispatchError::Unavailable; + let res = + InboundEntry::create_post_voting_entry(&inbound_entry, expected_proof_count) + .unwrap(); - let router_mock = RouterMock::::default(); - router_mock.mock_send(move |mock_sender, mock_message| { - assert_eq!(mock_sender, expected_sender); - assert_eq!(mock_message, expected_message.serialize()); + assert_eq!(res, None); + }); + } - Err(DispatchErrorWithPostInfo { - post_info: router_post_info, - error: router_err, - }) + #[test] + fn proof_entry_some() { + new_test_ext().execute_with(|| { + let inbound_entry = InboundEntry::::Proof(ProofEntry { + session_id: 1, + current_count: 2, }); - DomainRouters::::insert(domain.clone(), router_mock); + let expected_proof_count = 2; - let min_expected_weight = ::DbWeight::get() - .reads(1) + router_post_info.actual_weight.unwrap() - + Weight::from_parts(0, message.serialize().len() as u64); + let res = + InboundEntry::create_post_voting_entry(&inbound_entry, expected_proof_count) + .unwrap(); - let gateway_message = GatewayMessage::::Outbound { - sender, - destination: domain, - message: message.clone(), - }; + assert_eq!( + res, + Some(InboundEntry::::Proof(ProofEntry { + session_id: 1, + current_count: 1 + })) + ); + }); + } + + #[test] + fn proof_entry_count_underflow() { + new_test_ext().execute_with(|| { + let inbound_entry = InboundEntry::::Proof(ProofEntry { + session_id: 1, + current_count: 0, + }); + + let expected_proof_count = 2; - let (res, weight) = LiquidityPoolsGateway::process(gateway_message); - assert_noop!(res, router_err); - assert!(weight.all_lte(min_expected_weight)); + let res = + InboundEntry::create_post_voting_entry(&inbound_entry, expected_proof_count); + + assert_noop!(res, Arithmetic(Underflow)); }); } - } -} -mod batches { - use super::*; + #[test] + fn proof_entry_zero_updated_count() { + new_test_ext().execute_with(|| { + let inbound_entry = InboundEntry::::Proof(ProofEntry { + session_id: 1, + current_count: 1, + }); - const USER: AccountId32 = AccountId32::new([1; 32]); - const OTHER: AccountId32 = AccountId32::new([2; 32]); - const DOMAIN: Domain = Domain::EVM(1); - - #[test] - fn pack_empty() { - new_test_ext().execute_with(|| { - assert_ok!(LiquidityPoolsGateway::start_batch_message( - RuntimeOrigin::signed(USER), - DOMAIN - )); - assert_ok!(LiquidityPoolsGateway::end_batch_message( - RuntimeOrigin::signed(USER), - DOMAIN - )); - }); - } + let expected_proof_count = 2; + + let res = + InboundEntry::create_post_voting_entry(&inbound_entry, expected_proof_count) + .unwrap(); - #[test] - fn pack_several() { - new_test_ext().execute_with(|| { - assert_ok!(LiquidityPoolsGateway::start_batch_message( - RuntimeOrigin::signed(USER), - DOMAIN - )); - - let handle = MockLiquidityPoolsGatewayQueue::mock_submit(|_| Ok(())); - - // Ok Batched - assert_ok!(LiquidityPoolsGateway::handle(USER, DOMAIN, Message::Simple)); - - // Not batched, it belong to OTHER - assert_ok!(LiquidityPoolsGateway::handle( - OTHER, - DOMAIN, - Message::Simple - )); - - // Not batched, it belong to EVM 2 - assert_ok!(LiquidityPoolsGateway::handle( - USER, - Domain::EVM(2), - Message::Simple - )); - - // Ok Batched - assert_ok!(LiquidityPoolsGateway::handle(USER, DOMAIN, Message::Simple)); - - // Just the two non-packed messages - assert_eq!(handle.times(), 2); - - assert_ok!(LiquidityPoolsGateway::end_batch_message( - RuntimeOrigin::signed(USER), - DOMAIN - )); - - // Packed message queued - assert_eq!(handle.times(), 3); - }); + assert_eq!(res, None,); + }); + } } - #[test] - fn pack_over_limit() { - new_test_ext().execute_with(|| { - assert_ok!(LiquidityPoolsGateway::start_batch_message( - RuntimeOrigin::signed(USER), - DOMAIN - )); + mod validate { + use super::*; - MockLiquidityPoolsGatewayQueue::mock_submit(|_| Ok(())); + #[test] + fn success() { + new_test_ext().execute_with(|| { + let domain_address = TEST_DOMAIN_ADDRESS; + let router_ids = vec![ROUTER_ID_1, ROUTER_ID_2, ROUTER_ID_3]; + let session_id = 1; + let expected_proof_count = 2; + + let inbound_entry = InboundEntry::::Message(MessageEntry { + session_id, + domain_address, + message: Message::Simple, + expected_proof_count, + }); - (0..MAX_PACKED_MESSAGES).for_each(|_| { - assert_ok!(LiquidityPoolsGateway::handle(USER, DOMAIN, Message::Simple)); + assert_ok!(inbound_entry.validate(&router_ids, &ROUTER_ID_1)); + + let inbound_entry = InboundEntry::::Proof(ProofEntry { + session_id, + current_count: 1, + }); + + assert_ok!(inbound_entry.validate(&router_ids, &ROUTER_ID_2)); }); + } - assert_err!( - LiquidityPoolsGateway::handle(USER, DOMAIN, Message::Simple), - DispatchError::Other(MAX_PACKED_MESSAGES_ERR) - ); + #[test] + fn unknown_router() { + new_test_ext().execute_with(|| { + let router_ids = vec![ROUTER_ID_1, ROUTER_ID_2]; + let session_id = 1; - assert_ok!(LiquidityPoolsGateway::end_batch_message( - RuntimeOrigin::signed(USER), - DOMAIN - )); - }); - } + let inbound_entry = InboundEntry::::Proof(ProofEntry { + session_id, + current_count: 1, + }); - #[test] - fn end_before_start() { - new_test_ext().execute_with(|| { - assert_err!( - LiquidityPoolsGateway::end_batch_message(RuntimeOrigin::signed(USER), DOMAIN), - Error::::MessagePackingNotStarted - ); - }); - } + assert_noop!( + inbound_entry.validate(&router_ids, &ROUTER_ID_3), + Error::::UnknownRouter + ); + }); + } + + #[test] + fn message_type_mismatch() { + new_test_ext().execute_with(|| { + let domain_address = TEST_DOMAIN_ADDRESS; + let router_ids = vec![ROUTER_ID_1, ROUTER_ID_2, ROUTER_ID_3]; + let session_id = 1; + let expected_proof_count = 2; + + let inbound_entry = InboundEntry::::Message(MessageEntry { + session_id, + domain_address, + message: Message::Simple, + expected_proof_count, + }); + + assert_noop!( + inbound_entry.validate(&router_ids, &ROUTER_ID_2), + Error::::MessageExpectedFromFirstRouter + ); - #[test] - fn start_before_end() { - new_test_ext().execute_with(|| { - assert_ok!(LiquidityPoolsGateway::start_batch_message( - RuntimeOrigin::signed(USER), - DOMAIN - )); - - assert_err!( - LiquidityPoolsGateway::start_batch_message(RuntimeOrigin::signed(USER), DOMAIN), - Error::::MessagePackingAlreadyStarted - ); - }); + let inbound_entry = InboundEntry::::Proof(ProofEntry { + session_id, + current_count: 1, + }); + + assert_noop!( + inbound_entry.validate(&router_ids, &ROUTER_ID_1), + Error::::ProofNotExpectedFromFirstRouter + ); + }); + } } - #[test] - fn process_inbound() { - new_test_ext().execute_with(|| { - let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); - let domain_address = DomainAddress::EVM(0, address.into()); + mod increment_proof_count { + use super::*; + + #[test] + fn success_same_session() { + new_test_ext().execute_with(|| { + let session_id = 1; + let mut inbound_entry = InboundEntry::::Proof(ProofEntry { + session_id, + current_count: 1, + }); + + assert_ok!(inbound_entry.increment_proof_count(session_id)); + assert_eq!( + inbound_entry, + InboundEntry::::Proof(ProofEntry { + session_id, + current_count: 2, + }) + ); + }); + } + + #[test] + fn success_new_session() { + new_test_ext().execute_with(|| { + let session_id = 1; + let mut inbound_entry = InboundEntry::::Proof(ProofEntry { + session_id, + current_count: 1, + }); - MockLiquidityPools::mock_handle(|_, _| Ok(())); + let new_session_id = session_id + 1; - let (result, weight) = LiquidityPoolsGateway::process(GatewayMessage::Inbound { - domain_address, - message: Message::deserialize(&(1..=5).collect::>()).unwrap(), + assert_ok!(inbound_entry.increment_proof_count(new_session_id)); + assert_eq!( + inbound_entry, + InboundEntry::::Proof(ProofEntry { + session_id: new_session_id, + current_count: 1, + }) + ); }); + } + + #[test] + fn expected_message_proof_type() { + new_test_ext().execute_with(|| { + let mut inbound_entry = InboundEntry::::Message(MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 1, + }); - assert_eq!(weight, LP_DEFENSIVE_WEIGHT * 5); - assert_ok!(result); - }); + assert_noop!( + inbound_entry.increment_proof_count(1), + Error::::ExpectedMessageProofType + ); + }); + } } - #[test] - fn process_inbound_with_errors() { - new_test_ext().execute_with(|| { - let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); - let domain_address = DomainAddress::EVM(0, address.into()); - - let counter = Arc::new(AtomicU32::new(0)); - MockLiquidityPools::mock_handle(move |_, _| { - match counter.fetch_add(1, Ordering::Relaxed) { - 2 => Err(DispatchError::Unavailable), - _ => Ok(()), - } + mod pre_dispatch_update { + use super::*; + + #[test] + fn message_success_same_session() { + new_test_ext().execute_with(|| { + let mut inbound_entry_1 = InboundEntry::::Message(MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }); + + let inbound_entry_2 = inbound_entry_1.clone(); + + assert_ok!(inbound_entry_1.pre_dispatch_update(inbound_entry_2)); + assert_eq!( + inbound_entry_1, + InboundEntry::::Message(MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }) + ) + }); + } + + #[test] + fn message_success_session_change() { + new_test_ext().execute_with(|| { + let mut inbound_entry_1 = InboundEntry::::Message(MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }); + + let inbound_entry_2 = InboundEntry::::Message(MessageEntry { + session_id: 2, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 5, + }); + + assert_ok!(inbound_entry_1.pre_dispatch_update(inbound_entry_2.clone())); + assert_eq!(inbound_entry_1, inbound_entry_2) }); + } + + #[test] + fn proof_success_same_session() { + new_test_ext().execute_with(|| { + let mut inbound_entry_1 = InboundEntry::::Proof(ProofEntry { + session_id: 1, + current_count: 1, + }); - let (result, weight) = LiquidityPoolsGateway::process(GatewayMessage::Inbound { - domain_address, - message: Message::deserialize(&(1..=5).collect::>()).unwrap(), + let inbound_entry_2 = inbound_entry_1.clone(); + + assert_ok!(inbound_entry_1.pre_dispatch_update(inbound_entry_2)); + assert_eq!( + inbound_entry_1, + InboundEntry::::Proof(ProofEntry { + session_id: 1, + current_count: 2, + }) + ) }); + } - // 2 correct messages and 1 failed message processed. - assert_eq!(weight, LP_DEFENSIVE_WEIGHT * 3); - assert_err!(result, DispatchError::Unavailable); - }); + #[test] + fn proof_success_session_change() { + new_test_ext().execute_with(|| { + let mut inbound_entry_1 = InboundEntry::::Proof(ProofEntry { + session_id: 1, + current_count: 1, + }); + + let inbound_entry_2 = InboundEntry::::Proof(ProofEntry { + session_id: 2, + current_count: 3, + }); + + assert_ok!(inbound_entry_1.pre_dispatch_update(inbound_entry_2.clone())); + assert_eq!(inbound_entry_1, inbound_entry_2) + }); + } + + #[test] + fn mismatch_errors() { + new_test_ext().execute_with(|| { + let mut inbound_entry_1 = InboundEntry::::Message(MessageEntry { + session_id: 1, + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }); + + let mut inbound_entry_2 = InboundEntry::::Proof(ProofEntry { + session_id: 1, + current_count: 1, + }); + + assert_noop!( + inbound_entry_1.pre_dispatch_update(inbound_entry_2.clone()), + Error::::ExpectedMessageType + ); + + assert_noop!( + inbound_entry_2.pre_dispatch_update(inbound_entry_1), + Error::::ExpectedMessageProofType + ); + }); + } } } -*/ diff --git a/pallets/liquidity-pools-gateway/src/weights.rs b/pallets/liquidity-pools-gateway/src/weights.rs index b1ac9ed578..b330d71ac6 100644 --- a/pallets/liquidity-pools-gateway/src/weights.rs +++ b/pallets/liquidity-pools-gateway/src/weights.rs @@ -13,16 +13,16 @@ use frame_support::weights::{constants::RocksDbWeight, Weight}; pub trait WeightInfo { - fn set_domain_router() -> Weight; + fn set_routers() -> Weight; fn add_instance() -> Weight; fn remove_instance() -> Weight; fn add_relayer() -> Weight; fn remove_relayer() -> Weight; fn receive_message() -> Weight; - fn process_outbound_message() -> Weight; - fn process_failed_outbound_message() -> Weight; fn start_batch_message() -> Weight; fn end_batch_message() -> Weight; + fn set_domain_hook_address() -> Weight; + fn execute_message_recovery() -> Weight; } // NOTE: We use temporary weights here. `execute_epoch` is by far our heaviest @@ -31,7 +31,7 @@ pub trait WeightInfo { const N: u64 = 4; impl WeightInfo for () { - fn set_domain_router() -> Weight { + fn set_routers() -> Weight { // TODO: BENCHMARK CORRECTLY // // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` @@ -103,18 +103,18 @@ impl WeightInfo for () { .saturating_add(Weight::from_parts(0, 17774).saturating_mul(N)) } - fn process_outbound_message() -> Weight { + fn start_batch_message() -> Weight { // TODO: BENCHMARK CORRECTLY // // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` // This one has one read and one write for sure and possible one // read for `AdminOrigin` Weight::from_parts(30_117_000, 5991) - .saturating_add(RocksDbWeight::get().reads(2)) + .saturating_add(RocksDbWeight::get().reads(1)) .saturating_add(RocksDbWeight::get().writes(1)) } - fn process_failed_outbound_message() -> Weight { + fn end_batch_message() -> Weight { // TODO: BENCHMARK CORRECTLY // // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` @@ -122,21 +122,21 @@ impl WeightInfo for () { // read for `AdminOrigin` Weight::from_parts(30_117_000, 5991) .saturating_add(RocksDbWeight::get().reads(2)) - .saturating_add(RocksDbWeight::get().writes(1)) + .saturating_add(RocksDbWeight::get().writes(2)) } - fn start_batch_message() -> Weight { + fn set_domain_hook_address() -> Weight { // TODO: BENCHMARK CORRECTLY // // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` // This one has one read and one write for sure and possible one // read for `AdminOrigin` Weight::from_parts(30_117_000, 5991) - .saturating_add(RocksDbWeight::get().reads(1)) - .saturating_add(RocksDbWeight::get().writes(1)) + .saturating_add(RocksDbWeight::get().reads(2)) + .saturating_add(RocksDbWeight::get().writes(2)) } - fn end_batch_message() -> Weight { + fn execute_message_recovery() -> Weight { // TODO: BENCHMARK CORRECTLY // // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` diff --git a/pallets/liquidity-pools/Cargo.toml b/pallets/liquidity-pools/Cargo.toml index 9b2bab88a5..62bec053f1 100644 --- a/pallets/liquidity-pools/Cargo.toml +++ b/pallets/liquidity-pools/Cargo.toml @@ -22,6 +22,7 @@ scale-info = { workspace = true } serde = { workspace = true } serde-big-array = { workspace = true } sp-core = { workspace = true } +sp-io = { workspace = true } sp-runtime = { workspace = true } sp-std = { workspace = true } staging-xcm = { workspace = true } @@ -48,6 +49,7 @@ std = [ "frame-support/std", "frame-system/std", "sp-std/std", + "sp-io/std", "sp-runtime/std", "orml-traits/std", "staging-xcm/std", diff --git a/pallets/liquidity-pools/src/message.rs b/pallets/liquidity-pools/src/message.rs index 076c8f6250..f8f04014ec 100644 --- a/pallets/liquidity-pools/src/message.rs +++ b/pallets/liquidity-pools/src/message.rs @@ -5,7 +5,10 @@ //! also have a custom GMPF implementation, aiming for a fixed-size encoded //! representation for each message variant. -use cfg_traits::{liquidity_pools::LPEncoding, Seconds}; +use cfg_traits::{ + liquidity_pools::{LPEncoding, Proof}, + Seconds, +}; use cfg_types::domain_address::Domain; use frame_support::{pallet_prelude::RuntimeDebug, BoundedVec}; use parity_scale_codec::{Decode, Encode, MaxEncodedLen}; @@ -15,6 +18,7 @@ use serde::{ ser::{Error as _, SerializeTuple}, Deserialize, Serialize, Serializer, }; +use sp_io::hashing::keccak_256; use sp_runtime::{traits::ConstU32, DispatchError, DispatchResult}; use sp_std::{vec, vec::Vec}; @@ -557,6 +561,19 @@ impl LPEncoding for Message { fn empty() -> Message { Message::Batch(BatchMessages::default()) } + + fn get_proof(&self) -> Option { + match self { + Message::MessageProof { hash } => Some(*hash), + _ => None, + } + } + + fn to_proof_message(&self) -> Self { + let hash = keccak_256(&LPEncoding::serialize(self)); + + Message::MessageProof { hash } + } } /// A Liquidity Pool message for updating restrictions on foreign domains. diff --git a/runtime/altair/src/lib.rs b/runtime/altair/src/lib.rs index 3889aadaf9..834333cae0 100644 --- a/runtime/altair/src/lib.rs +++ b/runtime/altair/src/lib.rs @@ -27,7 +27,7 @@ use cfg_primitives::{ IBalance, InvestmentId, ItemId, LoanId, Nonce, OrderId, PalletIndex, PoolEpochId, PoolFeeId, PoolId, Signature, TrancheId, TrancheWeight, }, - LPGatewayQueueMessageNonce, + LPGatewayQueueMessageNonce, LPGatewaySessionId, }; use cfg_traits::{investments::OrderManager, Millis, PoolUpdateGuard, Seconds}; use cfg_types::{ @@ -117,7 +117,7 @@ use runtime_common::{ permissions::{IsUnfrozenTrancheInvestor, PoolAdminCheck}, remarks::Remark, rewards::SingleCurrencyMovement, - routing::{EvmAccountCodeChecker, RouterDispatcher, RouterId}, + routing::{EvmAccountCodeChecker, LPGatewayRouterProvider, RouterDispatcher, RouterId}, transfer_filter::{PreLpTransfer, PreNativeTransfer}, xcm::AccountIdToLocation, xcm_transactor, AllowanceDeposit, CurrencyED, @@ -1757,8 +1757,9 @@ impl pallet_liquidity_pools::Config for Runtime { } parameter_types! { - pub const MaxIncomingMessageSize: u32 = 1024; pub Sender: DomainAddress = gateway::get_gateway_account::(); + pub const MaxIncomingMessageSize: u32 = 1024; + pub const MaxRouterCount: u32 = 8; } impl pallet_liquidity_pools_gateway::Config for Runtime { @@ -1766,18 +1767,21 @@ impl pallet_liquidity_pools_gateway::Config for Runtime { type InboundMessageHandler = LiquidityPools; type LocalEVMOrigin = pallet_liquidity_pools_gateway::EnsureLocal; type MaxIncomingMessageSize = MaxIncomingMessageSize; + type MaxRouterCount = MaxRouterCount; type Message = pallet_liquidity_pools::Message; type MessageQueue = LiquidityPoolsGatewayQueue; type MessageSender = RouterDispatcher; type RouterId = RouterId; + type RouterProvider = LPGatewayRouterProvider; type RuntimeEvent = RuntimeEvent; type RuntimeOrigin = RuntimeOrigin; type Sender = Sender; + type SessionId = LPGatewaySessionId; type WeightInfo = (); } impl pallet_liquidity_pools_gateway_queue::Config for Runtime { - type Message = GatewayMessage; + type Message = GatewayMessage; type MessageNonce = LPGatewayQueueMessageNonce; type MessageProcessor = LiquidityPoolsGateway; type RuntimeEvent = RuntimeEvent; diff --git a/runtime/centrifuge/src/lib.rs b/runtime/centrifuge/src/lib.rs index a4e2102a18..3cae0346fd 100644 --- a/runtime/centrifuge/src/lib.rs +++ b/runtime/centrifuge/src/lib.rs @@ -27,7 +27,7 @@ use cfg_primitives::{ IBalance, InvestmentId, ItemId, LoanId, Nonce, OrderId, PalletIndex, PoolEpochId, PoolFeeId, PoolId, Signature, TrancheId, TrancheWeight, }, - LPGatewayQueueMessageNonce, + LPGatewayQueueMessageNonce, LPGatewaySessionId, }; use cfg_traits::{ investments::OrderManager, Millis, Permissions as PermissionsT, PoolUpdateGuard, PreConditions, @@ -117,7 +117,7 @@ use runtime_common::{ }, permissions::{IsUnfrozenTrancheInvestor, PoolAdminCheck}, rewards::SingleCurrencyMovement, - routing::{EvmAccountCodeChecker, RouterDispatcher, RouterId}, + routing::{EvmAccountCodeChecker, LPGatewayRouterProvider, RouterDispatcher, RouterId}, transfer_filter::{PreLpTransfer, PreNativeTransfer}, xcm::AccountIdToLocation, xcm_transactor, AllowanceDeposit, CurrencyED, @@ -1840,8 +1840,9 @@ impl pallet_liquidity_pools::Config for Runtime { } parameter_types! { - pub const MaxIncomingMessageSize: u32 = 1024; pub Sender: DomainAddress = gateway::get_gateway_account::(); + pub const MaxIncomingMessageSize: u32 = 1024; + pub const MaxRouterCount: u32 = 8; } parameter_types! { @@ -1865,18 +1866,21 @@ impl pallet_liquidity_pools_gateway::Config for Runtime { type InboundMessageHandler = LiquidityPools; type LocalEVMOrigin = pallet_liquidity_pools_gateway::EnsureLocal; type MaxIncomingMessageSize = MaxIncomingMessageSize; + type MaxRouterCount = MaxRouterCount; type Message = pallet_liquidity_pools::Message; type MessageQueue = LiquidityPoolsGatewayQueue; type MessageSender = RouterDispatcher; type RouterId = RouterId; + type RouterProvider = LPGatewayRouterProvider; type RuntimeEvent = RuntimeEvent; type RuntimeOrigin = RuntimeOrigin; type Sender = Sender; + type SessionId = LPGatewaySessionId; type WeightInfo = (); } impl pallet_liquidity_pools_gateway_queue::Config for Runtime { - type Message = GatewayMessage; + type Message = GatewayMessage; type MessageNonce = LPGatewayQueueMessageNonce; type MessageProcessor = LiquidityPoolsGateway; type RuntimeEvent = RuntimeEvent; diff --git a/runtime/common/src/routing.rs b/runtime/common/src/routing.rs index 9b9e345d3e..90cc97b62f 100644 --- a/runtime/common/src/routing.rs +++ b/runtime/common/src/routing.rs @@ -1,5 +1,5 @@ use cfg_traits::{ - liquidity_pools::{MessageSender, RouterSupport}, + liquidity_pools::{MessageSender, RouterProvider}, PreConditions, }; use cfg_types::domain_address::{Domain, DomainAddress}; @@ -37,8 +37,13 @@ impl From for Domain { } } -impl RouterSupport for RouterId { - fn for_domain(domain: Domain) -> Vec { +/// Static router provider used in the LP gateway. +pub struct LPGatewayRouterProvider; + +impl RouterProvider for LPGatewayRouterProvider { + type RouterId = RouterId; + + fn routers_for_domain(domain: Domain) -> Vec { match domain { Domain::EVM(chain_id) => vec![RouterId::Axelar(AxelarId::Evm(chain_id))], Domain::Centrifuge => vec![], diff --git a/runtime/development/src/lib.rs b/runtime/development/src/lib.rs index 973921019e..5fc9aa1b0e 100644 --- a/runtime/development/src/lib.rs +++ b/runtime/development/src/lib.rs @@ -27,7 +27,7 @@ use cfg_primitives::{ IBalance, InvestmentId, ItemId, LoanId, Nonce, OrderId, PalletIndex, PoolEpochId, PoolFeeId, PoolId, Signature, TrancheId, TrancheWeight, }, - LPGatewayQueueMessageNonce, + LPGatewayQueueMessageNonce, LPGatewaySessionId, }; use cfg_traits::{ investments::OrderManager, Millis, Permissions as PermissionsT, PoolUpdateGuard, PreConditions, @@ -125,7 +125,7 @@ use runtime_common::{ permissions::{IsUnfrozenTrancheInvestor, PoolAdminCheck}, remarks::Remark, rewards::SingleCurrencyMovement, - routing::{EvmAccountCodeChecker, RouterDispatcher, RouterId}, + routing::{EvmAccountCodeChecker, LPGatewayRouterProvider, RouterDispatcher, RouterId}, transfer_filter::{PreLpTransfer, PreNativeTransfer}, xcm::AccountIdToLocation, xcm_transactor, AllowanceDeposit, CurrencyED, @@ -1862,8 +1862,9 @@ impl pallet_liquidity_pools::Config for Runtime { } parameter_types! { - pub const MaxIncomingMessageSize: u32 = 1024; pub Sender: DomainAddress = gateway::get_gateway_account::(); + pub const MaxIncomingMessageSize: u32 = 1024; + pub const MaxRouterCount: u32 = 8; } impl pallet_liquidity_pools_gateway::Config for Runtime { @@ -1871,18 +1872,21 @@ impl pallet_liquidity_pools_gateway::Config for Runtime { type InboundMessageHandler = LiquidityPools; type LocalEVMOrigin = pallet_liquidity_pools_gateway::EnsureLocal; type MaxIncomingMessageSize = MaxIncomingMessageSize; + type MaxRouterCount = MaxRouterCount; type Message = pallet_liquidity_pools::Message; type MessageQueue = LiquidityPoolsGatewayQueue; type MessageSender = RouterDispatcher; type RouterId = RouterId; + type RouterProvider = LPGatewayRouterProvider; type RuntimeEvent = RuntimeEvent; type RuntimeOrigin = RuntimeOrigin; type Sender = Sender; + type SessionId = LPGatewaySessionId; type WeightInfo = (); } impl pallet_liquidity_pools_gateway_queue::Config for Runtime { - type Message = GatewayMessage; + type Message = GatewayMessage; type MessageNonce = LPGatewayQueueMessageNonce; type MessageProcessor = LiquidityPoolsGateway; type RuntimeEvent = RuntimeEvent; diff --git a/runtime/integration-tests/src/cases/liquidity_pools.rs b/runtime/integration-tests/src/cases/liquidity_pools.rs index ddc227f779..a3a80d3da6 100644 --- a/runtime/integration-tests/src/cases/liquidity_pools.rs +++ b/runtime/integration-tests/src/cases/liquidity_pools.rs @@ -22,6 +22,7 @@ use frame_support::{ OriginTrait, PalletInfo, }, }; +use pallet_axelar_router::AxelarId; use pallet_foreign_investments::ForeignInvestmentInfo; use pallet_investments::CollectOutcome; use pallet_liquidity_pools::Message; @@ -30,7 +31,7 @@ use pallet_liquidity_pools_gateway_queue::MessageNonceStore; use pallet_pool_system::tranches::{TrancheInput, TrancheLoc, TrancheType}; use runtime_common::{ account_conversion::AccountConverter, foreign_investments::IdentityPoolCurrencyConverter, - xcm::general_key, + routing::RouterId, xcm::general_key, }; use sp_core::Get; use sp_runtime::{ @@ -77,6 +78,7 @@ pub const DEFAULT_DOMAIN_ADDRESS_MOONBEAM: DomainAddress = DomainAddress::EVM(MOONBEAM_EVM_CHAIN_ID, DEFAULT_EVM_ADDRESS_MOONBEAM); pub const DEFAULT_OTHER_DOMAIN_ADDRESS: DomainAddress = DomainAddress::EVM(MOONBEAM_EVM_CHAIN_ID, [0; 20]); +pub const DEFAULT_ROUTER_ID: RouterId = RouterId::Axelar(AxelarId::Evm(MOONBEAM_EVM_CHAIN_ID)); pub type LiquidityPoolMessage = Message; @@ -288,6 +290,11 @@ pub mod utils { DEFAULT_BALANCE_GLMR, 0, )); + + assert_ok!(pallet_liquidity_pools_gateway::Pallet::::set_routers( + ::RuntimeOrigin::root(), + BoundedVec::try_from(vec![DEFAULT_ROUTER_ID]).unwrap(), + )); }); } @@ -1038,7 +1045,7 @@ mod foreign_investments { nonce, message: GatewayMessage::Outbound { sender: sender.clone(), - destination: DEFAULT_DOMAIN_ADDRESS_MOONBEAM.domain(), + router_id: DEFAULT_ROUTER_ID, message: LiquidityPoolMessage::FulfilledDepositRequest { pool_id, tranche_id: default_tranche_id::(pool_id), @@ -1142,7 +1149,7 @@ mod foreign_investments { nonce, message: GatewayMessage::Outbound { sender: sender.clone(), - destination: DEFAULT_DOMAIN_ADDRESS_MOONBEAM.domain(), + router_id: DEFAULT_ROUTER_ID, message: Message::FulfilledDepositRequest { pool_id, tranche_id: default_tranche_id::(pool_id), @@ -1235,7 +1242,7 @@ mod foreign_investments { nonce, message: GatewayMessage::Outbound { sender: sender.clone(), - destination: DEFAULT_DOMAIN_ADDRESS_MOONBEAM.domain(), + router_id: DEFAULT_ROUTER_ID, message: LiquidityPoolMessage::FulfilledDepositRequest { pool_id, tranche_id: default_tranche_id::(pool_id), @@ -1270,14 +1277,11 @@ mod foreign_investments { message: GatewayMessage::Outbound { sender: event_sender, - destination: event_domain, + router_id: event_router_id, message: Message::FulfilledDepositRequest { .. }, }, .. - } => { - event_sender == sender - && event_domain == DEFAULT_DOMAIN_ADDRESS_MOONBEAM.domain() - } + } => event_sender == sender && event_router_id == DEFAULT_ROUTER_ID, _ => false, } } else { @@ -1520,7 +1524,7 @@ mod foreign_investments { nonce, message: GatewayMessage::Outbound { sender: sender.clone(), - destination: DEFAULT_DOMAIN_ADDRESS_MOONBEAM.domain(), + router_id: DEFAULT_ROUTER_ID, message: LiquidityPoolMessage::FulfilledRedeemRequest { pool_id, tranche_id: default_tranche_id::(pool_id), @@ -1610,7 +1614,7 @@ mod foreign_investments { nonce, message: GatewayMessage::Outbound { sender: sender.clone(), - destination: DEFAULT_DOMAIN_ADDRESS_MOONBEAM.domain(), + router_id: DEFAULT_ROUTER_ID, message: LiquidityPoolMessage::FulfilledRedeemRequest { pool_id, tranche_id: default_tranche_id::(pool_id), @@ -1965,7 +1969,7 @@ mod foreign_investments { nonce, message: GatewayMessage::Outbound { sender: sender.clone(), - destination: DEFAULT_DOMAIN_ADDRESS_MOONBEAM.domain(), + router_id: DEFAULT_ROUTER_ID, message: LiquidityPoolMessage::FulfilledDepositRequest { pool_id, tranche_id: default_tranche_id::(pool_id), @@ -2117,7 +2121,7 @@ mod foreign_investments { nonce, message: GatewayMessage::Outbound { sender: sender.clone(), - destination: DEFAULT_DOMAIN_ADDRESS_MOONBEAM.domain(), + router_id: DEFAULT_ROUTER_ID, message: LiquidityPoolMessage::FulfilledCancelDepositRequest { pool_id, tranche_id: default_tranche_id::(pool_id), @@ -2229,7 +2233,7 @@ mod foreign_investments { nonce, message: GatewayMessage::Outbound { sender: sender.clone(), - destination: DEFAULT_DOMAIN_ADDRESS_MOONBEAM.domain(), + router_id: DEFAULT_ROUTER_ID, message: LiquidityPoolMessage::FulfilledCancelDepositRequest { pool_id, tranche_id: default_tranche_id::(pool_id), diff --git a/runtime/integration-tests/src/cases/liquidity_pools_gateway_queue.rs b/runtime/integration-tests/src/cases/liquidity_pools_gateway_queue.rs index 7034f18ecb..95f6ce9442 100644 --- a/runtime/integration-tests/src/cases/liquidity_pools_gateway_queue.rs +++ b/runtime/integration-tests/src/cases/liquidity_pools_gateway_queue.rs @@ -1,15 +1,17 @@ use cfg_traits::liquidity_pools::MessageQueue; -use cfg_types::domain_address::{Domain, DomainAddress}; -use frame_support::assert_ok; +use cfg_types::domain_address::DomainAddress; +use frame_support::{assert_ok, traits::OriginTrait}; use pallet_liquidity_pools::Message; use pallet_liquidity_pools_gateway::message::GatewayMessage; -use sp_runtime::traits::One; +use sp_runtime::{traits::One, BoundedVec}; use crate::{ + cases::liquidity_pools::{DEFAULT_DOMAIN_ADDRESS_MOONBEAM, DEFAULT_ROUTER_ID}, config::Runtime, env::{Blocks, Env}, envs::fudge_env::{FudgeEnv, FudgeSupport}, }; + /// NOTE - we're using fudge here because in a non-fudge environment, the event /// can only be read before block finalization. The LP gateway queue is /// processing messages during the `on_idle` hook, just before the block is @@ -23,9 +25,15 @@ fn inbound() { let mut env = FudgeEnv::::default(); let expected_event = env.parachain_state_mut(|| { + assert_ok!(pallet_liquidity_pools_gateway::Pallet::::set_routers( + ::RuntimeOrigin::root(), + BoundedVec::try_from(vec![DEFAULT_ROUTER_ID]).unwrap(), + )); + let nonce = ::MessageNonce::one(); let message = GatewayMessage::Inbound { - domain_address: DomainAddress::EVM(1, [2; 20]), + domain_address: DEFAULT_DOMAIN_ADDRESS_MOONBEAM, + router_id: DEFAULT_ROUTER_ID, message: Message::Invalid, }; @@ -53,10 +61,15 @@ fn outbound() { let mut env = FudgeEnv::::default(); let expected_event = env.parachain_state_mut(|| { + assert_ok!(pallet_liquidity_pools_gateway::Pallet::::set_routers( + ::RuntimeOrigin::root(), + BoundedVec::try_from(vec![DEFAULT_ROUTER_ID]).unwrap(), + )); + let nonce = ::MessageNonce::one(); let message = GatewayMessage::Outbound { sender: DomainAddress::Centrifuge([1; 32]), - destination: Domain::EVM(1), + router_id: DEFAULT_ROUTER_ID, message: Message::Invalid, }; diff --git a/runtime/integration-tests/src/cases/lp/mod.rs b/runtime/integration-tests/src/cases/lp/mod.rs index 1c6896a9bf..5ff21520b0 100644 --- a/runtime/integration-tests/src/cases/lp/mod.rs +++ b/runtime/integration-tests/src/cases/lp/mod.rs @@ -24,9 +24,9 @@ use ethabi::{ use frame_support::{assert_ok, dispatch::RawOrigin, traits::OriginTrait}; use frame_system::pallet_prelude::OriginFor; use hex_literal::hex; -use pallet_axelar_router::{AxelarConfig, DomainConfig, EvmConfig, FeeValues}; +use pallet_axelar_router::{AxelarConfig, AxelarId, DomainConfig, EvmConfig, FeeValues}; use pallet_evm::FeeCalculator; -use runtime_common::account_conversion::AccountConverter; +use runtime_common::{account_conversion::AccountConverter, routing::RouterId}; pub use setup_lp::*; use sp_core::Get; use sp_runtime::traits::{BlakeTwo256, Hash}; @@ -88,6 +88,8 @@ pub const EVM_DOMAIN_CHAIN_ID: u64 = 1; pub const EVM_DOMAIN: Domain = Domain::EVM(EVM_DOMAIN_CHAIN_ID); +pub const EVM_ROUTER_ID: RouterId = RouterId::Axelar(AxelarId::Evm(EVM_DOMAIN_CHAIN_ID)); + /// Represents Solidity enum Domain.Centrifuge pub const DOMAIN_CENTRIFUGE: u8 = 0; diff --git a/runtime/integration-tests/src/cases/lp/setup_lp.rs b/runtime/integration-tests/src/cases/lp/setup_lp.rs index efc3ad0299..e5cab8f68a 100644 --- a/runtime/integration-tests/src/cases/lp/setup_lp.rs +++ b/runtime/integration-tests/src/cases/lp/setup_lp.rs @@ -10,6 +10,8 @@ // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. +use sp_core::bounded_vec::BoundedVec; + use super::*; use crate::cases::lp::utils::pool_c_tranche_1_id; @@ -91,6 +93,11 @@ pub fn setup as EnvEvmExtension>::E DomainAddress::evm(EVM_DOMAIN_CHAIN_ID, EVM_LP_INSTANCE) )); + assert_ok!(pallet_liquidity_pools_gateway::Pallet::::set_routers( + RawOrigin::Root.into(), + BoundedVec::try_from(vec![EVM_ROUTER_ID]).unwrap(), + )); + assert_ok!( pallet_liquidity_pools_gateway::Pallet::::set_domain_hook_address( RawOrigin::Root.into(), diff --git a/runtime/integration-tests/src/cases/lp/utils.rs b/runtime/integration-tests/src/cases/lp/utils.rs index e8229d5c17..fe5e570088 100644 --- a/runtime/integration-tests/src/cases/lp/utils.rs +++ b/runtime/integration-tests/src/cases/lp/utils.rs @@ -13,7 +13,7 @@ use std::{cmp::min, fmt::Debug}; use cfg_primitives::{Balance, TrancheId}; -use cfg_types::domain_address::{Domain, DomainAddress}; +use cfg_types::domain_address::DomainAddress; use ethabi::ethereum_types::{H160, H256, U256}; use fp_evm::CallInfo; use frame_support::traits::{OriginTrait, PalletInfo}; @@ -34,7 +34,7 @@ use staging_xcm::{ }; use crate::{ - cases::lp::{EVM_DOMAIN_CHAIN_ID, POOL_A, POOL_B, POOL_C}, + cases::lp::{EVM_DOMAIN_CHAIN_ID, EVM_ROUTER_ID, POOL_A, POOL_B, POOL_C}, config::Runtime, utils::{accounts::Keyring, evm::receipt_ok, last_event, pool::get_tranche_ids}, }; @@ -149,14 +149,14 @@ pub fn process_gateway_message( GatewayMessage::Inbound { message, .. } => verifier(message), GatewayMessage::Outbound { sender, - destination, + router_id, message, } => { assert_eq!( sender, ::Sender::get() ); - assert_eq!(destination, Domain::EVM(EVM_DOMAIN_CHAIN_ID)); + assert_eq!(router_id, EVM_ROUTER_ID); verifier(message) } } diff --git a/runtime/integration-tests/src/cases/restricted_transfers.rs b/runtime/integration-tests/src/cases/restricted_transfers.rs index bb1cfe7466..5269cc1b3f 100644 --- a/runtime/integration-tests/src/cases/restricted_transfers.rs +++ b/runtime/integration-tests/src/cases/restricted_transfers.rs @@ -21,8 +21,9 @@ use cfg_types::{ }; use cumulus_primitives_core::WeightLimit; use frame_support::{assert_noop, assert_ok, dispatch::RawOrigin, traits::PalletInfo}; -use runtime_common::remarks::Remark; -use sp_runtime::traits::Zero; +use pallet_axelar_router::AxelarId; +use runtime_common::{remarks::Remark, routing::RouterId}; +use sp_runtime::{traits::Zero, BoundedVec}; use staging_xcm::{ v4::{Junction::*, Location, NetworkId}, VersionedLocation, @@ -361,6 +362,7 @@ mod eth_address { const TRANSFER: u32 = 10; const CHAIN_ID: u64 = 1; const CONTRACT_ACCOUNT: [u8; 20] = [1; 20]; + const ROUTER_ID: RouterId = RouterId::Axelar(AxelarId::Evm(CHAIN_ID)); #[test_runtimes(all)] fn restrict_lp_eth_transfer() { @@ -399,6 +401,11 @@ mod eth_address { env.parachain_state_mut(|| { let curr_contract = DomainAddress::EVM(CHAIN_ID, CONTRACT_ACCOUNT); + assert_ok!(pallet_liquidity_pools_gateway::Pallet::::set_routers( + RawOrigin::Root.into(), + BoundedVec::try_from(vec![ROUTER_ID]).unwrap(), + )); + assert_ok!( pallet_transfer_allowlist::Pallet::::add_transfer_allowance( RawOrigin::Signed(Keyring::Alice.into()).into(), diff --git a/runtime/integration-tests/src/cases/routers.rs b/runtime/integration-tests/src/cases/routers.rs index e4c13f0d55..0eaf6065fc 100644 --- a/runtime/integration-tests/src/cases/routers.rs +++ b/runtime/integration-tests/src/cases/routers.rs @@ -7,12 +7,12 @@ use cfg_types::{ use ethabi::{Function, Param, ParamType, Token}; use frame_support::{assert_ok, dispatch::RawOrigin}; use orml_traits::MultiCurrency; -use pallet_axelar_router::{AxelarConfig, DomainConfig, EvmConfig, FeeValues}; +use pallet_axelar_router::{AxelarConfig, AxelarId, DomainConfig, EvmConfig, FeeValues}; use pallet_liquidity_pools::Message; use pallet_liquidity_pools_gateway::message::GatewayMessage; use runtime_common::{ account_conversion::AccountConverter, evm::precompile::LP_AXELAR_GATEWAY, - gateway::get_gateway_h160_account, + gateway::get_gateway_h160_account, routing::RouterId, }; use sp_core::{Get, H160, H256, U256}; use sp_runtime::traits::{BlakeTwo256, Hash}; @@ -30,12 +30,15 @@ use crate::{ }; mod axelar_evm { + use frame_support::BoundedVec; + use super::*; const CHAIN_NAME: &str = "Ethereum"; const INITIAL: Balance = 100; const CHAIN_ID: EVMChainId = 1; const TEST_DOMAIN: Domain = Domain::EVM(CHAIN_ID); + const TEST_ROUTER_ID: RouterId = RouterId::Axelar(AxelarId::Evm(CHAIN_ID)); const AXELAR_CONTRACT_CODE: &[u8] = &[0, 0, 0]; const AXELAR_CONTRACT_ADDRESS: H160 = H160::repeat_byte(1); const LP_CONTRACT_ADDRESS: H160 = H160::repeat_byte(2); @@ -130,9 +133,14 @@ mod axelar_evm { Box::new(base_config::()), )); + assert_ok!(pallet_liquidity_pools_gateway::Pallet::::set_routers( + RawOrigin::Root.into(), + BoundedVec::try_from(vec![TEST_ROUTER_ID]).unwrap(), + )); + let gateway_message = GatewayMessage::Outbound { sender: T::Sender::get(), - destination: TEST_DOMAIN, + router_id: TEST_ROUTER_ID, message: Message::Invalid, }; @@ -162,6 +170,11 @@ mod axelar_evm { Box::new(base_config::()), )); + assert_ok!(pallet_liquidity_pools_gateway::Pallet::::set_routers( + RawOrigin::Root.into(), + BoundedVec::try_from(vec![TEST_ROUTER_ID]).unwrap(), + )); + let message = Message::TransferAssets { currency: pallet_liquidity_pools::Pallet::::try_get_general_index(Usd18.id()) .unwrap(), diff --git a/runtime/integration-tests/src/config.rs b/runtime/integration-tests/src/config.rs index 6873cbcbc8..8671a58c54 100644 --- a/runtime/integration-tests/src/config.rs +++ b/runtime/integration-tests/src/config.rs @@ -137,7 +137,7 @@ pub trait Runtime: TrancheId = TrancheId, BalanceRatio = Ratio, > + pallet_liquidity_pools_gateway::Config - + pallet_liquidity_pools_gateway_queue::Config> + + pallet_liquidity_pools_gateway_queue::Config> + pallet_xcm_transactor::Config + pallet_ethereum::Config + pallet_ethereum_transaction::Config diff --git a/runtime/integration-tests/submodules/liquidity-pools b/runtime/integration-tests/submodules/liquidity-pools index 6e8f1a29df..4301885b9a 160000 --- a/runtime/integration-tests/submodules/liquidity-pools +++ b/runtime/integration-tests/submodules/liquidity-pools @@ -1 +1 @@ -Subproject commit 6e8f1a29dff0d7cf5ff74285cfffadae8a8b303f +Subproject commit 4301885b9a3b8ec36f3bda4b789daa5b115c006a