diff --git a/Cargo.lock b/Cargo.lock index 9bb0ef1336..ded0246947 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5476,6 +5476,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -8266,9 +8275,13 @@ dependencies = [ "frame-support", "frame-system", "hex", + "itertools 0.13.0", + "lazy_static", + "mock-builder", "orml-traits", "parity-scale-codec", "scale-info", + "sp-arithmetic", "sp-core", "sp-io", "sp-runtime", diff --git a/Cargo.toml b/Cargo.toml index 42d8008e8e..836976ea9b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -86,6 +86,7 @@ impl-trait-for-tuples = "0.2.2" num-traits = { version = "0.2.17", default-features = false } num_enum = { version = "0.5.3", default-features = false } chrono = { version = "0.4", default-features = false } +itertools = { version = "0.13.0", default-features = false } # Cumulus cumulus-pallet-aura-ext = { git = "https://github.com/paritytech/polkadot-sdk", default-features = false, branch = "release-polkadot-v1.7.2" } diff --git a/libs/mocks/src/liquidity_pools.rs b/libs/mocks/src/liquidity_pools.rs index acee16d4db..dcae66f60f 100644 --- a/libs/mocks/src/liquidity_pools.rs +++ b/libs/mocks/src/liquidity_pools.rs @@ -2,7 +2,7 @@ pub mod pallet { use cfg_traits::liquidity_pools::InboundMessageHandler; use frame_support::pallet_prelude::*; - use mock_builder::{execute_call, register_call}; + use mock_builder::{execute_call, register_call, CallHandler}; #[pallet::config] pub trait Config: frame_system::Config { @@ -17,8 +17,10 @@ pub mod pallet { type CallIds = StorageMap<_, _, String, mock_builder::CallId>; impl Pallet { - pub fn mock_handle(f: impl Fn(T::DomainAddress, T::Message) -> DispatchResult + 'static) { - register_call!(move |(sender, msg)| f(sender, msg)); + pub fn mock_handle( + f: impl Fn(T::DomainAddress, T::Message) -> DispatchResult + 'static, + ) -> CallHandler { + register_call!(move |(sender, msg)| f(sender, msg)) } } diff --git a/libs/mocks/src/liquidity_pools_gateway_routers.rs b/libs/mocks/src/liquidity_pools_gateway_routers.rs index 22bfdf1bb0..41408ffeba 100644 --- a/libs/mocks/src/liquidity_pools_gateway_routers.rs +++ b/libs/mocks/src/liquidity_pools_gateway_routers.rs @@ -26,9 +26,14 @@ pub mod pallet { ) { register_call!(move |(sender, message)| f(sender, message)); } + + pub fn mock_hash(f: impl Fn() -> T::Hash + 'static) { + register_call!(move |()| f()); + } } impl MockedRouter for Pallet { + type Hash = T::Hash; type Sender = T::AccountId; fn init() -> DispatchResult { @@ -38,6 +43,10 @@ pub mod pallet { fn send(sender: Self::Sender, message: Vec) -> DispatchResultWithPostInfo { execute_call!((sender, message)) } + + fn hash() -> Self::Hash { + execute_call!(()) + } } } @@ -68,11 +77,16 @@ impl RouterMock { ) { pallet::Pallet::::mock_send(f) } + + pub fn mock_hash(&self, f: impl Fn() -> as Router>::Hash + 'static) { + pallet::Pallet::::mock_hash(f) + } } /// Here we implement the actual Router trait for the `RouterMock` which in turn /// calls the `MockedRouter` trait implementation. impl Router for RouterMock { + type Hash = T::Hash; type Sender = T::AccountId; fn init(&self) -> DispatchResult { @@ -82,6 +96,10 @@ impl Router for RouterMock { fn send(&self, sender: Self::Sender, message: Vec) -> DispatchResultWithPostInfo { pallet::Pallet::::send(sender, message) } + + fn hash(&self) -> Self::Hash { + pallet::Pallet::::hash() + } } /// A mocked Router trait that emulates the actual Router trait but without @@ -94,9 +112,13 @@ trait MockedRouter { /// The sender type of the outbound message. type Sender; + type Hash; + /// Initialize the router. fn init() -> DispatchResult; /// Send the message to the router's destination. fn send(sender: Self::Sender, message: Vec) -> DispatchResultWithPostInfo; + + fn hash() -> Self::Hash; } diff --git a/libs/traits/src/liquidity_pools.rs b/libs/traits/src/liquidity_pools.rs index 9ee729e4c6..6ff5b4b588 100644 --- a/libs/traits/src/liquidity_pools.rs +++ b/libs/traits/src/liquidity_pools.rs @@ -18,6 +18,8 @@ use frame_support::{ use sp_runtime::DispatchError; use sp_std::vec::Vec; +pub type Proof = [u8; 32]; + /// An encoding & decoding trait for the purpose of meeting the /// LiquidityPools General Message Passing Format pub trait LPEncoding: Sized { @@ -34,6 +36,12 @@ pub trait LPEncoding: Sized { /// Creates an empty message. /// It's the identity message for composing messages with pack_with fn empty() -> Self; + + /// Retrieves the message proof, if any. + fn get_message_proof(&self) -> Option; + + /// Converts the message into a message proof type. + fn to_message_proof(&self) -> Self; } /// The trait required for sending outbound messages. @@ -41,11 +49,17 @@ pub trait Router { /// The sender type of the outbound message. type Sender; + /// The router hash type. + type Hash; + /// Initialize the router. fn init(&self) -> DispatchResult; /// Send the message to the router's destination. fn send(&self, sender: Self::Sender, message: Vec) -> DispatchResultWithPostInfo; + + /// Generate a hash for this router. + fn hash(&self) -> Self::Hash; } /// The trait required for queueing messages. diff --git a/pallets/liquidity-pools-gateway/Cargo.toml b/pallets/liquidity-pools-gateway/Cargo.toml index 5f01d96284..e2caa19652 100644 --- a/pallets/liquidity-pools-gateway/Cargo.toml +++ b/pallets/liquidity-pools-gateway/Cargo.toml @@ -22,6 +22,7 @@ scale-info = { workspace = true } sp-core = { workspace = true } sp-runtime = { workspace = true } sp-std = { workspace = true } +sp-arithmetic = { workspace = true } # Benchmarking frame-benchmarking = { workspace = true, optional = true } @@ -35,6 +36,9 @@ cfg-utils = { workspace = true } [dev-dependencies] cfg-mocks = { workspace = true, default-features = true } sp-io = { workspace = true, default-features = true } +itertools = { workspace = true, default-features = true } +lazy_static = { workspace = true, default-features = true } +mock-builder = { workspace = true, default-features = true } [features] default = ["std"] @@ -53,6 +57,7 @@ std = [ "cfg-utils/std", "hex/std", "cfg-primitives/std", + "sp-arithmetic/std", ] try-runtime = [ "cfg-traits/try-runtime", diff --git a/pallets/liquidity-pools-gateway/routers/src/lib.rs b/pallets/liquidity-pools-gateway/routers/src/lib.rs index a731009fbb..47e3601398 100644 --- a/pallets/liquidity-pools-gateway/routers/src/lib.rs +++ b/pallets/liquidity-pools-gateway/routers/src/lib.rs @@ -76,6 +76,7 @@ where OriginFor: From + Into>>, { + type Hash = T::Hash; type Sender = T::AccountId; fn init(&self) -> DispatchResult { @@ -89,6 +90,12 @@ where DomainRouter::AxelarEVM(r) => r.do_send(sender, message), } } + + fn hash(&self) -> Self::Hash { + match self { + DomainRouter::AxelarEVM(r) => r.hash(), + } + } } /// A generic router used for executing EVM calls. diff --git a/pallets/liquidity-pools-gateway/routers/src/routers/axelar_evm.rs b/pallets/liquidity-pools-gateway/routers/src/routers/axelar_evm.rs index 670bb8b4dd..d9d5f069a9 100644 --- a/pallets/liquidity-pools-gateway/routers/src/routers/axelar_evm.rs +++ b/pallets/liquidity-pools-gateway/routers/src/routers/axelar_evm.rs @@ -21,7 +21,8 @@ use scale_info::{ prelude::{format, string::String}, TypeInfo, }; -use sp_core::{bounded::BoundedVec, ConstU32, H160}; +use sp_core::{bounded::BoundedVec, ConstU32, Hasher, H160}; +use sp_runtime::traits::BlakeTwo256; use sp_std::{collections::btree_map::BTreeMap, vec, vec::Vec}; use crate::{ @@ -77,6 +78,10 @@ where self.router.do_send(sender, eth_msg) } + + pub fn hash(&self) -> T::Hash { + BlakeTwo256::hash(self.evm_chain.encode().as_slice()) + } } /// Encodes the provided message into the format required for submitting it diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index 2dfc01b85e..c28907cbfb 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -31,7 +31,7 @@ use core::fmt::Debug; use cfg_primitives::LP_DEFENSIVE_WEIGHT; use cfg_traits::liquidity_pools::{ InboundMessageHandler, LPEncoding, MessageProcessor, MessageQueue, OutboundMessageHandler, - Router as DomainRouter, + Proof, Router as DomainRouter, }; use cfg_types::domain_address::{Domain, DomainAddress}; use frame_support::{dispatch::DispatchResult, pallet_prelude::*}; @@ -40,9 +40,10 @@ use message::GatewayMessage; use orml_traits::GetByKey; pub use pallet::*; use parity_scale_codec::{EncodeLike, FullCodec}; +use sp_arithmetic::traits::{BaseArithmetic, One}; use sp_std::convert::TryInto; -use crate::weights::WeightInfo; +use crate::{message_processing::InboundEntry, weights::WeightInfo}; mod origin; pub use origin::*; @@ -54,11 +55,15 @@ pub mod weights; #[cfg(test)] mod mock; +mod message_processing; #[cfg(test)] mod tests; #[frame_support::pallet] pub mod pallet { + use frame_system::pallet_prelude::BlockNumberFor; + use sp_arithmetic::traits::{EnsureAdd, EnsureAddAssign}; + use super::*; const STORAGE_VERSION: StorageVersion = StorageVersion::new(1); @@ -70,6 +75,13 @@ pub mod pallet { #[pallet::origin] pub type Origin = GatewayOrigin; + #[pallet::hooks] + impl Hooks> for Pallet { + fn on_idle(_now: BlockNumberFor, max_weight: Weight) -> Weight { + Self::clear_invalid_session_ids(max_weight) + } + } + #[pallet::config] pub trait Config: frame_system::Config { /// The origin type. @@ -94,7 +106,7 @@ pub mod pallet { type Message: LPEncoding + Clone + Debug + PartialEq + MaxEncodedLen + TypeInfo + FullCodec; /// The message router type that is stored for each domain. - type Router: DomainRouter + type Router: DomainRouter + Clone + Debug + MaxEncodedLen @@ -103,6 +115,9 @@ pub mod pallet { + EncodeLike + PartialEq; + /// The type used for identifying routers. + type RouterId: Clone + Debug + MaxEncodedLen + TypeInfo + FullCodec + EncodeLike + PartialEq; + /// The type that processes inbound messages. type InboundMessageHandler: InboundMessageHandler< Sender = DomainAddress, @@ -121,14 +136,34 @@ pub mod pallet { type Sender: Get; /// Type used for queueing messages. - type MessageQueue: MessageQueue>; + type MessageQueue: MessageQueue< + Message = GatewayMessage, + >; + + /// Maximum number of routers allowed for a domain. + #[pallet::constant] + type MaxRouterCount: Get; + + /// Type for identifying sessions of inbound routers. + type SessionId: Parameter + + Member + + BaseArithmetic + + Default + + Copy + + MaybeSerializeDeserialize + + TypeInfo + + MaxEncodedLen; } #[pallet::event] #[pallet::generate_deposit(pub (super) fn deposit_event)] pub enum Event { - /// The router for a given domain was set. - DomainRouterSet { domain: Domain, router: T::Router }, + /// The routers for a given domain were set. + RoutersSet { + domain: Domain, + router_ids: BoundedVec, + session_id: T::SessionId, + }, /// An instance was added to a domain. InstanceAdded { instance: DomainAddress }, @@ -141,14 +176,31 @@ pub mod pallet { domain: Domain, hook_address: [u8; 20], }, + + /// Message recovery was executed. + MessageRecoveryExecuted { + domain: Domain, + proof: Proof, + router_id: T::RouterId, + }, } - /// Storage for domain routers. + // TODO(cdamian): Add migration to clear this storage. + // /// Storage for domain routers. + // /// + // /// This can only be set by an admin. + // #[pallet::storage] + // #[pallet::getter(fn domain_routers)] + // pub type DomainRouters = StorageMap<_, Blake2_128Concat, Domain, + // T::Router>; + + /// Storage for routers specific for a domain. /// /// This can only be set by an admin. #[pallet::storage] - #[pallet::getter(fn domain_routers)] - pub type DomainRouters = StorageMap<_, Blake2_128Concat, Domain, T::Router>; + #[pallet::getter(fn routers)] + pub type Routers = + StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; /// Storage that contains a limited number of whitelisted instances of /// deployed liquidity pools for a particular domain. @@ -175,11 +227,38 @@ pub mod pallet { pub(crate) type PackedMessage = StorageMap<_, Blake2_128Concat, (T::AccountId, Domain), T::Message>; + /// Storage for pending inbound messages. + #[pallet::storage] + #[pallet::getter(fn pending_inbound_entries)] + pub type PendingInboundEntries = StorageDoubleMap< + _, + Blake2_128Concat, + T::SessionId, + Blake2_128Concat, + (Proof, T::RouterId), + InboundEntry, + >; + + /// Storage for the inbound message session IDs. + #[pallet::storage] + #[pallet::getter(fn inbound_message_sessions)] + pub type InboundMessageSessions = + StorageMap<_, Blake2_128Concat, Domain, T::SessionId>; + + /// Storage for inbound message session IDs. + #[pallet::storage] + pub type SessionIdStore = StorageValue<_, T::SessionId, ValueQuery>; + + /// Storage that keeps track of invalid session IDs. + /// + /// Any `PendingInboundEntries` mapped to the invalid IDs are removed from + /// storage during `on_idle`. + #[pallet::storage] + #[pallet::getter(fn invalid_message_sessions)] + pub type InvalidMessageSessions = StorageMap<_, Blake2_128Concat, T::SessionId, ()>; + #[pallet::error] pub enum Error { - /// Router initialization failed. - RouterInitFailed, - /// The origin of the message to be processed is invalid. InvalidMessageOrigin, @@ -195,8 +274,8 @@ pub mod pallet { /// Unknown instance. UnknownInstance, - /// Router not found. - RouterNotFound, + /// Routers not found. + RoutersNotFound, /// Emitted when you call `start_batch_messages()` but that was already /// called. You should finalize the message with `end_batch_messages()` @@ -205,27 +284,73 @@ pub mod pallet { /// Emitted when you can `end_batch_message()` but the packing process /// was not started by `start_batch_message()`. MessagePackingNotStarted, + + /// Invalid multi router. + InvalidMultiRouter, + + /// Inbound domain session not found. + InboundDomainSessionNotFound, + + /// The router that sent the inbound message is unknown. + UnknownInboundMessageRouter, + + /// The router that sent the message is not the first one. + MessageExpectedFromFirstRouter, + + /// The router that sent the proof should not be the first one. + ProofNotExpectedFromFirstRouter, + + /// A message was expected instead of a proof. + ExpectedMessageType, + + /// A message proof was expected instead of a message. + ExpectedMessageProofType, + + /// Pending inbound entry not found. + PendingInboundEntryNotFound, + + /// Message proof cannot be retrieved. + MessageProofRetrieval, + + /// Recovery message not found. + RecoveryMessageNotFound, } #[pallet::call] impl Pallet { - /// Set a domain's router, - #[pallet::weight(T::WeightInfo::set_domain_router())] + /// Sets the router IDs used for a specific domain, + #[pallet::weight(T::WeightInfo::set_domain_routers())] #[pallet::call_index(0)] - pub fn set_domain_router( + pub fn set_domain_routers( origin: OriginFor, domain: Domain, - router: T::Router, + router_ids: BoundedVec, ) -> DispatchResult { T::AdminOrigin::ensure_origin(origin)?; ensure!(domain != Domain::Centrifuge, Error::::DomainNotSupported); - router.init().map_err(|_| Error::::RouterInitFailed)?; + //TODO(cdamian): Outbound - Call router.init() for each router? + + >::insert(domain.clone(), router_ids.clone()); + + if let Some(old_session_id) = InboundMessageSessions::::get(domain.clone()) { + InvalidMessageSessions::::insert(old_session_id, ()); + } + + let session_id = SessionIdStore::::try_mutate(|n| { + n.ensure_add_assign(One::one())?; - >::insert(domain.clone(), router.clone()); + Ok::(*n) + })?; - Self::deposit_event(Event::DomainRouterSet { domain, router }); + InboundMessageSessions::::insert(domain.clone(), session_id); + + Self::deposit_event(Event::RoutersSet { + domain, + router_ids, + session_id, + }); Ok(()) } @@ -277,6 +402,7 @@ pub mod pallet { #[pallet::call_index(5)] pub fn receive_message( origin: OriginFor, + router_id: T::RouterId, msg: BoundedVec, ) -> DispatchResult { let GatewayOrigin::Domain(origin_address) = T::LocalEVMOrigin::ensure_origin(origin)?; @@ -290,10 +416,12 @@ pub mod pallet { Error::::UnknownInstance, ); - let gateway_message = GatewayMessage::::Inbound { - domain_address: origin_address, - message: T::Message::deserialize(&msg)?, - }; + let gateway_message = + GatewayMessage::::Inbound { + domain_address: origin_address, + message: T::Message::deserialize(&msg)?, + router_id, + }; T::MessageQueue::submit(gateway_message) } @@ -301,7 +429,7 @@ pub mod pallet { /// Set the address of the domain hook /// /// Can only be called by `AdminOrigin`. - #[pallet::weight(T::WeightInfo::set_domain_router())] + #[pallet::weight(T::WeightInfo::set_domain_hook_address())] #[pallet::call_index(8)] pub fn set_domain_hook_address( origin: OriginFor, @@ -352,61 +480,58 @@ pub mod pallet { None => Err(Error::::MessagePackingNotStarted.into()), } } - } - - impl Pallet { - /// Give the message to the `InboundMessageHandler` to be processed. - fn process_inbound_message( - domain_address: DomainAddress, - message: T::Message, - ) -> (DispatchResult, Weight) { - let mut count = 0; - - for submessage in message.submessages() { - count += 1; - - if let Err(e) = T::InboundMessageHandler::handle(domain_address.clone(), submessage) - { - // We only consume the processed weight if error during the batch - return (Err(e), LP_DEFENSIVE_WEIGHT.saturating_mul(count)); - } - } - - (Ok(()), LP_DEFENSIVE_WEIGHT.saturating_mul(count)) - } - /// Retrieves the router stored for the provided domain, sends the - /// message using the router, and calculates and returns the required - /// weight for these operations in the `DispatchResultWithPostInfo`. - fn process_outbound_message( - sender: T::AccountId, + /// Manually increase the proof count for a particular message and + /// executes it if the required count is reached. + /// + /// Can only be called by `AdminOrigin`. + #[pallet::weight(T::WeightInfo::execute_message_recovery())] + #[pallet::call_index(11)] + pub fn execute_message_recovery( + origin: OriginFor, domain: Domain, - message: T::Message, - ) -> (DispatchResult, Weight) { - let read_weight = T::DbWeight::get().reads(1); + proof: Proof, + router_id: T::RouterId, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; - let Some(router) = DomainRouters::::get(domain) else { - return (Err(Error::::RouterNotFound.into()), read_weight); - }; + let session_id = InboundMessageSessions::::get(&domain) + .ok_or(Error::::InboundDomainSessionNotFound)?; - let (result, router_weight) = match router.send(sender, message.serialize()) { - Ok(dispatch_info) => (Ok(()), dispatch_info.actual_weight), - Err(e) => (Err(e.error), e.post_info.actual_weight), - }; + let routers = Routers::::get(&domain).ok_or(Error::::RoutersNotFound)?; - (result, router_weight.unwrap_or(read_weight)) - } + ensure!( + routers.iter().any(|x| x == &router_id), + Error::::UnknownInboundMessageRouter + ); - fn queue_message(destination: Domain, message: T::Message) -> DispatchResult { - // We are using the sender specified in the pallet config so that we can - // ensure that the account is funded - let gateway_message = GatewayMessage::::Outbound { - sender: T::Sender::get(), - destination, - message, - }; + PendingInboundEntries::::try_mutate( + session_id, + (proof, router_id.clone()), + |storage_entry| match storage_entry { + Some(entry) => match entry { + InboundEntry::Proof { current_count } => { + current_count.ensure_add_assign(1).map_err(|e| e.into()) + } + InboundEntry::Message { .. } => { + Err(Error::::ExpectedMessageProofType.into()) + } + }, + None => { + *storage_entry = Some(InboundEntry::::Proof { current_count: 1 }); + + Ok::<(), DispatchError>(()) + } + }, + )?; + + Self::deposit_event(Event::::MessageRecoveryExecuted { + domain, + proof, + router_id, + }); - T::MessageQueue::submit(gateway_message) + Ok(()) } } @@ -439,23 +564,25 @@ pub mod pallet { } impl MessageProcessor for Pallet { - type Message = GatewayMessage; + type Message = GatewayMessage; fn process(msg: Self::Message) -> (DispatchResult, Weight) { match msg { GatewayMessage::Inbound { domain_address, message, - } => Self::process_inbound_message(domain_address, message), + router_id, + } => Self::process_inbound_message(domain_address, message, router_id), GatewayMessage::Outbound { sender, - destination, message, - } => Self::process_outbound_message(sender, destination, message), + router_id, + } => Self::process_outbound_message(sender, message, router_id), } } - /// Process a message. + /// Returns the max processing weight for a message, based on its + /// direction. fn max_processing_weight(msg: &Self::Message) -> Weight { match msg { GatewayMessage::Inbound { message, .. } => { diff --git a/pallets/liquidity-pools-gateway/src/message.rs b/pallets/liquidity-pools-gateway/src/message.rs index cf0bbb1a17..0d6fc4ff38 100644 --- a/pallets/liquidity-pools-gateway/src/message.rs +++ b/pallets/liquidity-pools-gateway/src/message.rs @@ -1,25 +1,29 @@ -use cfg_types::domain_address::{Domain, DomainAddress}; +use cfg_types::domain_address::DomainAddress; use frame_support::pallet_prelude::{Decode, Encode, MaxEncodedLen, TypeInfo}; /// Message type used by the LP gateway. #[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] -pub enum GatewayMessage { +pub enum GatewayMessage { Inbound { domain_address: DomainAddress, message: Message, + router_id: RouterId, }, Outbound { sender: AccountId, - destination: Domain, message: Message, + router_id: RouterId, }, } -impl Default for GatewayMessage { +impl Default + for GatewayMessage +{ fn default() -> Self { GatewayMessage::Inbound { domain_address: Default::default(), message: Default::default(), + router_id: Default::default(), } } } diff --git a/pallets/liquidity-pools-gateway/src/message_processing.rs b/pallets/liquidity-pools-gateway/src/message_processing.rs new file mode 100644 index 0000000000..80828991a1 --- /dev/null +++ b/pallets/liquidity-pools-gateway/src/message_processing.rs @@ -0,0 +1,482 @@ +use cfg_primitives::LP_DEFENSIVE_WEIGHT; +use cfg_traits::liquidity_pools::{InboundMessageHandler, LPEncoding, MessageQueue, Proof}; +use cfg_types::domain_address::{Domain, DomainAddress}; +use frame_support::{ + dispatch::DispatchResult, + ensure, + pallet_prelude::{Decode, Encode, Get, TypeInfo}, + weights::Weight, + BoundedVec, +}; +use parity_scale_codec::MaxEncodedLen; +use sp_arithmetic::traits::{EnsureAddAssign, EnsureSub}; +use sp_runtime::DispatchError; + +use crate::{ + message::GatewayMessage, Config, Error, InboundMessageSessions, InvalidMessageSessions, Pallet, + PendingInboundEntries, Routers, +}; + +/// The limit used when clearing the `PendingInboundEntries` for invalid +/// session IDs. +const INVALID_ID_REMOVAL_LIMIT: u32 = 100; + +/// Type used when storing inbound message information. +#[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] +#[scale_info(skip_type_params(T))] +pub enum InboundEntry { + Message { + domain_address: DomainAddress, + message: T::Message, + expected_proof_count: u32, + }, + Proof { + current_count: u32, + }, +} + +/// Type used when processing inbound messages. +#[derive(Clone)] +pub struct InboundProcessingInfo { + domain_address: DomainAddress, + routers: BoundedVec, + current_session_id: T::SessionId, + expected_proof_count_per_message: u32, +} + +impl Pallet { + /// Calculates and returns the proof count required for processing one + /// inbound message. + fn get_expected_proof_count(domain: &Domain) -> Result { + let routers = Routers::::get(domain).ok_or(Error::::RoutersNotFound)?; + + let expected_proof_count = routers.len().ensure_sub(1)?; + + Ok(expected_proof_count as u32) + } + + /// Gets the message proof for a message. + fn get_message_proof(message: T::Message) -> Proof { + match message.get_message_proof() { + None => message + .to_message_proof() + .get_message_proof() + .expect("message proof ensured by 'to_message_proof'"), + Some(proof) => proof, + } + } + + /// Creates an inbound entry based on whether the inbound message is a + /// proof or not. + fn create_inbound_entry( + domain_address: DomainAddress, + message: T::Message, + expected_proof_count: u32, + ) -> InboundEntry { + match message.get_message_proof() { + None => InboundEntry::Message { + domain_address, + message, + expected_proof_count, + }, + Some(_) => InboundEntry::Proof { current_count: 1 }, + } + } + + /// Validation ensures that: + /// + /// - the router that sent the inbound message is a valid router for the + /// specific domain. + /// - messages are only sent by the first inbound router. + /// - proofs are not sent by the first inbound router. + fn validate_inbound_entry( + inbound_processing_info: &InboundProcessingInfo, + router_id: &T::RouterId, + inbound_entry: &InboundEntry, + ) -> DispatchResult { + let routers = inbound_processing_info.routers.clone(); + + ensure!( + routers.iter().any(|x| x == router_id), + Error::::UnknownInboundMessageRouter + ); + + match inbound_entry { + InboundEntry::Message { .. } => { + ensure!( + routers.get(0) == Some(&router_id), + Error::::MessageExpectedFromFirstRouter + ); + + Ok(()) + } + InboundEntry::Proof { .. } => { + ensure!( + routers.get(0) != Some(&router_id), + Error::::ProofNotExpectedFromFirstRouter + ); + + Ok(()) + } + } + } + + /// Upserts an inbound entry for a particular message, increasing the + /// relevant counts accordingly. + fn upsert_pending_entry( + session_id: T::SessionId, + message_proof: Proof, + router_id: T::RouterId, + inbound_entry: InboundEntry, + weight: &mut Weight, + ) -> DispatchResult { + weight.saturating_accrue(T::DbWeight::get().writes(1)); + + PendingInboundEntries::::try_mutate( + session_id, + (message_proof, router_id), + |storage_entry| match storage_entry { + None => { + *storage_entry = Some(inbound_entry); + + Ok::<(), DispatchError>(()) + } + Some(stored_inbound_entry) => match stored_inbound_entry { + InboundEntry::Message { + expected_proof_count: old, + .. + } => match inbound_entry { + InboundEntry::Message { + expected_proof_count: new, + .. + } => old.ensure_add_assign(new).map_err(|e| e.into()), + InboundEntry::Proof { .. } => Err(Error::::ExpectedMessageType.into()), + }, + InboundEntry::Proof { current_count: old } => match inbound_entry { + InboundEntry::Proof { current_count: new } => { + old.ensure_add_assign(new).map_err(|e| e.into()) + } + InboundEntry::Message { .. } => { + Err(Error::::ExpectedMessageProofType.into()) + } + }, + }, + }, + ) + } + + /// Creates, validates and upserts the inbound entry. + fn validate_and_upsert_pending_entries( + inbound_processing_info: &InboundProcessingInfo, + message: T::Message, + message_proof: Proof, + router_id: T::RouterId, + weight: &mut Weight, + ) -> DispatchResult { + let inbound_entry = Self::create_inbound_entry( + inbound_processing_info.domain_address.clone(), + message, + inbound_processing_info.expected_proof_count_per_message, + ); + + Self::validate_inbound_entry(&inbound_processing_info, &router_id, &inbound_entry)?; + + Self::upsert_pending_entry( + inbound_processing_info.current_session_id, + message_proof, + router_id, + inbound_entry, + weight, + )?; + + Ok(()) + } + + /// Checks if the number of proofs required for executing one message + /// were received, and returns the message if so. + fn get_executable_message( + inbound_processing_info: &InboundProcessingInfo, + message_proof: Proof, + weight: &mut Weight, + ) -> Option { + let mut message = None; + let mut votes = 0; + + for router in &inbound_processing_info.routers { + weight.saturating_accrue(T::DbWeight::get().reads(1)); + + match PendingInboundEntries::::get( + inbound_processing_info.current_session_id, + (message_proof, router), + ) { + // We expected one InboundEntry for each router, if that's not the case, + // we can return. + None => return None, + Some(inbound_entry) => match inbound_entry { + InboundEntry::Message { + message: stored_message, + .. + } => message = Some(stored_message), + InboundEntry::Proof { current_count } => { + if current_count > 0 { + votes += 1; + } + } + }, + }; + } + + if votes == inbound_processing_info.expected_proof_count_per_message { + return message; + } + + None + } + + /// Decreases the counts for inbound entries and removes them if the + /// counts reach 0. + fn decrease_pending_entries_counts( + inbound_processing_info: &InboundProcessingInfo, + message_proof: Proof, + weight: &mut Weight, + ) -> DispatchResult { + for router in &inbound_processing_info.routers { + weight.saturating_accrue(T::DbWeight::get().writes(1)); + + match PendingInboundEntries::::try_mutate( + inbound_processing_info.current_session_id, + (message_proof, router), + |storage_entry| match storage_entry { + None => Err(Error::::PendingInboundEntryNotFound.into()), + Some(stored_inbound_entry) => match stored_inbound_entry { + InboundEntry::Message { + expected_proof_count, + .. + } => { + let updated_count = (*expected_proof_count).ensure_sub( + inbound_processing_info.expected_proof_count_per_message, + )?; + + if updated_count == 0 { + *storage_entry = None; + } else { + *expected_proof_count = updated_count; + } + + Ok::<(), DispatchError>(()) + } + InboundEntry::Proof { current_count } => { + let updated_count = (*current_count).ensure_sub(1)?; + + if updated_count == 0 { + *storage_entry = None; + } else { + *current_count = updated_count; + } + + Ok::<(), DispatchError>(()) + } + }, + }, + ) { + Ok(()) => {} + Err(e) => return Err(e), + } + } + + Ok(()) + } + + /// Retrieves the information required for processing an inbound + /// message. + fn get_inbound_processing_info( + domain_address: DomainAddress, + weight: &mut Weight, + ) -> Result, DispatchError> { + let routers = + Routers::::get(domain_address.domain()).ok_or(Error::::RoutersNotFound)?; + + weight.saturating_accrue(T::DbWeight::get().reads(1)); + + let current_session_id = InboundMessageSessions::::get(domain_address.domain()) + .ok_or(Error::::InboundDomainSessionNotFound)?; + + weight.saturating_accrue(T::DbWeight::get().reads(1)); + + let expected_proof_count = Self::get_expected_proof_count(&domain_address.domain())?; + + weight.saturating_accrue(T::DbWeight::get().reads(1)); + + Ok(InboundProcessingInfo { + domain_address, + routers, + current_session_id, + expected_proof_count_per_message: expected_proof_count, + }) + } + + /// Iterates over a batch of messages and checks if the requirements for + /// processing each message are met. + pub(crate) fn process_inbound_message( + domain_address: DomainAddress, + message: T::Message, + router_id: T::RouterId, + ) -> (DispatchResult, Weight) { + let mut weight = Default::default(); + + let inbound_processing_info = + match Self::get_inbound_processing_info(domain_address.clone(), &mut weight) { + Ok(i) => i, + Err(e) => return (Err(e), weight), + }; + + weight.saturating_accrue( + Weight::from_parts(0, T::Message::max_encoded_len() as u64) + .saturating_add(LP_DEFENSIVE_WEIGHT), + ); + + let mut count = 0; + + for submessage in message.submessages() { + count += 1; + + let message_proof = Self::get_message_proof(message.clone()); + + if let Err(e) = Self::validate_and_upsert_pending_entries( + &inbound_processing_info, + submessage.clone(), + message_proof, + router_id.clone(), + &mut weight, + ) { + return (Err(e), weight); + } + + match Self::get_executable_message(&inbound_processing_info, message_proof, &mut weight) + { + Some(m) => { + if let Err(e) = Self::decrease_pending_entries_counts( + &inbound_processing_info, + message_proof, + &mut weight, + ) { + return (Err(e), weight.saturating_mul(count)); + } + + if let Err(e) = T::InboundMessageHandler::handle(domain_address.clone(), m) { + // We only consume the processed weight if error during the batch + return (Err(e), weight.saturating_mul(count)); + } + } + None => continue, + } + } + + (Ok(()), weight.saturating_mul(count)) + } + + /// Retrieves the stored router, sends the message, and calculates and + /// returns the router operation result and the weight used. + pub(crate) fn process_outbound_message( + sender: T::AccountId, + message: T::Message, + router_id: T::RouterId, + ) -> (DispatchResult, Weight) { + let read_weight = T::DbWeight::get().reads(1); + + // TODO(cdamian): Update when the router refactor is done. + + // let Some(router) = Routers::::get(router_id) else { + // return (Err(Error::::RouterNotFound.into()), read_weight); + // }; + // + // let (result, router_weight) = match router.send(sender, message.serialize()) + // { Ok(dispatch_info) => (Ok(()), dispatch_info.actual_weight), + // Err(e) => (Err(e.error), e.post_info.actual_weight), + // }; + // + // (result, router_weight.unwrap_or(read_weight)) + + (Ok(()), read_weight) + } + + /// Retrieves the IDs of the routers set for a domain and queues the + /// message and proofs accordingly. + pub(crate) fn queue_message(destination: Domain, message: T::Message) -> DispatchResult { + let router_ids = + Routers::::get(destination.clone()).ok_or(Error::::RoutersNotFound)?; + + let message_proof = message.to_message_proof(); + let mut message_opt = Some(message); + + for router_id in router_ids { + // Ensure that we only send the actual message once, using one router. + // The remaining routers will send the message proof. + let router_msg = match message_opt.take() { + Some(m) => m, + None => message_proof.clone(), + }; + + // We are using the sender specified in the pallet config so that we can + // ensure that the account is funded + let gateway_message = + GatewayMessage::::Outbound { + sender: T::Sender::get(), + message: router_msg, + router_id, + }; + + T::MessageQueue::submit(gateway_message)?; + } + + Ok(()) + } + + /// Clears `PendingInboundEntries` mapped to invalid session IDs as long as + /// there is enough weight available for this operation. + /// + /// The invalid session IDs are removed from storage if all entries mapped + /// to them were cleared. + pub(crate) fn clear_invalid_session_ids(max_weight: Weight) -> Weight { + let invalid_session_ids = InvalidMessageSessions::::iter_keys().collect::>(); + + let mut weight = T::DbWeight::get().reads(1); + + for invalid_session_id in invalid_session_ids { + let mut cursor: Option> = None; + + loop { + let res = PendingInboundEntries::::clear_prefix( + invalid_session_id, + INVALID_ID_REMOVAL_LIMIT, + cursor.as_ref().map(|x| x.as_ref()), + ); + + weight.saturating_accrue( + T::DbWeight::get().reads_writes(res.loops.into(), res.unique.into()), + ); + + if weight.all_gte(max_weight) { + return weight; + } + + cursor = match res.maybe_cursor { + None => { + InvalidMessageSessions::::remove(invalid_session_id); + + weight.saturating_accrue(T::DbWeight::get().writes(1)); + + if weight.all_gte(max_weight) { + return weight; + } + + break; + } + Some(c) => Some(c), + }; + } + } + + weight + } +} diff --git a/pallets/liquidity-pools-gateway/src/mock.rs b/pallets/liquidity-pools-gateway/src/mock.rs index 348ccf8d91..b97d1ba99a 100644 --- a/pallets/liquidity-pools-gateway/src/mock.rs +++ b/pallets/liquidity-pools-gateway/src/mock.rs @@ -1,8 +1,10 @@ +use std::fmt::{Debug, Formatter}; + use cfg_mocks::{ pallet_mock_liquidity_pools, pallet_mock_liquidity_pools_gateway_queue, pallet_mock_routers, RouterMock, }; -use cfg_traits::liquidity_pools::LPEncoding; +use cfg_traits::liquidity_pools::{LPEncoding, Proof}; use cfg_types::domain_address::DomainAddress; use frame_support::{derive_impl, weights::constants::RocksDbWeight}; use frame_system::EnsureRoot; @@ -18,11 +20,24 @@ pub const LP_ADMIN_ACCOUNT: AccountId32 = AccountId32::new([u8::MAX; 32]); pub const MAX_PACKED_MESSAGES_ERR: &str = "packed limit error"; pub const MAX_PACKED_MESSAGES: usize = 10; -#[derive(Default, Debug, Eq, PartialEq, Clone, Encode, Decode, TypeInfo)] +pub const MESSAGE_PROOF: [u8; 32] = [1; 32]; + +#[derive(Default, Eq, PartialEq, Clone, Encode, Decode, TypeInfo, Hash)] pub enum Message { #[default] Simple, Pack(Vec), + Proof([u8; 32]), +} + +impl Debug for Message { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Message::Simple => write!(f, "Simple"), + Message::Pack(p) => write!(f, "Pack - {:?}", p), + Message::Proof(_) => write!(f, "Proof"), + } + } } /// Avoiding automatic infinity loop with the MaxEncodedLen derive @@ -35,8 +50,8 @@ impl MaxEncodedLen for Message { impl LPEncoding for Message { fn serialize(&self) -> Vec { match self { - Self::Simple => vec![0x42], Self::Pack(list) => list.iter().map(|_| 0x42).collect(), + _ => vec![0x42], } } @@ -50,10 +65,6 @@ impl LPEncoding for Message { fn pack_with(&mut self, other: Self) -> DispatchResult { match self { - Self::Simple => { - *self = Self::Pack(vec![Self::Simple, other]); - Ok(()) - } Self::Pack(list) if list.len() == MAX_PACKED_MESSAGES => { Err(MAX_PACKED_MESSAGES_ERR.into()) } @@ -61,19 +72,37 @@ impl LPEncoding for Message { list.push(other); Ok(()) } + _ => { + *self = Self::Pack(vec![self.clone(), other]); + Ok(()) + } } } fn submessages(&self) -> Vec { match self { - Self::Simple => vec![Self::Simple], Self::Pack(list) => list.clone(), + _ => vec![self.clone()], } } fn empty() -> Self { Self::Pack(vec![]) } + + fn get_message_proof(&self) -> Option { + match self { + Message::Proof(p) => Some(p.clone()), + _ => None, + } + } + + fn to_message_proof(&self) -> Self { + match self { + Message::Proof(_) => self.clone(), + _ => Message::Proof(MESSAGE_PROOF), + } + } } frame_support::construct_runtime!( @@ -102,13 +131,14 @@ impl pallet_mock_liquidity_pools::Config for Runtime { impl pallet_mock_routers::Config for Runtime {} impl pallet_mock_liquidity_pools_gateway_queue::Config for Runtime { - type Message = GatewayMessage; + type Message = GatewayMessage; } frame_support::parameter_types! { pub Sender: AccountId32 = AccountId32::from(H256::from_low_u64_be(1).to_fixed_bytes()); pub const MaxIncomingMessageSize: u32 = 1024; pub const LpAdminAccount: AccountId32 = LP_ADMIN_ACCOUNT; + pub const MaxRouterCount: u32 = 8; } impl pallet_liquidity_pools_gateway::Config for Runtime { @@ -116,12 +146,16 @@ impl pallet_liquidity_pools_gateway::Config for Runtime { type InboundMessageHandler = MockLiquidityPools; type LocalEVMOrigin = EnsureLocal; type MaxIncomingMessageSize = MaxIncomingMessageSize; + type MaxRouterCount = MaxRouterCount; type Message = Message; type MessageQueue = MockLiquidityPoolsGatewayQueue; type Router = RouterMock; + //TODO(cdamian): Change to some other type for tests? + type RouterId = H256; type RuntimeEvent = RuntimeEvent; type RuntimeOrigin = RuntimeOrigin; type Sender = Sender; + type SessionId = u32; type WeightInfo = (); } diff --git a/pallets/liquidity-pools-gateway/src/tests.rs b/pallets/liquidity-pools-gateway/src/tests.rs index 13afa3a3bc..7e01d1c45b 100644 --- a/pallets/liquidity-pools-gateway/src/tests.rs +++ b/pallets/liquidity-pools-gateway/src/tests.rs @@ -1,4 +1,5 @@ -use cfg_mocks::*; +use std::collections::HashMap; + use cfg_primitives::LP_DEFENSIVE_WEIGHT; use cfg_traits::liquidity_pools::{LPEncoding, MessageProcessor, OutboundMessageHandler}; use cfg_types::domain_address::*; @@ -6,8 +7,15 @@ use frame_support::{ assert_err, assert_noop, assert_ok, dispatch::PostDispatchInfo, pallet_prelude::Pays, weights::Weight, }; -use sp_core::{bounded::BoundedVec, crypto::AccountId32, ByteArray, H160}; -use sp_runtime::{DispatchError, DispatchError::BadOrigin, DispatchErrorWithPostInfo}; +use itertools::Itertools; +use lazy_static::lazy_static; +use parity_scale_codec::MaxEncodedLen; +use sp_arithmetic::ArithmeticError::Overflow; +use sp_core::{bounded::BoundedVec, crypto::AccountId32, ByteArray, H160, H256}; +use sp_runtime::{ + DispatchError, + DispatchError::{Arithmetic, BadOrigin}, +}; use sp_std::sync::{ atomic::{AtomicU32, Ordering}, Arc, @@ -18,7 +26,15 @@ use super::{ origin::*, pallet::*, }; -use crate::GatewayMessage; +use crate::{message_processing::InboundEntry, GatewayMessage}; + +pub const TEST_DOMAIN_ADDRESS: DomainAddress = DomainAddress::EVM(0, [1; 20]); + +lazy_static! { + static ref ROUTER_HASH_1: H256 = H256::from_low_u64_be(1); + static ref ROUTER_HASH_2: H256 = H256::from_low_u64_be(2); + static ref ROUTER_HASH_3: H256 = H256::from_low_u64_be(3); +} mod utils { use super::*; @@ -41,43 +57,67 @@ mod utils { use utils::*; -mod set_domain_router { +mod set_domain_routers { use super::*; #[test] fn success() { new_test_ext().execute_with(|| { let domain = Domain::EVM(0); - let router = RouterMock::::default(); - router.mock_init(move || Ok(())); - assert_ok!(LiquidityPoolsGateway::set_domain_router( + let mut session_id = 1; + + let router_id_1 = H256::from_low_u64_be(1); + let router_id_2 = H256::from_low_u64_be(2); + let router_id_3 = H256::from_low_u64_be(3); + + let mut router_ids = + BoundedVec::try_from(vec![router_id_1, router_id_2, router_id_3]).unwrap(); + + assert_ok!(LiquidityPoolsGateway::set_domain_routers( RuntimeOrigin::root(), domain.clone(), - router.clone(), + router_ids.clone(), )); - let storage_entry = DomainRouters::::get(domain.clone()); - assert_eq!(storage_entry.unwrap(), router); + assert_eq!(Routers::::get(domain.clone()).unwrap(), router_ids); + assert_eq!( + InboundMessageSessions::::get(domain.clone()), + Some(session_id) + ); + assert_eq!(InvalidMessageSessions::::get(session_id - 1), None); - event_exists(Event::::DomainRouterSet { domain, router }); - }); - } - #[test] - fn router_init_error() { - new_test_ext().execute_with(|| { - let domain = Domain::EVM(0); - let router = RouterMock::::default(); - router.mock_init(move || Err(DispatchError::Other("error"))); + event_exists(Event::::RoutersSet { + domain: domain.clone(), + router_ids, + session_id, + }); - assert_noop!( - LiquidityPoolsGateway::set_domain_router( - RuntimeOrigin::root(), - domain.clone(), - router, - ), - Error::::RouterInitFailed, + router_ids = BoundedVec::try_from(vec![router_id_3, router_id_2, router_id_1]).unwrap(); + + session_id += 1; + + assert_ok!(LiquidityPoolsGateway::set_domain_routers( + RuntimeOrigin::root(), + domain.clone(), + router_ids.clone(), + )); + + assert_eq!(Routers::::get(domain.clone()).unwrap(), router_ids); + assert_eq!( + InboundMessageSessions::::get(domain.clone()), + Some(session_id) + ); + assert_eq!( + InvalidMessageSessions::::get(session_id - 1), + Some(()) ); + + event_exists(Event::::RoutersSet { + domain, + router_ids, + session_id, + }); }); } @@ -85,19 +125,19 @@ mod set_domain_router { fn bad_origin() { new_test_ext().execute_with(|| { let domain = Domain::EVM(0); - let router = RouterMock::::default(); assert_noop!( - LiquidityPoolsGateway::set_domain_router( + LiquidityPoolsGateway::set_domain_routers( RuntimeOrigin::signed(get_test_account_id()), domain.clone(), - router, + BoundedVec::try_from(vec![]).unwrap(), ), BadOrigin ); - let storage_entry = DomainRouters::::get(domain); - assert!(storage_entry.is_none()); + assert!(Routers::::get(domain.clone()).is_none()); + assert!(InboundMessageSessions::::get(domain).is_none()); + assert!(InvalidMessageSessions::::get(0).is_none()); }); } @@ -105,19 +145,37 @@ mod set_domain_router { fn unsupported_domain() { new_test_ext().execute_with(|| { let domain = Domain::Centrifuge; - let router = RouterMock::::default(); assert_noop!( - LiquidityPoolsGateway::set_domain_router( + LiquidityPoolsGateway::set_domain_routers( RuntimeOrigin::root(), domain.clone(), - router + BoundedVec::try_from(vec![]).unwrap(), ), Error::::DomainNotSupported ); - let storage_entry = DomainRouters::::get(domain); - assert!(storage_entry.is_none()); + assert!(Routers::::get(domain.clone()).is_none()); + assert!(InboundMessageSessions::::get(domain).is_none()); + assert!(InvalidMessageSessions::::get(0).is_none()); + }); + } + + #[test] + fn session_id_overflow() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + + SessionIdStore::::set(u32::MAX); + + assert_noop!( + LiquidityPoolsGateway::set_domain_routers( + RuntimeOrigin::root(), + domain, + BoundedVec::try_from(vec![]).unwrap(), + ), + Arithmetic(Overflow) + ); }); } } @@ -282,6 +340,8 @@ mod receive_message_domain { let domain_address = DomainAddress::EVM(0, address.into()); let message = Message::Simple; + let router_id = H256::from_low_u64_be(1); + assert_ok!(LiquidityPoolsGateway::add_instance( RuntimeOrigin::root(), domain_address.clone(), @@ -289,9 +349,10 @@ mod receive_message_domain { let encoded_msg = message.serialize(); - let gateway_message = GatewayMessage::::Inbound { + let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), + router_id, }; MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_message| { @@ -301,6 +362,7 @@ mod receive_message_domain { assert_ok!(LiquidityPoolsGateway::receive_message( GatewayOrigin::Domain(domain_address).into(), + router_id, BoundedVec::::try_from(encoded_msg).unwrap() )); }); @@ -311,9 +373,12 @@ mod receive_message_domain { new_test_ext().execute_with(|| { let encoded_msg = Message::Simple.serialize(); + let router_id = H256::from_low_u64_be(1); + assert_noop!( LiquidityPoolsGateway::receive_message( RuntimeOrigin::signed(AccountId32::new([0u8; 32])), + router_id, BoundedVec::::try_from(encoded_msg).unwrap() ), BadOrigin, @@ -326,10 +391,12 @@ mod receive_message_domain { new_test_ext().execute_with(|| { let domain_address = DomainAddress::Centrifuge(get_test_account_id().into()); let encoded_msg = Message::Simple.serialize(); + let router_id = H256::from_low_u64_be(1); assert_noop!( LiquidityPoolsGateway::receive_message( GatewayOrigin::Domain(domain_address).into(), + router_id, BoundedVec::::try_from(encoded_msg).unwrap() ), Error::::InvalidMessageOrigin, @@ -343,10 +410,12 @@ mod receive_message_domain { let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); let domain_address = DomainAddress::EVM(0, address.into()); let encoded_msg = Message::Simple.serialize(); + let router_id = H256::from_low_u64_be(1); assert_noop!( LiquidityPoolsGateway::receive_message( GatewayOrigin::Domain(domain_address).into(), + router_id, BoundedVec::::try_from(encoded_msg).unwrap() ), Error::::UnknownInstance, @@ -361,6 +430,8 @@ mod receive_message_domain { let domain_address = DomainAddress::EVM(0, address.into()); let message = Message::Simple; + let router_id = H256::from_low_u64_be(1); + assert_ok!(LiquidityPoolsGateway::add_instance( RuntimeOrigin::root(), domain_address.clone(), @@ -370,9 +441,10 @@ mod receive_message_domain { let err = sp_runtime::DispatchError::from("liquidity_pools error"); - let gateway_message = GatewayMessage::::Inbound { + let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), + router_id, }; MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_message| { @@ -383,6 +455,7 @@ mod receive_message_domain { assert_noop!( LiquidityPoolsGateway::receive_message( GatewayOrigin::Domain(domain_address).into(), + router_id, BoundedVec::::try_from(encoded_msg).unwrap() ), err, @@ -400,24 +473,52 @@ mod outbound_message_handler_impl { let domain = Domain::EVM(0); let sender = get_test_account_id(); let msg = Message::Simple; - - let router = RouterMock::::default(); - router.mock_init(move || Ok(())); - - assert_ok!(LiquidityPoolsGateway::set_domain_router( + let message_proof = msg.to_message_proof().get_message_proof().unwrap(); + + let router_id_1 = H256::from_low_u64_be(1); + let router_id_2 = H256::from_low_u64_be(2); + let router_id_3 = H256::from_low_u64_be(3); + + //TODO(cdamian): Router init + // let router_hash_1 = H256::from_low_u64_be(1); + // let router_hash_2 = H256::from_low_u64_be(2); + // let router_hash_3 = H256::from_low_u64_be(3); + // + // let router_mock_1 = RouterMock::::default(); + // let router_mock_2 = RouterMock::::default(); + // let router_mock_3 = RouterMock::::default(); + // + // router_mock_1.mock_init(move || Ok(())); + // router_mock_1.mock_hash(move || router_hash_1); + // router_mock_2.mock_init(move || Ok(())); + // router_mock_2.mock_hash(move || router_hash_2); + // router_mock_3.mock_init(move || Ok(())); + // router_mock_3.mock_hash(move || router_hash_3); + + assert_ok!(LiquidityPoolsGateway::set_domain_routers( RuntimeOrigin::root(), domain.clone(), - router.clone(), + BoundedVec::try_from(vec![router_id_1, router_id_2, router_id_3]).unwrap(), )); - let gateway_message = GatewayMessage::::Outbound { - sender: ::Sender::get(), - destination: domain.clone(), - message: msg.clone(), - }; - MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_msg| { - assert_eq!(mock_msg, gateway_message); + match mock_msg { + GatewayMessage::Inbound { .. } => { + assert!(false, "expected outbound message") + } + GatewayMessage::Outbound { + sender, message, .. + } => { + assert_eq!(sender, ::Sender::get()); + + match message { + Message::Proof(p) => { + assert_eq!(p, message_proof); + } + _ => {} + } + } + } Ok(()) }); @@ -447,30 +548,48 @@ mod outbound_message_handler_impl { let sender = get_test_account_id(); let msg = Message::Simple; - let router = RouterMock::::default(); - router.mock_init(move || Ok(())); - - assert_ok!(LiquidityPoolsGateway::set_domain_router( + let router_id_1 = H256::from_low_u64_be(1); + let router_id_2 = H256::from_low_u64_be(2); + let router_id_3 = H256::from_low_u64_be(3); + + //TODO(cdamian): Router init? + // let router_hash_1 = H256::from_low_u64_be(1); + // let router_hash_2 = H256::from_low_u64_be(2); + // let router_hash_3 = H256::from_low_u64_be(3); + // + // let router_mock_1 = RouterMock::::default(); + // let router_mock_2 = RouterMock::::default(); + // let router_mock_3 = RouterMock::::default(); + // + // router_mock_1.mock_init(move || Ok(())); + // router_mock_1.mock_hash(move || router_hash_1); + // router_mock_2.mock_init(move || Ok(())); + // router_mock_2.mock_hash(move || router_hash_2); + // router_mock_3.mock_init(move || Ok(())); + // router_mock_3.mock_hash(move || router_hash_3); + + assert_ok!(LiquidityPoolsGateway::set_domain_routers( RuntimeOrigin::root(), domain.clone(), - router.clone(), + BoundedVec::try_from(vec![router_id_1, router_id_2, router_id_3]).unwrap(), )); - let gateway_message = GatewayMessage::::Outbound { + let gateway_message = GatewayMessage::Outbound { sender: ::Sender::get(), - destination: domain.clone(), message: msg.clone(), + router_id: router_id_1, }; let err = DispatchError::Unavailable; - MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_msg| { + let handler = MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_msg| { assert_eq!(mock_msg, gateway_message); Err(err) }); assert_noop!(LiquidityPoolsGateway::handle(sender, domain, msg), err); + assert_eq!(handler.times(), 1); }); } } @@ -530,36 +649,1896 @@ mod message_processor_impl { mod inbound { use super::*; - #[test] - fn success() { - new_test_ext().execute_with(|| { - let domain_address = DomainAddress::EVM(1, [1; 20]); - let message = Message::Simple; - let gateway_message = GatewayMessage::::Inbound { - domain_address: domain_address.clone(), - message: message.clone(), - }; + #[macro_use] + mod util { + use super::*; + + pub fn run_inbound_message_test_suite(suite: InboundMessageTestSuite) { + let test_routers = suite.routers; + + for test in suite.tests { + println!("Executing test for - {:?}", test.router_messages); + + new_test_ext().execute_with(|| { + let session_id = 1; + + Routers::::insert( + TEST_DOMAIN_ADDRESS.domain(), + BoundedVec::try_from(test_routers.clone()).unwrap(), + ); + InboundMessageSessions::::insert( + TEST_DOMAIN_ADDRESS.domain(), + session_id, + ); + + let handler = MockLiquidityPools::mock_handle(move |_, _| Ok(())); + + for router_message in test.router_messages { + let gateway_message = GatewayMessage::Inbound { + domain_address: TEST_DOMAIN_ADDRESS, + message: router_message.1, + router_id: router_message.0, + }; + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_ok!(res); + } + + let expected_message_submitted_times = + test.expected_test_result.message_submitted_times; + let message_submitted_times = handler.times(); + + assert_eq!( + message_submitted_times, + expected_message_submitted_times, + "Expected message to be submitted {expected_message_submitted_times} times, was {message_submitted_times}" + ); + + for expected_storage_entry in + test.expected_test_result.expected_storage_entries + { + let expected_storage_entry_router_hash = expected_storage_entry.0; + let expected_inbound_entry = expected_storage_entry.1; + + let storage_entry = PendingInboundEntries::::get( + session_id, + (MESSAGE_PROOF, expected_storage_entry_router_hash), + ); + assert_eq!(storage_entry, expected_inbound_entry, "Expected inbound entry {expected_inbound_entry:?}, found {storage_entry:?}"); + } + }); + } + } + + /// Used for generating all `RouterMessage` combinations like: + /// + /// vec![ + /// (*ROUTER_HASH_1, Message::Simple), + /// (*ROUTER_HASH_1, Message::Simple), + /// ] + /// vec![ + /// (*ROUTER_HASH_1, Message::Simple), + /// (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + /// ] + /// vec![ + /// (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + /// (*ROUTER_HASH_1, Message::Simple), + /// ] + /// vec![ + /// (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + /// (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + /// ] + pub fn generate_test_combinations( + t: T, + count: usize, + ) -> Vec::Item>> + where + T: IntoIterator + Clone, + T::IntoIter: Clone, + T::Item: Clone, + { + std::iter::repeat(t.clone().into_iter()) + .take(count) + .multi_cartesian_product() + .collect::>() + } + + /// Type used for mapping a message to a router hash. + pub type RouterMessage = (H256, Message); + + /// Type used for aggregating tests for inbound messages. + pub struct InboundMessageTestSuite { + pub routers: Vec, + pub tests: Vec, + } + + /// Type used for defining a test which contains a set of + /// `RouterMessage` combinations and the expected test result. + pub struct InboundMessageTest { + pub router_messages: Vec, + pub expected_test_result: ExpectedTestResult, + } + + /// Type used for defining the number of expected inbound message + /// submission and the exected storage state. + #[derive(Clone, Debug)] + pub struct ExpectedTestResult { + pub message_submitted_times: u32, + pub expected_storage_entries: Vec<(H256, Option>)>, + } + + /// Generates the combinations of `RouterMessage` used when testing, + /// maps the `ExpectedTestResult` for each and creates the + /// `InboundMessageTestSuite`. + pub fn generate_test_suite( + routers: Vec, + test_data: Vec, + expected_results: HashMap, ExpectedTestResult>, + message_count: usize, + ) -> InboundMessageTestSuite { + let tests = generate_test_combinations(test_data, message_count); + + let tests = tests + .into_iter() + .map(|router_messages| { + let expected_test_result = expected_results + .get(&router_messages) + .expect( + format!("test for {router_messages:?} should be covered").as_str(), + ) + .clone(); + + InboundMessageTest { + router_messages, + expected_test_result, + } + }) + .collect::>(); - MockLiquidityPools::mock_handle(move |mock_domain_address, mock_mesage| { - assert_eq!(mock_domain_address, domain_address); - assert_eq!(mock_mesage, message); + InboundMessageTestSuite { routers, tests } + } + } - Ok(()) + use util::*; + + mod one_router { + use super::*; + + #[test] + fn success() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + let message_proof = message.to_message_proof().get_message_proof().unwrap(); + let session_id = 1; + let domain_address = DomainAddress::EVM(1, [1; 20]); + let router_hash = *ROUTER_HASH_1; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + router_id: router_hash, + }; + + Routers::::insert( + domain_address.domain(), + BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), + ); + InboundMessageSessions::::insert(domain_address.domain(), session_id); + + let handler = MockLiquidityPools::mock_handle( + move |mock_domain_address, mock_message| { + assert_eq!(mock_domain_address, domain_address); + assert_eq!(mock_message, message); + + Ok(()) + }, + ); + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_ok!(res); + assert_eq!(handler.times(), 1); + + assert!(PendingInboundEntries::::get( + session_id, + (message_proof, router_hash) + ) + .is_none()); }); + } + + #[test] + fn multi_router_not_found() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + let domain_address = DomainAddress::EVM(1, [1; 20]); + let router_hash = *ROUTER_HASH_1; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + router_id: router_hash, + }; + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::RoutersNotFound); + }); + } + + #[test] + fn inbound_domain_session_not_found() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + let domain_address = DomainAddress::EVM(1, [1; 20]); + let router_hash = *ROUTER_HASH_1; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + router_id: router_hash, + }; + + Routers::::insert( + domain_address.domain(), + BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), + ); + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::InboundDomainSessionNotFound); + }); + } + + #[test] + fn unknown_inbound_message_router() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + let session_id = 1; + let domain_address = DomainAddress::EVM(1, [1; 20]); + let router_hash = *ROUTER_HASH_1; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + // The router stored has a different hash, this should trigger the expected + // error. + router_id: *ROUTER_HASH_2, + }; + + Routers::::insert( + domain_address.domain(), + BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), + ); + InboundMessageSessions::::insert(domain_address.domain(), session_id); + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::UnknownInboundMessageRouter); + }); + } + + #[test] + fn expected_message_proof_type() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + let message_proof = message.to_message_proof().get_message_proof().unwrap(); + let session_id = 1; + let domain_address = DomainAddress::EVM(1, [1; 20]); + let router_hash = *ROUTER_HASH_1; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + router_id: router_hash, + }; + + Routers::::insert( + domain_address.domain(), + BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), + ); + InboundMessageSessions::::insert(domain_address.domain(), session_id); + PendingInboundEntries::::insert( + session_id, + (message_proof, router_hash), + InboundEntry::::Proof { current_count: 0 }, + ); + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::ExpectedMessageProofType); + }); + } + } - let (res, _) = LiquidityPoolsGateway::process(gateway_message); - assert_ok!(res); - }); + mod two_routers { + use super::*; + + mod success { + use super::*; + + lazy_static! { + static ref TEST_DATA: Vec = vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ]; + } + + mod two_messages { + use super::*; + + const MESSAGE_COUNT: usize = 2; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![*ROUTER_HASH_1, *ROUTER_HASH_2], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } + + mod three_messages { + use super::*; + + const MESSAGE_COUNT: usize = 3; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 3, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 3, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 1, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 1, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 1, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![*ROUTER_HASH_1, *ROUTER_HASH_2], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } + + mod four_messages { + use super::*; + + const MESSAGE_COUNT: usize = 4; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 4, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![*ROUTER_HASH_1, *ROUTER_HASH_2], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } + } + + mod failure { + use super::*; + + #[test] + fn message_expected_from_first_router() { + new_test_ext().execute_with(|| { + let session_id = 1; + + Routers::::insert( + TEST_DOMAIN_ADDRESS.domain(), + BoundedVec::<_, _>::try_from(vec![*ROUTER_HASH_1, *ROUTER_HASH_2]) + .unwrap(), + ); + InboundMessageSessions::::insert( + TEST_DOMAIN_ADDRESS.domain(), + session_id, + ); + + let gateway_message = GatewayMessage::Inbound { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + router_id: *ROUTER_HASH_2, + }; + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::MessageExpectedFromFirstRouter); + }); + } + + #[test] + fn proof_not_expected_from_first_router() { + new_test_ext().execute_with(|| { + let session_id = 1; + + Routers::::insert( + TEST_DOMAIN_ADDRESS.domain(), + BoundedVec::<_, _>::try_from(vec![*ROUTER_HASH_1, *ROUTER_HASH_2]) + .unwrap(), + ); + InboundMessageSessions::::insert( + TEST_DOMAIN_ADDRESS.domain(), + session_id, + ); + + let gateway_message = GatewayMessage::Inbound { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Proof(MESSAGE_PROOF), + router_id: *ROUTER_HASH_1, + }; + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::ProofNotExpectedFromFirstRouter); + }); + } + } + } + + mod three_routers { + use super::*; + + lazy_static! { + static ref TEST_DATA: Vec = vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ]; + } + + mod two_messages { + use super::*; + + const MESSAGE_COUNT: usize = 2; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![*ROUTER_HASH_1, *ROUTER_HASH_2, *ROUTER_HASH_3], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } + + mod three_messages { + use super::*; + + const MESSAGE_COUNT: usize = 3; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 6, + }), + ), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 3, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 3, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![*ROUTER_HASH_1, *ROUTER_HASH_2, *ROUTER_HASH_3], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } } #[test] fn inbound_message_handler_error() { new_test_ext().execute_with(|| { let domain_address = DomainAddress::EVM(1, [1; 20]); + + let router_id = H256::from_low_u64_be(1); + + Routers::::insert( + domain_address.domain(), + BoundedVec::try_from(vec![router_id]).unwrap(), + ); + InboundMessageSessions::::insert(domain_address.domain(), 1); + let message = Message::Simple; - let gateway_message = GatewayMessage::::Inbound { + let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), + router_id, }; let err = DispatchError::Unavailable; @@ -571,9 +2550,8 @@ mod message_processor_impl { Err(err) }); - let (res, weight) = LiquidityPoolsGateway::process(gateway_message); + let (res, _) = LiquidityPoolsGateway::process(gateway_message); assert_noop!(res, err); - assert_eq!(weight, LP_DEFENSIVE_WEIGHT); }); } } @@ -596,24 +2574,30 @@ mod message_processor_impl { pays_fee: Pays::Yes, }; - let router_mock = RouterMock::::default(); - router_mock.mock_send(move |mock_sender, mock_message| { - assert_eq!(mock_sender, expected_sender); - assert_eq!(mock_message, expected_message.serialize()); - - Ok(router_post_info) - }); - - DomainRouters::::insert(domain.clone(), router_mock); + let router_id = H256::from_low_u64_be(1); + + //TODO(cdamian): Drop mock? + // let router_hash = H256::from_low_u64_be(1); + // + // let router_mock = RouterMock::::default(); + // router_mock.mock_send(move |mock_sender, mock_message| { + // assert_eq!(mock_sender, expected_sender); + // assert_eq!(mock_message, expected_message.serialize()); + // + // Ok(router_post_info) + // }); + // router_mock.mock_hash(move || router_hash); + // + // DomainRouters::::insert(domain.clone(), router_mock); let min_expected_weight = ::DbWeight::get() .reads(1) + router_post_info.actual_weight.unwrap() + Weight::from_parts(0, message.serialize().len() as u64); - let gateway_message = GatewayMessage::::Outbound { + let gateway_message = GatewayMessage::Outbound { sender, - destination: domain, message: message.clone(), + router_id, }; let (res, weight) = LiquidityPoolsGateway::process(gateway_message); @@ -622,26 +2606,27 @@ mod message_processor_impl { }); } - #[test] - fn router_not_found() { - new_test_ext().execute_with(|| { - let sender = get_test_account_id(); - let domain = Domain::EVM(1); - let message = Message::Simple; - - let expected_weight = ::DbWeight::get().reads(1); - - let gateway_message = GatewayMessage::::Outbound { - sender, - destination: domain, - message, - }; - - let (res, weight) = LiquidityPoolsGateway::process(gateway_message); - assert_noop!(res, Error::::RouterNotFound); - assert_eq!(weight, expected_weight); - }); - } + //TODO(cdamian): Fix when bi-directional routers are in. + // #[test] + // fn router_not_found() { + // new_test_ext().execute_with(|| { + // let sender = get_test_account_id(); + // let message = Message::Simple; + // + // let expected_weight = ::DbWeight::get().reads(1); + // + // let gateway_message = GatewayMessage::Outbound { + // sender, + // message, + // router_id: H256::from_low_u64_be(1), + // }; + // + // let (res, weight) = LiquidityPoolsGateway::process(gateway_message); + // assert_noop!(res, Error::::RouterNotFound); + // assert_eq!(weight, expected_weight); + // }); + // } #[test] fn router_error() { @@ -658,33 +2643,35 @@ mod message_processor_impl { pays_fee: Pays::Yes, }; - let router_err = DispatchError::Unavailable; - - let router_mock = RouterMock::::default(); - router_mock.mock_send(move |mock_sender, mock_message| { - assert_eq!(mock_sender, expected_sender); - assert_eq!(mock_message, expected_message.serialize()); - - Err(DispatchErrorWithPostInfo { - post_info: router_post_info, - error: router_err, - }) - }); - - DomainRouters::::insert(domain.clone(), router_mock); + // let router_err = DispatchError::Unavailable; + // + // let router_mock = RouterMock::::default(); + // router_mock.mock_send(move |mock_sender, mock_message| { + // assert_eq!(mock_sender, expected_sender); + // assert_eq!(mock_message, expected_message.serialize()); + // + // Err(DispatchErrorWithPostInfo { + // post_info: router_post_info, + // error: router_err, + // }) + // }); + // + // DomainRouters::::insert(domain.clone(), router_mock); let min_expected_weight = ::DbWeight::get() .reads(1) + router_post_info.actual_weight.unwrap() + Weight::from_parts(0, message.serialize().len() as u64); - let gateway_message = GatewayMessage::::Outbound { + let gateway_message = GatewayMessage::Outbound { sender, - destination: domain, message: message.clone(), + router_id: H256::from_low_u64_be(1), }; let (res, weight) = LiquidityPoolsGateway::process(gateway_message); - assert_noop!(res, router_err); + //TODO(cdamian): Error out + assert_ok!(res); + // assert_noop!(res, router_err) assert!(weight.all_lte(min_expected_weight)); }); } @@ -725,6 +2712,10 @@ mod batches { // Ok Batched assert_ok!(LiquidityPoolsGateway::handle(USER, DOMAIN, Message::Simple)); + let router_id_1 = H256::from_low_u64_be(1); + + Routers::::insert(DOMAIN, BoundedVec::try_from(vec![router_id_1]).unwrap()); + // Not batched, it belong to OTHER assert_ok!(LiquidityPoolsGateway::handle( OTHER, @@ -732,6 +2723,11 @@ mod batches { Message::Simple )); + Routers::::insert( + Domain::EVM(2), + BoundedVec::try_from(vec![router_id_1]).unwrap(), + ); + // Not batched, it belong to EVM 2 assert_ok!(LiquidityPoolsGateway::handle( USER, @@ -774,6 +2770,10 @@ mod batches { DispatchError::Other(MAX_PACKED_MESSAGES_ERR) ); + let router_id_1 = H256::from_low_u64_be(1); + + Routers::::insert(DOMAIN, BoundedVec::try_from(vec![router_id_1]).unwrap()); + assert_ok!(LiquidityPoolsGateway::end_batch_message( RuntimeOrigin::signed(USER), DOMAIN @@ -812,15 +2812,54 @@ mod batches { let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); let domain_address = DomainAddress::EVM(0, address.into()); - MockLiquidityPools::mock_handle(|_, _| Ok(())); + let router_id_1 = H256::from_low_u64_be(1); + + Routers::::insert( + domain_address.domain(), + BoundedVec::try_from(vec![router_id_1]).unwrap(), + ); + InboundMessageSessions::::insert(domain_address.domain(), 1); + + let handler = MockLiquidityPools::mock_handle(|_, _| Ok(())); + + let submessage_count = 5; let (result, weight) = LiquidityPoolsGateway::process(GatewayMessage::Inbound { domain_address, - message: Message::deserialize(&(1..=5).collect::>()).unwrap(), + message: Message::deserialize(&(1..=submessage_count).collect::>()).unwrap(), + router_id: *ROUTER_HASH_1, }); - assert_eq!(weight, LP_DEFENSIVE_WEIGHT * 5); + let expected_weight = Weight::default() + // get_inbound_processing_info + .saturating_add(::DbWeight::get().reads(3)) + // process_inbound_message + .saturating_add(Weight::from_parts(0, Message::max_encoded_len() as u64)) + .saturating_add(LP_DEFENSIVE_WEIGHT) + // upsert_pending_entry + .saturating_add( + ::DbWeight::get() + .writes(1) + .saturating_mul(submessage_count.into()), + ) + // get_executable_message + .saturating_add( + ::DbWeight::get() + .reads(1) + .saturating_mul(submessage_count.into()), + ) + // decrease_pending_entries_counts + .saturating_add( + ::DbWeight::get() + .writes(1) + .saturating_mul(submessage_count.into()), + ) + // process_inbound_message + .saturating_mul(submessage_count.into()); + assert_ok!(result); + assert_eq!(weight, expected_weight); + assert_eq!(handler.times(), submessage_count as u32); }); } @@ -830,22 +2869,226 @@ mod batches { let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); let domain_address = DomainAddress::EVM(0, address.into()); + let router_id_1 = H256::from_low_u64_be(1); + + Routers::::insert( + domain_address.domain(), + BoundedVec::try_from(vec![router_id_1]).unwrap(), + ); + InboundMessageSessions::::insert(domain_address.domain(), 1); + let counter = Arc::new(AtomicU32::new(0)); - MockLiquidityPools::mock_handle(move |_, _| { + + let handler = MockLiquidityPools::mock_handle(move |_, _| { match counter.fetch_add(1, Ordering::Relaxed) { 2 => Err(DispatchError::Unavailable), _ => Ok(()), } }); - let (result, weight) = LiquidityPoolsGateway::process(GatewayMessage::Inbound { + let (result, _) = LiquidityPoolsGateway::process(GatewayMessage::Inbound { domain_address, message: Message::deserialize(&(1..=5).collect::>()).unwrap(), + router_id: *ROUTER_HASH_1, }); - // 2 correct messages and 1 failed message processed. - assert_eq!(weight, LP_DEFENSIVE_WEIGHT * 3); assert_err!(result, DispatchError::Unavailable); + // 2 correct messages and 1 failed message processed. + assert_eq!(handler.times(), 3); + }); + } +} + +mod execute_message_recovery { + use super::*; + + #[test] + fn success() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + let router_id = H256::from_low_u64_be(1); + let session_id = 1; + + Routers::::insert( + domain.clone(), + BoundedVec::try_from(vec![router_id]).unwrap(), + ); + InboundMessageSessions::::insert(domain.clone(), session_id); + + assert_ok!(LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + domain.clone(), + MESSAGE_PROOF, + router_id + )); + + event_exists(Event::::MessageRecoveryExecuted { + domain: domain.clone(), + proof: MESSAGE_PROOF, + router_id: router_id.clone(), + }); + + let inbound_entry = + PendingInboundEntries::::get(session_id, (MESSAGE_PROOF, router_id)) + .expect("inbound entry is stored"); + + assert_eq!( + inbound_entry, + InboundEntry::::Proof { current_count: 1 } + ); + + assert_ok!(LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + domain.clone(), + MESSAGE_PROOF, + router_id + )); + + let inbound_entry = + PendingInboundEntries::::get(session_id, (MESSAGE_PROOF, router_id)) + .expect("inbound entry is stored"); + + assert_eq!( + inbound_entry, + InboundEntry::::Proof { current_count: 2 } + ); + + event_exists(Event::::MessageRecoveryExecuted { + domain: domain.clone(), + proof: MESSAGE_PROOF, + router_id: router_id.clone(), + }); + }); + } + + #[test] + fn inbound_domain_session_not_found() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + let router_id = H256::from_low_u64_be(1); + + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + domain.clone(), + MESSAGE_PROOF, + router_id + ), + Error::::InboundDomainSessionNotFound + ); + }); + } + + #[test] + fn routers_not_found() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + let router_id = H256::from_low_u64_be(1); + let session_id = 1; + + InboundMessageSessions::::insert(domain.clone(), session_id); + + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + domain.clone(), + MESSAGE_PROOF, + router_id + ), + Error::::RoutersNotFound + ); + }); + } + + #[test] + fn unknown_inbound_message_router() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + let router_id_1 = H256::from_low_u64_be(1); + let router_id_2 = H256::from_low_u64_be(2); + let session_id = 1; + + Routers::::insert( + domain.clone(), + BoundedVec::try_from(vec![router_id_1]).unwrap(), + ); + InboundMessageSessions::::insert(domain.clone(), session_id); + + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + domain.clone(), + MESSAGE_PROOF, + router_id_2 + ), + Error::::UnknownInboundMessageRouter + ); + }); + } + + #[test] + fn proof_count_overflow() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + let router_id = H256::from_low_u64_be(1); + let session_id = 1; + + Routers::::insert( + domain.clone(), + BoundedVec::try_from(vec![router_id]).unwrap(), + ); + InboundMessageSessions::::insert(domain.clone(), session_id); + PendingInboundEntries::::insert( + session_id, + (MESSAGE_PROOF, router_id), + InboundEntry::::Proof { + current_count: u32::MAX, + }, + ); + + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + domain.clone(), + MESSAGE_PROOF, + router_id + ), + Arithmetic(Overflow) + ); + }); + } + + #[test] + fn expected_message_proof_type() { + new_test_ext().execute_with(|| { + let domain_address = TEST_DOMAIN_ADDRESS; + let router_id = H256::from_low_u64_be(1); + let session_id = 1; + + Routers::::insert( + domain_address.domain(), + BoundedVec::try_from(vec![router_id]).unwrap(), + ); + InboundMessageSessions::::insert(domain_address.domain(), session_id); + PendingInboundEntries::::insert( + session_id, + (MESSAGE_PROOF, router_id), + InboundEntry::::Message { + domain_address: domain_address.clone(), + message: Message::Simple, + expected_proof_count: 2, + }, + ); + + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + domain_address.domain(), + MESSAGE_PROOF, + router_id + ), + Error::::ExpectedMessageProofType + ); }); } } diff --git a/pallets/liquidity-pools-gateway/src/weights.rs b/pallets/liquidity-pools-gateway/src/weights.rs index b1ac9ed578..aaf598ec91 100644 --- a/pallets/liquidity-pools-gateway/src/weights.rs +++ b/pallets/liquidity-pools-gateway/src/weights.rs @@ -13,16 +13,16 @@ use frame_support::weights::{constants::RocksDbWeight, Weight}; pub trait WeightInfo { - fn set_domain_router() -> Weight; + fn set_domain_routers() -> Weight; fn add_instance() -> Weight; fn remove_instance() -> Weight; fn add_relayer() -> Weight; fn remove_relayer() -> Weight; fn receive_message() -> Weight; - fn process_outbound_message() -> Weight; - fn process_failed_outbound_message() -> Weight; fn start_batch_message() -> Weight; fn end_batch_message() -> Weight; + fn set_domain_hook_address() -> Weight; + fn execute_message_recovery() -> Weight; } // NOTE: We use temporary weights here. `execute_epoch` is by far our heaviest @@ -31,7 +31,7 @@ pub trait WeightInfo { const N: u64 = 4; impl WeightInfo for () { - fn set_domain_router() -> Weight { + fn set_domain_routers() -> Weight { // TODO: BENCHMARK CORRECTLY // // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` @@ -103,18 +103,18 @@ impl WeightInfo for () { .saturating_add(Weight::from_parts(0, 17774).saturating_mul(N)) } - fn process_outbound_message() -> Weight { + fn start_batch_message() -> Weight { // TODO: BENCHMARK CORRECTLY // // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` // This one has one read and one write for sure and possible one // read for `AdminOrigin` Weight::from_parts(30_117_000, 5991) - .saturating_add(RocksDbWeight::get().reads(2)) + .saturating_add(RocksDbWeight::get().reads(1)) .saturating_add(RocksDbWeight::get().writes(1)) } - fn process_failed_outbound_message() -> Weight { + fn end_batch_message() -> Weight { // TODO: BENCHMARK CORRECTLY // // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` @@ -122,21 +122,21 @@ impl WeightInfo for () { // read for `AdminOrigin` Weight::from_parts(30_117_000, 5991) .saturating_add(RocksDbWeight::get().reads(2)) - .saturating_add(RocksDbWeight::get().writes(1)) + .saturating_add(RocksDbWeight::get().writes(2)) } - fn start_batch_message() -> Weight { + fn set_domain_hook_address() -> Weight { // TODO: BENCHMARK CORRECTLY // // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` // This one has one read and one write for sure and possible one // read for `AdminOrigin` Weight::from_parts(30_117_000, 5991) - .saturating_add(RocksDbWeight::get().reads(1)) - .saturating_add(RocksDbWeight::get().writes(1)) + .saturating_add(RocksDbWeight::get().reads(2)) + .saturating_add(RocksDbWeight::get().writes(2)) } - fn end_batch_message() -> Weight { + fn execute_message_recovery() -> Weight { // TODO: BENCHMARK CORRECTLY // // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` diff --git a/pallets/liquidity-pools/src/message.rs b/pallets/liquidity-pools/src/message.rs index d981abfb16..d4ae5d5f4e 100644 --- a/pallets/liquidity-pools/src/message.rs +++ b/pallets/liquidity-pools/src/message.rs @@ -5,7 +5,10 @@ //! also have a custom GMPF implementation, aiming for a fixed-size encoded //! representation for each message variant. -use cfg_traits::{liquidity_pools::LPEncoding, Seconds}; +use cfg_traits::{ + liquidity_pools::{LPEncoding, Proof}, + Seconds, +}; use cfg_types::domain_address::Domain; use frame_support::{pallet_prelude::RuntimeDebug, BoundedVec}; use parity_scale_codec::{Decode, Encode, MaxEncodedLen}; @@ -15,7 +18,7 @@ use serde::{ ser::{Error as _, SerializeTuple}, Deserialize, Serialize, Serializer, }; -use sp_core::U256; +use sp_core::{keccak_256, U256}; use sp_runtime::{traits::ConstU32, DispatchError, DispatchResult}; use sp_std::{vec, vec::Vec}; @@ -558,6 +561,19 @@ impl LPEncoding for Message { fn empty() -> Message { Message::Batch(BatchMessages::default()) } + + fn get_message_proof(&self) -> Option { + match self { + Message::MessageProof { hash } => Some(hash.clone()), + _ => None, + } + } + + fn to_message_proof(&self) -> Self { + let hash = keccak_256(&LPEncoding::serialize(self)); + + Message::MessageProof { hash } + } } /// A Liquidity Pool message for updating restrictions on foreign domains.