From 15aaedccb0ca958c26d97a5271fb2aa3baa78fca Mon Sep 17 00:00:00 2001 From: Cosmin Damian <17934949+cdamian@users.noreply.github.com> Date: Thu, 8 Aug 2024 14:37:39 +0300 Subject: [PATCH 01/11] wip --- Cargo.toml | 1 + libs/mocks/src/liquidity_pools.rs | 8 +- .../src/liquidity_pools_gateway_routers.rs | 22 + libs/traits/src/liquidity_pools.rs | 11 + pallets/liquidity-pools-gateway/Cargo.toml | 3 + .../routers/src/lib.rs | 9 + .../routers/src/routers/axelar_evm.rs | 7 +- pallets/liquidity-pools-gateway/src/lib.rs | 297 ++++++- .../liquidity-pools-gateway/src/message.rs | 6 +- pallets/liquidity-pools-gateway/src/mock.rs | 49 +- pallets/liquidity-pools-gateway/src/tests.rs | 729 ++++++++++++++++-- .../liquidity-pools-gateway/src/weights.rs | 36 + pallets/liquidity-pools/src/message.rs | 20 +- 13 files changed, 1118 insertions(+), 80 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 42d8008e8e..836976ea9b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -86,6 +86,7 @@ impl-trait-for-tuples = "0.2.2" num-traits = { version = "0.2.17", default-features = false } num_enum = { version = "0.5.3", default-features = false } chrono = { version = "0.4", default-features = false } +itertools = { version = "0.13.0", default-features = false } # Cumulus cumulus-pallet-aura-ext = { git = "https://github.com/paritytech/polkadot-sdk", default-features = false, branch = "release-polkadot-v1.7.2" } diff --git a/libs/mocks/src/liquidity_pools.rs b/libs/mocks/src/liquidity_pools.rs index acee16d4db..dcae66f60f 100644 --- a/libs/mocks/src/liquidity_pools.rs +++ b/libs/mocks/src/liquidity_pools.rs @@ -2,7 +2,7 @@ pub mod pallet { use cfg_traits::liquidity_pools::InboundMessageHandler; use frame_support::pallet_prelude::*; - use mock_builder::{execute_call, register_call}; + use mock_builder::{execute_call, register_call, CallHandler}; #[pallet::config] pub trait Config: frame_system::Config { @@ -17,8 +17,10 @@ pub mod pallet { type CallIds = StorageMap<_, _, String, mock_builder::CallId>; impl Pallet { - pub fn mock_handle(f: impl Fn(T::DomainAddress, T::Message) -> DispatchResult + 'static) { - register_call!(move |(sender, msg)| f(sender, msg)); + pub fn mock_handle( + f: impl Fn(T::DomainAddress, T::Message) -> DispatchResult + 'static, + ) -> CallHandler { + register_call!(move |(sender, msg)| f(sender, msg)) } } diff --git a/libs/mocks/src/liquidity_pools_gateway_routers.rs b/libs/mocks/src/liquidity_pools_gateway_routers.rs index 22bfdf1bb0..41408ffeba 100644 --- a/libs/mocks/src/liquidity_pools_gateway_routers.rs +++ b/libs/mocks/src/liquidity_pools_gateway_routers.rs @@ -26,9 +26,14 @@ pub mod pallet { ) { register_call!(move |(sender, message)| f(sender, message)); } + + pub fn mock_hash(f: impl Fn() -> T::Hash + 'static) { + register_call!(move |()| f()); + } } impl MockedRouter for Pallet { + type Hash = T::Hash; type Sender = T::AccountId; fn init() -> DispatchResult { @@ -38,6 +43,10 @@ pub mod pallet { fn send(sender: Self::Sender, message: Vec) -> DispatchResultWithPostInfo { execute_call!((sender, message)) } + + fn hash() -> Self::Hash { + execute_call!(()) + } } } @@ -68,11 +77,16 @@ impl RouterMock { ) { pallet::Pallet::::mock_send(f) } + + pub fn mock_hash(&self, f: impl Fn() -> as Router>::Hash + 'static) { + pallet::Pallet::::mock_hash(f) + } } /// Here we implement the actual Router trait for the `RouterMock` which in turn /// calls the `MockedRouter` trait implementation. impl Router for RouterMock { + type Hash = T::Hash; type Sender = T::AccountId; fn init(&self) -> DispatchResult { @@ -82,6 +96,10 @@ impl Router for RouterMock { fn send(&self, sender: Self::Sender, message: Vec) -> DispatchResultWithPostInfo { pallet::Pallet::::send(sender, message) } + + fn hash(&self) -> Self::Hash { + pallet::Pallet::::hash() + } } /// A mocked Router trait that emulates the actual Router trait but without @@ -94,9 +112,13 @@ trait MockedRouter { /// The sender type of the outbound message. type Sender; + type Hash; + /// Initialize the router. fn init() -> DispatchResult; /// Send the message to the router's destination. fn send(sender: Self::Sender, message: Vec) -> DispatchResultWithPostInfo; + + fn hash() -> Self::Hash; } diff --git a/libs/traits/src/liquidity_pools.rs b/libs/traits/src/liquidity_pools.rs index 9ee729e4c6..2dfd5d891b 100644 --- a/libs/traits/src/liquidity_pools.rs +++ b/libs/traits/src/liquidity_pools.rs @@ -18,6 +18,8 @@ use frame_support::{ use sp_runtime::DispatchError; use sp_std::vec::Vec; +pub type Proof = [u8; 32]; + /// An encoding & decoding trait for the purpose of meeting the /// LiquidityPools General Message Passing Format pub trait LPEncoding: Sized { @@ -34,6 +36,9 @@ pub trait LPEncoding: Sized { /// Creates an empty message. /// It's the identity message for composing messages with pack_with fn empty() -> Self; + + fn get_message_proof(&self) -> Option; + fn to_message_proof(&self) -> Self; } /// The trait required for sending outbound messages. @@ -41,11 +46,17 @@ pub trait Router { /// The sender type of the outbound message. type Sender; + /// The router hash type. + type Hash; + /// Initialize the router. fn init(&self) -> DispatchResult; /// Send the message to the router's destination. fn send(&self, sender: Self::Sender, message: Vec) -> DispatchResultWithPostInfo; + + /// Generate a hash for this router. + fn hash(&self) -> Self::Hash; } /// The trait required for queueing messages. diff --git a/pallets/liquidity-pools-gateway/Cargo.toml b/pallets/liquidity-pools-gateway/Cargo.toml index 5f01d96284..f5b0356569 100644 --- a/pallets/liquidity-pools-gateway/Cargo.toml +++ b/pallets/liquidity-pools-gateway/Cargo.toml @@ -35,6 +35,9 @@ cfg-utils = { workspace = true } [dev-dependencies] cfg-mocks = { workspace = true, default-features = true } sp-io = { workspace = true, default-features = true } +itertools = { workspace = true, default-features = true } +lazy_static = { workspace = true, default-features = true } +mock-builder = { workspace = true, default-features = true } [features] default = ["std"] diff --git a/pallets/liquidity-pools-gateway/routers/src/lib.rs b/pallets/liquidity-pools-gateway/routers/src/lib.rs index a731009fbb..df47238053 100644 --- a/pallets/liquidity-pools-gateway/routers/src/lib.rs +++ b/pallets/liquidity-pools-gateway/routers/src/lib.rs @@ -76,6 +76,7 @@ where OriginFor: From + Into>>, { + type Hash = T::Hash; type Sender = T::AccountId; fn init(&self) -> DispatchResult { @@ -89,6 +90,14 @@ where DomainRouter::AxelarEVM(r) => r.do_send(sender, message), } } + + fn hash(&self) -> Self::Hash { + match self { + DomainRouter::EthereumXCM(r) => r.hash(), + DomainRouter::AxelarEVM(r) => r.hash(), + DomainRouter::AxelarXCM(r) => r.hash(), + } + } } /// A generic router used for executing EVM calls. diff --git a/pallets/liquidity-pools-gateway/routers/src/routers/axelar_evm.rs b/pallets/liquidity-pools-gateway/routers/src/routers/axelar_evm.rs index 670bb8b4dd..d9d5f069a9 100644 --- a/pallets/liquidity-pools-gateway/routers/src/routers/axelar_evm.rs +++ b/pallets/liquidity-pools-gateway/routers/src/routers/axelar_evm.rs @@ -21,7 +21,8 @@ use scale_info::{ prelude::{format, string::String}, TypeInfo, }; -use sp_core::{bounded::BoundedVec, ConstU32, H160}; +use sp_core::{bounded::BoundedVec, ConstU32, Hasher, H160}; +use sp_runtime::traits::BlakeTwo256; use sp_std::{collections::btree_map::BTreeMap, vec, vec::Vec}; use crate::{ @@ -77,6 +78,10 @@ where self.router.do_send(sender, eth_msg) } + + pub fn hash(&self) -> T::Hash { + BlakeTwo256::hash(self.evm_chain.encode().as_slice()) + } } /// Encodes the provided message into the format required for submitting it diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index 2dfc01b85e..366169b1d6 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -40,7 +40,8 @@ use message::GatewayMessage; use orml_traits::GetByKey; pub use pallet::*; use parity_scale_codec::{EncodeLike, FullCodec}; -use sp_std::convert::TryInto; +use sp_runtime::traits::EnsureAddAssign; +use sp_std::{cmp::Ordering, convert::TryInto, vec::Vec}; use crate::weights::WeightInfo; @@ -94,7 +95,7 @@ pub mod pallet { type Message: LPEncoding + Clone + Debug + PartialEq + MaxEncodedLen + TypeInfo + FullCodec; /// The message router type that is stored for each domain. - type Router: DomainRouter + type Router: DomainRouter + Clone + Debug + MaxEncodedLen @@ -121,7 +122,13 @@ pub mod pallet { type Sender: Get; /// Type used for queueing messages. - type MessageQueue: MessageQueue>; + type MessageQueue: MessageQueue< + Message = GatewayMessage, + >; + + /// Number of routers for a domain. + #[pallet::constant] + type MultiRouterCount: Get; } #[pallet::event] @@ -141,6 +148,12 @@ pub mod pallet { domain: Domain, hook_address: [u8; 20], }, + + /// The routers for a given domain were set. + DomainMultiRouterSet { + domain: Domain, + routers: BoundedVec, + }, } /// Storage for domain routers. @@ -175,6 +188,34 @@ pub mod pallet { pub(crate) type PackedMessage = StorageMap<_, Blake2_128Concat, (T::AccountId, Domain), T::Message>; + /// Storage for routers. + /// + /// This can only be set by an admin. + #[pallet::storage] + #[pallet::getter(fn routers)] + pub type Routers = StorageMap<_, Blake2_128Concat, T::Hash, T::Router>; + + /// Storage for domain multi-routers. + /// + /// This can only be set by an admin. + #[pallet::storage] + #[pallet::getter(fn domain_multi_routers)] + pub type DomainMultiRouters = + StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; + + /// Storage that keeps track of incoming message proofs. + #[pallet::storage] + #[pallet::getter(fn inbound_message_proof_count)] + pub type InboundMessageProofCount = + StorageMap<_, Blake2_128Concat, Proof, u32, ValueQuery>; + + /// Storage that keeps track of incoming messages and the expected proof + /// count. + #[pallet::storage] + #[pallet::getter(fn inbound_messages)] + pub type InboundMessages = + StorageMap<_, Blake2_128Concat, Proof, (DomainAddress, T::Message, u32)>; + #[pallet::error] pub enum Error { /// Router initialization failed. @@ -205,6 +246,18 @@ pub mod pallet { /// Emitted when you can `end_batch_message()` but the packing process /// was not started by `start_batch_message()`. MessagePackingNotStarted, + + /// Invalid multi router. + InvalidMultiRouter, + + /// Multi-router not found. + MultiRouterNotFound, + + /// Message proof cannot be retrieved. + MessageProofRetrieval, + + /// Recovery message not found. + RecoveryMessageNotFound, } #[pallet::call] @@ -290,7 +343,7 @@ pub mod pallet { Error::::UnknownInstance, ); - let gateway_message = GatewayMessage::::Inbound { + let gateway_message = GatewayMessage::::Inbound { domain_address: origin_address, message: T::Message::deserialize(&msg)?, }; @@ -301,7 +354,7 @@ pub mod pallet { /// Set the address of the domain hook /// /// Can only be called by `AdminOrigin`. - #[pallet::weight(T::WeightInfo::set_domain_router())] + #[pallet::weight(T::WeightInfo::set_domain_hook_address())] #[pallet::call_index(8)] pub fn set_domain_hook_address( origin: OriginFor, @@ -352,9 +405,136 @@ pub mod pallet { None => Err(Error::::MessagePackingNotStarted.into()), } } + + /// Set routers for a particular domain. + #[pallet::weight(T::WeightInfo::set_domain_multi_router())] + #[pallet::call_index(11)] + pub fn set_domain_multi_router( + origin: OriginFor, + domain: Domain, + routers: BoundedVec, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; + + ensure!(domain != Domain::Centrifuge, Error::::DomainNotSupported); + ensure!( + routers.len() == T::MultiRouterCount::get() as usize, + Error::::InvalidMultiRouter + ); + + let mut router_hashes = Vec::new(); + + for router in &routers { + router.init().map_err(|_| Error::::RouterInitFailed)?; + + let router_hash = router.hash(); + + router_hashes.push(router_hash); + + Routers::::insert(router_hash, router); + } + + >::insert( + domain.clone(), + BoundedVec::try_from(router_hashes).map_err(|_| Error::::InvalidMultiRouter)?, + ); + + Self::clear_storages_for_inbound_messages(); + + Self::deposit_event(Event::DomainMultiRouterSet { domain, routers }); + + Ok(()) + } + + /// Manually increase the proof count for a particular message and + /// executes it if the required count is reached. + /// + /// Can only be called by `AdminOrigin`. + #[pallet::weight(T::WeightInfo::execute_message_recovery())] + #[pallet::call_index(12)] + pub fn execute_message_recovery( + origin: OriginFor, + message_proof: Proof, + proof_count: u32, + ) -> DispatchResult { + //TODO(cdamian): Implement this. + unimplemented!() + } } impl Pallet { + fn clear_storages_for_inbound_messages() { + let _ = InboundMessages::::clear(u32::MAX, None); + let _ = InboundMessageProofCount::::clear(u32::MAX, None); + } + + //TODO(cdamian): Use safe math + fn get_expected_message_proof_count() -> u32 { + T::MultiRouterCount::get() - 1 + } + + /// Inserts a message and its expected proof count, or increases the + /// message proof count for a particular message. + fn get_proof_and_current_count( + domain_address: DomainAddress, + message: T::Message, + weight: &mut Weight, + ) -> Result<(Proof, u32), DispatchError> { + match message.get_message_proof() { + None => { + let message_proof = message + .to_message_proof() + .get_message_proof() + .expect("message proof ensured by 'to_message_proof'"); + + match InboundMessages::::try_mutate(message_proof, |storage_entry| { + match storage_entry { + None => { + *storage_entry = Some(( + domain_address, + message, + Self::get_expected_message_proof_count(), + )); + } + Some((_, _, expected_proof_count)) => { + // We already have a message, in this case we should expect another + // set of message proofs. + expected_proof_count + .ensure_add_assign(Self::get_expected_message_proof_count())?; + } + }; + + Ok(()) + }) { + Ok(_) => {} + Err(e) => return Err(e), + }; + + *weight = weight.saturating_add(T::DbWeight::get().reads_writes(1, 1)); + + Ok(( + message_proof, + InboundMessageProofCount::::get(message_proof), + )) + } + Some(message_proof) => { + let message_proof_count = + match InboundMessageProofCount::::try_mutate(message_proof, |count| { + count.ensure_add_assign(1)?; + + Ok(*count) + }) { + Ok(r) => r, + Err(e) => return Err(e), + }; + + *weight = weight.saturating_add(T::DbWeight::get().writes(1)); + + Ok((message_proof, message_proof_count)) + } + } + } + /// Give the message to the `InboundMessageHandler` to be processed. fn process_inbound_message( domain_address: DomainAddress, @@ -365,6 +545,69 @@ pub mod pallet { for submessage in message.submessages() { count += 1; + let (message_proof, mut current_message_proof_count) = + match Self::get_proof_and_current_count( + domain_address.clone(), + message.clone(), + &mut weight, + ) { + Ok(r) => r, + Err(e) => return (Err(e), weight), + }; + + let (_, message, mut total_expected_proof_count) = + match InboundMessages::::get(message_proof) { + None => return (Ok(()), weight), + Some(r) => r, + }; + + weight = weight.saturating_add(T::DbWeight::get().reads(1)); + + let expected_message_proof_count = Self::get_expected_message_proof_count(); + + match current_message_proof_count.cmp(&expected_message_proof_count) { + Ordering::Less => return (Ok(()), weight), + Ordering::Equal => { + InboundMessageProofCount::::remove(message_proof); + total_expected_proof_count -= expected_message_proof_count; + + if total_expected_proof_count == 0 { + InboundMessages::::remove(message_proof); + } else { + InboundMessages::::insert( + message_proof, + ( + domain_address.clone(), + message.clone(), + total_expected_proof_count, + ), + ); + } + } + Ordering::Greater => { + current_message_proof_count -= expected_message_proof_count; + InboundMessageProofCount::::insert( + message_proof, + current_message_proof_count, + ); + + total_expected_proof_count -= expected_message_proof_count; + + if total_expected_proof_count == 0 { + InboundMessages::::remove(message_proof); + } else { + InboundMessages::::insert( + message_proof, + ( + domain_address.clone(), + message.clone(), + total_expected_proof_count, + ), + ); + } + } + } + if let Err(e) = T::InboundMessageHandler::handle(domain_address.clone(), submessage) { // We only consume the processed weight if error during the batch @@ -380,12 +623,12 @@ pub mod pallet { /// weight for these operations in the `DispatchResultWithPostInfo`. fn process_outbound_message( sender: T::AccountId, - domain: Domain, message: T::Message, + router_hash: T::Hash, ) -> (DispatchResult, Weight) { let read_weight = T::DbWeight::get().reads(1); - let Some(router) = DomainRouters::::get(domain) else { + let Some(router) = Routers::::get(router_hash) else { return (Err(Error::::RouterNotFound.into()), read_weight); }; @@ -398,15 +641,33 @@ pub mod pallet { } fn queue_message(destination: Domain, message: T::Message) -> DispatchResult { - // We are using the sender specified in the pallet config so that we can - // ensure that the account is funded - let gateway_message = GatewayMessage::::Outbound { - sender: T::Sender::get(), - destination, - message, - }; + let router_hashes = DomainMultiRouters::::get(destination.clone()) + .ok_or(Error::::MultiRouterNotFound)?; + + let message_proof = message.to_message_proof(); + let mut message_opt = Some(message); + + for router_hash in router_hashes { + // Ensure that we only send the actual message once, using one router. + // The remaining routers will send the message proof. + let router_msg = match message_opt.take() { + Some(m) => m, + None => message_proof.clone(), + }; + + // We are using the sender specified in the pallet config so that we can + // ensure that the account is funded + let gateway_message = + GatewayMessage::::Outbound { + sender: T::Sender::get(), + message: router_msg, + router_hash, + }; + + T::MessageQueue::submit(gateway_message)?; + } - T::MessageQueue::submit(gateway_message) + Ok(()) } } @@ -439,7 +700,7 @@ pub mod pallet { } impl MessageProcessor for Pallet { - type Message = GatewayMessage; + type Message = GatewayMessage; fn process(msg: Self::Message) -> (DispatchResult, Weight) { match msg { @@ -449,9 +710,9 @@ pub mod pallet { } => Self::process_inbound_message(domain_address, message), GatewayMessage::Outbound { sender, - destination, message, - } => Self::process_outbound_message(sender, destination, message), + router_hash, + } => Self::process_outbound_message(sender, message, router_hash), } } diff --git a/pallets/liquidity-pools-gateway/src/message.rs b/pallets/liquidity-pools-gateway/src/message.rs index cf0bbb1a17..a2f568c23e 100644 --- a/pallets/liquidity-pools-gateway/src/message.rs +++ b/pallets/liquidity-pools-gateway/src/message.rs @@ -3,19 +3,19 @@ use frame_support::pallet_prelude::{Decode, Encode, MaxEncodedLen, TypeInfo}; /// Message type used by the LP gateway. #[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] -pub enum GatewayMessage { +pub enum GatewayMessage { Inbound { domain_address: DomainAddress, message: Message, }, Outbound { sender: AccountId, - destination: Domain, message: Message, + router_hash: Hash, }, } -impl Default for GatewayMessage { +impl Default for GatewayMessage { fn default() -> Self { GatewayMessage::Inbound { domain_address: Default::default(), diff --git a/pallets/liquidity-pools-gateway/src/mock.rs b/pallets/liquidity-pools-gateway/src/mock.rs index 348ccf8d91..662de1ff60 100644 --- a/pallets/liquidity-pools-gateway/src/mock.rs +++ b/pallets/liquidity-pools-gateway/src/mock.rs @@ -1,8 +1,10 @@ +use std::fmt::{Debug, Formatter}; + use cfg_mocks::{ pallet_mock_liquidity_pools, pallet_mock_liquidity_pools_gateway_queue, pallet_mock_routers, RouterMock, }; -use cfg_traits::liquidity_pools::LPEncoding; +use cfg_traits::liquidity_pools::{LPEncoding, Proof}; use cfg_types::domain_address::DomainAddress; use frame_support::{derive_impl, weights::constants::RocksDbWeight}; use frame_system::EnsureRoot; @@ -18,11 +20,24 @@ pub const LP_ADMIN_ACCOUNT: AccountId32 = AccountId32::new([u8::MAX; 32]); pub const MAX_PACKED_MESSAGES_ERR: &str = "packed limit error"; pub const MAX_PACKED_MESSAGES: usize = 10; -#[derive(Default, Debug, Eq, PartialEq, Clone, Encode, Decode, TypeInfo)] +pub const MESSAGE_PROOF: [u8; 32] = [1; 32]; + +#[derive(Default, Eq, PartialEq, Clone, Encode, Decode, TypeInfo, Hash)] pub enum Message { #[default] Simple, Pack(Vec), + Proof([u8; 32]), +} + +impl Debug for Message { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Message::Simple => write!(f, "Simple"), + Message::Pack(p) => write!(f, "Pack - {:?}", p), + Message::Proof(_) => write!(f, "Proof"), + } + } } /// Avoiding automatic infinity loop with the MaxEncodedLen derive @@ -35,8 +50,8 @@ impl MaxEncodedLen for Message { impl LPEncoding for Message { fn serialize(&self) -> Vec { match self { - Self::Simple => vec![0x42], Self::Pack(list) => list.iter().map(|_| 0x42).collect(), + _ => vec![0x42], } } @@ -50,10 +65,6 @@ impl LPEncoding for Message { fn pack_with(&mut self, other: Self) -> DispatchResult { match self { - Self::Simple => { - *self = Self::Pack(vec![Self::Simple, other]); - Ok(()) - } Self::Pack(list) if list.len() == MAX_PACKED_MESSAGES => { Err(MAX_PACKED_MESSAGES_ERR.into()) } @@ -61,19 +72,37 @@ impl LPEncoding for Message { list.push(other); Ok(()) } + _ => { + *self = Self::Pack(vec![self.clone(), other]); + Ok(()) + } } } fn submessages(&self) -> Vec { match self { - Self::Simple => vec![Self::Simple], Self::Pack(list) => list.clone(), + _ => vec![self.clone()], } } fn empty() -> Self { Self::Pack(vec![]) } + + fn get_message_proof(&self) -> Option { + match self { + Message::Proof(p) => Some(p.clone()), + _ => None, + } + } + + fn to_message_proof(&self) -> Self { + match self { + Message::Proof(_) => self.clone(), + _ => Message::Proof(MESSAGE_PROOF), + } + } } frame_support::construct_runtime!( @@ -102,13 +131,14 @@ impl pallet_mock_liquidity_pools::Config for Runtime { impl pallet_mock_routers::Config for Runtime {} impl pallet_mock_liquidity_pools_gateway_queue::Config for Runtime { - type Message = GatewayMessage; + type Message = GatewayMessage; } frame_support::parameter_types! { pub Sender: AccountId32 = AccountId32::from(H256::from_low_u64_be(1).to_fixed_bytes()); pub const MaxIncomingMessageSize: u32 = 1024; pub const LpAdminAccount: AccountId32 = LP_ADMIN_ACCOUNT; + pub const MultiRouterCount: u32 = 3; } impl pallet_liquidity_pools_gateway::Config for Runtime { @@ -118,6 +148,7 @@ impl pallet_liquidity_pools_gateway::Config for Runtime { type MaxIncomingMessageSize = MaxIncomingMessageSize; type Message = Message; type MessageQueue = MockLiquidityPoolsGatewayQueue; + type MultiRouterCount = MultiRouterCount; type Router = RouterMock; type RuntimeEvent = RuntimeEvent; type RuntimeOrigin = RuntimeOrigin; diff --git a/pallets/liquidity-pools-gateway/src/tests.rs b/pallets/liquidity-pools-gateway/src/tests.rs index 13afa3a3bc..283d49f81f 100644 --- a/pallets/liquidity-pools-gateway/src/tests.rs +++ b/pallets/liquidity-pools-gateway/src/tests.rs @@ -1,12 +1,16 @@ +use std::collections::HashMap; + use cfg_mocks::*; -use cfg_primitives::LP_DEFENSIVE_WEIGHT; -use cfg_traits::liquidity_pools::{LPEncoding, MessageProcessor, OutboundMessageHandler}; +use cfg_traits::liquidity_pools::{LPEncoding, MessageProcessor, OutboundMessageHandler, Proof}; use cfg_types::domain_address::*; use frame_support::{ assert_err, assert_noop, assert_ok, dispatch::PostDispatchInfo, pallet_prelude::Pays, weights::Weight, }; -use sp_core::{bounded::BoundedVec, crypto::AccountId32, ByteArray, H160}; +use itertools::Itertools; +use lazy_static::lazy_static; +use parity_scale_codec::MaxEncodedLen; +use sp_core::{bounded::BoundedVec, crypto::AccountId32, ByteArray, H160, H256}; use sp_runtime::{DispatchError, DispatchError::BadOrigin, DispatchErrorWithPostInfo}; use sp_std::sync::{ atomic::{AtomicU32, Ordering}, @@ -289,7 +293,7 @@ mod receive_message_domain { let encoded_msg = message.serialize(); - let gateway_message = GatewayMessage::::Inbound { + let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), }; @@ -370,7 +374,7 @@ mod receive_message_domain { let err = sp_runtime::DispatchError::from("liquidity_pools error"); - let gateway_message = GatewayMessage::::Inbound { + let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), }; @@ -400,24 +404,47 @@ mod outbound_message_handler_impl { let domain = Domain::EVM(0); let sender = get_test_account_id(); let msg = Message::Simple; + let message_proof = msg.to_message_proof().get_message_proof().unwrap(); - let router = RouterMock::::default(); - router.mock_init(move || Ok(())); + let router_hash_1 = H256::from_low_u64_be(1); + let router_hash_2 = H256::from_low_u64_be(2); + let router_hash_3 = H256::from_low_u64_be(3); - assert_ok!(LiquidityPoolsGateway::set_domain_router( + let router_mock_1 = RouterMock::::default(); + let router_mock_2 = RouterMock::::default(); + let router_mock_3 = RouterMock::::default(); + + router_mock_1.mock_init(move || Ok(())); + router_mock_1.mock_hash(move || router_hash_1); + router_mock_2.mock_init(move || Ok(())); + router_mock_2.mock_hash(move || router_hash_2); + router_mock_3.mock_init(move || Ok(())); + router_mock_3.mock_hash(move || router_hash_3); + + assert_ok!(LiquidityPoolsGateway::set_domain_multi_router( RuntimeOrigin::root(), domain.clone(), - router.clone(), + BoundedVec::try_from(vec![router_mock_1, router_mock_2, router_mock_3]).unwrap(), )); - let gateway_message = GatewayMessage::::Outbound { - sender: ::Sender::get(), - destination: domain.clone(), - message: msg.clone(), - }; - MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_msg| { - assert_eq!(mock_msg, gateway_message); + match mock_msg { + GatewayMessage::Inbound { .. } => { + assert!(false, "expected outbound message") + } + GatewayMessage::Outbound { + sender, message, .. + } => { + assert_eq!(sender, ::Sender::get()); + + match message { + Message::Proof(p) => { + assert_eq!(p, message_proof); + } + _ => {} + } + } + } Ok(()) }); @@ -447,19 +474,31 @@ mod outbound_message_handler_impl { let sender = get_test_account_id(); let msg = Message::Simple; - let router = RouterMock::::default(); - router.mock_init(move || Ok(())); + let router_hash_1 = H256::from_low_u64_be(1); + let router_hash_2 = H256::from_low_u64_be(2); + let router_hash_3 = H256::from_low_u64_be(3); - assert_ok!(LiquidityPoolsGateway::set_domain_router( + let router_mock_1 = RouterMock::::default(); + let router_mock_2 = RouterMock::::default(); + let router_mock_3 = RouterMock::::default(); + + router_mock_1.mock_init(move || Ok(())); + router_mock_1.mock_hash(move || router_hash_1); + router_mock_2.mock_init(move || Ok(())); + router_mock_2.mock_hash(move || router_hash_2); + router_mock_3.mock_init(move || Ok(())); + router_mock_3.mock_hash(move || router_hash_3); + + assert_ok!(LiquidityPoolsGateway::set_domain_multi_router( RuntimeOrigin::root(), domain.clone(), - router.clone(), + BoundedVec::try_from(vec![router_mock_1, router_mock_2, router_mock_3]).unwrap(), )); - let gateway_message = GatewayMessage::::Outbound { + let gateway_message = GatewayMessage::Outbound { sender: ::Sender::get(), - destination: domain.clone(), message: msg.clone(), + router_hash: router_hash_3, }; let err = DispatchError::Unavailable; @@ -530,34 +569,634 @@ mod message_processor_impl { mod inbound { use super::*; + #[macro_use] + mod util { + use super::*; + + macro_rules! run_tests { + ($tests:expr) => { + // $tests = Vec<(Vec, &ExpectedTestResult)> + for test in $tests { + new_test_ext().execute_with(|| { + println!("Executing test for - {:?}", test.0); + + let handler = MockLiquidityPools::mock_handle(move |_, _| Ok(())); + + // test.0 = Vec + for test_message in test.0 { + let domain_address = DomainAddress::EVM(1, [1; 20]); + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: test_message.clone(), + }; + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_ok!(res); + } + + assert_eq!(handler.times(), test.1.mock_called_times); + + assert_eq!( + InboundMessages::::get(MESSAGE_PROOF), + // test.1 = &ExpectedTestResult + test.1.inbound_message, + ); + assert_eq!( + InboundMessageProofCount::::get(MESSAGE_PROOF), + // test.1 = &ExpectedTestResult + test.1.proof_count, + ); + }); + } + }; + } + + lazy_static! { + static ref TEST_MESSAGES: Vec = + vec![Message::Simple, Message::Proof(MESSAGE_PROOF),]; + } + + /// Generate all `Message` combinations for a specific + /// number of messages, like: + /// + /// vec![ + /// Message::Simple, + /// Message::Simple, + /// ] + /// vec![ + /// Message::Simple, + /// Message::Proof(MESSAGE_PROOF), + /// ] + /// vec![ + /// Message::Proof(MESSAGE_PROOF), + /// Message::Simple, + /// ] + /// vec![ + /// Message::Proof(MESSAGE_PROOF), + /// Message::Proof(MESSAGE_PROOF), + /// ] + pub fn generate_test_combinations(count: usize) -> Vec> { + std::iter::repeat(TEST_MESSAGES.clone().into_iter()) + .take(count) + .multi_cartesian_product() + .collect::>() + } + + pub struct ExpectedTestResult { + pub inbound_message: Option<(DomainAddress, Message, u32)>, + pub proof_count: u32, + pub mock_called_times: u32, + } + } + + use util::*; + + mod combined_messages { + use super::*; + + mod two_messages { + use super::*; + + lazy_static! { + static ref TEST_MAP: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![Message::Simple, Message::Simple], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 4 + )), + proof_count: 0, + mock_called_times: 0, + } + ), + ( + vec![Message::Proof(MESSAGE_PROOF), Message::Proof(MESSAGE_PROOF)], + ExpectedTestResult { + inbound_message: None, + proof_count: 2, + mock_called_times: 0, + } + ), + ( + vec![Message::Simple, Message::Proof(MESSAGE_PROOF)], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 2 + )), + proof_count: 1, + mock_called_times: 0, + } + ), + ( + vec![Message::Proof(MESSAGE_PROOF), Message::Simple], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 2 + )), + proof_count: 1, + mock_called_times: 0, + } + ), + ]); + } + + #[test] + fn two_messages() { + let tests = generate_test_combinations(2) + .iter() + .map(|x| { + ( + x.clone(), + TEST_MAP + .get(x) + .expect(format!("test for {x:?} should be covered").as_str()), + ) + }) + .collect::>(); + + run_tests!(tests); + } + } + + mod three_messages { + use super::*; + + lazy_static! { + static ref TEST_MAP: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![Message::Simple, Message::Simple, Message::Simple,], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 6 + )), + proof_count: 0, + mock_called_times: 0, + } + ), + ( + vec![ + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + ], + ExpectedTestResult { + inbound_message: None, + proof_count: 3, + mock_called_times: 0, + } + ), + ( + vec![ + Message::Simple, + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + ], + ExpectedTestResult { + inbound_message: None, + proof_count: 0, + mock_called_times: 1, + } + ), + ( + vec![ + Message::Proof(MESSAGE_PROOF), + Message::Simple, + Message::Proof(MESSAGE_PROOF), + ], + ExpectedTestResult { + inbound_message: None, + proof_count: 0, + mock_called_times: 1, + } + ), + ( + vec![ + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + Message::Simple, + ], + ExpectedTestResult { + inbound_message: None, + proof_count: 0, + mock_called_times: 1, + } + ), + ( + vec![ + Message::Proof(MESSAGE_PROOF), + Message::Simple, + Message::Simple, + ], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 4 + )), + proof_count: 1, + mock_called_times: 0, + } + ), + ( + vec![ + Message::Simple, + Message::Proof(MESSAGE_PROOF), + Message::Simple, + ], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 4 + )), + proof_count: 1, + mock_called_times: 0, + } + ), + ( + vec![ + Message::Simple, + Message::Simple, + Message::Proof(MESSAGE_PROOF), + ], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 4 + )), + proof_count: 1, + mock_called_times: 0, + } + ) + ]); + } + + #[test] + fn three_messages() { + let tests = generate_test_combinations(3) + .iter() + .map(|x| { + ( + x.clone(), + TEST_MAP + .get(x) + .expect(format!("test for {x:?} should be covered").as_str()), + ) + }) + .collect::>(); + + run_tests!(tests); + } + } + + mod four_messages { + use super::*; + + lazy_static! { + static ref TEST_MAP: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + Message::Simple, + Message::Simple, + Message::Simple, + Message::Simple, + ], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 8 + )), + proof_count: 0, + mock_called_times: 0, + } + ), + ( + vec![ + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + ], + ExpectedTestResult { + inbound_message: None, + proof_count: 4, + mock_called_times: 0, + } + ), + ( + vec![ + Message::Simple, + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + ], + ExpectedTestResult { + inbound_message: None, + proof_count: 1, + mock_called_times: 1, + } + ), + ( + vec![ + Message::Proof(MESSAGE_PROOF), + Message::Simple, + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + ], + ExpectedTestResult { + inbound_message: None, + proof_count: 1, + mock_called_times: 1, + } + ), + ( + vec![ + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + Message::Simple, + Message::Proof(MESSAGE_PROOF), + ], + ExpectedTestResult { + inbound_message: None, + proof_count: 1, + mock_called_times: 1, + } + ), + ( + vec![ + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + Message::Simple, + ], + ExpectedTestResult { + inbound_message: None, + proof_count: 1, + mock_called_times: 1, + } + ), + ( + vec![ + Message::Simple, + Message::Simple, + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + ], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 2 + )), + proof_count: 0, + mock_called_times: 1, + } + ), + ( + vec![ + Message::Simple, + Message::Proof(MESSAGE_PROOF), + Message::Simple, + Message::Proof(MESSAGE_PROOF), + ], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 2 + )), + proof_count: 0, + mock_called_times: 1, + } + ), + ( + vec![ + Message::Proof(MESSAGE_PROOF), + Message::Simple, + Message::Simple, + Message::Proof(MESSAGE_PROOF), + ], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 2 + )), + proof_count: 0, + mock_called_times: 1, + } + ), + ( + vec![ + Message::Simple, + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + Message::Simple, + ], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 2 + )), + proof_count: 0, + mock_called_times: 1, + } + ), + ( + vec![ + Message::Proof(MESSAGE_PROOF), + Message::Proof(MESSAGE_PROOF), + Message::Simple, + Message::Simple, + ], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 2 + )), + proof_count: 0, + mock_called_times: 1, + } + ), + ( + vec![ + Message::Proof(MESSAGE_PROOF), + Message::Simple, + Message::Proof(MESSAGE_PROOF), + Message::Simple, + ], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 2 + )), + proof_count: 0, + mock_called_times: 1, + } + ), + ( + vec![ + Message::Simple, + Message::Simple, + Message::Simple, + Message::Proof(MESSAGE_PROOF), + ], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 6 + )), + proof_count: 1, + mock_called_times: 0, + } + ), + ( + vec![ + Message::Simple, + Message::Simple, + Message::Proof(MESSAGE_PROOF), + Message::Simple, + ], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 6 + )), + proof_count: 1, + mock_called_times: 0, + } + ), + ( + vec![ + Message::Simple, + Message::Proof(MESSAGE_PROOF), + Message::Simple, + Message::Simple, + ], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 6 + )), + proof_count: 1, + mock_called_times: 0, + } + ), + ( + vec![ + Message::Proof(MESSAGE_PROOF), + Message::Simple, + Message::Simple, + Message::Simple, + ], + ExpectedTestResult { + inbound_message: Some(( + DomainAddress::EVM(1, [1; 20]), + Message::Simple, + 6 + )), + proof_count: 1, + mock_called_times: 0, + } + ), + ]); + } + + #[test] + fn four_messages() { + let tests = generate_test_combinations(4) + .iter() + .filter(|x| TEST_MAP.get(x.clone()).is_some()) + .map(|x| { + ( + x.clone(), + TEST_MAP + .get(x) + .expect(format!("test for {x:?} should be covered").as_str()), + ) + }) + .collect::>(); + + run_tests!(tests); + } + } + } + #[test] - fn success() { + fn two_non_proof_and_four_proofs() { + let tests = generate_test_combinations(6) + .into_iter() + .filter(|x| { + let r = x.iter().counts_by(|c| c.clone()); + let non_proof_count = r.get(&Message::Simple); + let proof_count = r.get(&Message::Proof(MESSAGE_PROOF)); + + match (non_proof_count, proof_count) { + (Some(non_proof_count), Some(proof_count)) => { + *non_proof_count == 2 && *proof_count == 4 + } + _ => false, + } + }) + .map(|x| { + ( + x, + ExpectedTestResult { + inbound_message: None, + proof_count: 0, + mock_called_times: 2, + }, + ) + }) + .collect::>(); + + run_tests!(tests); + } + + #[test] + fn inbound_message_handler_error() { new_test_ext().execute_with(|| { let domain_address = DomainAddress::EVM(1, [1; 20]); - let message = Message::Simple; - let gateway_message = GatewayMessage::::Inbound { + + let message = Message::Proof(MESSAGE_PROOF); + let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), }; - MockLiquidityPools::mock_handle(move |mock_domain_address, mock_mesage| { - assert_eq!(mock_domain_address, domain_address); - assert_eq!(mock_mesage, message); + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_ok!(res); - Ok(()) - }); + let message = Message::Proof(MESSAGE_PROOF); + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + }; let (res, _) = LiquidityPoolsGateway::process(gateway_message); assert_ok!(res); - }); - } - #[test] - fn inbound_message_handler_error() { - new_test_ext().execute_with(|| { - let domain_address = DomainAddress::EVM(1, [1; 20]); let message = Message::Simple; - let gateway_message = GatewayMessage::::Inbound { + let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), }; @@ -596,6 +1235,8 @@ mod message_processor_impl { pays_fee: Pays::Yes, }; + let router_hash = H256::from_low_u64_be(1); + let router_mock = RouterMock::::default(); router_mock.mock_send(move |mock_sender, mock_message| { assert_eq!(mock_sender, expected_sender); @@ -603,6 +1244,7 @@ mod message_processor_impl { Ok(router_post_info) }); + router_mock.mock_hash(move || router_hash); DomainRouters::::insert(domain.clone(), router_mock); @@ -610,10 +1252,10 @@ mod message_processor_impl { .reads(1) + router_post_info.actual_weight.unwrap() + Weight::from_parts(0, message.serialize().len() as u64); - let gateway_message = GatewayMessage::::Outbound { + let gateway_message = GatewayMessage::Outbound { sender, - destination: domain, message: message.clone(), + router_hash, }; let (res, weight) = LiquidityPoolsGateway::process(gateway_message); @@ -626,15 +1268,14 @@ mod message_processor_impl { fn router_not_found() { new_test_ext().execute_with(|| { let sender = get_test_account_id(); - let domain = Domain::EVM(1); let message = Message::Simple; let expected_weight = ::DbWeight::get().reads(1); - let gateway_message = GatewayMessage::::Outbound { + let gateway_message = GatewayMessage::Outbound { sender, - destination: domain, message, + router_hash: H256::from_low_u64_be(1), }; let (res, weight) = LiquidityPoolsGateway::process(gateway_message); @@ -677,10 +1318,10 @@ mod message_processor_impl { .reads(1) + router_post_info.actual_weight.unwrap() + Weight::from_parts(0, message.serialize().len() as u64); - let gateway_message = GatewayMessage::::Outbound { + let gateway_message = GatewayMessage::Outbound { sender, - destination: domain, message: message.clone(), + router_hash: H256::from_low_u64_be(1), }; let (res, weight) = LiquidityPoolsGateway::process(gateway_message); diff --git a/pallets/liquidity-pools-gateway/src/weights.rs b/pallets/liquidity-pools-gateway/src/weights.rs index b1ac9ed578..ce5650e801 100644 --- a/pallets/liquidity-pools-gateway/src/weights.rs +++ b/pallets/liquidity-pools-gateway/src/weights.rs @@ -23,6 +23,9 @@ pub trait WeightInfo { fn process_failed_outbound_message() -> Weight; fn start_batch_message() -> Weight; fn end_batch_message() -> Weight; + fn set_domain_hook_address() -> Weight; + fn set_domain_multi_router() -> Weight; + fn execute_message_recovery() -> Weight; } // NOTE: We use temporary weights here. `execute_epoch` is by far our heaviest @@ -146,4 +149,37 @@ impl WeightInfo for () { .saturating_add(RocksDbWeight::get().reads(2)) .saturating_add(RocksDbWeight::get().writes(2)) } + + fn set_domain_hook_address() -> Weight { + // TODO: BENCHMARK CORRECTLY + // + // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` + // This one has one read and one write for sure and possible one + // read for `AdminOrigin` + Weight::from_parts(30_117_000, 5991) + .saturating_add(RocksDbWeight::get().reads(2)) + .saturating_add(RocksDbWeight::get().writes(2)) + } + + fn set_domain_multi_router() -> Weight { + // TODO: BENCHMARK CORRECTLY + // + // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` + // This one has one read and one write for sure and possible one + // read for `AdminOrigin` + Weight::from_parts(30_117_000, 5991) + .saturating_add(RocksDbWeight::get().reads(2)) + .saturating_add(RocksDbWeight::get().writes(2)) + } + + fn execute_message_recovery() -> Weight { + // TODO: BENCHMARK CORRECTLY + // + // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` + // This one has one read and one write for sure and possible one + // read for `AdminOrigin` + Weight::from_parts(30_117_000, 5991) + .saturating_add(RocksDbWeight::get().reads(2)) + .saturating_add(RocksDbWeight::get().writes(2)) + } } diff --git a/pallets/liquidity-pools/src/message.rs b/pallets/liquidity-pools/src/message.rs index d981abfb16..d4ae5d5f4e 100644 --- a/pallets/liquidity-pools/src/message.rs +++ b/pallets/liquidity-pools/src/message.rs @@ -5,7 +5,10 @@ //! also have a custom GMPF implementation, aiming for a fixed-size encoded //! representation for each message variant. -use cfg_traits::{liquidity_pools::LPEncoding, Seconds}; +use cfg_traits::{ + liquidity_pools::{LPEncoding, Proof}, + Seconds, +}; use cfg_types::domain_address::Domain; use frame_support::{pallet_prelude::RuntimeDebug, BoundedVec}; use parity_scale_codec::{Decode, Encode, MaxEncodedLen}; @@ -15,7 +18,7 @@ use serde::{ ser::{Error as _, SerializeTuple}, Deserialize, Serialize, Serializer, }; -use sp_core::U256; +use sp_core::{keccak_256, U256}; use sp_runtime::{traits::ConstU32, DispatchError, DispatchResult}; use sp_std::{vec, vec::Vec}; @@ -558,6 +561,19 @@ impl LPEncoding for Message { fn empty() -> Message { Message::Batch(BatchMessages::default()) } + + fn get_message_proof(&self) -> Option { + match self { + Message::MessageProof { hash } => Some(hash.clone()), + _ => None, + } + } + + fn to_message_proof(&self) -> Self { + let hash = keccak_256(&LPEncoding::serialize(self)); + + Message::MessageProof { hash } + } } /// A Liquidity Pool message for updating restrictions on foreign domains. From e6057161bef2f6352b7949f595313bb9c559a97a Mon Sep 17 00:00:00 2001 From: Cosmin Damian <17934949+cdamian@users.noreply.github.com> Date: Thu, 8 Aug 2024 15:22:44 +0300 Subject: [PATCH 02/11] lp-gateway: Add router hash for inbound messages, extrinsic to set inbound routers, session ID storage --- Cargo.lock | 1 + pallets/liquidity-pools-gateway/Cargo.toml | 2 + pallets/liquidity-pools-gateway/src/lib.rs | 70 ++++++++++++++++--- .../liquidity-pools-gateway/src/message.rs | 6 +- pallets/liquidity-pools-gateway/src/mock.rs | 1 + pallets/liquidity-pools-gateway/src/tests.rs | 11 ++- .../liquidity-pools-gateway/src/weights.rs | 16 ++++- 7 files changed, 93 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9bb0ef1336..b5dfd2b257 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8269,6 +8269,7 @@ dependencies = [ "orml-traits", "parity-scale-codec", "scale-info", + "sp-arithmetic", "sp-core", "sp-io", "sp-runtime", diff --git a/pallets/liquidity-pools-gateway/Cargo.toml b/pallets/liquidity-pools-gateway/Cargo.toml index f5b0356569..e2caa19652 100644 --- a/pallets/liquidity-pools-gateway/Cargo.toml +++ b/pallets/liquidity-pools-gateway/Cargo.toml @@ -22,6 +22,7 @@ scale-info = { workspace = true } sp-core = { workspace = true } sp-runtime = { workspace = true } sp-std = { workspace = true } +sp-arithmetic = { workspace = true } # Benchmarking frame-benchmarking = { workspace = true, optional = true } @@ -56,6 +57,7 @@ std = [ "cfg-utils/std", "hex/std", "cfg-primitives/std", + "sp-arithmetic/std", ] try-runtime = [ "cfg-traits/try-runtime", diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index 366169b1d6..64c0dff9bf 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -40,6 +40,7 @@ use message::GatewayMessage; use orml_traits::GetByKey; pub use pallet::*; use parity_scale_codec::{EncodeLike, FullCodec}; +use sp_arithmetic::traits::BaseArithmetic; use sp_runtime::traits::EnsureAddAssign; use sp_std::{cmp::Ordering, convert::TryInto, vec::Vec}; @@ -129,6 +130,16 @@ pub mod pallet { /// Number of routers for a domain. #[pallet::constant] type MultiRouterCount: Get; + + /// Type for identifying sessions of inbound routers. + type SessionId: Parameter + + Member + + BaseArithmetic + + Default + + Copy + + MaybeSerializeDeserialize + + TypeInfo + + MaxEncodedLen; } #[pallet::event] @@ -149,11 +160,16 @@ pub mod pallet { hook_address: [u8; 20], }, - /// The routers for a given domain were set. - DomainMultiRouterSet { + /// The outbound routers for a given domain were set. + OutboundRoutersSet { domain: Domain, routers: BoundedVec, }, + + /// Inbound routers were set. + InboundRoutersSet { + router_hashes: BoundedVec, + }, } /// Storage for domain routers. @@ -216,6 +232,14 @@ pub mod pallet { pub type InboundMessages = StorageMap<_, Blake2_128Concat, Proof, (DomainAddress, T::Message, u32)>; + #[pallet::storage] + #[pallet::getter(fn inbound_routers)] + pub type InboundRouters = + StorageValue<_, BoundedVec, ValueQuery>; + + #[pallet::storage] + pub type SessionIdStore = StorageValue<_, T::SessionId, ValueQuery>; + #[pallet::error] pub enum Error { /// Router initialization failed. @@ -346,6 +370,8 @@ pub mod pallet { let gateway_message = GatewayMessage::::Inbound { domain_address: origin_address, message: T::Message::deserialize(&msg)?, + //TODO(cdamian): Use an actual router hash. + router_hash: T::Hash::default(), }; T::MessageQueue::submit(gateway_message) @@ -406,10 +432,10 @@ pub mod pallet { } } - /// Set routers for a particular domain. - #[pallet::weight(T::WeightInfo::set_domain_multi_router())] + /// Set outbound routers for a particular domain. + #[pallet::weight(T::WeightInfo::set_outbound_routers())] #[pallet::call_index(11)] - pub fn set_domain_multi_router( + pub fn set_outbound_routers( origin: OriginFor, domain: Domain, routers: BoundedVec, @@ -439,9 +465,33 @@ pub mod pallet { BoundedVec::try_from(router_hashes).map_err(|_| Error::::InvalidMultiRouter)?, ); - Self::clear_storages_for_inbound_messages(); + Self::deposit_event(Event::OutboundRoutersSet { domain, routers }); - Self::deposit_event(Event::DomainMultiRouterSet { domain, routers }); + Ok(()) + } + + /// Set inbound routers. + #[pallet::weight(T::WeightInfo::set_inbound_routers())] + #[pallet::call_index(12)] + pub fn set_inbound_routers( + origin: OriginFor, + router_hashes: BoundedVec, + ) -> DispatchResult { + T::AdminOrigin::ensure_origin(origin)?; + + ensure!( + router_hashes.len() == T::MultiRouterCount::get() as usize, + Error::::InvalidMultiRouter + ); + + SessionIdStore::::try_mutate(|n| { + n.ensure_add_assign(One::one())?; + Ok::<(), DispatchError>(()) + })?; + + InboundRouters::::set(router_hashes.clone()); + + Self::deposit_event(Event::InboundRoutersSet { router_hashes }); Ok(()) } @@ -451,7 +501,7 @@ pub mod pallet { /// /// Can only be called by `AdminOrigin`. #[pallet::weight(T::WeightInfo::execute_message_recovery())] - #[pallet::call_index(12)] + #[pallet::call_index(13)] pub fn execute_message_recovery( origin: OriginFor, message_proof: Proof, @@ -539,6 +589,7 @@ pub mod pallet { fn process_inbound_message( domain_address: DomainAddress, message: T::Message, + router_hash: T::Hash, ) -> (DispatchResult, Weight) { let mut count = 0; @@ -707,7 +758,8 @@ pub mod pallet { GatewayMessage::Inbound { domain_address, message, - } => Self::process_inbound_message(domain_address, message), + router_hash, + } => Self::process_inbound_message(domain_address, message, router_hash), GatewayMessage::Outbound { sender, message, diff --git a/pallets/liquidity-pools-gateway/src/message.rs b/pallets/liquidity-pools-gateway/src/message.rs index a2f568c23e..42226a46eb 100644 --- a/pallets/liquidity-pools-gateway/src/message.rs +++ b/pallets/liquidity-pools-gateway/src/message.rs @@ -7,6 +7,7 @@ pub enum GatewayMessage { Inbound { domain_address: DomainAddress, message: Message, + router_hash: Hash, }, Outbound { sender: AccountId, @@ -15,11 +16,14 @@ pub enum GatewayMessage { }, } -impl Default for GatewayMessage { +impl Default + for GatewayMessage +{ fn default() -> Self { GatewayMessage::Inbound { domain_address: Default::default(), message: Default::default(), + router_hash: Default::default(), } } } diff --git a/pallets/liquidity-pools-gateway/src/mock.rs b/pallets/liquidity-pools-gateway/src/mock.rs index 662de1ff60..6406b673f8 100644 --- a/pallets/liquidity-pools-gateway/src/mock.rs +++ b/pallets/liquidity-pools-gateway/src/mock.rs @@ -153,6 +153,7 @@ impl pallet_liquidity_pools_gateway::Config for Runtime { type RuntimeEvent = RuntimeEvent; type RuntimeOrigin = RuntimeOrigin; type Sender = Sender; + type SessionId = u64; type WeightInfo = (); } diff --git a/pallets/liquidity-pools-gateway/src/tests.rs b/pallets/liquidity-pools-gateway/src/tests.rs index 283d49f81f..3c9f056071 100644 --- a/pallets/liquidity-pools-gateway/src/tests.rs +++ b/pallets/liquidity-pools-gateway/src/tests.rs @@ -296,6 +296,7 @@ mod receive_message_domain { let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), + router_hash: H256::from_low_u64_be(1), }; MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_message| { @@ -377,6 +378,7 @@ mod receive_message_domain { let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), + router_hash: H256::from_low_u64_be(1), }; MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_message| { @@ -421,7 +423,7 @@ mod outbound_message_handler_impl { router_mock_3.mock_init(move || Ok(())); router_mock_3.mock_hash(move || router_hash_3); - assert_ok!(LiquidityPoolsGateway::set_domain_multi_router( + assert_ok!(LiquidityPoolsGateway::set_outbound_routers( RuntimeOrigin::root(), domain.clone(), BoundedVec::try_from(vec![router_mock_1, router_mock_2, router_mock_3]).unwrap(), @@ -489,7 +491,7 @@ mod outbound_message_handler_impl { router_mock_3.mock_init(move || Ok(())); router_mock_3.mock_hash(move || router_hash_3); - assert_ok!(LiquidityPoolsGateway::set_domain_multi_router( + assert_ok!(LiquidityPoolsGateway::set_outbound_routers( RuntimeOrigin::root(), domain.clone(), BoundedVec::try_from(vec![router_mock_1, router_mock_2, router_mock_3]).unwrap(), @@ -588,6 +590,8 @@ mod message_processor_impl { let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: test_message.clone(), + //TODO(cdamian): Use test router hash. + router_hash: H256::from_low_u64_be(1), }; let (res, _) = LiquidityPoolsGateway::process(gateway_message); @@ -1181,6 +1185,7 @@ mod message_processor_impl { let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), + router_hash: H256::from_low_u64_be(1), }; let (res, _) = LiquidityPoolsGateway::process(gateway_message); @@ -1190,6 +1195,7 @@ mod message_processor_impl { let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), + router_hash: H256::from_low_u64_be(1), }; let (res, _) = LiquidityPoolsGateway::process(gateway_message); @@ -1199,6 +1205,7 @@ mod message_processor_impl { let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), + router_hash: H256::from_low_u64_be(1), }; let err = DispatchError::Unavailable; diff --git a/pallets/liquidity-pools-gateway/src/weights.rs b/pallets/liquidity-pools-gateway/src/weights.rs index ce5650e801..c568dfc4bf 100644 --- a/pallets/liquidity-pools-gateway/src/weights.rs +++ b/pallets/liquidity-pools-gateway/src/weights.rs @@ -24,8 +24,9 @@ pub trait WeightInfo { fn start_batch_message() -> Weight; fn end_batch_message() -> Weight; fn set_domain_hook_address() -> Weight; - fn set_domain_multi_router() -> Weight; + fn set_outbound_routers() -> Weight; fn execute_message_recovery() -> Weight; + fn set_inbound_routers() -> Weight; } // NOTE: We use temporary weights here. `execute_epoch` is by far our heaviest @@ -161,7 +162,7 @@ impl WeightInfo for () { .saturating_add(RocksDbWeight::get().writes(2)) } - fn set_domain_multi_router() -> Weight { + fn set_outbound_routers() -> Weight { // TODO: BENCHMARK CORRECTLY // // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` @@ -182,4 +183,15 @@ impl WeightInfo for () { .saturating_add(RocksDbWeight::get().reads(2)) .saturating_add(RocksDbWeight::get().writes(2)) } + + fn set_inbound_routers() -> Weight { + // TODO: BENCHMARK CORRECTLY + // + // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` + // This one has one read and one write for sure and possible one + // read for `AdminOrigin` + Weight::from_parts(30_117_000, 5991) + .saturating_add(RocksDbWeight::get().reads(2)) + .saturating_add(RocksDbWeight::get().writes(2)) + } } From a9830aa138287a47f8434b3ac157be47c6f98c18 Mon Sep 17 00:00:00 2001 From: Cosmin Damian <17934949+cdamian@users.noreply.github.com> Date: Fri, 9 Aug 2024 13:32:32 +0300 Subject: [PATCH 03/11] lp-gateway: Use router hashes for inbound, use session ID, update inbound message processing logic --- pallets/liquidity-pools-gateway/src/lib.rs | 482 +++++++++++++++------ 1 file changed, 341 insertions(+), 141 deletions(-) diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index 64c0dff9bf..169c7ab601 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -59,6 +59,20 @@ mod mock; #[cfg(test)] mod tests; +/// Type that stores the information required when processing inbound messages. +#[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] +#[scale_info(skip_type_params(T))] +pub enum InboundEntry { + Message { + domain_address: DomainAddress, + message: T::Message, + expected_proof_count: u32, + }, + Proof { + current_count: u32, + }, +} + #[frame_support::pallet] pub mod pallet { use super::*; @@ -168,6 +182,7 @@ pub mod pallet { /// Inbound routers were set. InboundRoutersSet { + domain: Domain, router_hashes: BoundedVec, }, } @@ -204,39 +219,48 @@ pub mod pallet { pub(crate) type PackedMessage = StorageMap<_, Blake2_128Concat, (T::AccountId, Domain), T::Message>; - /// Storage for routers. + /// Storage for outbound routers. /// /// This can only be set by an admin. #[pallet::storage] #[pallet::getter(fn routers)] - pub type Routers = StorageMap<_, Blake2_128Concat, T::Hash, T::Router>; + pub type OutboundRouters = StorageMap<_, Blake2_128Concat, T::Hash, T::Router>; - /// Storage for domain multi-routers. + /// Storage for outbound routers specific for a domain. /// /// This can only be set by an admin. #[pallet::storage] - #[pallet::getter(fn domain_multi_routers)] - pub type DomainMultiRouters = + #[pallet::getter(fn outbound_domain_routers)] + pub type OutboundDomainRouters = StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; - /// Storage that keeps track of incoming message proofs. - #[pallet::storage] - #[pallet::getter(fn inbound_message_proof_count)] - pub type InboundMessageProofCount = - StorageMap<_, Blake2_128Concat, Proof, u32, ValueQuery>; - - /// Storage that keeps track of incoming messages and the expected proof - /// count. + /// Storage for pending inbound messages. #[pallet::storage] - #[pallet::getter(fn inbound_messages)] - pub type InboundMessages = - StorageMap<_, Blake2_128Concat, Proof, (DomainAddress, T::Message, u32)>; - + #[pallet::getter(fn pending_inbound_entries)] + pub type PendingInboundEntries = StorageDoubleMap< + _, + Blake2_128Concat, + T::SessionId, + Blake2_128Concat, + (Proof, T::Hash), + InboundEntry, + >; + + /// Storage for inbound routers specific for a domain. + /// + /// This can only be set by an admin. #[pallet::storage] #[pallet::getter(fn inbound_routers)] pub type InboundRouters = - StorageValue<_, BoundedVec, ValueQuery>; + StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; + + /// Storage for the session ID of an inbound domain. + #[pallet::storage] + #[pallet::getter(fn inbound_domain_sessions)] + pub type InboundDomainSessions = + StorageMap<_, Blake2_128Concat, Domain, T::SessionId>; + /// Storage for inbound router session IDs. #[pallet::storage] pub type SessionIdStore = StorageValue<_, T::SessionId, ValueQuery>; @@ -457,10 +481,10 @@ pub mod pallet { router_hashes.push(router_hash); - Routers::::insert(router_hash, router); + OutboundRouters::::insert(router_hash, router); } - >::insert( + >::insert( domain.clone(), BoundedVec::try_from(router_hashes).map_err(|_| Error::::InvalidMultiRouter)?, ); @@ -475,6 +499,7 @@ pub mod pallet { #[pallet::call_index(12)] pub fn set_inbound_routers( origin: OriginFor, + domain: Domain, router_hashes: BoundedVec, ) -> DispatchResult { T::AdminOrigin::ensure_origin(origin)?; @@ -484,14 +509,18 @@ pub mod pallet { Error::::InvalidMultiRouter ); - SessionIdStore::::try_mutate(|n| { + let session_id = SessionIdStore::::try_mutate(|n| { n.ensure_add_assign(One::one())?; - Ok::<(), DispatchError>(()) + Ok::(*n) })?; - InboundRouters::::set(router_hashes.clone()); + InboundRouters::::insert(domain.clone(), router_hashes.clone()); + InboundDomainSessions::::insert(domain.clone(), session_id); - Self::deposit_event(Event::InboundRoutersSet { router_hashes }); + Self::deposit_event(Event::InboundRoutersSet { + domain, + router_hashes, + }); Ok(()) } @@ -513,76 +542,261 @@ pub mod pallet { } impl Pallet { - fn clear_storages_for_inbound_messages() { - let _ = InboundMessages::::clear(u32::MAX, None); - let _ = InboundMessageProofCount::::clear(u32::MAX, None); - } - //TODO(cdamian): Use safe math fn get_expected_message_proof_count() -> u32 { T::MultiRouterCount::get() - 1 } - /// Inserts a message and its expected proof count, or increases the - /// message proof count for a particular message. - fn get_proof_and_current_count( + fn get_message_proof(message: T::Message) -> Proof { + match message.get_message_proof() { + None => message + .to_message_proof() + .get_message_proof() + .expect("message proof ensured by 'to_message_proof'"), + Some(proof) => proof, + } + } + + fn create_inbound_entry( domain_address: DomainAddress, message: T::Message, - weight: &mut Weight, - ) -> Result<(Proof, u32), DispatchError> { + ) -> InboundEntry { match message.get_message_proof() { - None => { - let message_proof = message - .to_message_proof() - .get_message_proof() - .expect("message proof ensured by 'to_message_proof'"); - - match InboundMessages::::try_mutate(message_proof, |storage_entry| { - match storage_entry { - None => { - *storage_entry = Some(( - domain_address, - message, - Self::get_expected_message_proof_count(), - )); + None => InboundEntry::Message { + domain_address, + message, + expected_proof_count: Self::get_expected_message_proof_count(), + }, + Some(_) => InboundEntry::Proof { current_count: 1 }, + } + } + + /// Validation ensures that: + /// + /// - the router that sent the inbound message is a valid router for the + /// specific domain. + /// - messages are only sent by the first inbound router. + /// - proofs are not sent by the first inbound router. + fn validate_inbound_entry( + domain: Domain, + router_hash: T::Hash, + inbound_entry: &InboundEntry, + ) -> DispatchResult { + let inbound_routers = + //TODO(cdamian): Add new error + InboundRouters::::get(domain).ok_or(Error::::InvalidMultiRouter)?; + + ensure!( + inbound_routers.iter().any(|x| x == &router_hash), + //TODO(cdamian): Add error + Error::::InvalidMultiRouter + ); + + match inbound_entry { + InboundEntry::Message { .. } => { + ensure!( + inbound_routers.get(0) == Some(&router_hash), + //TODO(cdamian): Add error + Error::::InvalidMultiRouter + ); + + Ok(()) + } + InboundEntry::Proof { .. } => { + ensure!( + inbound_routers.get(0) != Some(&router_hash), + //TODO(cdamian): Add error + Error::::InvalidMultiRouter + ); + + Ok(()) + } + } + } + + fn update_storage_entry(old: &mut InboundEntry, new: InboundEntry) -> DispatchResult { + match old { + InboundEntry::Message { + expected_proof_count, + .. + } => match new { + InboundEntry::Message { .. } => { + expected_proof_count + .ensure_add_assign(Self::get_expected_message_proof_count())?; + + Ok(()) + } + //TODO(cdamian): Update error + InboundEntry::Proof { .. } => Err(Error::::InvalidMultiRouter.into()), + }, + InboundEntry::Proof { current_count } => match new { + InboundEntry::Proof { .. } => { + current_count.ensure_add_assign(1)?; + + Ok(()) + } + //TODO(cdamian): Update error + InboundEntry::Message { .. } => Err(Error::::InvalidMultiRouter.into()), + }, + } + } + + fn update_pending_entry( + session_id: T::SessionId, + message_proof: Proof, + router_hash: T::Hash, + inbound_entry: InboundEntry, + weight: &mut Weight, + ) -> DispatchResult { + weight.saturating_accrue(T::DbWeight::get().writes(1)); + + PendingInboundEntries::::try_mutate( + session_id, + (message_proof, router_hash), + |storage_entry| match storage_entry { + None => { + *storage_entry = Some(inbound_entry); + + Ok::<(), DispatchError>(()) + } + Some(stored_inbound_entry) => match stored_inbound_entry { + InboundEntry::Message { + expected_proof_count: old, + .. + } => match inbound_entry { + InboundEntry::Message { + expected_proof_count: new, + .. + } => old.ensure_add_assign(new).map_err(|e| e.into()), + InboundEntry::Proof { .. } => { + // TODO(cdamian): Add new error. + Err(Error::::InvalidMultiRouter.into()) } - Some((_, _, expected_proof_count)) => { - // We already have a message, in this case we should expect another - // set of message proofs. - expected_proof_count - .ensure_add_assign(Self::get_expected_message_proof_count())?; + }, + InboundEntry::Proof { current_count: old } => match inbound_entry { + InboundEntry::Proof { current_count: new } => { + old.ensure_add_assign(new).map_err(|e| e.into()) } - }; + InboundEntry::Message { .. } => { + // TODO(cdamian): Add new error. + Err(Error::::InvalidMultiRouter.into()) + } + }, + }, + }, + ) + } - Ok(()) - }) { - Ok(_) => {} - Err(e) => return Err(e), - }; + fn validate_and_update_pending_entries( + session_id: T::SessionId, + message_proof: Proof, + router_hash: T::Hash, + domain_address: DomainAddress, + message: T::Message, + weight: &mut Weight, + ) -> DispatchResult { + let session_id = InboundDomainSessions::::get(domain_address.domain()) + .ok_or(Error::::InvalidMultiRouter)?; - *weight = weight.saturating_add(T::DbWeight::get().reads_writes(1, 1)); + let message_proof = Self::get_message_proof(message.clone()); - Ok(( - message_proof, - InboundMessageProofCount::::get(message_proof), - )) - } - Some(message_proof) => { - let message_proof_count = - match InboundMessageProofCount::::try_mutate(message_proof, |count| { - count.ensure_add_assign(1)?; + let inbound_entry = Self::create_inbound_entry(domain_address.clone(), message); + + Self::validate_inbound_entry(domain_address.domain(), router_hash, &inbound_entry)?; + + Self::update_pending_entry( + session_id, + message_proof, + router_hash, + inbound_entry, + weight, + )?; + + Ok(()) + } + + fn get_executable_message( + inbound_routers: BoundedVec, + session_id: T::SessionId, + message_proof: Proof, + ) -> Option { + let mut message = None; + let mut proof_count = 0; + + for inbound_router in inbound_routers { + match PendingInboundEntries::::get(session_id, (message_proof, inbound_router)) { + // We expected one InboundEntry for each router, if that's not the case, + // we can return. + None => return None, + Some(inbound_entry) => match inbound_entry { + InboundEntry::Message { + message: stored_message, + .. + } => message = Some(stored_message), + InboundEntry::Proof { current_count } => { + if current_count > 0 { + proof_count += 1; + } + } + }, + }; + } + + if proof_count == Self::get_expected_message_proof_count() { + return message; + } - Ok(*count) - }) { - Ok(r) => r, - Err(e) => return Err(e), - }; + None + } - *weight = weight.saturating_add(T::DbWeight::get().writes(1)); + fn decrease_pending_entries_counts( + inbound_routers: BoundedVec, + session_id: T::SessionId, + message_proof: Proof, + ) -> DispatchResult { + for inbound_router in inbound_routers { + match PendingInboundEntries::::try_mutate( + session_id, + (message_proof, inbound_router), + |storage_entry| match storage_entry { + // TODO(cdamian): Add new error + None => Err(Error::::InvalidMultiRouter.into()), + Some(stored_inbound_entry) => match stored_inbound_entry { + InboundEntry::Message { + expected_proof_count, + .. + } => { + let updated_count = (*expected_proof_count) + .ensure_sub(Self::get_expected_message_proof_count())?; + + if updated_count == 0 { + *storage_entry = None; + } else { + *expected_proof_count = updated_count; + } + + Ok::<(), DispatchError>(()) + } + InboundEntry::Proof { current_count } => { + let updated_count = (*current_count).ensure_sub(1)?; - Ok((message_proof, message_proof_count)) + if updated_count == 0 { + *storage_entry = None; + } else { + *current_count = updated_count; + } + + Ok::<(), DispatchError>(()) + } + }, + }, + ) { + Ok(()) => {} + Err(e) => return Err(e), } } + + Ok(()) } /// Give the message to the `InboundMessageHandler` to be processed. @@ -591,78 +805,64 @@ pub mod pallet { message: T::Message, router_hash: T::Hash, ) -> (DispatchResult, Weight) { + let mut weight = T::DbWeight::get().reads(1); + + let Some(inbound_routers) = InboundRouters::::get(domain_address.domain()) else { + //TODO(cdamian): Add new error + return (Err(Error::::InvalidMultiRouter.into()), weight); + }; + + if inbound_routers.len() == 0 {} + + let Some(session_id) = InboundDomainSessions::::get(domain_address.domain()) else { + //TODO(cdamian): Add error + return (Err(Error::::InvalidMultiRouter.into()), weight); + }; + + let message_proof = Self::get_message_proof(message.clone()); + + weight.saturating_accrue( + Weight::from_parts(0, T::Message::max_encoded_len() as u64) + .saturating_add(LP_DEFENSIVE_WEIGHT), + ); + let mut count = 0; for submessage in message.submessages() { count += 1; - let (message_proof, mut current_message_proof_count) = - match Self::get_proof_and_current_count( - domain_address.clone(), - message.clone(), - &mut weight, - ) { - Ok(r) => r, - Err(e) => return (Err(e), weight), - }; - - let (_, message, mut total_expected_proof_count) = - match InboundMessages::::get(message_proof) { - None => return (Ok(()), weight), - Some(r) => r, - }; + if let Err(e) = Self::validate_and_update_pending_entries( + session_id, + message_proof, + router_hash, + domain_address.clone(), + submessage.clone(), + &mut weight, + ) { + return (Err(e), weight); + } - weight = weight.saturating_add(T::DbWeight::get().reads(1)); - - let expected_message_proof_count = Self::get_expected_message_proof_count(); - - match current_message_proof_count.cmp(&expected_message_proof_count) { - Ordering::Less => return (Ok(()), weight), - Ordering::Equal => { - InboundMessageProofCount::::remove(message_proof); - total_expected_proof_count -= expected_message_proof_count; - - if total_expected_proof_count == 0 { - InboundMessages::::remove(message_proof); - } else { - InboundMessages::::insert( - message_proof, - ( - domain_address.clone(), - message.clone(), - total_expected_proof_count, - ), - ); - } - } - Ordering::Greater => { - current_message_proof_count -= expected_message_proof_count; - InboundMessageProofCount::::insert( + match Self::get_executable_message( + inbound_routers.clone(), + session_id, + message_proof, + ) { + Some(m) => { + if let Err(e) = Self::decrease_pending_entries_counts( + inbound_routers.clone(), + session_id, message_proof, - current_message_proof_count, - ); - - total_expected_proof_count -= expected_message_proof_count; - - if total_expected_proof_count == 0 { - InboundMessages::::remove(message_proof); - } else { - InboundMessages::::insert( - message_proof, - ( - domain_address.clone(), - message.clone(), - total_expected_proof_count, - ), - ); + ) { + return (Err(e), weight.saturating_mul(count)); } - } - } - if let Err(e) = T::InboundMessageHandler::handle(domain_address.clone(), submessage) - { - // We only consume the processed weight if error during the batch - return (Err(e), LP_DEFENSIVE_WEIGHT.saturating_mul(count)); + if let Err(e) = T::InboundMessageHandler::handle(domain_address.clone(), m) + { + // We only consume the processed weight if error during the batch + return (Err(e), weight.saturating_mul(count)); + } + } + None => continue, } } @@ -679,7 +879,7 @@ pub mod pallet { ) -> (DispatchResult, Weight) { let read_weight = T::DbWeight::get().reads(1); - let Some(router) = Routers::::get(router_hash) else { + let Some(router) = OutboundRouters::::get(router_hash) else { return (Err(Error::::RouterNotFound.into()), read_weight); }; @@ -692,7 +892,7 @@ pub mod pallet { } fn queue_message(destination: Domain, message: T::Message) -> DispatchResult { - let router_hashes = DomainMultiRouters::::get(destination.clone()) + let router_hashes = OutboundDomainRouters::::get(destination.clone()) .ok_or(Error::::MultiRouterNotFound)?; let message_proof = message.to_message_proof(); From 9d8e562990d8cc57de11c3895710ee3c8b23c3d3 Mon Sep 17 00:00:00 2001 From: Cosmin Damian <17934949+cdamian@users.noreply.github.com> Date: Fri, 9 Aug 2024 15:16:58 +0300 Subject: [PATCH 04/11] lp-gateway: Add and use InboundProcessingInfo --- pallets/liquidity-pools-gateway/src/lib.rs | 170 ++++++++++++------- pallets/liquidity-pools-gateway/src/mock.rs | 3 +- pallets/liquidity-pools-gateway/src/tests.rs | 1 + 3 files changed, 108 insertions(+), 66 deletions(-) diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index 169c7ab601..7190b36424 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -73,6 +73,14 @@ pub enum InboundEntry { }, } +#[derive(Clone)] +pub struct InboundProcessingInfo { + domain_address: DomainAddress, + inbound_routers: BoundedVec, + current_session_id: T::SessionId, + expected_proof_count_per_message: u32, +} + #[frame_support::pallet] pub mod pallet { use super::*; @@ -141,9 +149,9 @@ pub mod pallet { Message = GatewayMessage, >; - /// Number of routers for a domain. + /// Maximum number of routers allowed for a domain. #[pallet::constant] - type MultiRouterCount: Get; + type MaxRouterCount: Get; /// Type for identifying sessions of inbound routers. type SessionId: Parameter @@ -177,13 +185,13 @@ pub mod pallet { /// The outbound routers for a given domain were set. OutboundRoutersSet { domain: Domain, - routers: BoundedVec, + routers: BoundedVec, }, /// Inbound routers were set. InboundRoutersSet { domain: Domain, - router_hashes: BoundedVec, + router_hashes: BoundedVec, }, } @@ -232,7 +240,7 @@ pub mod pallet { #[pallet::storage] #[pallet::getter(fn outbound_domain_routers)] pub type OutboundDomainRouters = - StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; + StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; /// Storage for pending inbound messages. #[pallet::storage] @@ -252,7 +260,7 @@ pub mod pallet { #[pallet::storage] #[pallet::getter(fn inbound_routers)] pub type InboundRouters = - StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; + StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; /// Storage for the session ID of an inbound domain. #[pallet::storage] @@ -462,13 +470,13 @@ pub mod pallet { pub fn set_outbound_routers( origin: OriginFor, domain: Domain, - routers: BoundedVec, + routers: BoundedVec, ) -> DispatchResult { T::AdminOrigin::ensure_origin(origin)?; ensure!(domain != Domain::Centrifuge, Error::::DomainNotSupported); ensure!( - routers.len() == T::MultiRouterCount::get() as usize, + routers.len() == T::MaxRouterCount::get() as usize, Error::::InvalidMultiRouter ); @@ -500,12 +508,12 @@ pub mod pallet { pub fn set_inbound_routers( origin: OriginFor, domain: Domain, - router_hashes: BoundedVec, + router_hashes: BoundedVec, ) -> DispatchResult { T::AdminOrigin::ensure_origin(origin)?; ensure!( - router_hashes.len() == T::MultiRouterCount::get() as usize, + router_hashes.len() == T::MaxRouterCount::get() as usize, Error::::InvalidMultiRouter ); @@ -543,8 +551,13 @@ pub mod pallet { impl Pallet { //TODO(cdamian): Use safe math - fn get_expected_message_proof_count() -> u32 { - T::MultiRouterCount::get() - 1 + fn get_expected_proof_count(domain: &Domain) -> Result { + let routers = + InboundRouters::::get(domain).ok_or(Error::::MultiRouterNotFound)?; + + let expected_proof_count = routers.len().ensure_sub(1)?; + + Ok(expected_proof_count as u32) } fn get_message_proof(message: T::Message) -> Proof { @@ -560,12 +573,13 @@ pub mod pallet { fn create_inbound_entry( domain_address: DomainAddress, message: T::Message, + expected_proof_count: u32, ) -> InboundEntry { match message.get_message_proof() { None => InboundEntry::Message { domain_address, message, - expected_proof_count: Self::get_expected_message_proof_count(), + expected_proof_count, }, Some(_) => InboundEntry::Proof { current_count: 1 }, } @@ -614,15 +628,21 @@ pub mod pallet { } } - fn update_storage_entry(old: &mut InboundEntry, new: InboundEntry) -> DispatchResult { + fn update_storage_entry( + domain: Domain, + old: &mut InboundEntry, + new: InboundEntry, + ) -> DispatchResult { match old { InboundEntry::Message { - expected_proof_count, + expected_proof_count: stored_expected_proof_count, .. } => match new { InboundEntry::Message { .. } => { - expected_proof_count - .ensure_add_assign(Self::get_expected_message_proof_count())?; + let expected_message_proof_count = Self::get_expected_proof_count(&domain)?; + + stored_expected_proof_count + .ensure_add_assign(expected_message_proof_count)?; Ok(()) } @@ -688,24 +708,26 @@ pub mod pallet { } fn validate_and_update_pending_entries( - session_id: T::SessionId, + inbound_processing_info: &InboundProcessingInfo, + message: T::Message, message_proof: Proof, router_hash: T::Hash, - domain_address: DomainAddress, - message: T::Message, weight: &mut Weight, ) -> DispatchResult { - let session_id = InboundDomainSessions::::get(domain_address.domain()) - .ok_or(Error::::InvalidMultiRouter)?; - - let message_proof = Self::get_message_proof(message.clone()); - - let inbound_entry = Self::create_inbound_entry(domain_address.clone(), message); + let inbound_entry = Self::create_inbound_entry( + inbound_processing_info.domain_address.clone(), + message, + inbound_processing_info.expected_proof_count_per_message, + ); - Self::validate_inbound_entry(domain_address.domain(), router_hash, &inbound_entry)?; + Self::validate_inbound_entry( + inbound_processing_info.domain_address.domain(), + router_hash, + &inbound_entry, + )?; Self::update_pending_entry( - session_id, + inbound_processing_info.current_session_id, message_proof, router_hash, inbound_entry, @@ -716,15 +738,17 @@ pub mod pallet { } fn get_executable_message( - inbound_routers: BoundedVec, - session_id: T::SessionId, + inbound_processing_info: &InboundProcessingInfo, message_proof: Proof, ) -> Option { let mut message = None; - let mut proof_count = 0; + let mut votes = 0; - for inbound_router in inbound_routers { - match PendingInboundEntries::::get(session_id, (message_proof, inbound_router)) { + for inbound_router in &inbound_processing_info.inbound_routers { + match PendingInboundEntries::::get( + inbound_processing_info.current_session_id, + (message_proof, inbound_router), + ) { // We expected one InboundEntry for each router, if that's not the case, // we can return. None => return None, @@ -735,14 +759,14 @@ pub mod pallet { } => message = Some(stored_message), InboundEntry::Proof { current_count } => { if current_count > 0 { - proof_count += 1; + votes += 1; } } }, }; } - if proof_count == Self::get_expected_message_proof_count() { + if votes == inbound_processing_info.expected_proof_count_per_message { return message; } @@ -750,13 +774,12 @@ pub mod pallet { } fn decrease_pending_entries_counts( - inbound_routers: BoundedVec, - session_id: T::SessionId, + inbound_processing_info: &InboundProcessingInfo, message_proof: Proof, ) -> DispatchResult { - for inbound_router in inbound_routers { + for inbound_router in &inbound_processing_info.inbound_routers { match PendingInboundEntries::::try_mutate( - session_id, + inbound_processing_info.current_session_id, (message_proof, inbound_router), |storage_entry| match storage_entry { // TODO(cdamian): Add new error @@ -766,8 +789,9 @@ pub mod pallet { expected_proof_count, .. } => { - let updated_count = (*expected_proof_count) - .ensure_sub(Self::get_expected_message_proof_count())?; + let updated_count = (*expected_proof_count).ensure_sub( + inbound_processing_info.expected_proof_count_per_message, + )?; if updated_count == 0 { *storage_entry = None; @@ -799,27 +823,47 @@ pub mod pallet { Ok(()) } + fn get_inbound_processing_info( + domain_address: DomainAddress, + weight: &mut Weight, + ) -> Result, DispatchError> { + let inbound_routers = + //TODO(cdamian): Add new error + InboundRouters::::get(domain_address.domain()).ok_or(Error::::InvalidMultiRouter)?; + + weight.saturating_accrue(T::DbWeight::get().reads(1)); + + let current_session_id = + //TODO(cdamian): Add new error + InboundDomainSessions::::get(domain_address.domain()).ok_or(Error::::InvalidMultiRouter)?; + + weight.saturating_accrue(T::DbWeight::get().reads(1)); + + let expected_proof_count = Self::get_expected_proof_count(&domain_address.domain())?; + + weight.saturating_accrue(T::DbWeight::get().reads(1)); + + Ok(InboundProcessingInfo { + domain_address, + inbound_routers, + current_session_id, + expected_proof_count_per_message: expected_proof_count, + }) + } + /// Give the message to the `InboundMessageHandler` to be processed. fn process_inbound_message( domain_address: DomainAddress, message: T::Message, router_hash: T::Hash, ) -> (DispatchResult, Weight) { - let mut weight = T::DbWeight::get().reads(1); + let mut weight = Default::default(); - let Some(inbound_routers) = InboundRouters::::get(domain_address.domain()) else { - //TODO(cdamian): Add new error - return (Err(Error::::InvalidMultiRouter.into()), weight); - }; - - if inbound_routers.len() == 0 {} - - let Some(session_id) = InboundDomainSessions::::get(domain_address.domain()) else { - //TODO(cdamian): Add error - return (Err(Error::::InvalidMultiRouter.into()), weight); - }; - - let message_proof = Self::get_message_proof(message.clone()); + let inbound_processing_info = + match Self::get_inbound_processing_info(domain_address.clone(), &mut weight) { + Ok(i) => i, + Err(e) => return (Err(e), weight), + }; weight.saturating_accrue( Weight::from_parts(0, T::Message::max_encoded_len() as u64) @@ -831,26 +875,22 @@ pub mod pallet { for submessage in message.submessages() { count += 1; + let message_proof = Self::get_message_proof(message.clone()); + if let Err(e) = Self::validate_and_update_pending_entries( - session_id, + &inbound_processing_info, + submessage.clone(), message_proof, router_hash, - domain_address.clone(), - submessage.clone(), &mut weight, ) { return (Err(e), weight); } - match Self::get_executable_message( - inbound_routers.clone(), - session_id, - message_proof, - ) { + match Self::get_executable_message(&inbound_processing_info, message_proof) { Some(m) => { if let Err(e) = Self::decrease_pending_entries_counts( - inbound_routers.clone(), - session_id, + &inbound_processing_info, message_proof, ) { return (Err(e), weight.saturating_mul(count)); diff --git a/pallets/liquidity-pools-gateway/src/mock.rs b/pallets/liquidity-pools-gateway/src/mock.rs index 6406b673f8..428d21427d 100644 --- a/pallets/liquidity-pools-gateway/src/mock.rs +++ b/pallets/liquidity-pools-gateway/src/mock.rs @@ -138,7 +138,7 @@ frame_support::parameter_types! { pub Sender: AccountId32 = AccountId32::from(H256::from_low_u64_be(1).to_fixed_bytes()); pub const MaxIncomingMessageSize: u32 = 1024; pub const LpAdminAccount: AccountId32 = LP_ADMIN_ACCOUNT; - pub const MultiRouterCount: u32 = 3; + pub const MaxRouterCount: u32 = 8; } impl pallet_liquidity_pools_gateway::Config for Runtime { @@ -146,6 +146,7 @@ impl pallet_liquidity_pools_gateway::Config for Runtime { type InboundMessageHandler = MockLiquidityPools; type LocalEVMOrigin = EnsureLocal; type MaxIncomingMessageSize = MaxIncomingMessageSize; + type MaxRouterCount = MaxRouterCount; type Message = Message; type MessageQueue = MockLiquidityPoolsGatewayQueue; type MultiRouterCount = MultiRouterCount; diff --git a/pallets/liquidity-pools-gateway/src/tests.rs b/pallets/liquidity-pools-gateway/src/tests.rs index 3c9f056071..a66cc89115 100644 --- a/pallets/liquidity-pools-gateway/src/tests.rs +++ b/pallets/liquidity-pools-gateway/src/tests.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use cfg_mocks::*; +use cfg_primitives::LP_DEFENSIVE_WEIGHT; use cfg_traits::liquidity_pools::{LPEncoding, MessageProcessor, OutboundMessageHandler, Proof}; use cfg_types::domain_address::*; use frame_support::{ From 51c2af5fa83dba8d791dfc0ca1eb43108eac168c Mon Sep 17 00:00:00 2001 From: Cosmin Damian <17934949+cdamian@users.noreply.github.com> Date: Fri, 9 Aug 2024 19:52:50 +0300 Subject: [PATCH 05/11] lp-gateway: Unit tests WIP --- Cargo.lock | 12 + pallets/liquidity-pools-gateway/src/lib.rs | 126 +- pallets/liquidity-pools-gateway/src/mock.rs | 1 - pallets/liquidity-pools-gateway/src/tests.rs | 1171 ++++++++++-------- 4 files changed, 720 insertions(+), 590 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b5dfd2b257..ded0246947 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5476,6 +5476,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -8266,6 +8275,9 @@ dependencies = [ "frame-support", "frame-system", "hex", + "itertools 0.13.0", + "lazy_static", + "mock-builder", "orml-traits", "parity-scale-codec", "scale-info", diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index 7190b36424..424ea1d580 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -31,7 +31,7 @@ use core::fmt::Debug; use cfg_primitives::LP_DEFENSIVE_WEIGHT; use cfg_traits::liquidity_pools::{ InboundMessageHandler, LPEncoding, MessageProcessor, MessageQueue, OutboundMessageHandler, - Router as DomainRouter, + Proof, Router as DomainRouter, }; use cfg_types::domain_address::{Domain, DomainAddress}; use frame_support::{dispatch::DispatchResult, pallet_prelude::*}; @@ -40,7 +40,7 @@ use message::GatewayMessage; use orml_traits::GetByKey; pub use pallet::*; use parity_scale_codec::{EncodeLike, FullCodec}; -use sp_arithmetic::traits::BaseArithmetic; +use sp_arithmetic::traits::{BaseArithmetic, EnsureSub, One}; use sp_runtime::traits::EnsureAddAssign; use sp_std::{cmp::Ordering, convert::TryInto, vec::Vec}; @@ -63,14 +63,14 @@ mod tests; #[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] #[scale_info(skip_type_params(T))] pub enum InboundEntry { - Message { - domain_address: DomainAddress, - message: T::Message, - expected_proof_count: u32, - }, - Proof { - current_count: u32, - }, + Message { + domain_address: DomainAddress, + message: T::Message, + expected_proof_count: u32, + }, + Proof { + current_count: u32, + }, } #[derive(Clone)] @@ -306,6 +306,27 @@ pub mod pallet { /// Invalid multi router. InvalidMultiRouter, + /// Inbound domain session not found. + InboundDomainSessionNotFound, + + /// The router that sent the inbound message is unknown. + UnknownInboundMessageRouter, + + /// The router that sent the message is not the first one. + MessageExpectedFromFirstRouter, + + /// The router that sent the proof should not be the first one. + ProofNotExpectedFromFirstRouter, + + /// A message was expected instead of a proof. + ExpectedMessageType, + + /// A message proof was expected instead of a message. + ExpectedMessageProofType, + + /// Pending inbound entry not found. + PendingInboundEntryNotFound, + /// Multi-router not found. MultiRouterNotFound, @@ -475,10 +496,6 @@ pub mod pallet { T::AdminOrigin::ensure_origin(origin)?; ensure!(domain != Domain::Centrifuge, Error::::DomainNotSupported); - ensure!( - routers.len() == T::MaxRouterCount::get() as usize, - Error::::InvalidMultiRouter - ); let mut router_hashes = Vec::new(); @@ -512,11 +529,6 @@ pub mod pallet { ) -> DispatchResult { T::AdminOrigin::ensure_origin(origin)?; - ensure!( - router_hashes.len() == T::MaxRouterCount::get() as usize, - Error::::InvalidMultiRouter - ); - let session_id = SessionIdStore::::try_mutate(|n| { n.ensure_add_assign(One::one())?; Ok::(*n) @@ -550,7 +562,6 @@ pub mod pallet { } impl Pallet { - //TODO(cdamian): Use safe math fn get_expected_proof_count(domain: &Domain) -> Result { let routers = InboundRouters::::get(domain).ok_or(Error::::MultiRouterNotFound)?; @@ -592,26 +603,23 @@ pub mod pallet { /// - messages are only sent by the first inbound router. /// - proofs are not sent by the first inbound router. fn validate_inbound_entry( - domain: Domain, + inbound_processing_info: &InboundProcessingInfo, router_hash: T::Hash, inbound_entry: &InboundEntry, ) -> DispatchResult { - let inbound_routers = - //TODO(cdamian): Add new error - InboundRouters::::get(domain).ok_or(Error::::InvalidMultiRouter)?; + let inbound_routers = inbound_processing_info.inbound_routers.clone(); ensure!( inbound_routers.iter().any(|x| x == &router_hash), - //TODO(cdamian): Add error - Error::::InvalidMultiRouter + Error::::UnknownInboundMessageRouter ); + //TODO(cdamian): Test with 2 routers match inbound_entry { InboundEntry::Message { .. } => { ensure!( inbound_routers.get(0) == Some(&router_hash), - //TODO(cdamian): Add error - Error::::InvalidMultiRouter + Error::::MessageExpectedFromFirstRouter ); Ok(()) @@ -619,8 +627,7 @@ pub mod pallet { InboundEntry::Proof { .. } => { ensure!( inbound_routers.get(0) != Some(&router_hash), - //TODO(cdamian): Add error - Error::::InvalidMultiRouter + Error::::ProofNotExpectedFromFirstRouter ); Ok(()) @@ -628,39 +635,6 @@ pub mod pallet { } } - fn update_storage_entry( - domain: Domain, - old: &mut InboundEntry, - new: InboundEntry, - ) -> DispatchResult { - match old { - InboundEntry::Message { - expected_proof_count: stored_expected_proof_count, - .. - } => match new { - InboundEntry::Message { .. } => { - let expected_message_proof_count = Self::get_expected_proof_count(&domain)?; - - stored_expected_proof_count - .ensure_add_assign(expected_message_proof_count)?; - - Ok(()) - } - //TODO(cdamian): Update error - InboundEntry::Proof { .. } => Err(Error::::InvalidMultiRouter.into()), - }, - InboundEntry::Proof { current_count } => match new { - InboundEntry::Proof { .. } => { - current_count.ensure_add_assign(1)?; - - Ok(()) - } - //TODO(cdamian): Update error - InboundEntry::Message { .. } => Err(Error::::InvalidMultiRouter.into()), - }, - } - } - fn update_pending_entry( session_id: T::SessionId, message_proof: Proof, @@ -689,8 +663,8 @@ pub mod pallet { .. } => old.ensure_add_assign(new).map_err(|e| e.into()), InboundEntry::Proof { .. } => { - // TODO(cdamian): Add new error. - Err(Error::::InvalidMultiRouter.into()) + //TODO(cdamian): Test with 2 routers + Err(Error::::ExpectedMessageType.into()) } }, InboundEntry::Proof { current_count: old } => match inbound_entry { @@ -698,8 +672,7 @@ pub mod pallet { old.ensure_add_assign(new).map_err(|e| e.into()) } InboundEntry::Message { .. } => { - // TODO(cdamian): Add new error. - Err(Error::::InvalidMultiRouter.into()) + Err(Error::::ExpectedMessageProofType.into()) } }, }, @@ -720,11 +693,7 @@ pub mod pallet { inbound_processing_info.expected_proof_count_per_message, ); - Self::validate_inbound_entry( - inbound_processing_info.domain_address.domain(), - router_hash, - &inbound_entry, - )?; + Self::validate_inbound_entry(&inbound_processing_info, router_hash, &inbound_entry)?; Self::update_pending_entry( inbound_processing_info.current_session_id, @@ -782,8 +751,7 @@ pub mod pallet { inbound_processing_info.current_session_id, (message_proof, inbound_router), |storage_entry| match storage_entry { - // TODO(cdamian): Add new error - None => Err(Error::::InvalidMultiRouter.into()), + None => Err(Error::::PendingInboundEntryNotFound.into()), Some(stored_inbound_entry) => match stored_inbound_entry { InboundEntry::Message { expected_proof_count, @@ -827,15 +795,13 @@ pub mod pallet { domain_address: DomainAddress, weight: &mut Weight, ) -> Result, DispatchError> { - let inbound_routers = - //TODO(cdamian): Add new error - InboundRouters::::get(domain_address.domain()).ok_or(Error::::InvalidMultiRouter)?; + let inbound_routers = InboundRouters::::get(domain_address.domain()) + .ok_or(Error::::MultiRouterNotFound)?; weight.saturating_accrue(T::DbWeight::get().reads(1)); - let current_session_id = - //TODO(cdamian): Add new error - InboundDomainSessions::::get(domain_address.domain()).ok_or(Error::::InvalidMultiRouter)?; + let current_session_id = InboundDomainSessions::::get(domain_address.domain()) + .ok_or(Error::::InboundDomainSessionNotFound)?; weight.saturating_accrue(T::DbWeight::get().reads(1)); diff --git a/pallets/liquidity-pools-gateway/src/mock.rs b/pallets/liquidity-pools-gateway/src/mock.rs index 428d21427d..332b56df33 100644 --- a/pallets/liquidity-pools-gateway/src/mock.rs +++ b/pallets/liquidity-pools-gateway/src/mock.rs @@ -149,7 +149,6 @@ impl pallet_liquidity_pools_gateway::Config for Runtime { type MaxRouterCount = MaxRouterCount; type Message = Message; type MessageQueue = MockLiquidityPoolsGatewayQueue; - type MultiRouterCount = MultiRouterCount; type Router = RouterMock; type RuntimeEvent = RuntimeEvent; type RuntimeOrigin = RuntimeOrigin; diff --git a/pallets/liquidity-pools-gateway/src/tests.rs b/pallets/liquidity-pools-gateway/src/tests.rs index a66cc89115..d38d5a532f 100644 --- a/pallets/liquidity-pools-gateway/src/tests.rs +++ b/pallets/liquidity-pools-gateway/src/tests.rs @@ -23,7 +23,13 @@ use super::{ origin::*, pallet::*, }; -use crate::GatewayMessage; +use crate::{GatewayMessage, InboundEntry}; + +lazy_static! { + static ref ROUTER_HASH_1: H256 = H256::from_low_u64_be(1); + static ref ROUTER_HASH_2: H256 = H256::from_low_u64_be(2); + static ref ROUTER_HASH_3: H256 = H256::from_low_u64_be(3); +} mod utils { use super::*; @@ -652,530 +658,675 @@ mod message_processor_impl { pub proof_count: u32, pub mock_called_times: u32, } + + pub fn gen_new(t: T, count: usize) -> Vec::Item>> + where + T: IntoIterator + Clone, + T::IntoIter: Clone, + T::Item: Clone, + { + std::iter::repeat(t.clone().into_iter()) + .take(count) + .multi_cartesian_product() + .collect::>() + } } use util::*; - mod combined_messages { + mod one_router { use super::*; - mod two_messages { - use super::*; - - lazy_static! { - static ref TEST_MAP: HashMap, ExpectedTestResult> = - HashMap::from([ - ( - vec![Message::Simple, Message::Simple], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 4 - )), - proof_count: 0, - mock_called_times: 0, - } - ), - ( - vec![Message::Proof(MESSAGE_PROOF), Message::Proof(MESSAGE_PROOF)], - ExpectedTestResult { - inbound_message: None, - proof_count: 2, - mock_called_times: 0, - } - ), - ( - vec![Message::Simple, Message::Proof(MESSAGE_PROOF)], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 2 - )), - proof_count: 1, - mock_called_times: 0, - } - ), - ( - vec![Message::Proof(MESSAGE_PROOF), Message::Simple], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 2 - )), - proof_count: 1, - mock_called_times: 0, - } - ), - ]); - } + #[test] + fn success() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + let session_id = 1; + let domain_address = DomainAddress::EVM(1, [1; 20]); + let router_hash = *ROUTER_HASH_1; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + router_hash, + }; + + InboundRouters::::insert( + domain_address.domain(), + BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), + ); + InboundDomainSessions::::insert(domain_address.domain(), session_id); + + let handler = MockLiquidityPools::mock_handle( + move |mock_domain_address, mock_message| { + assert_eq!(mock_domain_address, domain_address); + assert_eq!(mock_message, message); + + Ok(()) + }, + ); - #[test] - fn two_messages() { - let tests = generate_test_combinations(2) - .iter() - .map(|x| { - ( - x.clone(), - TEST_MAP - .get(x) - .expect(format!("test for {x:?} should be covered").as_str()), - ) - }) - .collect::>(); - - run_tests!(tests); - } + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_ok!(res); + assert_eq!(handler.times(), 1); + }); } - mod three_messages { - use super::*; - - lazy_static! { - static ref TEST_MAP: HashMap, ExpectedTestResult> = - HashMap::from([ - ( - vec![Message::Simple, Message::Simple, Message::Simple,], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 6 - )), - proof_count: 0, - mock_called_times: 0, - } - ), - ( - vec![ - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - ], - ExpectedTestResult { - inbound_message: None, - proof_count: 3, - mock_called_times: 0, - } - ), - ( - vec![ - Message::Simple, - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - ], - ExpectedTestResult { - inbound_message: None, - proof_count: 0, - mock_called_times: 1, - } - ), - ( - vec![ - Message::Proof(MESSAGE_PROOF), - Message::Simple, - Message::Proof(MESSAGE_PROOF), - ], - ExpectedTestResult { - inbound_message: None, - proof_count: 0, - mock_called_times: 1, - } - ), - ( - vec![ - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - Message::Simple, - ], - ExpectedTestResult { - inbound_message: None, - proof_count: 0, - mock_called_times: 1, - } - ), - ( - vec![ - Message::Proof(MESSAGE_PROOF), - Message::Simple, - Message::Simple, - ], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 4 - )), - proof_count: 1, - mock_called_times: 0, - } - ), - ( - vec![ - Message::Simple, - Message::Proof(MESSAGE_PROOF), - Message::Simple, - ], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 4 - )), - proof_count: 1, - mock_called_times: 0, - } - ), - ( - vec![ - Message::Simple, - Message::Simple, - Message::Proof(MESSAGE_PROOF), - ], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 4 - )), - proof_count: 1, - mock_called_times: 0, - } - ) - ]); - } + #[test] + fn multi_router_not_found() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + let domain_address = DomainAddress::EVM(1, [1; 20]); + let router_hash = *ROUTER_HASH_1; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + router_hash, + }; + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::MultiRouterNotFound); + }); + } - #[test] - fn three_messages() { - let tests = generate_test_combinations(3) - .iter() - .map(|x| { - ( - x.clone(), - TEST_MAP - .get(x) - .expect(format!("test for {x:?} should be covered").as_str()), - ) - }) - .collect::>(); - - run_tests!(tests); - } + #[test] + fn inbound_domain_session_not_found() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + let domain_address = DomainAddress::EVM(1, [1; 20]); + let router_hash = *ROUTER_HASH_1; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + router_hash, + }; + + InboundRouters::::insert( + domain_address.domain(), + BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), + ); + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::InboundDomainSessionNotFound); + }); } - mod four_messages { - use super::*; - - lazy_static! { - static ref TEST_MAP: HashMap, ExpectedTestResult> = - HashMap::from([ - ( - vec![ - Message::Simple, - Message::Simple, - Message::Simple, - Message::Simple, - ], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 8 - )), - proof_count: 0, - mock_called_times: 0, - } - ), - ( - vec![ - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - ], - ExpectedTestResult { - inbound_message: None, - proof_count: 4, - mock_called_times: 0, - } - ), - ( - vec![ - Message::Simple, - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - ], - ExpectedTestResult { - inbound_message: None, - proof_count: 1, - mock_called_times: 1, - } - ), - ( - vec![ - Message::Proof(MESSAGE_PROOF), - Message::Simple, - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - ], - ExpectedTestResult { - inbound_message: None, - proof_count: 1, - mock_called_times: 1, - } - ), - ( - vec![ - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - Message::Simple, - Message::Proof(MESSAGE_PROOF), - ], - ExpectedTestResult { - inbound_message: None, - proof_count: 1, - mock_called_times: 1, - } - ), - ( - vec![ - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - Message::Simple, - ], - ExpectedTestResult { - inbound_message: None, - proof_count: 1, - mock_called_times: 1, - } - ), - ( - vec![ - Message::Simple, - Message::Simple, - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - ], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 2 - )), - proof_count: 0, - mock_called_times: 1, - } - ), - ( - vec![ - Message::Simple, - Message::Proof(MESSAGE_PROOF), - Message::Simple, - Message::Proof(MESSAGE_PROOF), - ], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 2 - )), - proof_count: 0, - mock_called_times: 1, - } - ), - ( - vec![ - Message::Proof(MESSAGE_PROOF), - Message::Simple, - Message::Simple, - Message::Proof(MESSAGE_PROOF), - ], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 2 - )), - proof_count: 0, - mock_called_times: 1, - } - ), - ( - vec![ - Message::Simple, - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - Message::Simple, - ], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 2 - )), - proof_count: 0, - mock_called_times: 1, - } - ), - ( - vec![ - Message::Proof(MESSAGE_PROOF), - Message::Proof(MESSAGE_PROOF), - Message::Simple, - Message::Simple, - ], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 2 - )), - proof_count: 0, - mock_called_times: 1, - } - ), - ( - vec![ - Message::Proof(MESSAGE_PROOF), - Message::Simple, - Message::Proof(MESSAGE_PROOF), - Message::Simple, - ], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 2 - )), - proof_count: 0, - mock_called_times: 1, - } - ), - ( - vec![ - Message::Simple, - Message::Simple, - Message::Simple, - Message::Proof(MESSAGE_PROOF), - ], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 6 - )), - proof_count: 1, - mock_called_times: 0, - } - ), - ( - vec![ - Message::Simple, - Message::Simple, - Message::Proof(MESSAGE_PROOF), - Message::Simple, - ], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 6 - )), - proof_count: 1, - mock_called_times: 0, - } - ), - ( - vec![ - Message::Simple, - Message::Proof(MESSAGE_PROOF), - Message::Simple, - Message::Simple, - ], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 6 - )), - proof_count: 1, - mock_called_times: 0, - } - ), - ( - vec![ - Message::Proof(MESSAGE_PROOF), - Message::Simple, - Message::Simple, - Message::Simple, - ], - ExpectedTestResult { - inbound_message: Some(( - DomainAddress::EVM(1, [1; 20]), - Message::Simple, - 6 - )), - proof_count: 1, - mock_called_times: 0, - } - ), - ]); - } + #[test] + fn unknown_inbound_message_router() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + let session_id = 1; + let domain_address = DomainAddress::EVM(1, [1; 20]); + let router_hash = *ROUTER_HASH_1; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + // The router stored has a different hash, this should trigger the expected + // error. + router_hash: *ROUTER_HASH_2, + }; + + InboundRouters::::insert( + domain_address.domain(), + BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), + ); + InboundDomainSessions::::insert(domain_address.domain(), session_id); + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::UnknownInboundMessageRouter); + }); + } - #[test] - fn four_messages() { - let tests = generate_test_combinations(4) - .iter() - .filter(|x| TEST_MAP.get(x.clone()).is_some()) - .map(|x| { - ( - x.clone(), - TEST_MAP - .get(x) - .expect(format!("test for {x:?} should be covered").as_str()), - ) - }) - .collect::>(); - - run_tests!(tests); - } + #[test] + fn expected_message_proof_type() { + new_test_ext().execute_with(|| { + let message = Message::Simple; + let message_proof = message.to_message_proof().get_message_proof().unwrap(); + let session_id = 1; + let domain_address = DomainAddress::EVM(1, [1; 20]); + let router_hash = *ROUTER_HASH_1; + let gateway_message = GatewayMessage::Inbound { + domain_address: domain_address.clone(), + message: message.clone(), + router_hash, + }; + + InboundRouters::::insert( + domain_address.domain(), + BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), + ); + InboundDomainSessions::::insert(domain_address.domain(), session_id); + PendingInboundEntries::::insert( + session_id, + (message_proof, router_hash), + InboundEntry::::Proof { current_count: 0 }, + ); + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::ExpectedMessageProofType); + }); } } - #[test] - fn two_non_proof_and_four_proofs() { - let tests = generate_test_combinations(6) - .into_iter() - .filter(|x| { - let r = x.iter().counts_by(|c| c.clone()); - let non_proof_count = r.get(&Message::Simple); - let proof_count = r.get(&Message::Proof(MESSAGE_PROOF)); - - match (non_proof_count, proof_count) { - (Some(non_proof_count), Some(proof_count)) => { - *non_proof_count == 2 && *proof_count == 4 - } - _ => false, - } - }) - .map(|x| { - ( - x, - ExpectedTestResult { - inbound_message: None, - proof_count: 0, - mock_called_times: 2, - }, - ) - }) - .collect::>(); - - run_tests!(tests); - } + // mod combined_messages { + // use super::*; + // + // mod two_messages { + // use super::*; + // + // lazy_static! { + // static ref TEST_MAP: HashMap, ExpectedTestResult> = + // HashMap::from([ + // ( + // vec![Message::Simple, Message::Simple], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 4 + // )), + // proof_count: 0, + // mock_called_times: 0, + // } + // ), + // ( + // vec![Message::Proof(MESSAGE_PROOF), Message::Proof(MESSAGE_PROOF)], + // ExpectedTestResult { + // inbound_message: None, + // proof_count: 2, + // mock_called_times: 0, + // } + // ), + // ( + // vec![Message::Simple, Message::Proof(MESSAGE_PROOF)], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 2 + // )), + // proof_count: 1, + // mock_called_times: 0, + // } + // ), + // ( + // vec![Message::Proof(MESSAGE_PROOF), Message::Simple], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 2 + // )), + // proof_count: 1, + // mock_called_times: 0, + // } + // ), + // ]); + // } + // + // #[test] + // fn two_messages() { + // let tests = generate_test_combinations(2) + // .iter() + // .map(|x| { + // ( + // x.clone(), + // TEST_MAP + // .get(x) + // .expect(format!("test for {x:?} should be covered").as_str()), + // ) + // }) + // .collect::>(); + // + // run_tests!(tests); + // } + // } + // + // mod three_messages { + // use super::*; + // + // lazy_static! { + // static ref TEST_MAP: HashMap, ExpectedTestResult> = + // HashMap::from([ + // ( + // vec![Message::Simple, Message::Simple, Message::Simple,], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 6 + // )), + // proof_count: 0, + // mock_called_times: 0, + // } + // ), + // ( + // vec![ + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // ], + // ExpectedTestResult { + // inbound_message: None, + // proof_count: 3, + // mock_called_times: 0, + // } + // ), + // ( + // vec![ + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // ], + // ExpectedTestResult { + // inbound_message: None, + // proof_count: 0, + // mock_called_times: 1, + // } + // ), + // ( + // vec![ + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // ], + // ExpectedTestResult { + // inbound_message: None, + // proof_count: 0, + // mock_called_times: 1, + // } + // ), + // ( + // vec![ + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // ], + // ExpectedTestResult { + // inbound_message: None, + // proof_count: 0, + // mock_called_times: 1, + // } + // ), + // ( + // vec![ + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // Message::Simple, + // ], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 4 + // )), + // proof_count: 1, + // mock_called_times: 0, + // } + // ), + // ( + // vec![ + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // ], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 4 + // )), + // proof_count: 1, + // mock_called_times: 0, + // } + // ), + // ( + // vec![ + // Message::Simple, + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // ], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 4 + // )), + // proof_count: 1, + // mock_called_times: 0, + // } + // ) + // ]); + // } + // + // #[test] + // fn three_messages() { + // let tests = generate_test_combinations(3) + // .iter() + // .map(|x| { + // ( + // x.clone(), + // TEST_MAP + // .get(x) + // .expect(format!("test for {x:?} should be covered").as_str()), + // ) + // }) + // .collect::>(); + // + // run_tests!(tests); + // } + // } + // + // mod four_messages { + // use super::*; + // + // lazy_static! { + // static ref TEST_MAP: HashMap, ExpectedTestResult> = + // HashMap::from([ + // ( + // vec![ + // Message::Simple, + // Message::Simple, + // Message::Simple, + // Message::Simple, + // ], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 8 + // )), + // proof_count: 0, + // mock_called_times: 0, + // } + // ), + // ( + // vec![ + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // ], + // ExpectedTestResult { + // inbound_message: None, + // proof_count: 4, + // mock_called_times: 0, + // } + // ), + // ( + // vec![ + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // ], + // ExpectedTestResult { + // inbound_message: None, + // proof_count: 1, + // mock_called_times: 1, + // } + // ), + // ( + // vec![ + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // ], + // ExpectedTestResult { + // inbound_message: None, + // proof_count: 1, + // mock_called_times: 1, + // } + // ), + // ( + // vec![ + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // ], + // ExpectedTestResult { + // inbound_message: None, + // proof_count: 1, + // mock_called_times: 1, + // } + // ), + // ( + // vec![ + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // ], + // ExpectedTestResult { + // inbound_message: None, + // proof_count: 1, + // mock_called_times: 1, + // } + // ), + // ( + // vec![ + // Message::Simple, + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // ], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 2 + // )), + // proof_count: 0, + // mock_called_times: 1, + // } + // ), + // ( + // vec![ + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // ], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 2 + // )), + // proof_count: 0, + // mock_called_times: 1, + // } + // ), + // ( + // vec![ + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // ], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 2 + // )), + // proof_count: 0, + // mock_called_times: 1, + // } + // ), + // ( + // vec![ + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // ], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 2 + // )), + // proof_count: 0, + // mock_called_times: 1, + // } + // ), + // ( + // vec![ + // Message::Proof(MESSAGE_PROOF), + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // Message::Simple, + // ], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 2 + // )), + // proof_count: 0, + // mock_called_times: 1, + // } + // ), + // ( + // vec![ + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // ], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 2 + // )), + // proof_count: 0, + // mock_called_times: 1, + // } + // ), + // ( + // vec![ + // Message::Simple, + // Message::Simple, + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // ], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 6 + // )), + // proof_count: 1, + // mock_called_times: 0, + // } + // ), + // ( + // vec![ + // Message::Simple, + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // ], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 6 + // )), + // proof_count: 1, + // mock_called_times: 0, + // } + // ), + // ( + // vec![ + // Message::Simple, + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // Message::Simple, + // ], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 6 + // )), + // proof_count: 1, + // mock_called_times: 0, + // } + // ), + // ( + // vec![ + // Message::Proof(MESSAGE_PROOF), + // Message::Simple, + // Message::Simple, + // Message::Simple, + // ], + // ExpectedTestResult { + // inbound_message: Some(( + // DomainAddress::EVM(1, [1; 20]), + // Message::Simple, + // 6 + // )), + // proof_count: 1, + // mock_called_times: 0, + // } + // ), + // ]); + // } + // + // #[test] + // fn four_messages() { + // let tests = generate_test_combinations(4) + // .iter() + // .filter(|x| TEST_MAP.get(x.clone()).is_some()) + // .map(|x| { + // ( + // x.clone(), + // TEST_MAP + // .get(x) + // .expect(format!("test for {x:?} should be covered").as_str()), + // ) + // }) + // .collect::>(); + // + // run_tests!(tests); + // } + // } + // } + // + // #[test] + // fn two_non_proof_and_four_proofs() { + // let tests = generate_test_combinations(6) + // .into_iter() + // .filter(|x| { + // let r = x.iter().counts_by(|c| c.clone()); + // let non_proof_count = r.get(&Message::Simple); + // let proof_count = r.get(&Message::Proof(MESSAGE_PROOF)); + // + // match (non_proof_count, proof_count) { + // (Some(non_proof_count), Some(proof_count)) => { + // *non_proof_count == 2 && *proof_count == 4 + // } + // _ => false, + // } + // }) + // .map(|x| { + // ( + // x, + // ExpectedTestResult { + // inbound_message: None, + // proof_count: 0, + // mock_called_times: 2, + // }, + // ) + // }) + // .collect::>(); + // + // run_tests!(tests); + // } #[test] fn inbound_message_handler_error() { @@ -1466,6 +1617,7 @@ mod batches { let (result, weight) = LiquidityPoolsGateway::process(GatewayMessage::Inbound { domain_address, message: Message::deserialize(&(1..=5).collect::>()).unwrap(), + router_hash: *ROUTER_HASH_1, }); assert_eq!(weight, LP_DEFENSIVE_WEIGHT * 5); @@ -1490,6 +1642,7 @@ mod batches { let (result, weight) = LiquidityPoolsGateway::process(GatewayMessage::Inbound { domain_address, message: Message::deserialize(&(1..=5).collect::>()).unwrap(), + router_hash: *ROUTER_HASH_1, }); // 2 correct messages and 1 failed message processed. From 9be4c58eafbed564bdd6c330ef36c541ae569c00 Mon Sep 17 00:00:00 2001 From: Cosmin Damian <17934949+cdamian@users.noreply.github.com> Date: Sat, 10 Aug 2024 12:08:55 +0300 Subject: [PATCH 06/11] lp-gateway: Unit tests WIP 2 --- pallets/liquidity-pools-gateway/src/lib.rs | 1 - pallets/liquidity-pools-gateway/src/tests.rs | 2283 +++++++++++++----- 2 files changed, 1699 insertions(+), 585 deletions(-) diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index 424ea1d580..f62b872b7e 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -614,7 +614,6 @@ pub mod pallet { Error::::UnknownInboundMessageRouter ); - //TODO(cdamian): Test with 2 routers match inbound_entry { InboundEntry::Message { .. } => { ensure!( diff --git a/pallets/liquidity-pools-gateway/src/tests.rs b/pallets/liquidity-pools-gateway/src/tests.rs index d38d5a532f..c2352ad917 100644 --- a/pallets/liquidity-pools-gateway/src/tests.rs +++ b/pallets/liquidity-pools-gateway/src/tests.rs @@ -25,6 +25,8 @@ use super::{ }; use crate::{GatewayMessage, InboundEntry}; +pub const TEST_DOMAIN_ADDRESS: DomainAddress = DomainAddress::EVM(0, [1; 20]); + lazy_static! { static ref ROUTER_HASH_1: H256 = H256::from_low_u64_be(1); static ref ROUTER_HASH_2: H256 = H256::from_low_u64_be(2); @@ -582,84 +584,85 @@ mod message_processor_impl { mod util { use super::*; - macro_rules! run_tests { - ($tests:expr) => { - // $tests = Vec<(Vec, &ExpectedTestResult)> - for test in $tests { - new_test_ext().execute_with(|| { - println!("Executing test for - {:?}", test.0); - - let handler = MockLiquidityPools::mock_handle(move |_, _| Ok(())); - - // test.0 = Vec - for test_message in test.0 { - let domain_address = DomainAddress::EVM(1, [1; 20]); - let gateway_message = GatewayMessage::Inbound { - domain_address: domain_address.clone(), - message: test_message.clone(), - //TODO(cdamian): Use test router hash. - router_hash: H256::from_low_u64_be(1), - }; - - let (res, _) = LiquidityPoolsGateway::process(gateway_message); - assert_ok!(res); - } + pub fn run_inbound_message_test_suite(suite: InboundMessageTestSuite) { + let test_routers = suite.routers; - assert_eq!(handler.times(), test.1.mock_called_times); + for test in suite.tests { + println!("Executing test for - {:?}", test.router_messages); - assert_eq!( - InboundMessages::::get(MESSAGE_PROOF), - // test.1 = &ExpectedTestResult - test.1.inbound_message, - ); - assert_eq!( - InboundMessageProofCount::::get(MESSAGE_PROOF), - // test.1 = &ExpectedTestResult - test.1.proof_count, - ); - }); - } - }; - } + new_test_ext().execute_with(|| { + let session_id = 1; - lazy_static! { - static ref TEST_MESSAGES: Vec = - vec![Message::Simple, Message::Proof(MESSAGE_PROOF),]; + InboundRouters::::insert( + TEST_DOMAIN_ADDRESS.domain(), + BoundedVec::try_from(test_routers.clone()).unwrap(), + ); + InboundDomainSessions::::insert( + TEST_DOMAIN_ADDRESS.domain(), + session_id, + ); + + let handler = MockLiquidityPools::mock_handle(move |_, _| Ok(())); + + for router_message in test.router_messages { + let gateway_message = GatewayMessage::Inbound { + domain_address: TEST_DOMAIN_ADDRESS, + message: router_message.1, + router_hash: router_message.0, + }; + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_ok!(res); + } + + let expected_message_submitted_times = + test.expected_test_result.message_submitted_times; + let message_submitted_times = handler.times(); + + assert_eq!( + message_submitted_times, + expected_message_submitted_times, + "Expected message to be submitted {expected_message_submitted_times} times, was {message_submitted_times}" + ); + + for expected_storage_entry in + test.expected_test_result.expected_storage_entries + { + let expected_storage_entry_router_hash = expected_storage_entry.0; + let expected_inbound_entry = expected_storage_entry.1; + + let storage_entry = PendingInboundEntries::::get( + session_id, + (MESSAGE_PROOF, expected_storage_entry_router_hash), + ); + assert_eq!(storage_entry, expected_inbound_entry, "Expected inbound entry {expected_inbound_entry:?}, found {storage_entry:?}"); + } + }); + } } - /// Generate all `Message` combinations for a specific - /// number of messages, like: + /// Generate all `TestEntry` combinations like: /// /// vec![ - /// Message::Simple, - /// Message::Simple, + /// (*ROUTER_HASH_1, Message::Simple), + /// (*ROUTER_HASH_1, Message::Simple), /// ] /// vec![ - /// Message::Simple, - /// Message::Proof(MESSAGE_PROOF), + /// (*ROUTER_HASH_1, Message::Simple), + /// (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), /// ] /// vec![ - /// Message::Proof(MESSAGE_PROOF), - /// Message::Simple, + /// (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + /// (*ROUTER_HASH_1, Message::Simple), /// ] /// vec![ - /// Message::Proof(MESSAGE_PROOF), - /// Message::Proof(MESSAGE_PROOF), + /// (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + /// (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), /// ] - pub fn generate_test_combinations(count: usize) -> Vec> { - std::iter::repeat(TEST_MESSAGES.clone().into_iter()) - .take(count) - .multi_cartesian_product() - .collect::>() - } - - pub struct ExpectedTestResult { - pub inbound_message: Option<(DomainAddress, Message, u32)>, - pub proof_count: u32, - pub mock_called_times: u32, - } - - pub fn gen_new(t: T, count: usize) -> Vec::Item>> + pub fn generate_test_combinations( + t: T, + count: usize, + ) -> Vec::Item>> where T: IntoIterator + Clone, T::IntoIter: Clone, @@ -670,6 +673,52 @@ mod message_processor_impl { .multi_cartesian_product() .collect::>() } + + pub type RouterMessage = (H256, Message); + + pub struct InboundMessageTestSuite { + pub routers: Vec, + pub tests: Vec, + } + + pub struct InboundMessageTest { + pub router_messages: Vec, + pub expected_test_result: ExpectedTestResult, + } + + #[derive(Clone, Debug)] + pub struct ExpectedTestResult { + pub message_submitted_times: u32, + pub expected_storage_entries: Vec<(H256, Option>)>, + } + + pub fn generate_test_suite( + routers: Vec, + test_data: Vec, + expected_results: HashMap, ExpectedTestResult>, + message_count: usize, + ) -> InboundMessageTestSuite { + let tests = generate_test_combinations(test_data, message_count); + + let tests = tests + .into_iter() + .map(|router_messages| { + let expected_test_result = expected_results + .get(&router_messages) + .expect( + format!("test for {router_messages:?} should be covered").as_str(), + ) + .clone(); + + InboundMessageTest { + router_messages, + expected_test_result, + } + }) + .collect::>(); + + InboundMessageTestSuite { routers, tests } + } } use util::*; @@ -681,6 +730,7 @@ mod message_processor_impl { fn success() { new_test_ext().execute_with(|| { let message = Message::Simple; + let message_proof = message.to_message_proof().get_message_proof().unwrap(); let session_id = 1; let domain_address = DomainAddress::EVM(1, [1; 20]); let router_hash = *ROUTER_HASH_1; @@ -708,6 +758,12 @@ mod message_processor_impl { let (res, _) = LiquidityPoolsGateway::process(gateway_message); assert_ok!(res); assert_eq!(handler.times(), 1); + + assert!(PendingInboundEntries::::get( + session_id, + (message_proof, router_hash) + ) + .is_none()); }); } @@ -807,526 +863,1585 @@ mod message_processor_impl { } } - // mod combined_messages { - // use super::*; - // - // mod two_messages { - // use super::*; - // - // lazy_static! { - // static ref TEST_MAP: HashMap, ExpectedTestResult> = - // HashMap::from([ - // ( - // vec![Message::Simple, Message::Simple], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 4 - // )), - // proof_count: 0, - // mock_called_times: 0, - // } - // ), - // ( - // vec![Message::Proof(MESSAGE_PROOF), Message::Proof(MESSAGE_PROOF)], - // ExpectedTestResult { - // inbound_message: None, - // proof_count: 2, - // mock_called_times: 0, - // } - // ), - // ( - // vec![Message::Simple, Message::Proof(MESSAGE_PROOF)], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 2 - // )), - // proof_count: 1, - // mock_called_times: 0, - // } - // ), - // ( - // vec![Message::Proof(MESSAGE_PROOF), Message::Simple], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 2 - // )), - // proof_count: 1, - // mock_called_times: 0, - // } - // ), - // ]); - // } - // - // #[test] - // fn two_messages() { - // let tests = generate_test_combinations(2) - // .iter() - // .map(|x| { - // ( - // x.clone(), - // TEST_MAP - // .get(x) - // .expect(format!("test for {x:?} should be covered").as_str()), - // ) - // }) - // .collect::>(); - // - // run_tests!(tests); - // } - // } - // - // mod three_messages { - // use super::*; - // - // lazy_static! { - // static ref TEST_MAP: HashMap, ExpectedTestResult> = - // HashMap::from([ - // ( - // vec![Message::Simple, Message::Simple, Message::Simple,], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 6 - // )), - // proof_count: 0, - // mock_called_times: 0, - // } - // ), - // ( - // vec![ - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // ], - // ExpectedTestResult { - // inbound_message: None, - // proof_count: 3, - // mock_called_times: 0, - // } - // ), - // ( - // vec![ - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // ], - // ExpectedTestResult { - // inbound_message: None, - // proof_count: 0, - // mock_called_times: 1, - // } - // ), - // ( - // vec![ - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // ], - // ExpectedTestResult { - // inbound_message: None, - // proof_count: 0, - // mock_called_times: 1, - // } - // ), - // ( - // vec![ - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // ], - // ExpectedTestResult { - // inbound_message: None, - // proof_count: 0, - // mock_called_times: 1, - // } - // ), - // ( - // vec![ - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // Message::Simple, - // ], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 4 - // )), - // proof_count: 1, - // mock_called_times: 0, - // } - // ), - // ( - // vec![ - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // ], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 4 - // )), - // proof_count: 1, - // mock_called_times: 0, - // } - // ), - // ( - // vec![ - // Message::Simple, - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // ], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 4 - // )), - // proof_count: 1, - // mock_called_times: 0, - // } - // ) - // ]); - // } - // - // #[test] - // fn three_messages() { - // let tests = generate_test_combinations(3) - // .iter() - // .map(|x| { - // ( - // x.clone(), - // TEST_MAP - // .get(x) - // .expect(format!("test for {x:?} should be covered").as_str()), - // ) - // }) - // .collect::>(); - // - // run_tests!(tests); - // } - // } - // - // mod four_messages { - // use super::*; - // - // lazy_static! { - // static ref TEST_MAP: HashMap, ExpectedTestResult> = - // HashMap::from([ - // ( - // vec![ - // Message::Simple, - // Message::Simple, - // Message::Simple, - // Message::Simple, - // ], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 8 - // )), - // proof_count: 0, - // mock_called_times: 0, - // } - // ), - // ( - // vec![ - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // ], - // ExpectedTestResult { - // inbound_message: None, - // proof_count: 4, - // mock_called_times: 0, - // } - // ), - // ( - // vec![ - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // ], - // ExpectedTestResult { - // inbound_message: None, - // proof_count: 1, - // mock_called_times: 1, - // } - // ), - // ( - // vec![ - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // ], - // ExpectedTestResult { - // inbound_message: None, - // proof_count: 1, - // mock_called_times: 1, - // } - // ), - // ( - // vec![ - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // ], - // ExpectedTestResult { - // inbound_message: None, - // proof_count: 1, - // mock_called_times: 1, - // } - // ), - // ( - // vec![ - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // ], - // ExpectedTestResult { - // inbound_message: None, - // proof_count: 1, - // mock_called_times: 1, - // } - // ), - // ( - // vec![ - // Message::Simple, - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // ], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 2 - // )), - // proof_count: 0, - // mock_called_times: 1, - // } - // ), - // ( - // vec![ - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // ], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 2 - // )), - // proof_count: 0, - // mock_called_times: 1, - // } - // ), - // ( - // vec![ - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // ], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 2 - // )), - // proof_count: 0, - // mock_called_times: 1, - // } - // ), - // ( - // vec![ - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // ], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 2 - // )), - // proof_count: 0, - // mock_called_times: 1, - // } - // ), - // ( - // vec![ - // Message::Proof(MESSAGE_PROOF), - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // Message::Simple, - // ], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 2 - // )), - // proof_count: 0, - // mock_called_times: 1, - // } - // ), - // ( - // vec![ - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // ], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 2 - // )), - // proof_count: 0, - // mock_called_times: 1, - // } - // ), - // ( - // vec![ - // Message::Simple, - // Message::Simple, - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // ], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 6 - // )), - // proof_count: 1, - // mock_called_times: 0, - // } - // ), - // ( - // vec![ - // Message::Simple, - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // ], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 6 - // )), - // proof_count: 1, - // mock_called_times: 0, - // } - // ), - // ( - // vec![ - // Message::Simple, - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // Message::Simple, - // ], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 6 - // )), - // proof_count: 1, - // mock_called_times: 0, - // } - // ), - // ( - // vec![ - // Message::Proof(MESSAGE_PROOF), - // Message::Simple, - // Message::Simple, - // Message::Simple, - // ], - // ExpectedTestResult { - // inbound_message: Some(( - // DomainAddress::EVM(1, [1; 20]), - // Message::Simple, - // 6 - // )), - // proof_count: 1, - // mock_called_times: 0, - // } - // ), - // ]); - // } - // - // #[test] - // fn four_messages() { - // let tests = generate_test_combinations(4) - // .iter() - // .filter(|x| TEST_MAP.get(x.clone()).is_some()) - // .map(|x| { - // ( - // x.clone(), - // TEST_MAP - // .get(x) - // .expect(format!("test for {x:?} should be covered").as_str()), - // ) - // }) - // .collect::>(); - // - // run_tests!(tests); - // } - // } - // } - // - // #[test] - // fn two_non_proof_and_four_proofs() { - // let tests = generate_test_combinations(6) - // .into_iter() - // .filter(|x| { - // let r = x.iter().counts_by(|c| c.clone()); - // let non_proof_count = r.get(&Message::Simple); - // let proof_count = r.get(&Message::Proof(MESSAGE_PROOF)); - // - // match (non_proof_count, proof_count) { - // (Some(non_proof_count), Some(proof_count)) => { - // *non_proof_count == 2 && *proof_count == 4 - // } - // _ => false, - // } - // }) - // .map(|x| { - // ( - // x, - // ExpectedTestResult { - // inbound_message: None, - // proof_count: 0, - // mock_called_times: 2, - // }, - // ) - // }) - // .collect::>(); - // - // run_tests!(tests); - // } + mod two_routers { + use super::*; + + mod success { + use super::*; + + lazy_static! { + static ref TEST_DATA: Vec = vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ]; + } + + mod two_messages { + use super::*; + + const MESSAGE_COUNT: usize = 2; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![*ROUTER_HASH_1, *ROUTER_HASH_2], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } + + mod three_messages { + use super::*; + + const MESSAGE_COUNT: usize = 3; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 3, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 3, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 1, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 1, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 1, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![*ROUTER_HASH_1, *ROUTER_HASH_2], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } + + mod four_messages { + use super::*; + + const MESSAGE_COUNT: usize = 4; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 4, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 2, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![*ROUTER_HASH_1, *ROUTER_HASH_2], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } + } + + mod failure { + use super::*; + + #[test] + fn message_expected_from_first_router() { + new_test_ext().execute_with(|| { + let session_id = 1; + + InboundRouters::::insert( + TEST_DOMAIN_ADDRESS.domain(), + BoundedVec::<_, _>::try_from(vec![*ROUTER_HASH_1, *ROUTER_HASH_2]) + .unwrap(), + ); + InboundDomainSessions::::insert( + TEST_DOMAIN_ADDRESS.domain(), + session_id, + ); + + let gateway_message = GatewayMessage::Inbound { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + router_hash: *ROUTER_HASH_2, + }; + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::MessageExpectedFromFirstRouter); + }); + } + + #[test] + fn proof_not_expected_from_first_router() { + new_test_ext().execute_with(|| { + let session_id = 1; + + InboundRouters::::insert( + TEST_DOMAIN_ADDRESS.domain(), + BoundedVec::<_, _>::try_from(vec![*ROUTER_HASH_1, *ROUTER_HASH_2]) + .unwrap(), + ); + InboundDomainSessions::::insert( + TEST_DOMAIN_ADDRESS.domain(), + session_id, + ); + + let gateway_message = GatewayMessage::Inbound { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Proof(MESSAGE_PROOF), + router_hash: *ROUTER_HASH_1, + }; + + let (res, _) = LiquidityPoolsGateway::process(gateway_message); + assert_noop!(res, Error::::ProofNotExpectedFromFirstRouter); + }); + } + } + } + + mod three_routers { + use super::*; + + lazy_static! { + static ref TEST_DATA: Vec = vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ]; + } + + mod two_messages { + use super::*; + + const MESSAGE_COUNT: usize = 2; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![*ROUTER_HASH_1, *ROUTER_HASH_2, *ROUTER_HASH_3], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } + + mod three_messages { + use super::*; + + const MESSAGE_COUNT: usize = 3; + + #[test] + fn success() { + let expected_results: HashMap, ExpectedTestResult> = + HashMap::from([ + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 6, + }), + ), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 3, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 3, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + ( + *ROUTER_HASH_1, + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 4, + }), + ), + (*ROUTER_HASH_2, None), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_1, Message::Simple), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 1, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + (*ROUTER_HASH_2, None), + (*ROUTER_HASH_3, None), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ], + }, + ), + ( + vec![ + (*ROUTER_HASH_2, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + (*ROUTER_HASH_3, Message::Proof(MESSAGE_PROOF)), + ], + ExpectedTestResult { + message_submitted_times: 0, + expected_storage_entries: vec![ + (*ROUTER_HASH_1, None), + ( + *ROUTER_HASH_2, + Some(InboundEntry::::Proof { + current_count: 1, + }), + ), + ( + *ROUTER_HASH_3, + Some(InboundEntry::::Proof { + current_count: 2, + }), + ), + ], + }, + ), + ]); + + let suite = generate_test_suite( + vec![*ROUTER_HASH_1, *ROUTER_HASH_2, *ROUTER_HASH_3], + TEST_DATA.clone(), + expected_results, + MESSAGE_COUNT, + ); + + run_inbound_message_test_suite(suite); + } + } + } #[test] fn inbound_message_handler_error() { From 2535664782a9cc0d4c86a17b025abab8bb21afd5 Mon Sep 17 00:00:00 2001 From: Cosmin Damian <17934949+cdamian@users.noreply.github.com> Date: Sat, 10 Aug 2024 13:34:38 +0300 Subject: [PATCH 07/11] docs: Improve comments --- libs/traits/src/liquidity_pools.rs | 3 + .../routers/src/lib.rs | 2 - pallets/liquidity-pools-gateway/src/lib.rs | 60 ++++++++++++++----- pallets/liquidity-pools-gateway/src/tests.rs | 23 ++++--- 4 files changed, 63 insertions(+), 25 deletions(-) diff --git a/libs/traits/src/liquidity_pools.rs b/libs/traits/src/liquidity_pools.rs index 2dfd5d891b..6ff5b4b588 100644 --- a/libs/traits/src/liquidity_pools.rs +++ b/libs/traits/src/liquidity_pools.rs @@ -37,7 +37,10 @@ pub trait LPEncoding: Sized { /// It's the identity message for composing messages with pack_with fn empty() -> Self; + /// Retrieves the message proof, if any. fn get_message_proof(&self) -> Option; + + /// Converts the message into a message proof type. fn to_message_proof(&self) -> Self; } diff --git a/pallets/liquidity-pools-gateway/routers/src/lib.rs b/pallets/liquidity-pools-gateway/routers/src/lib.rs index df47238053..47e3601398 100644 --- a/pallets/liquidity-pools-gateway/routers/src/lib.rs +++ b/pallets/liquidity-pools-gateway/routers/src/lib.rs @@ -93,9 +93,7 @@ where fn hash(&self) -> Self::Hash { match self { - DomainRouter::EthereumXCM(r) => r.hash(), DomainRouter::AxelarEVM(r) => r.hash(), - DomainRouter::AxelarXCM(r) => r.hash(), } } } diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index f62b872b7e..454a293fb5 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -59,7 +59,7 @@ mod mock; #[cfg(test)] mod tests; -/// Type that stores the information required when processing inbound messages. +/// Type used when storing inbound message information. #[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] #[scale_info(skip_type_params(T))] pub enum InboundEntry { @@ -73,6 +73,7 @@ pub enum InboundEntry { }, } +/// Type used when processing inbound messages. #[derive(Clone)] pub struct InboundProcessingInfo { domain_address: DomainAddress, @@ -83,6 +84,8 @@ pub struct InboundProcessingInfo { #[frame_support::pallet] pub mod pallet { + use sp_arithmetic::traits::EnsureAdd; + use super::*; const STORAGE_VERSION: StorageVersion = StorageVersion::new(1); @@ -262,13 +265,13 @@ pub mod pallet { pub type InboundRouters = StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; - /// Storage for the session ID of an inbound domain. + /// Storage for the inbound message session IDs. #[pallet::storage] - #[pallet::getter(fn inbound_domain_sessions)] - pub type InboundDomainSessions = + #[pallet::getter(fn inbound_message_sessions)] + pub type InboundMessageSessions = StorageMap<_, Blake2_128Concat, Domain, T::SessionId>; - /// Storage for inbound router session IDs. + /// Storage for inbound message session IDs. #[pallet::storage] pub type SessionIdStore = StorageValue<_, T::SessionId, ValueQuery>; @@ -529,13 +532,22 @@ pub mod pallet { ) -> DispatchResult { T::AdminOrigin::ensure_origin(origin)?; - let session_id = SessionIdStore::::try_mutate(|n| { - n.ensure_add_assign(One::one())?; - Ok::(*n) + let (old_session_id, new_session_id) = SessionIdStore::::try_mutate(|n| { + let old_session_id = *n; + let new_session_id = n.ensure_add(One::one())?; + + *n = new_session_id; + + Ok::<(T::SessionId, T::SessionId), DispatchError>((old_session_id, new_session_id)) })?; InboundRouters::::insert(domain.clone(), router_hashes.clone()); - InboundDomainSessions::::insert(domain.clone(), session_id); + InboundMessageSessions::::insert(domain.clone(), new_session_id); + + //TODO(cdamian): The storages are updated with the new session. + // We can process the removal of entries associated with the old entries + // `on_idle`. + let _ = PendingInboundEntries::::clear_prefix(old_session_id, u32::MAX, None); Self::deposit_event(Event::InboundRoutersSet { domain, @@ -562,6 +574,8 @@ pub mod pallet { } impl Pallet { + /// Calculates and returns the proof count required for processing one + /// inbound message. fn get_expected_proof_count(domain: &Domain) -> Result { let routers = InboundRouters::::get(domain).ok_or(Error::::MultiRouterNotFound)?; @@ -571,6 +585,7 @@ pub mod pallet { Ok(expected_proof_count as u32) } + /// Gets the message proof for a message. fn get_message_proof(message: T::Message) -> Proof { match message.get_message_proof() { None => message @@ -581,6 +596,8 @@ pub mod pallet { } } + /// Creates an inbound entry based on whether the inbound message is a + /// proof or not. fn create_inbound_entry( domain_address: DomainAddress, message: T::Message, @@ -634,6 +651,8 @@ pub mod pallet { } } + /// Updates the inbound entry for a particular message, increasing the + /// counts accordingly. fn update_pending_entry( session_id: T::SessionId, message_proof: Proof, @@ -662,7 +681,6 @@ pub mod pallet { .. } => old.ensure_add_assign(new).map_err(|e| e.into()), InboundEntry::Proof { .. } => { - //TODO(cdamian): Test with 2 routers Err(Error::::ExpectedMessageType.into()) } }, @@ -679,6 +697,7 @@ pub mod pallet { ) } + /// Creates, validates and updates the inbound entry. fn validate_and_update_pending_entries( inbound_processing_info: &InboundProcessingInfo, message: T::Message, @@ -705,6 +724,8 @@ pub mod pallet { Ok(()) } + /// Checks if the number of proofs required for executing one message + /// were received, and returns the message if so. fn get_executable_message( inbound_processing_info: &InboundProcessingInfo, message_proof: Proof, @@ -741,6 +762,8 @@ pub mod pallet { None } + /// Decreases the counts for inbound entries and removes them if the + /// counts reach 0. fn decrease_pending_entries_counts( inbound_processing_info: &InboundProcessingInfo, message_proof: Proof, @@ -790,6 +813,8 @@ pub mod pallet { Ok(()) } + /// Retrieves the information required for processing an inbound + /// message. fn get_inbound_processing_info( domain_address: DomainAddress, weight: &mut Weight, @@ -799,7 +824,7 @@ pub mod pallet { weight.saturating_accrue(T::DbWeight::get().reads(1)); - let current_session_id = InboundDomainSessions::::get(domain_address.domain()) + let current_session_id = InboundMessageSessions::::get(domain_address.domain()) .ok_or(Error::::InboundDomainSessionNotFound)?; weight.saturating_accrue(T::DbWeight::get().reads(1)); @@ -816,7 +841,8 @@ pub mod pallet { }) } - /// Give the message to the `InboundMessageHandler` to be processed. + /// Iterates over a batch of messages and checks if the requirements for + /// processing each message are met. fn process_inbound_message( domain_address: DomainAddress, message: T::Message, @@ -874,9 +900,8 @@ pub mod pallet { (Ok(()), LP_DEFENSIVE_WEIGHT.saturating_mul(count)) } - /// Retrieves the router stored for the provided domain, sends the - /// message using the router, and calculates and returns the required - /// weight for these operations in the `DispatchResultWithPostInfo`. + /// Retrieves the stored router, sends the message, and calculates and + /// returns the router operation result and the weight used. fn process_outbound_message( sender: T::AccountId, message: T::Message, @@ -896,6 +921,8 @@ pub mod pallet { (result, router_weight.unwrap_or(read_weight)) } + /// Retrieves the hashes of the routers set for a domain and queues the + /// message and proofs accordingly. fn queue_message(destination: Domain, message: T::Message) -> DispatchResult { let router_hashes = OutboundDomainRouters::::get(destination.clone()) .ok_or(Error::::MultiRouterNotFound)?; @@ -973,7 +1000,8 @@ pub mod pallet { } } - /// Process a message. + /// Returns the max processing weight for a message, based on its + /// direction. fn max_processing_weight(msg: &Self::Message) -> Weight { match msg { GatewayMessage::Inbound { message, .. } => { diff --git a/pallets/liquidity-pools-gateway/src/tests.rs b/pallets/liquidity-pools-gateway/src/tests.rs index c2352ad917..74908488db 100644 --- a/pallets/liquidity-pools-gateway/src/tests.rs +++ b/pallets/liquidity-pools-gateway/src/tests.rs @@ -597,7 +597,7 @@ mod message_processor_impl { TEST_DOMAIN_ADDRESS.domain(), BoundedVec::try_from(test_routers.clone()).unwrap(), ); - InboundDomainSessions::::insert( + InboundMessageSessions::::insert( TEST_DOMAIN_ADDRESS.domain(), session_id, ); @@ -641,7 +641,7 @@ mod message_processor_impl { } } - /// Generate all `TestEntry` combinations like: + /// Used for generating all `RouterMessage` combinations like: /// /// vec![ /// (*ROUTER_HASH_1, Message::Simple), @@ -674,24 +674,33 @@ mod message_processor_impl { .collect::>() } + /// Type used for mapping a message to a router hash. pub type RouterMessage = (H256, Message); + /// Type used for aggregating tests for inbound messages. pub struct InboundMessageTestSuite { pub routers: Vec, pub tests: Vec, } + /// Type used for defining a test which contains a set of + /// `RouterMessage` combinations and the expected test result. pub struct InboundMessageTest { pub router_messages: Vec, pub expected_test_result: ExpectedTestResult, } + /// Type used for defining the number of expected inbound message + /// submission and the exected storage state. #[derive(Clone, Debug)] pub struct ExpectedTestResult { pub message_submitted_times: u32, pub expected_storage_entries: Vec<(H256, Option>)>, } + /// Generates the combinations of `RouterMessage` used when testing, + /// maps the `ExpectedTestResult` for each and creates the + /// `InboundMessageTestSuite`. pub fn generate_test_suite( routers: Vec, test_data: Vec, @@ -744,7 +753,7 @@ mod message_processor_impl { domain_address.domain(), BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), ); - InboundDomainSessions::::insert(domain_address.domain(), session_id); + InboundMessageSessions::::insert(domain_address.domain(), session_id); let handler = MockLiquidityPools::mock_handle( move |mock_domain_address, mock_message| { @@ -825,7 +834,7 @@ mod message_processor_impl { domain_address.domain(), BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), ); - InboundDomainSessions::::insert(domain_address.domain(), session_id); + InboundMessageSessions::::insert(domain_address.domain(), session_id); let (res, _) = LiquidityPoolsGateway::process(gateway_message); assert_noop!(res, Error::::UnknownInboundMessageRouter); @@ -850,7 +859,7 @@ mod message_processor_impl { domain_address.domain(), BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), ); - InboundDomainSessions::::insert(domain_address.domain(), session_id); + InboundMessageSessions::::insert(domain_address.domain(), session_id); PendingInboundEntries::::insert( session_id, (message_proof, router_hash), @@ -1480,7 +1489,7 @@ mod message_processor_impl { BoundedVec::<_, _>::try_from(vec![*ROUTER_HASH_1, *ROUTER_HASH_2]) .unwrap(), ); - InboundDomainSessions::::insert( + InboundMessageSessions::::insert( TEST_DOMAIN_ADDRESS.domain(), session_id, ); @@ -1506,7 +1515,7 @@ mod message_processor_impl { BoundedVec::<_, _>::try_from(vec![*ROUTER_HASH_1, *ROUTER_HASH_2]) .unwrap(), ); - InboundDomainSessions::::insert( + InboundMessageSessions::::insert( TEST_DOMAIN_ADDRESS.domain(), session_id, ); From 0a6482df5abe5f575193dcf5b874998db1083a6d Mon Sep 17 00:00:00 2001 From: Cosmin Damian <17934949+cdamian@users.noreply.github.com> Date: Sat, 10 Aug 2024 13:41:39 +0300 Subject: [PATCH 08/11] lp-gateway: Move message processing logic to a new file --- pallets/liquidity-pools-gateway/src/lib.rs | 409 +---------------- .../src/message_processing.rs | 417 ++++++++++++++++++ pallets/liquidity-pools-gateway/src/tests.rs | 2 +- 3 files changed, 421 insertions(+), 407 deletions(-) create mode 100644 pallets/liquidity-pools-gateway/src/message_processing.rs diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index 454a293fb5..1e8049e89b 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -42,9 +42,9 @@ pub use pallet::*; use parity_scale_codec::{EncodeLike, FullCodec}; use sp_arithmetic::traits::{BaseArithmetic, EnsureSub, One}; use sp_runtime::traits::EnsureAddAssign; -use sp_std::{cmp::Ordering, convert::TryInto, vec::Vec}; +use sp_std::{convert::TryInto, vec::Vec}; -use crate::weights::WeightInfo; +use crate::{message_processing::InboundEntry, weights::WeightInfo}; mod origin; pub use origin::*; @@ -56,32 +56,10 @@ pub mod weights; #[cfg(test)] mod mock; +mod message_processing; #[cfg(test)] mod tests; -/// Type used when storing inbound message information. -#[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] -#[scale_info(skip_type_params(T))] -pub enum InboundEntry { - Message { - domain_address: DomainAddress, - message: T::Message, - expected_proof_count: u32, - }, - Proof { - current_count: u32, - }, -} - -/// Type used when processing inbound messages. -#[derive(Clone)] -pub struct InboundProcessingInfo { - domain_address: DomainAddress, - inbound_routers: BoundedVec, - current_session_id: T::SessionId, - expected_proof_count_per_message: u32, -} - #[frame_support::pallet] pub mod pallet { use sp_arithmetic::traits::EnsureAdd; @@ -573,387 +551,6 @@ pub mod pallet { } } - impl Pallet { - /// Calculates and returns the proof count required for processing one - /// inbound message. - fn get_expected_proof_count(domain: &Domain) -> Result { - let routers = - InboundRouters::::get(domain).ok_or(Error::::MultiRouterNotFound)?; - - let expected_proof_count = routers.len().ensure_sub(1)?; - - Ok(expected_proof_count as u32) - } - - /// Gets the message proof for a message. - fn get_message_proof(message: T::Message) -> Proof { - match message.get_message_proof() { - None => message - .to_message_proof() - .get_message_proof() - .expect("message proof ensured by 'to_message_proof'"), - Some(proof) => proof, - } - } - - /// Creates an inbound entry based on whether the inbound message is a - /// proof or not. - fn create_inbound_entry( - domain_address: DomainAddress, - message: T::Message, - expected_proof_count: u32, - ) -> InboundEntry { - match message.get_message_proof() { - None => InboundEntry::Message { - domain_address, - message, - expected_proof_count, - }, - Some(_) => InboundEntry::Proof { current_count: 1 }, - } - } - - /// Validation ensures that: - /// - /// - the router that sent the inbound message is a valid router for the - /// specific domain. - /// - messages are only sent by the first inbound router. - /// - proofs are not sent by the first inbound router. - fn validate_inbound_entry( - inbound_processing_info: &InboundProcessingInfo, - router_hash: T::Hash, - inbound_entry: &InboundEntry, - ) -> DispatchResult { - let inbound_routers = inbound_processing_info.inbound_routers.clone(); - - ensure!( - inbound_routers.iter().any(|x| x == &router_hash), - Error::::UnknownInboundMessageRouter - ); - - match inbound_entry { - InboundEntry::Message { .. } => { - ensure!( - inbound_routers.get(0) == Some(&router_hash), - Error::::MessageExpectedFromFirstRouter - ); - - Ok(()) - } - InboundEntry::Proof { .. } => { - ensure!( - inbound_routers.get(0) != Some(&router_hash), - Error::::ProofNotExpectedFromFirstRouter - ); - - Ok(()) - } - } - } - - /// Updates the inbound entry for a particular message, increasing the - /// counts accordingly. - fn update_pending_entry( - session_id: T::SessionId, - message_proof: Proof, - router_hash: T::Hash, - inbound_entry: InboundEntry, - weight: &mut Weight, - ) -> DispatchResult { - weight.saturating_accrue(T::DbWeight::get().writes(1)); - - PendingInboundEntries::::try_mutate( - session_id, - (message_proof, router_hash), - |storage_entry| match storage_entry { - None => { - *storage_entry = Some(inbound_entry); - - Ok::<(), DispatchError>(()) - } - Some(stored_inbound_entry) => match stored_inbound_entry { - InboundEntry::Message { - expected_proof_count: old, - .. - } => match inbound_entry { - InboundEntry::Message { - expected_proof_count: new, - .. - } => old.ensure_add_assign(new).map_err(|e| e.into()), - InboundEntry::Proof { .. } => { - Err(Error::::ExpectedMessageType.into()) - } - }, - InboundEntry::Proof { current_count: old } => match inbound_entry { - InboundEntry::Proof { current_count: new } => { - old.ensure_add_assign(new).map_err(|e| e.into()) - } - InboundEntry::Message { .. } => { - Err(Error::::ExpectedMessageProofType.into()) - } - }, - }, - }, - ) - } - - /// Creates, validates and updates the inbound entry. - fn validate_and_update_pending_entries( - inbound_processing_info: &InboundProcessingInfo, - message: T::Message, - message_proof: Proof, - router_hash: T::Hash, - weight: &mut Weight, - ) -> DispatchResult { - let inbound_entry = Self::create_inbound_entry( - inbound_processing_info.domain_address.clone(), - message, - inbound_processing_info.expected_proof_count_per_message, - ); - - Self::validate_inbound_entry(&inbound_processing_info, router_hash, &inbound_entry)?; - - Self::update_pending_entry( - inbound_processing_info.current_session_id, - message_proof, - router_hash, - inbound_entry, - weight, - )?; - - Ok(()) - } - - /// Checks if the number of proofs required for executing one message - /// were received, and returns the message if so. - fn get_executable_message( - inbound_processing_info: &InboundProcessingInfo, - message_proof: Proof, - ) -> Option { - let mut message = None; - let mut votes = 0; - - for inbound_router in &inbound_processing_info.inbound_routers { - match PendingInboundEntries::::get( - inbound_processing_info.current_session_id, - (message_proof, inbound_router), - ) { - // We expected one InboundEntry for each router, if that's not the case, - // we can return. - None => return None, - Some(inbound_entry) => match inbound_entry { - InboundEntry::Message { - message: stored_message, - .. - } => message = Some(stored_message), - InboundEntry::Proof { current_count } => { - if current_count > 0 { - votes += 1; - } - } - }, - }; - } - - if votes == inbound_processing_info.expected_proof_count_per_message { - return message; - } - - None - } - - /// Decreases the counts for inbound entries and removes them if the - /// counts reach 0. - fn decrease_pending_entries_counts( - inbound_processing_info: &InboundProcessingInfo, - message_proof: Proof, - ) -> DispatchResult { - for inbound_router in &inbound_processing_info.inbound_routers { - match PendingInboundEntries::::try_mutate( - inbound_processing_info.current_session_id, - (message_proof, inbound_router), - |storage_entry| match storage_entry { - None => Err(Error::::PendingInboundEntryNotFound.into()), - Some(stored_inbound_entry) => match stored_inbound_entry { - InboundEntry::Message { - expected_proof_count, - .. - } => { - let updated_count = (*expected_proof_count).ensure_sub( - inbound_processing_info.expected_proof_count_per_message, - )?; - - if updated_count == 0 { - *storage_entry = None; - } else { - *expected_proof_count = updated_count; - } - - Ok::<(), DispatchError>(()) - } - InboundEntry::Proof { current_count } => { - let updated_count = (*current_count).ensure_sub(1)?; - - if updated_count == 0 { - *storage_entry = None; - } else { - *current_count = updated_count; - } - - Ok::<(), DispatchError>(()) - } - }, - }, - ) { - Ok(()) => {} - Err(e) => return Err(e), - } - } - - Ok(()) - } - - /// Retrieves the information required for processing an inbound - /// message. - fn get_inbound_processing_info( - domain_address: DomainAddress, - weight: &mut Weight, - ) -> Result, DispatchError> { - let inbound_routers = InboundRouters::::get(domain_address.domain()) - .ok_or(Error::::MultiRouterNotFound)?; - - weight.saturating_accrue(T::DbWeight::get().reads(1)); - - let current_session_id = InboundMessageSessions::::get(domain_address.domain()) - .ok_or(Error::::InboundDomainSessionNotFound)?; - - weight.saturating_accrue(T::DbWeight::get().reads(1)); - - let expected_proof_count = Self::get_expected_proof_count(&domain_address.domain())?; - - weight.saturating_accrue(T::DbWeight::get().reads(1)); - - Ok(InboundProcessingInfo { - domain_address, - inbound_routers, - current_session_id, - expected_proof_count_per_message: expected_proof_count, - }) - } - - /// Iterates over a batch of messages and checks if the requirements for - /// processing each message are met. - fn process_inbound_message( - domain_address: DomainAddress, - message: T::Message, - router_hash: T::Hash, - ) -> (DispatchResult, Weight) { - let mut weight = Default::default(); - - let inbound_processing_info = - match Self::get_inbound_processing_info(domain_address.clone(), &mut weight) { - Ok(i) => i, - Err(e) => return (Err(e), weight), - }; - - weight.saturating_accrue( - Weight::from_parts(0, T::Message::max_encoded_len() as u64) - .saturating_add(LP_DEFENSIVE_WEIGHT), - ); - - let mut count = 0; - - for submessage in message.submessages() { - count += 1; - - let message_proof = Self::get_message_proof(message.clone()); - - if let Err(e) = Self::validate_and_update_pending_entries( - &inbound_processing_info, - submessage.clone(), - message_proof, - router_hash, - &mut weight, - ) { - return (Err(e), weight); - } - - match Self::get_executable_message(&inbound_processing_info, message_proof) { - Some(m) => { - if let Err(e) = Self::decrease_pending_entries_counts( - &inbound_processing_info, - message_proof, - ) { - return (Err(e), weight.saturating_mul(count)); - } - - if let Err(e) = T::InboundMessageHandler::handle(domain_address.clone(), m) - { - // We only consume the processed weight if error during the batch - return (Err(e), weight.saturating_mul(count)); - } - } - None => continue, - } - } - - (Ok(()), LP_DEFENSIVE_WEIGHT.saturating_mul(count)) - } - - /// Retrieves the stored router, sends the message, and calculates and - /// returns the router operation result and the weight used. - fn process_outbound_message( - sender: T::AccountId, - message: T::Message, - router_hash: T::Hash, - ) -> (DispatchResult, Weight) { - let read_weight = T::DbWeight::get().reads(1); - - let Some(router) = OutboundRouters::::get(router_hash) else { - return (Err(Error::::RouterNotFound.into()), read_weight); - }; - - let (result, router_weight) = match router.send(sender, message.serialize()) { - Ok(dispatch_info) => (Ok(()), dispatch_info.actual_weight), - Err(e) => (Err(e.error), e.post_info.actual_weight), - }; - - (result, router_weight.unwrap_or(read_weight)) - } - - /// Retrieves the hashes of the routers set for a domain and queues the - /// message and proofs accordingly. - fn queue_message(destination: Domain, message: T::Message) -> DispatchResult { - let router_hashes = OutboundDomainRouters::::get(destination.clone()) - .ok_or(Error::::MultiRouterNotFound)?; - - let message_proof = message.to_message_proof(); - let mut message_opt = Some(message); - - for router_hash in router_hashes { - // Ensure that we only send the actual message once, using one router. - // The remaining routers will send the message proof. - let router_msg = match message_opt.take() { - Some(m) => m, - None => message_proof.clone(), - }; - - // We are using the sender specified in the pallet config so that we can - // ensure that the account is funded - let gateway_message = - GatewayMessage::::Outbound { - sender: T::Sender::get(), - message: router_msg, - router_hash, - }; - - T::MessageQueue::submit(gateway_message)?; - } - - Ok(()) - } - } - impl OutboundMessageHandler for Pallet { type Destination = Domain; type Message = T::Message; diff --git a/pallets/liquidity-pools-gateway/src/message_processing.rs b/pallets/liquidity-pools-gateway/src/message_processing.rs new file mode 100644 index 0000000000..9b939329ef --- /dev/null +++ b/pallets/liquidity-pools-gateway/src/message_processing.rs @@ -0,0 +1,417 @@ +use cfg_primitives::LP_DEFENSIVE_WEIGHT; +use cfg_traits::liquidity_pools::{InboundMessageHandler, LPEncoding, MessageQueue, Proof, Router}; +use cfg_types::domain_address::{Domain, DomainAddress}; +use frame_support::{ + dispatch::DispatchResult, + ensure, + pallet_prelude::{Decode, Encode, Get, TypeInfo}, + weights::Weight, + BoundedVec, +}; +use parity_scale_codec::MaxEncodedLen; +use sp_arithmetic::traits::{EnsureAddAssign, EnsureSub}; +use sp_runtime::DispatchError; + +use crate::{ + message::GatewayMessage, Config, Error, InboundMessageSessions, InboundRouters, + OutboundDomainRouters, OutboundRouters, Pallet, PendingInboundEntries, +}; + +/// Type used when storing inbound message information. +#[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] +#[scale_info(skip_type_params(T))] +pub enum InboundEntry { + Message { + domain_address: DomainAddress, + message: T::Message, + expected_proof_count: u32, + }, + Proof { + current_count: u32, + }, +} + +/// Type used when processing inbound messages. +#[derive(Clone)] +pub struct InboundProcessingInfo { + domain_address: DomainAddress, + inbound_routers: BoundedVec, + current_session_id: T::SessionId, + expected_proof_count_per_message: u32, +} + +impl Pallet { + /// Calculates and returns the proof count required for processing one + /// inbound message. + fn get_expected_proof_count(domain: &Domain) -> Result { + let routers = InboundRouters::::get(domain).ok_or(Error::::MultiRouterNotFound)?; + + let expected_proof_count = routers.len().ensure_sub(1)?; + + Ok(expected_proof_count as u32) + } + + /// Gets the message proof for a message. + fn get_message_proof(message: T::Message) -> Proof { + match message.get_message_proof() { + None => message + .to_message_proof() + .get_message_proof() + .expect("message proof ensured by 'to_message_proof'"), + Some(proof) => proof, + } + } + + /// Creates an inbound entry based on whether the inbound message is a + /// proof or not. + fn create_inbound_entry( + domain_address: DomainAddress, + message: T::Message, + expected_proof_count: u32, + ) -> InboundEntry { + match message.get_message_proof() { + None => InboundEntry::Message { + domain_address, + message, + expected_proof_count, + }, + Some(_) => InboundEntry::Proof { current_count: 1 }, + } + } + + /// Validation ensures that: + /// + /// - the router that sent the inbound message is a valid router for the + /// specific domain. + /// - messages are only sent by the first inbound router. + /// - proofs are not sent by the first inbound router. + fn validate_inbound_entry( + inbound_processing_info: &InboundProcessingInfo, + router_hash: T::Hash, + inbound_entry: &InboundEntry, + ) -> DispatchResult { + let inbound_routers = inbound_processing_info.inbound_routers.clone(); + + ensure!( + inbound_routers.iter().any(|x| x == &router_hash), + Error::::UnknownInboundMessageRouter + ); + + match inbound_entry { + InboundEntry::Message { .. } => { + ensure!( + inbound_routers.get(0) == Some(&router_hash), + Error::::MessageExpectedFromFirstRouter + ); + + Ok(()) + } + InboundEntry::Proof { .. } => { + ensure!( + inbound_routers.get(0) != Some(&router_hash), + Error::::ProofNotExpectedFromFirstRouter + ); + + Ok(()) + } + } + } + + /// Updates the inbound entry for a particular message, increasing the + /// counts accordingly. + fn update_pending_entry( + session_id: T::SessionId, + message_proof: Proof, + router_hash: T::Hash, + inbound_entry: InboundEntry, + weight: &mut Weight, + ) -> DispatchResult { + weight.saturating_accrue(T::DbWeight::get().writes(1)); + + PendingInboundEntries::::try_mutate( + session_id, + (message_proof, router_hash), + |storage_entry| match storage_entry { + None => { + *storage_entry = Some(inbound_entry); + + Ok::<(), DispatchError>(()) + } + Some(stored_inbound_entry) => match stored_inbound_entry { + InboundEntry::Message { + expected_proof_count: old, + .. + } => match inbound_entry { + InboundEntry::Message { + expected_proof_count: new, + .. + } => old.ensure_add_assign(new).map_err(|e| e.into()), + InboundEntry::Proof { .. } => Err(Error::::ExpectedMessageType.into()), + }, + InboundEntry::Proof { current_count: old } => match inbound_entry { + InboundEntry::Proof { current_count: new } => { + old.ensure_add_assign(new).map_err(|e| e.into()) + } + InboundEntry::Message { .. } => { + Err(Error::::ExpectedMessageProofType.into()) + } + }, + }, + }, + ) + } + + /// Creates, validates and updates the inbound entry. + fn validate_and_update_pending_entries( + inbound_processing_info: &InboundProcessingInfo, + message: T::Message, + message_proof: Proof, + router_hash: T::Hash, + weight: &mut Weight, + ) -> DispatchResult { + let inbound_entry = Self::create_inbound_entry( + inbound_processing_info.domain_address.clone(), + message, + inbound_processing_info.expected_proof_count_per_message, + ); + + Self::validate_inbound_entry(&inbound_processing_info, router_hash, &inbound_entry)?; + + Self::update_pending_entry( + inbound_processing_info.current_session_id, + message_proof, + router_hash, + inbound_entry, + weight, + )?; + + Ok(()) + } + + /// Checks if the number of proofs required for executing one message + /// were received, and returns the message if so. + fn get_executable_message( + inbound_processing_info: &InboundProcessingInfo, + message_proof: Proof, + ) -> Option { + let mut message = None; + let mut votes = 0; + + for inbound_router in &inbound_processing_info.inbound_routers { + match PendingInboundEntries::::get( + inbound_processing_info.current_session_id, + (message_proof, inbound_router), + ) { + // We expected one InboundEntry for each router, if that's not the case, + // we can return. + None => return None, + Some(inbound_entry) => match inbound_entry { + InboundEntry::Message { + message: stored_message, + .. + } => message = Some(stored_message), + InboundEntry::Proof { current_count } => { + if current_count > 0 { + votes += 1; + } + } + }, + }; + } + + if votes == inbound_processing_info.expected_proof_count_per_message { + return message; + } + + None + } + + /// Decreases the counts for inbound entries and removes them if the + /// counts reach 0. + fn decrease_pending_entries_counts( + inbound_processing_info: &InboundProcessingInfo, + message_proof: Proof, + ) -> DispatchResult { + for inbound_router in &inbound_processing_info.inbound_routers { + match PendingInboundEntries::::try_mutate( + inbound_processing_info.current_session_id, + (message_proof, inbound_router), + |storage_entry| match storage_entry { + None => Err(Error::::PendingInboundEntryNotFound.into()), + Some(stored_inbound_entry) => match stored_inbound_entry { + InboundEntry::Message { + expected_proof_count, + .. + } => { + let updated_count = (*expected_proof_count).ensure_sub( + inbound_processing_info.expected_proof_count_per_message, + )?; + + if updated_count == 0 { + *storage_entry = None; + } else { + *expected_proof_count = updated_count; + } + + Ok::<(), DispatchError>(()) + } + InboundEntry::Proof { current_count } => { + let updated_count = (*current_count).ensure_sub(1)?; + + if updated_count == 0 { + *storage_entry = None; + } else { + *current_count = updated_count; + } + + Ok::<(), DispatchError>(()) + } + }, + }, + ) { + Ok(()) => {} + Err(e) => return Err(e), + } + } + + Ok(()) + } + + /// Retrieves the information required for processing an inbound + /// message. + fn get_inbound_processing_info( + domain_address: DomainAddress, + weight: &mut Weight, + ) -> Result, DispatchError> { + let inbound_routers = InboundRouters::::get(domain_address.domain()) + .ok_or(Error::::MultiRouterNotFound)?; + + weight.saturating_accrue(T::DbWeight::get().reads(1)); + + let current_session_id = InboundMessageSessions::::get(domain_address.domain()) + .ok_or(Error::::InboundDomainSessionNotFound)?; + + weight.saturating_accrue(T::DbWeight::get().reads(1)); + + let expected_proof_count = Self::get_expected_proof_count(&domain_address.domain())?; + + weight.saturating_accrue(T::DbWeight::get().reads(1)); + + Ok(InboundProcessingInfo { + domain_address, + inbound_routers, + current_session_id, + expected_proof_count_per_message: expected_proof_count, + }) + } + + /// Iterates over a batch of messages and checks if the requirements for + /// processing each message are met. + pub(crate) fn process_inbound_message( + domain_address: DomainAddress, + message: T::Message, + router_hash: T::Hash, + ) -> (DispatchResult, Weight) { + let mut weight = Default::default(); + + let inbound_processing_info = + match Self::get_inbound_processing_info(domain_address.clone(), &mut weight) { + Ok(i) => i, + Err(e) => return (Err(e), weight), + }; + + weight.saturating_accrue( + Weight::from_parts(0, T::Message::max_encoded_len() as u64) + .saturating_add(LP_DEFENSIVE_WEIGHT), + ); + + let mut count = 0; + + for submessage in message.submessages() { + count += 1; + + let message_proof = Self::get_message_proof(message.clone()); + + if let Err(e) = Self::validate_and_update_pending_entries( + &inbound_processing_info, + submessage.clone(), + message_proof, + router_hash, + &mut weight, + ) { + return (Err(e), weight); + } + + match Self::get_executable_message(&inbound_processing_info, message_proof) { + Some(m) => { + if let Err(e) = Self::decrease_pending_entries_counts( + &inbound_processing_info, + message_proof, + ) { + return (Err(e), weight.saturating_mul(count)); + } + + if let Err(e) = T::InboundMessageHandler::handle(domain_address.clone(), m) { + // We only consume the processed weight if error during the batch + return (Err(e), weight.saturating_mul(count)); + } + } + None => continue, + } + } + + (Ok(()), LP_DEFENSIVE_WEIGHT.saturating_mul(count)) + } + + /// Retrieves the stored router, sends the message, and calculates and + /// returns the router operation result and the weight used. + pub(crate) fn process_outbound_message( + sender: T::AccountId, + message: T::Message, + router_hash: T::Hash, + ) -> (DispatchResult, Weight) { + let read_weight = T::DbWeight::get().reads(1); + + let Some(router) = OutboundRouters::::get(router_hash) else { + return (Err(Error::::RouterNotFound.into()), read_weight); + }; + + let (result, router_weight) = match router.send(sender, message.serialize()) { + Ok(dispatch_info) => (Ok(()), dispatch_info.actual_weight), + Err(e) => (Err(e.error), e.post_info.actual_weight), + }; + + (result, router_weight.unwrap_or(read_weight)) + } + + /// Retrieves the hashes of the routers set for a domain and queues the + /// message and proofs accordingly. + pub(crate) fn queue_message(destination: Domain, message: T::Message) -> DispatchResult { + let router_hashes = OutboundDomainRouters::::get(destination.clone()) + .ok_or(Error::::MultiRouterNotFound)?; + + let message_proof = message.to_message_proof(); + let mut message_opt = Some(message); + + for router_hash in router_hashes { + // Ensure that we only send the actual message once, using one router. + // The remaining routers will send the message proof. + let router_msg = match message_opt.take() { + Some(m) => m, + None => message_proof.clone(), + }; + + // We are using the sender specified in the pallet config so that we can + // ensure that the account is funded + let gateway_message = GatewayMessage::::Outbound { + sender: T::Sender::get(), + message: router_msg, + router_hash, + }; + + T::MessageQueue::submit(gateway_message)?; + } + + Ok(()) + } +} diff --git a/pallets/liquidity-pools-gateway/src/tests.rs b/pallets/liquidity-pools-gateway/src/tests.rs index 74908488db..772f070d4d 100644 --- a/pallets/liquidity-pools-gateway/src/tests.rs +++ b/pallets/liquidity-pools-gateway/src/tests.rs @@ -23,7 +23,7 @@ use super::{ origin::*, pallet::*, }; -use crate::{GatewayMessage, InboundEntry}; +use crate::{message_processing::InboundEntry, GatewayMessage}; pub const TEST_DOMAIN_ADDRESS: DomainAddress = DomainAddress::EVM(0, [1; 20]); From 128bc9f190691e875f5bbcc26feaf95b65e872c3 Mon Sep 17 00:00:00 2001 From: Cosmin Damian <17934949+cdamian@users.noreply.github.com> Date: Mon, 12 Aug 2024 14:28:23 +0300 Subject: [PATCH 09/11] lp-gateway: Merge inbound/outbound routers extrinsics into one, add logic for removing invalid session IDs on idle --- pallets/liquidity-pools-gateway/src/lib.rs | 173 ++++++------------ .../src/message_processing.rs | 120 ++++++++---- .../liquidity-pools-gateway/src/weights.rs | 52 +----- 3 files changed, 145 insertions(+), 200 deletions(-) diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index 1e8049e89b..458f410af5 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -62,6 +62,7 @@ mod tests; #[frame_support::pallet] pub mod pallet { + use frame_system::pallet_prelude::BlockNumberFor; use sp_arithmetic::traits::EnsureAdd; use super::*; @@ -75,6 +76,13 @@ pub mod pallet { #[pallet::origin] pub type Origin = GatewayOrigin; + #[pallet::hooks] + impl Hooks> for Pallet { + fn on_idle(_now: BlockNumberFor, max_weight: Weight) -> Weight { + Self::clear_invalid_session_ids(max_weight) + } + } + #[pallet::config] pub trait Config: frame_system::Config { /// The origin type. @@ -148,8 +156,12 @@ pub mod pallet { #[pallet::event] #[pallet::generate_deposit(pub (super) fn deposit_event)] pub enum Event { - /// The router for a given domain was set. - DomainRouterSet { domain: Domain, router: T::Router }, + /// The routers for a given domain were set. + RoutersSet { + domain: Domain, + //TODO(cdamian): Use T::RouterId + router_ids: BoundedVec, + }, /// An instance was added to a domain. InstanceAdded { instance: DomainAddress }, @@ -162,26 +174,24 @@ pub mod pallet { domain: Domain, hook_address: [u8; 20], }, - - /// The outbound routers for a given domain were set. - OutboundRoutersSet { - domain: Domain, - routers: BoundedVec, - }, - - /// Inbound routers were set. - InboundRoutersSet { - domain: Domain, - router_hashes: BoundedVec, - }, } - /// Storage for domain routers. + // TODO(cdamian): Add migration to clear this storage. + // /// Storage for domain routers. + // /// + // /// This can only be set by an admin. + // #[pallet::storage] + // #[pallet::getter(fn domain_routers)] + // pub type DomainRouters = StorageMap<_, Blake2_128Concat, Domain, + // T::Router>; + + /// Storage for routers specific for a domain. /// /// This can only be set by an admin. #[pallet::storage] - #[pallet::getter(fn domain_routers)] - pub type DomainRouters = StorageMap<_, Blake2_128Concat, Domain, T::Router>; + #[pallet::getter(fn routers)] + pub type Routers = + StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; /// Storage that contains a limited number of whitelisted instances of /// deployed liquidity pools for a particular domain. @@ -208,21 +218,6 @@ pub mod pallet { pub(crate) type PackedMessage = StorageMap<_, Blake2_128Concat, (T::AccountId, Domain), T::Message>; - /// Storage for outbound routers. - /// - /// This can only be set by an admin. - #[pallet::storage] - #[pallet::getter(fn routers)] - pub type OutboundRouters = StorageMap<_, Blake2_128Concat, T::Hash, T::Router>; - - /// Storage for outbound routers specific for a domain. - /// - /// This can only be set by an admin. - #[pallet::storage] - #[pallet::getter(fn outbound_domain_routers)] - pub type OutboundDomainRouters = - StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; - /// Storage for pending inbound messages. #[pallet::storage] #[pallet::getter(fn pending_inbound_entries)] @@ -235,14 +230,6 @@ pub mod pallet { InboundEntry, >; - /// Storage for inbound routers specific for a domain. - /// - /// This can only be set by an admin. - #[pallet::storage] - #[pallet::getter(fn inbound_routers)] - pub type InboundRouters = - StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; - /// Storage for the inbound message session IDs. #[pallet::storage] #[pallet::getter(fn inbound_message_sessions)] @@ -253,6 +240,14 @@ pub mod pallet { #[pallet::storage] pub type SessionIdStore = StorageValue<_, T::SessionId, ValueQuery>; + /// Storage that keeps track of invalid session IDs. + /// + /// Any `PendingInboundEntries` mapped to the invalid IDs are removed from + /// storage during `on_idle`. + #[pallet::storage] + #[pallet::getter(fn invalid_session_ids)] + pub type InvalidSessionIds = StorageMap<_, Blake2_128Concat, T::SessionId, ()>; + #[pallet::error] pub enum Error { /// Router initialization failed. @@ -320,23 +315,34 @@ pub mod pallet { #[pallet::call] impl Pallet { - /// Set a domain's router, - #[pallet::weight(T::WeightInfo::set_domain_router())] + /// Sets the router IDs used for a specific domain, + #[pallet::weight(T::WeightInfo::set_domain_routers())] #[pallet::call_index(0)] - pub fn set_domain_router( + pub fn set_domain_routers( origin: OriginFor, domain: Domain, - router: T::Router, + //TODO(cdamian): Use T::RouterId + router_ids: BoundedVec, ) -> DispatchResult { T::AdminOrigin::ensure_origin(origin)?; - ensure!(domain != Domain::Centrifuge, Error::::DomainNotSupported); + //TODO(cdamian): Outbound - Call router.init() for each router? - router.init().map_err(|_| Error::::RouterInitFailed)?; + >::insert(domain.clone(), router_ids.clone()); - >::insert(domain.clone(), router.clone()); + let (old_session_id, new_session_id) = SessionIdStore::::try_mutate(|n| { + let old_session_id = *n; + let new_session_id = old_session_id.ensure_add(One::one())?; + + *n = new_session_id; + + Ok::<(T::SessionId, T::SessionId), DispatchError>((old_session_id, new_session_id)) + })?; - Self::deposit_event(Event::DomainRouterSet { domain, router }); + InboundMessageSessions::::insert(domain.clone(), new_session_id); + InvalidSessionIds::::insert(old_session_id, ()); + + Self::deposit_event(Event::RoutersSet { domain, router_ids }); Ok(()) } @@ -466,81 +472,12 @@ pub mod pallet { } } - /// Set outbound routers for a particular domain. - #[pallet::weight(T::WeightInfo::set_outbound_routers())] - #[pallet::call_index(11)] - pub fn set_outbound_routers( - origin: OriginFor, - domain: Domain, - routers: BoundedVec, - ) -> DispatchResult { - T::AdminOrigin::ensure_origin(origin)?; - - ensure!(domain != Domain::Centrifuge, Error::::DomainNotSupported); - - let mut router_hashes = Vec::new(); - - for router in &routers { - router.init().map_err(|_| Error::::RouterInitFailed)?; - - let router_hash = router.hash(); - - router_hashes.push(router_hash); - - OutboundRouters::::insert(router_hash, router); - } - - >::insert( - domain.clone(), - BoundedVec::try_from(router_hashes).map_err(|_| Error::::InvalidMultiRouter)?, - ); - - Self::deposit_event(Event::OutboundRoutersSet { domain, routers }); - - Ok(()) - } - - /// Set inbound routers. - #[pallet::weight(T::WeightInfo::set_inbound_routers())] - #[pallet::call_index(12)] - pub fn set_inbound_routers( - origin: OriginFor, - domain: Domain, - router_hashes: BoundedVec, - ) -> DispatchResult { - T::AdminOrigin::ensure_origin(origin)?; - - let (old_session_id, new_session_id) = SessionIdStore::::try_mutate(|n| { - let old_session_id = *n; - let new_session_id = n.ensure_add(One::one())?; - - *n = new_session_id; - - Ok::<(T::SessionId, T::SessionId), DispatchError>((old_session_id, new_session_id)) - })?; - - InboundRouters::::insert(domain.clone(), router_hashes.clone()); - InboundMessageSessions::::insert(domain.clone(), new_session_id); - - //TODO(cdamian): The storages are updated with the new session. - // We can process the removal of entries associated with the old entries - // `on_idle`. - let _ = PendingInboundEntries::::clear_prefix(old_session_id, u32::MAX, None); - - Self::deposit_event(Event::InboundRoutersSet { - domain, - router_hashes, - }); - - Ok(()) - } - /// Manually increase the proof count for a particular message and /// executes it if the required count is reached. /// /// Can only be called by `AdminOrigin`. #[pallet::weight(T::WeightInfo::execute_message_recovery())] - #[pallet::call_index(13)] + #[pallet::call_index(11)] pub fn execute_message_recovery( origin: OriginFor, message_proof: Proof, diff --git a/pallets/liquidity-pools-gateway/src/message_processing.rs b/pallets/liquidity-pools-gateway/src/message_processing.rs index 9b939329ef..e65b70f41d 100644 --- a/pallets/liquidity-pools-gateway/src/message_processing.rs +++ b/pallets/liquidity-pools-gateway/src/message_processing.rs @@ -13,10 +13,14 @@ use sp_arithmetic::traits::{EnsureAddAssign, EnsureSub}; use sp_runtime::DispatchError; use crate::{ - message::GatewayMessage, Config, Error, InboundMessageSessions, InboundRouters, - OutboundDomainRouters, OutboundRouters, Pallet, PendingInboundEntries, + message::GatewayMessage, Config, Error, InboundMessageSessions, InvalidSessionIds, Pallet, + PendingInboundEntries, Routers, }; +/// The limit used when clearing the `PendingInboundEntries` for invalid +/// session IDs. +const INVALID_ID_REMOVAL_LIMIT: u32 = 100; + /// Type used when storing inbound message information. #[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] #[scale_info(skip_type_params(T))] @@ -35,7 +39,7 @@ pub enum InboundEntry { #[derive(Clone)] pub struct InboundProcessingInfo { domain_address: DomainAddress, - inbound_routers: BoundedVec, + routers: BoundedVec, current_session_id: T::SessionId, expected_proof_count_per_message: u32, } @@ -44,7 +48,7 @@ impl Pallet { /// Calculates and returns the proof count required for processing one /// inbound message. fn get_expected_proof_count(domain: &Domain) -> Result { - let routers = InboundRouters::::get(domain).ok_or(Error::::MultiRouterNotFound)?; + let routers = Routers::::get(domain).ok_or(Error::::MultiRouterNotFound)?; let expected_proof_count = routers.len().ensure_sub(1)?; @@ -90,17 +94,17 @@ impl Pallet { router_hash: T::Hash, inbound_entry: &InboundEntry, ) -> DispatchResult { - let inbound_routers = inbound_processing_info.inbound_routers.clone(); + let routers = inbound_processing_info.routers.clone(); ensure!( - inbound_routers.iter().any(|x| x == &router_hash), + routers.iter().any(|x| x == &router_hash), Error::::UnknownInboundMessageRouter ); match inbound_entry { InboundEntry::Message { .. } => { ensure!( - inbound_routers.get(0) == Some(&router_hash), + routers.get(0) == Some(&router_hash), Error::::MessageExpectedFromFirstRouter ); @@ -108,7 +112,7 @@ impl Pallet { } InboundEntry::Proof { .. } => { ensure!( - inbound_routers.get(0) != Some(&router_hash), + routers.get(0) != Some(&router_hash), Error::::ProofNotExpectedFromFirstRouter ); @@ -117,9 +121,9 @@ impl Pallet { } } - /// Updates the inbound entry for a particular message, increasing the - /// counts accordingly. - fn update_pending_entry( + /// Upserts an inbound entry for a particular message, increasing the + /// relevant counts accordingly. + fn upsert_pending_entry( session_id: T::SessionId, message_proof: Proof, router_hash: T::Hash, @@ -161,8 +165,8 @@ impl Pallet { ) } - /// Creates, validates and updates the inbound entry. - fn validate_and_update_pending_entries( + /// Creates, validates and upserts the inbound entry. + fn validate_and_upsert_pending_entries( inbound_processing_info: &InboundProcessingInfo, message: T::Message, message_proof: Proof, @@ -177,7 +181,7 @@ impl Pallet { Self::validate_inbound_entry(&inbound_processing_info, router_hash, &inbound_entry)?; - Self::update_pending_entry( + Self::upsert_pending_entry( inbound_processing_info.current_session_id, message_proof, router_hash, @@ -197,10 +201,10 @@ impl Pallet { let mut message = None; let mut votes = 0; - for inbound_router in &inbound_processing_info.inbound_routers { + for router in &inbound_processing_info.routers { match PendingInboundEntries::::get( inbound_processing_info.current_session_id, - (message_proof, inbound_router), + (message_proof, router), ) { // We expected one InboundEntry for each router, if that's not the case, // we can return. @@ -232,10 +236,10 @@ impl Pallet { inbound_processing_info: &InboundProcessingInfo, message_proof: Proof, ) -> DispatchResult { - for inbound_router in &inbound_processing_info.inbound_routers { + for router in &inbound_processing_info.routers { match PendingInboundEntries::::try_mutate( inbound_processing_info.current_session_id, - (message_proof, inbound_router), + (message_proof, router), |storage_entry| match storage_entry { None => Err(Error::::PendingInboundEntryNotFound.into()), Some(stored_inbound_entry) => match stored_inbound_entry { @@ -283,8 +287,8 @@ impl Pallet { domain_address: DomainAddress, weight: &mut Weight, ) -> Result, DispatchError> { - let inbound_routers = InboundRouters::::get(domain_address.domain()) - .ok_or(Error::::MultiRouterNotFound)?; + let routers = + Routers::::get(domain_address.domain()).ok_or(Error::::MultiRouterNotFound)?; weight.saturating_accrue(T::DbWeight::get().reads(1)); @@ -299,7 +303,7 @@ impl Pallet { Ok(InboundProcessingInfo { domain_address, - inbound_routers, + routers, current_session_id, expected_proof_count_per_message: expected_proof_count, }) @@ -332,7 +336,7 @@ impl Pallet { let message_proof = Self::get_message_proof(message.clone()); - if let Err(e) = Self::validate_and_update_pending_entries( + if let Err(e) = Self::validate_and_upsert_pending_entries( &inbound_processing_info, submessage.clone(), message_proof, @@ -372,23 +376,27 @@ impl Pallet { ) -> (DispatchResult, Weight) { let read_weight = T::DbWeight::get().reads(1); - let Some(router) = OutboundRouters::::get(router_hash) else { - return (Err(Error::::RouterNotFound.into()), read_weight); - }; + // TODO(cdamian): Update when the router refactor is done. - let (result, router_weight) = match router.send(sender, message.serialize()) { - Ok(dispatch_info) => (Ok(()), dispatch_info.actual_weight), - Err(e) => (Err(e.error), e.post_info.actual_weight), - }; + // let Some(router) = Routers::::get(router_hash) else { + // return (Err(Error::::RouterNotFound.into()), read_weight); + // }; + // + // let (result, router_weight) = match router.send(sender, message.serialize()) + // { Ok(dispatch_info) => (Ok(()), dispatch_info.actual_weight), + // Err(e) => (Err(e.error), e.post_info.actual_weight), + // }; + // + // (result, router_weight.unwrap_or(read_weight)) - (result, router_weight.unwrap_or(read_weight)) + (Ok(()), read_weight) } /// Retrieves the hashes of the routers set for a domain and queues the /// message and proofs accordingly. pub(crate) fn queue_message(destination: Domain, message: T::Message) -> DispatchResult { - let router_hashes = OutboundDomainRouters::::get(destination.clone()) - .ok_or(Error::::MultiRouterNotFound)?; + let router_hashes = + Routers::::get(destination.clone()).ok_or(Error::::MultiRouterNotFound)?; let message_proof = message.to_message_proof(); let mut message_opt = Some(message); @@ -414,4 +422,52 @@ impl Pallet { Ok(()) } + + /// Clears `PendingInboundEntries` mapped to invalid session IDs as long as + /// there is enough weight available for this operation. + /// + /// The invalid session IDs are removed from storage if all entries mapped + /// to them were cleared. + pub(crate) fn clear_invalid_session_ids(max_weight: Weight) -> Weight { + let invalid_session_ids = InvalidSessionIds::::iter_keys().collect::>(); + + let mut weight = T::DbWeight::get().reads(1); + + for invalid_session_id in invalid_session_ids { + let mut cursor: Option> = None; + + loop { + let res = PendingInboundEntries::::clear_prefix( + invalid_session_id, + INVALID_ID_REMOVAL_LIMIT, + cursor.as_ref().map(|x| x.as_ref()), + ); + + weight.saturating_accrue( + T::DbWeight::get().reads_writes(res.loops.into(), res.unique.into()), + ); + + if weight.all_gte(max_weight) { + return weight; + } + + cursor = match res.maybe_cursor { + None => { + InvalidSessionIds::::remove(invalid_session_id); + + weight.saturating_accrue(T::DbWeight::get().writes(1)); + + if weight.all_gte(max_weight) { + return weight; + } + + break; + } + Some(c) => Some(c), + }; + } + } + + weight + } } diff --git a/pallets/liquidity-pools-gateway/src/weights.rs b/pallets/liquidity-pools-gateway/src/weights.rs index c568dfc4bf..aaf598ec91 100644 --- a/pallets/liquidity-pools-gateway/src/weights.rs +++ b/pallets/liquidity-pools-gateway/src/weights.rs @@ -13,20 +13,16 @@ use frame_support::weights::{constants::RocksDbWeight, Weight}; pub trait WeightInfo { - fn set_domain_router() -> Weight; + fn set_domain_routers() -> Weight; fn add_instance() -> Weight; fn remove_instance() -> Weight; fn add_relayer() -> Weight; fn remove_relayer() -> Weight; fn receive_message() -> Weight; - fn process_outbound_message() -> Weight; - fn process_failed_outbound_message() -> Weight; fn start_batch_message() -> Weight; fn end_batch_message() -> Weight; fn set_domain_hook_address() -> Weight; - fn set_outbound_routers() -> Weight; fn execute_message_recovery() -> Weight; - fn set_inbound_routers() -> Weight; } // NOTE: We use temporary weights here. `execute_epoch` is by far our heaviest @@ -35,7 +31,7 @@ pub trait WeightInfo { const N: u64 = 4; impl WeightInfo for () { - fn set_domain_router() -> Weight { + fn set_domain_routers() -> Weight { // TODO: BENCHMARK CORRECTLY // // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` @@ -107,28 +103,6 @@ impl WeightInfo for () { .saturating_add(Weight::from_parts(0, 17774).saturating_mul(N)) } - fn process_outbound_message() -> Weight { - // TODO: BENCHMARK CORRECTLY - // - // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` - // This one has one read and one write for sure and possible one - // read for `AdminOrigin` - Weight::from_parts(30_117_000, 5991) - .saturating_add(RocksDbWeight::get().reads(2)) - .saturating_add(RocksDbWeight::get().writes(1)) - } - - fn process_failed_outbound_message() -> Weight { - // TODO: BENCHMARK CORRECTLY - // - // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` - // This one has one read and one write for sure and possible one - // read for `AdminOrigin` - Weight::from_parts(30_117_000, 5991) - .saturating_add(RocksDbWeight::get().reads(2)) - .saturating_add(RocksDbWeight::get().writes(1)) - } - fn start_batch_message() -> Weight { // TODO: BENCHMARK CORRECTLY // @@ -162,17 +136,6 @@ impl WeightInfo for () { .saturating_add(RocksDbWeight::get().writes(2)) } - fn set_outbound_routers() -> Weight { - // TODO: BENCHMARK CORRECTLY - // - // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` - // This one has one read and one write for sure and possible one - // read for `AdminOrigin` - Weight::from_parts(30_117_000, 5991) - .saturating_add(RocksDbWeight::get().reads(2)) - .saturating_add(RocksDbWeight::get().writes(2)) - } - fn execute_message_recovery() -> Weight { // TODO: BENCHMARK CORRECTLY // @@ -183,15 +146,4 @@ impl WeightInfo for () { .saturating_add(RocksDbWeight::get().reads(2)) .saturating_add(RocksDbWeight::get().writes(2)) } - - fn set_inbound_routers() -> Weight { - // TODO: BENCHMARK CORRECTLY - // - // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve` - // This one has one read and one write for sure and possible one - // read for `AdminOrigin` - Weight::from_parts(30_117_000, 5991) - .saturating_add(RocksDbWeight::get().reads(2)) - .saturating_add(RocksDbWeight::get().writes(2)) - } } From 54ee3cb7f33b0f2313f9b33b88d9796a67fe8859 Mon Sep 17 00:00:00 2001 From: Cosmin Damian <17934949+cdamian@users.noreply.github.com> Date: Mon, 12 Aug 2024 21:01:47 +0300 Subject: [PATCH 10/11] lp-gateway: Unit tests WIP --- pallets/liquidity-pools-gateway/src/lib.rs | 47 +- .../liquidity-pools-gateway/src/message.rs | 10 +- .../src/message_processing.rs | 59 ++- pallets/liquidity-pools-gateway/src/mock.rs | 2 + pallets/liquidity-pools-gateway/src/tests.rs | 429 +++++++++++------- 5 files changed, 326 insertions(+), 221 deletions(-) diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index 458f410af5..663a07caa5 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -40,9 +40,8 @@ use message::GatewayMessage; use orml_traits::GetByKey; pub use pallet::*; use parity_scale_codec::{EncodeLike, FullCodec}; -use sp_arithmetic::traits::{BaseArithmetic, EnsureSub, One}; -use sp_runtime::traits::EnsureAddAssign; -use sp_std::{convert::TryInto, vec::Vec}; +use sp_arithmetic::traits::{BaseArithmetic, One}; +use sp_std::convert::TryInto; use crate::{message_processing::InboundEntry, weights::WeightInfo}; @@ -116,6 +115,9 @@ pub mod pallet { + EncodeLike + PartialEq; + /// The type used for identifying routers. + type RouterId: Clone + Debug + MaxEncodedLen + TypeInfo + FullCodec + EncodeLike + PartialEq; + /// The type that processes inbound messages. type InboundMessageHandler: InboundMessageHandler< Sender = DomainAddress, @@ -135,7 +137,7 @@ pub mod pallet { /// Type used for queueing messages. type MessageQueue: MessageQueue< - Message = GatewayMessage, + Message = GatewayMessage, >; /// Maximum number of routers allowed for a domain. @@ -159,8 +161,7 @@ pub mod pallet { /// The routers for a given domain were set. RoutersSet { domain: Domain, - //TODO(cdamian): Use T::RouterId - router_ids: BoundedVec, + router_ids: BoundedVec, }, /// An instance was added to a domain. @@ -191,7 +192,7 @@ pub mod pallet { #[pallet::storage] #[pallet::getter(fn routers)] pub type Routers = - StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; + StorageMap<_, Blake2_128Concat, Domain, BoundedVec>; /// Storage that contains a limited number of whitelisted instances of /// deployed liquidity pools for a particular domain. @@ -226,7 +227,7 @@ pub mod pallet { Blake2_128Concat, T::SessionId, Blake2_128Concat, - (Proof, T::Hash), + (Proof, T::RouterId), InboundEntry, >; @@ -321,11 +322,12 @@ pub mod pallet { pub fn set_domain_routers( origin: OriginFor, domain: Domain, - //TODO(cdamian): Use T::RouterId - router_ids: BoundedVec, + router_ids: BoundedVec, ) -> DispatchResult { T::AdminOrigin::ensure_origin(origin)?; + ensure!(domain != Domain::Centrifuge, Error::::DomainNotSupported); + //TODO(cdamian): Outbound - Call router.init() for each router? >::insert(domain.clone(), router_ids.clone()); @@ -394,6 +396,7 @@ pub mod pallet { #[pallet::call_index(5)] pub fn receive_message( origin: OriginFor, + router_id: T::RouterId, msg: BoundedVec, ) -> DispatchResult { let GatewayOrigin::Domain(origin_address) = T::LocalEVMOrigin::ensure_origin(origin)?; @@ -407,12 +410,12 @@ pub mod pallet { Error::::UnknownInstance, ); - let gateway_message = GatewayMessage::::Inbound { - domain_address: origin_address, - message: T::Message::deserialize(&msg)?, - //TODO(cdamian): Use an actual router hash. - router_hash: T::Hash::default(), - }; + let gateway_message = + GatewayMessage::::Inbound { + domain_address: origin_address, + message: T::Message::deserialize(&msg)?, + router_id, + }; T::MessageQueue::submit(gateway_message) } @@ -480,8 +483,8 @@ pub mod pallet { #[pallet::call_index(11)] pub fn execute_message_recovery( origin: OriginFor, + domain: Domain, message_proof: Proof, - proof_count: u32, ) -> DispatchResult { //TODO(cdamian): Implement this. unimplemented!() @@ -517,20 +520,20 @@ pub mod pallet { } impl MessageProcessor for Pallet { - type Message = GatewayMessage; + type Message = GatewayMessage; fn process(msg: Self::Message) -> (DispatchResult, Weight) { match msg { GatewayMessage::Inbound { domain_address, message, - router_hash, - } => Self::process_inbound_message(domain_address, message, router_hash), + router_id, + } => Self::process_inbound_message(domain_address, message, router_id), GatewayMessage::Outbound { sender, message, - router_hash, - } => Self::process_outbound_message(sender, message, router_hash), + router_id, + } => Self::process_outbound_message(sender, message, router_id), } } diff --git a/pallets/liquidity-pools-gateway/src/message.rs b/pallets/liquidity-pools-gateway/src/message.rs index 42226a46eb..0d6fc4ff38 100644 --- a/pallets/liquidity-pools-gateway/src/message.rs +++ b/pallets/liquidity-pools-gateway/src/message.rs @@ -1,18 +1,18 @@ -use cfg_types::domain_address::{Domain, DomainAddress}; +use cfg_types::domain_address::DomainAddress; use frame_support::pallet_prelude::{Decode, Encode, MaxEncodedLen, TypeInfo}; /// Message type used by the LP gateway. #[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)] -pub enum GatewayMessage { +pub enum GatewayMessage { Inbound { domain_address: DomainAddress, message: Message, - router_hash: Hash, + router_id: RouterId, }, Outbound { sender: AccountId, message: Message, - router_hash: Hash, + router_id: RouterId, }, } @@ -23,7 +23,7 @@ impl Default GatewayMessage::Inbound { domain_address: Default::default(), message: Default::default(), - router_hash: Default::default(), + router_id: Default::default(), } } } diff --git a/pallets/liquidity-pools-gateway/src/message_processing.rs b/pallets/liquidity-pools-gateway/src/message_processing.rs index e65b70f41d..c627aa3d9c 100644 --- a/pallets/liquidity-pools-gateway/src/message_processing.rs +++ b/pallets/liquidity-pools-gateway/src/message_processing.rs @@ -1,5 +1,5 @@ use cfg_primitives::LP_DEFENSIVE_WEIGHT; -use cfg_traits::liquidity_pools::{InboundMessageHandler, LPEncoding, MessageQueue, Proof, Router}; +use cfg_traits::liquidity_pools::{InboundMessageHandler, LPEncoding, MessageQueue, Proof}; use cfg_types::domain_address::{Domain, DomainAddress}; use frame_support::{ dispatch::DispatchResult, @@ -39,7 +39,7 @@ pub enum InboundEntry { #[derive(Clone)] pub struct InboundProcessingInfo { domain_address: DomainAddress, - routers: BoundedVec, + routers: BoundedVec, current_session_id: T::SessionId, expected_proof_count_per_message: u32, } @@ -91,20 +91,20 @@ impl Pallet { /// - proofs are not sent by the first inbound router. fn validate_inbound_entry( inbound_processing_info: &InboundProcessingInfo, - router_hash: T::Hash, + router_id: &T::RouterId, inbound_entry: &InboundEntry, ) -> DispatchResult { let routers = inbound_processing_info.routers.clone(); ensure!( - routers.iter().any(|x| x == &router_hash), + routers.iter().any(|x| x == router_id), Error::::UnknownInboundMessageRouter ); match inbound_entry { InboundEntry::Message { .. } => { ensure!( - routers.get(0) == Some(&router_hash), + routers.get(0) == Some(&router_id), Error::::MessageExpectedFromFirstRouter ); @@ -112,7 +112,7 @@ impl Pallet { } InboundEntry::Proof { .. } => { ensure!( - routers.get(0) != Some(&router_hash), + routers.get(0) != Some(&router_id), Error::::ProofNotExpectedFromFirstRouter ); @@ -126,7 +126,7 @@ impl Pallet { fn upsert_pending_entry( session_id: T::SessionId, message_proof: Proof, - router_hash: T::Hash, + router_id: T::RouterId, inbound_entry: InboundEntry, weight: &mut Weight, ) -> DispatchResult { @@ -134,7 +134,7 @@ impl Pallet { PendingInboundEntries::::try_mutate( session_id, - (message_proof, router_hash), + (message_proof, router_id), |storage_entry| match storage_entry { None => { *storage_entry = Some(inbound_entry); @@ -170,7 +170,7 @@ impl Pallet { inbound_processing_info: &InboundProcessingInfo, message: T::Message, message_proof: Proof, - router_hash: T::Hash, + router_id: T::RouterId, weight: &mut Weight, ) -> DispatchResult { let inbound_entry = Self::create_inbound_entry( @@ -179,12 +179,12 @@ impl Pallet { inbound_processing_info.expected_proof_count_per_message, ); - Self::validate_inbound_entry(&inbound_processing_info, router_hash, &inbound_entry)?; + Self::validate_inbound_entry(&inbound_processing_info, &router_id, &inbound_entry)?; Self::upsert_pending_entry( inbound_processing_info.current_session_id, message_proof, - router_hash, + router_id, inbound_entry, weight, )?; @@ -197,11 +197,14 @@ impl Pallet { fn get_executable_message( inbound_processing_info: &InboundProcessingInfo, message_proof: Proof, + weight: &mut Weight, ) -> Option { let mut message = None; let mut votes = 0; for router in &inbound_processing_info.routers { + weight.saturating_accrue(T::DbWeight::get().reads(1)); + match PendingInboundEntries::::get( inbound_processing_info.current_session_id, (message_proof, router), @@ -235,8 +238,11 @@ impl Pallet { fn decrease_pending_entries_counts( inbound_processing_info: &InboundProcessingInfo, message_proof: Proof, + weight: &mut Weight, ) -> DispatchResult { for router in &inbound_processing_info.routers { + weight.saturating_accrue(T::DbWeight::get().writes(1)); + match PendingInboundEntries::::try_mutate( inbound_processing_info.current_session_id, (message_proof, router), @@ -314,7 +320,7 @@ impl Pallet { pub(crate) fn process_inbound_message( domain_address: DomainAddress, message: T::Message, - router_hash: T::Hash, + router_id: T::RouterId, ) -> (DispatchResult, Weight) { let mut weight = Default::default(); @@ -340,17 +346,19 @@ impl Pallet { &inbound_processing_info, submessage.clone(), message_proof, - router_hash, + router_id.clone(), &mut weight, ) { return (Err(e), weight); } - match Self::get_executable_message(&inbound_processing_info, message_proof) { + match Self::get_executable_message(&inbound_processing_info, message_proof, &mut weight) + { Some(m) => { if let Err(e) = Self::decrease_pending_entries_counts( &inbound_processing_info, message_proof, + &mut weight, ) { return (Err(e), weight.saturating_mul(count)); } @@ -364,7 +372,7 @@ impl Pallet { } } - (Ok(()), LP_DEFENSIVE_WEIGHT.saturating_mul(count)) + (Ok(()), weight.saturating_mul(count)) } /// Retrieves the stored router, sends the message, and calculates and @@ -372,13 +380,13 @@ impl Pallet { pub(crate) fn process_outbound_message( sender: T::AccountId, message: T::Message, - router_hash: T::Hash, + router_id: T::RouterId, ) -> (DispatchResult, Weight) { let read_weight = T::DbWeight::get().reads(1); // TODO(cdamian): Update when the router refactor is done. - // let Some(router) = Routers::::get(router_hash) else { + // let Some(router) = Routers::::get(router_id) else { // return (Err(Error::::RouterNotFound.into()), read_weight); // }; // @@ -392,16 +400,16 @@ impl Pallet { (Ok(()), read_weight) } - /// Retrieves the hashes of the routers set for a domain and queues the + /// Retrieves the IDs of the routers set for a domain and queues the /// message and proofs accordingly. pub(crate) fn queue_message(destination: Domain, message: T::Message) -> DispatchResult { - let router_hashes = + let router_ids = Routers::::get(destination.clone()).ok_or(Error::::MultiRouterNotFound)?; let message_proof = message.to_message_proof(); let mut message_opt = Some(message); - for router_hash in router_hashes { + for router_id in router_ids { // Ensure that we only send the actual message once, using one router. // The remaining routers will send the message proof. let router_msg = match message_opt.take() { @@ -411,11 +419,12 @@ impl Pallet { // We are using the sender specified in the pallet config so that we can // ensure that the account is funded - let gateway_message = GatewayMessage::::Outbound { - sender: T::Sender::get(), - message: router_msg, - router_hash, - }; + let gateway_message = + GatewayMessage::::Outbound { + sender: T::Sender::get(), + message: router_msg, + router_id, + }; T::MessageQueue::submit(gateway_message)?; } diff --git a/pallets/liquidity-pools-gateway/src/mock.rs b/pallets/liquidity-pools-gateway/src/mock.rs index 332b56df33..a162519560 100644 --- a/pallets/liquidity-pools-gateway/src/mock.rs +++ b/pallets/liquidity-pools-gateway/src/mock.rs @@ -150,6 +150,8 @@ impl pallet_liquidity_pools_gateway::Config for Runtime { type Message = Message; type MessageQueue = MockLiquidityPoolsGatewayQueue; type Router = RouterMock; + //TODO(cdamian): Change to some other type for tests? + type RouterId = H256; type RuntimeEvent = RuntimeEvent; type RuntimeOrigin = RuntimeOrigin; type Sender = Sender; diff --git a/pallets/liquidity-pools-gateway/src/tests.rs b/pallets/liquidity-pools-gateway/src/tests.rs index 772f070d4d..f8da936f1a 100644 --- a/pallets/liquidity-pools-gateway/src/tests.rs +++ b/pallets/liquidity-pools-gateway/src/tests.rs @@ -1,8 +1,7 @@ use std::collections::HashMap; -use cfg_mocks::*; use cfg_primitives::LP_DEFENSIVE_WEIGHT; -use cfg_traits::liquidity_pools::{LPEncoding, MessageProcessor, OutboundMessageHandler, Proof}; +use cfg_traits::liquidity_pools::{LPEncoding, MessageProcessor, OutboundMessageHandler}; use cfg_types::domain_address::*; use frame_support::{ assert_err, assert_noop, assert_ok, dispatch::PostDispatchInfo, pallet_prelude::Pays, @@ -12,7 +11,7 @@ use itertools::Itertools; use lazy_static::lazy_static; use parity_scale_codec::MaxEncodedLen; use sp_core::{bounded::BoundedVec, crypto::AccountId32, ByteArray, H160, H256}; -use sp_runtime::{DispatchError, DispatchError::BadOrigin, DispatchErrorWithPostInfo}; +use sp_runtime::{DispatchError, DispatchError::BadOrigin}; use sp_std::sync::{ atomic::{AtomicU32, Ordering}, Arc, @@ -54,63 +53,77 @@ mod utils { use utils::*; -mod set_domain_router { +mod set_domain_routers { use super::*; #[test] fn success() { new_test_ext().execute_with(|| { let domain = Domain::EVM(0); - let router = RouterMock::::default(); - router.mock_init(move || Ok(())); - assert_ok!(LiquidityPoolsGateway::set_domain_router( + let router_id_1 = H256::from_low_u64_be(1); + let router_id_2 = H256::from_low_u64_be(2); + let router_id_3 = H256::from_low_u64_be(3); + + //TODO(cdamian): Enable this after we figure out router init? + // let router = RouterMock::::default(); + // router.mock_init(move || Ok(())); + + let router_ids = + BoundedVec::try_from(vec![router_id_1, router_id_2, router_id_3]).unwrap(); + + assert_ok!(LiquidityPoolsGateway::set_domain_routers( RuntimeOrigin::root(), domain.clone(), - router.clone(), + router_ids.clone(), )); - let storage_entry = DomainRouters::::get(domain.clone()); - assert_eq!(storage_entry.unwrap(), router); - - event_exists(Event::::DomainRouterSet { domain, router }); - }); - } - #[test] - fn router_init_error() { - new_test_ext().execute_with(|| { - let domain = Domain::EVM(0); - let router = RouterMock::::default(); - router.mock_init(move || Err(DispatchError::Other("error"))); - - assert_noop!( - LiquidityPoolsGateway::set_domain_router( - RuntimeOrigin::root(), - domain.clone(), - router, - ), - Error::::RouterInitFailed, + assert_eq!(Routers::::get(domain.clone()).unwrap(), router_ids); + assert_eq!( + InboundMessageSessions::::get(domain.clone()), + Some(1) ); + assert_eq!(InvalidSessionIds::::get(0), Some(())); + + event_exists(Event::::RoutersSet { domain, router_ids }); }); } + //TODO(cdamian): Enable this after we figure out router init? + // + // fn router_init_error() { + // new_test_ext().execute_with(|| { + // let domain = Domain::EVM(0); + // let router = RouterMock::::default(); + // router.mock_init(move || Err(DispatchError::Other("error"))); + // + // assert_noop!( + // LiquidityPoolsGateway::set_domain_router( + // RuntimeOrigin::root(), + // domain.clone(), + // router, + // ), + // Error::::RouterInitFailed, + // ); + // }); + // } #[test] fn bad_origin() { new_test_ext().execute_with(|| { let domain = Domain::EVM(0); - let router = RouterMock::::default(); assert_noop!( - LiquidityPoolsGateway::set_domain_router( + LiquidityPoolsGateway::set_domain_routers( RuntimeOrigin::signed(get_test_account_id()), domain.clone(), - router, + BoundedVec::try_from(vec![]).unwrap(), ), BadOrigin ); - let storage_entry = DomainRouters::::get(domain); - assert!(storage_entry.is_none()); + assert!(Routers::::get(domain.clone()).is_none()); + assert!(InboundMessageSessions::::get(domain).is_none()); + assert!(InvalidSessionIds::::get(0).is_none()); }); } @@ -118,19 +131,19 @@ mod set_domain_router { fn unsupported_domain() { new_test_ext().execute_with(|| { let domain = Domain::Centrifuge; - let router = RouterMock::::default(); assert_noop!( - LiquidityPoolsGateway::set_domain_router( + LiquidityPoolsGateway::set_domain_routers( RuntimeOrigin::root(), domain.clone(), - router + BoundedVec::try_from(vec![]).unwrap(), ), Error::::DomainNotSupported ); - let storage_entry = DomainRouters::::get(domain); - assert!(storage_entry.is_none()); + assert!(Routers::::get(domain.clone()).is_none()); + assert!(InboundMessageSessions::::get(domain).is_none()); + assert!(InvalidSessionIds::::get(0).is_none()); }); } } @@ -295,6 +308,8 @@ mod receive_message_domain { let domain_address = DomainAddress::EVM(0, address.into()); let message = Message::Simple; + let router_id = H256::from_low_u64_be(1); + assert_ok!(LiquidityPoolsGateway::add_instance( RuntimeOrigin::root(), domain_address.clone(), @@ -305,7 +320,7 @@ mod receive_message_domain { let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), - router_hash: H256::from_low_u64_be(1), + router_id, }; MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_message| { @@ -315,6 +330,7 @@ mod receive_message_domain { assert_ok!(LiquidityPoolsGateway::receive_message( GatewayOrigin::Domain(domain_address).into(), + router_id, BoundedVec::::try_from(encoded_msg).unwrap() )); }); @@ -325,9 +341,12 @@ mod receive_message_domain { new_test_ext().execute_with(|| { let encoded_msg = Message::Simple.serialize(); + let router_id = H256::from_low_u64_be(1); + assert_noop!( LiquidityPoolsGateway::receive_message( RuntimeOrigin::signed(AccountId32::new([0u8; 32])), + router_id, BoundedVec::::try_from(encoded_msg).unwrap() ), BadOrigin, @@ -340,10 +359,12 @@ mod receive_message_domain { new_test_ext().execute_with(|| { let domain_address = DomainAddress::Centrifuge(get_test_account_id().into()); let encoded_msg = Message::Simple.serialize(); + let router_id = H256::from_low_u64_be(1); assert_noop!( LiquidityPoolsGateway::receive_message( GatewayOrigin::Domain(domain_address).into(), + router_id, BoundedVec::::try_from(encoded_msg).unwrap() ), Error::::InvalidMessageOrigin, @@ -357,10 +378,12 @@ mod receive_message_domain { let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); let domain_address = DomainAddress::EVM(0, address.into()); let encoded_msg = Message::Simple.serialize(); + let router_id = H256::from_low_u64_be(1); assert_noop!( LiquidityPoolsGateway::receive_message( GatewayOrigin::Domain(domain_address).into(), + router_id, BoundedVec::::try_from(encoded_msg).unwrap() ), Error::::UnknownInstance, @@ -375,6 +398,8 @@ mod receive_message_domain { let domain_address = DomainAddress::EVM(0, address.into()); let message = Message::Simple; + let router_id = H256::from_low_u64_be(1); + assert_ok!(LiquidityPoolsGateway::add_instance( RuntimeOrigin::root(), domain_address.clone(), @@ -387,7 +412,7 @@ mod receive_message_domain { let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), - router_hash: H256::from_low_u64_be(1), + router_id, }; MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_message| { @@ -398,6 +423,7 @@ mod receive_message_domain { assert_noop!( LiquidityPoolsGateway::receive_message( GatewayOrigin::Domain(domain_address).into(), + router_id, BoundedVec::::try_from(encoded_msg).unwrap() ), err, @@ -417,25 +443,30 @@ mod outbound_message_handler_impl { let msg = Message::Simple; let message_proof = msg.to_message_proof().get_message_proof().unwrap(); - let router_hash_1 = H256::from_low_u64_be(1); - let router_hash_2 = H256::from_low_u64_be(2); - let router_hash_3 = H256::from_low_u64_be(3); - - let router_mock_1 = RouterMock::::default(); - let router_mock_2 = RouterMock::::default(); - let router_mock_3 = RouterMock::::default(); - - router_mock_1.mock_init(move || Ok(())); - router_mock_1.mock_hash(move || router_hash_1); - router_mock_2.mock_init(move || Ok(())); - router_mock_2.mock_hash(move || router_hash_2); - router_mock_3.mock_init(move || Ok(())); - router_mock_3.mock_hash(move || router_hash_3); - - assert_ok!(LiquidityPoolsGateway::set_outbound_routers( + let router_id_1 = H256::from_low_u64_be(1); + let router_id_2 = H256::from_low_u64_be(2); + let router_id_3 = H256::from_low_u64_be(3); + + //TODO(cdamian): Router init + // let router_hash_1 = H256::from_low_u64_be(1); + // let router_hash_2 = H256::from_low_u64_be(2); + // let router_hash_3 = H256::from_low_u64_be(3); + // + // let router_mock_1 = RouterMock::::default(); + // let router_mock_2 = RouterMock::::default(); + // let router_mock_3 = RouterMock::::default(); + // + // router_mock_1.mock_init(move || Ok(())); + // router_mock_1.mock_hash(move || router_hash_1); + // router_mock_2.mock_init(move || Ok(())); + // router_mock_2.mock_hash(move || router_hash_2); + // router_mock_3.mock_init(move || Ok(())); + // router_mock_3.mock_hash(move || router_hash_3); + + assert_ok!(LiquidityPoolsGateway::set_domain_routers( RuntimeOrigin::root(), domain.clone(), - BoundedVec::try_from(vec![router_mock_1, router_mock_2, router_mock_3]).unwrap(), + BoundedVec::try_from(vec![router_id_1, router_id_2, router_id_3]).unwrap(), )); MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_msg| { @@ -485,42 +516,48 @@ mod outbound_message_handler_impl { let sender = get_test_account_id(); let msg = Message::Simple; - let router_hash_1 = H256::from_low_u64_be(1); - let router_hash_2 = H256::from_low_u64_be(2); - let router_hash_3 = H256::from_low_u64_be(3); - - let router_mock_1 = RouterMock::::default(); - let router_mock_2 = RouterMock::::default(); - let router_mock_3 = RouterMock::::default(); - - router_mock_1.mock_init(move || Ok(())); - router_mock_1.mock_hash(move || router_hash_1); - router_mock_2.mock_init(move || Ok(())); - router_mock_2.mock_hash(move || router_hash_2); - router_mock_3.mock_init(move || Ok(())); - router_mock_3.mock_hash(move || router_hash_3); - - assert_ok!(LiquidityPoolsGateway::set_outbound_routers( + let router_id_1 = H256::from_low_u64_be(1); + let router_id_2 = H256::from_low_u64_be(2); + let router_id_3 = H256::from_low_u64_be(3); + + //TODO(cdamian): Router init? + // let router_hash_1 = H256::from_low_u64_be(1); + // let router_hash_2 = H256::from_low_u64_be(2); + // let router_hash_3 = H256::from_low_u64_be(3); + // + // let router_mock_1 = RouterMock::::default(); + // let router_mock_2 = RouterMock::::default(); + // let router_mock_3 = RouterMock::::default(); + // + // router_mock_1.mock_init(move || Ok(())); + // router_mock_1.mock_hash(move || router_hash_1); + // router_mock_2.mock_init(move || Ok(())); + // router_mock_2.mock_hash(move || router_hash_2); + // router_mock_3.mock_init(move || Ok(())); + // router_mock_3.mock_hash(move || router_hash_3); + + assert_ok!(LiquidityPoolsGateway::set_domain_routers( RuntimeOrigin::root(), domain.clone(), - BoundedVec::try_from(vec![router_mock_1, router_mock_2, router_mock_3]).unwrap(), + BoundedVec::try_from(vec![router_id_1, router_id_2, router_id_3]).unwrap(), )); let gateway_message = GatewayMessage::Outbound { sender: ::Sender::get(), message: msg.clone(), - router_hash: router_hash_3, + router_id: router_id_1, }; let err = DispatchError::Unavailable; - MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_msg| { + let handler = MockLiquidityPoolsGatewayQueue::mock_submit(move |mock_msg| { assert_eq!(mock_msg, gateway_message); Err(err) }); assert_noop!(LiquidityPoolsGateway::handle(sender, domain, msg), err); + assert_eq!(handler.times(), 1); }); } } @@ -593,7 +630,7 @@ mod message_processor_impl { new_test_ext().execute_with(|| { let session_id = 1; - InboundRouters::::insert( + Routers::::insert( TEST_DOMAIN_ADDRESS.domain(), BoundedVec::try_from(test_routers.clone()).unwrap(), ); @@ -608,7 +645,7 @@ mod message_processor_impl { let gateway_message = GatewayMessage::Inbound { domain_address: TEST_DOMAIN_ADDRESS, message: router_message.1, - router_hash: router_message.0, + router_id: router_message.0, }; let (res, _) = LiquidityPoolsGateway::process(gateway_message); @@ -746,10 +783,10 @@ mod message_processor_impl { let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), - router_hash, + router_id: router_hash, }; - InboundRouters::::insert( + Routers::::insert( domain_address.domain(), BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), ); @@ -785,7 +822,7 @@ mod message_processor_impl { let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), - router_hash, + router_id: router_hash, }; let (res, _) = LiquidityPoolsGateway::process(gateway_message); @@ -802,10 +839,10 @@ mod message_processor_impl { let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), - router_hash, + router_id: router_hash, }; - InboundRouters::::insert( + Routers::::insert( domain_address.domain(), BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), ); @@ -827,10 +864,10 @@ mod message_processor_impl { message: message.clone(), // The router stored has a different hash, this should trigger the expected // error. - router_hash: *ROUTER_HASH_2, + router_id: *ROUTER_HASH_2, }; - InboundRouters::::insert( + Routers::::insert( domain_address.domain(), BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), ); @@ -852,10 +889,10 @@ mod message_processor_impl { let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), - router_hash, + router_id: router_hash, }; - InboundRouters::::insert( + Routers::::insert( domain_address.domain(), BoundedVec::<_, _>::try_from(vec![router_hash]).unwrap(), ); @@ -1484,7 +1521,7 @@ mod message_processor_impl { new_test_ext().execute_with(|| { let session_id = 1; - InboundRouters::::insert( + Routers::::insert( TEST_DOMAIN_ADDRESS.domain(), BoundedVec::<_, _>::try_from(vec![*ROUTER_HASH_1, *ROUTER_HASH_2]) .unwrap(), @@ -1497,7 +1534,7 @@ mod message_processor_impl { let gateway_message = GatewayMessage::Inbound { domain_address: TEST_DOMAIN_ADDRESS, message: Message::Simple, - router_hash: *ROUTER_HASH_2, + router_id: *ROUTER_HASH_2, }; let (res, _) = LiquidityPoolsGateway::process(gateway_message); @@ -1510,7 +1547,7 @@ mod message_processor_impl { new_test_ext().execute_with(|| { let session_id = 1; - InboundRouters::::insert( + Routers::::insert( TEST_DOMAIN_ADDRESS.domain(), BoundedVec::<_, _>::try_from(vec![*ROUTER_HASH_1, *ROUTER_HASH_2]) .unwrap(), @@ -1523,7 +1560,7 @@ mod message_processor_impl { let gateway_message = GatewayMessage::Inbound { domain_address: TEST_DOMAIN_ADDRESS, message: Message::Proof(MESSAGE_PROOF), - router_hash: *ROUTER_HASH_1, + router_id: *ROUTER_HASH_1, }; let (res, _) = LiquidityPoolsGateway::process(gateway_message); @@ -2457,31 +2494,19 @@ mod message_processor_impl { new_test_ext().execute_with(|| { let domain_address = DomainAddress::EVM(1, [1; 20]); - let message = Message::Proof(MESSAGE_PROOF); - let gateway_message = GatewayMessage::Inbound { - domain_address: domain_address.clone(), - message: message.clone(), - router_hash: H256::from_low_u64_be(1), - }; + let router_id = H256::from_low_u64_be(1); - let (res, _) = LiquidityPoolsGateway::process(gateway_message); - assert_ok!(res); - - let message = Message::Proof(MESSAGE_PROOF); - let gateway_message = GatewayMessage::Inbound { - domain_address: domain_address.clone(), - message: message.clone(), - router_hash: H256::from_low_u64_be(1), - }; - - let (res, _) = LiquidityPoolsGateway::process(gateway_message); - assert_ok!(res); + Routers::::insert( + domain_address.domain(), + BoundedVec::try_from(vec![router_id]).unwrap(), + ); + InboundMessageSessions::::insert(domain_address.domain(), 1); let message = Message::Simple; let gateway_message = GatewayMessage::Inbound { domain_address: domain_address.clone(), message: message.clone(), - router_hash: H256::from_low_u64_be(1), + router_id, }; let err = DispatchError::Unavailable; @@ -2493,9 +2518,8 @@ mod message_processor_impl { Err(err) }); - let (res, weight) = LiquidityPoolsGateway::process(gateway_message); + let (res, _) = LiquidityPoolsGateway::process(gateway_message); assert_noop!(res, err); - assert_eq!(weight, LP_DEFENSIVE_WEIGHT); }); } } @@ -2518,18 +2542,21 @@ mod message_processor_impl { pays_fee: Pays::Yes, }; - let router_hash = H256::from_low_u64_be(1); - - let router_mock = RouterMock::::default(); - router_mock.mock_send(move |mock_sender, mock_message| { - assert_eq!(mock_sender, expected_sender); - assert_eq!(mock_message, expected_message.serialize()); - - Ok(router_post_info) - }); - router_mock.mock_hash(move || router_hash); - - DomainRouters::::insert(domain.clone(), router_mock); + let router_id = H256::from_low_u64_be(1); + + //TODO(cdamian): Drop mock? + // let router_hash = H256::from_low_u64_be(1); + // + // let router_mock = RouterMock::::default(); + // router_mock.mock_send(move |mock_sender, mock_message| { + // assert_eq!(mock_sender, expected_sender); + // assert_eq!(mock_message, expected_message.serialize()); + // + // Ok(router_post_info) + // }); + // router_mock.mock_hash(move || router_hash); + // + // DomainRouters::::insert(domain.clone(), router_mock); let min_expected_weight = ::DbWeight::get() .reads(1) + router_post_info.actual_weight.unwrap() @@ -2538,7 +2565,7 @@ mod message_processor_impl { let gateway_message = GatewayMessage::Outbound { sender, message: message.clone(), - router_hash, + router_id, }; let (res, weight) = LiquidityPoolsGateway::process(gateway_message); @@ -2547,25 +2574,27 @@ mod message_processor_impl { }); } - #[test] - fn router_not_found() { - new_test_ext().execute_with(|| { - let sender = get_test_account_id(); - let message = Message::Simple; - - let expected_weight = ::DbWeight::get().reads(1); - - let gateway_message = GatewayMessage::Outbound { - sender, - message, - router_hash: H256::from_low_u64_be(1), - }; - - let (res, weight) = LiquidityPoolsGateway::process(gateway_message); - assert_noop!(res, Error::::RouterNotFound); - assert_eq!(weight, expected_weight); - }); - } + //TODO(cdamian): Fix when bi-directional routers are in. + // #[test] + // fn router_not_found() { + // new_test_ext().execute_with(|| { + // let sender = get_test_account_id(); + // let message = Message::Simple; + // + // let expected_weight = ::DbWeight::get().reads(1); + // + // let gateway_message = GatewayMessage::Outbound { + // sender, + // message, + // router_id: H256::from_low_u64_be(1), + // }; + // + // let (res, weight) = LiquidityPoolsGateway::process(gateway_message); + // assert_noop!(res, Error::::RouterNotFound); + // assert_eq!(weight, expected_weight); + // }); + // } #[test] fn router_error() { @@ -2582,20 +2611,20 @@ mod message_processor_impl { pays_fee: Pays::Yes, }; - let router_err = DispatchError::Unavailable; - - let router_mock = RouterMock::::default(); - router_mock.mock_send(move |mock_sender, mock_message| { - assert_eq!(mock_sender, expected_sender); - assert_eq!(mock_message, expected_message.serialize()); - - Err(DispatchErrorWithPostInfo { - post_info: router_post_info, - error: router_err, - }) - }); - - DomainRouters::::insert(domain.clone(), router_mock); + // let router_err = DispatchError::Unavailable; + // + // let router_mock = RouterMock::::default(); + // router_mock.mock_send(move |mock_sender, mock_message| { + // assert_eq!(mock_sender, expected_sender); + // assert_eq!(mock_message, expected_message.serialize()); + // + // Err(DispatchErrorWithPostInfo { + // post_info: router_post_info, + // error: router_err, + // }) + // }); + // + // DomainRouters::::insert(domain.clone(), router_mock); let min_expected_weight = ::DbWeight::get() .reads(1) + router_post_info.actual_weight.unwrap() @@ -2604,11 +2633,13 @@ mod message_processor_impl { let gateway_message = GatewayMessage::Outbound { sender, message: message.clone(), - router_hash: H256::from_low_u64_be(1), + router_id: H256::from_low_u64_be(1), }; let (res, weight) = LiquidityPoolsGateway::process(gateway_message); - assert_noop!(res, router_err); + //TODO(cdamian): Error out + assert_ok!(res); + // assert_noop!(res, router_err) assert!(weight.all_lte(min_expected_weight)); }); } @@ -2649,6 +2680,10 @@ mod batches { // Ok Batched assert_ok!(LiquidityPoolsGateway::handle(USER, DOMAIN, Message::Simple)); + let router_id_1 = H256::from_low_u64_be(1); + + Routers::::insert(DOMAIN, BoundedVec::try_from(vec![router_id_1]).unwrap()); + // Not batched, it belong to OTHER assert_ok!(LiquidityPoolsGateway::handle( OTHER, @@ -2656,6 +2691,11 @@ mod batches { Message::Simple )); + Routers::::insert( + Domain::EVM(2), + BoundedVec::try_from(vec![router_id_1]).unwrap(), + ); + // Not batched, it belong to EVM 2 assert_ok!(LiquidityPoolsGateway::handle( USER, @@ -2698,6 +2738,10 @@ mod batches { DispatchError::Other(MAX_PACKED_MESSAGES_ERR) ); + let router_id_1 = H256::from_low_u64_be(1); + + Routers::::insert(DOMAIN, BoundedVec::try_from(vec![router_id_1]).unwrap()); + assert_ok!(LiquidityPoolsGateway::end_batch_message( RuntimeOrigin::signed(USER), DOMAIN @@ -2736,16 +2780,54 @@ mod batches { let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); let domain_address = DomainAddress::EVM(0, address.into()); - MockLiquidityPools::mock_handle(|_, _| Ok(())); + let router_id_1 = H256::from_low_u64_be(1); + + Routers::::insert( + domain_address.domain(), + BoundedVec::try_from(vec![router_id_1]).unwrap(), + ); + InboundMessageSessions::::insert(domain_address.domain(), 1); + + let handler = MockLiquidityPools::mock_handle(|_, _| Ok(())); + + let submessage_count = 5; let (result, weight) = LiquidityPoolsGateway::process(GatewayMessage::Inbound { domain_address, - message: Message::deserialize(&(1..=5).collect::>()).unwrap(), - router_hash: *ROUTER_HASH_1, + message: Message::deserialize(&(1..=submessage_count).collect::>()).unwrap(), + router_id: *ROUTER_HASH_1, }); - assert_eq!(weight, LP_DEFENSIVE_WEIGHT * 5); + let expected_weight = Weight::default() + // get_inbound_processing_info + .saturating_add(::DbWeight::get().reads(3)) + // process_inbound_message + .saturating_add(Weight::from_parts(0, Message::max_encoded_len() as u64)) + .saturating_add(LP_DEFENSIVE_WEIGHT) + // upsert_pending_entry + .saturating_add( + ::DbWeight::get() + .writes(1) + .saturating_mul(submessage_count.into()), + ) + // get_executable_message + .saturating_add( + ::DbWeight::get() + .reads(1) + .saturating_mul(submessage_count.into()), + ) + // decrease_pending_entries_counts + .saturating_add( + ::DbWeight::get() + .writes(1) + .saturating_mul(submessage_count.into()), + ) + // process_inbound_message + .saturating_mul(submessage_count.into()); + assert_ok!(result); + assert_eq!(weight, expected_weight); + assert_eq!(handler.times(), submessage_count as u32); }); } @@ -2755,23 +2837,32 @@ mod batches { let address = H160::from_slice(&get_test_account_id().as_slice()[..20]); let domain_address = DomainAddress::EVM(0, address.into()); + let router_id_1 = H256::from_low_u64_be(1); + + Routers::::insert( + domain_address.domain(), + BoundedVec::try_from(vec![router_id_1]).unwrap(), + ); + InboundMessageSessions::::insert(domain_address.domain(), 1); + let counter = Arc::new(AtomicU32::new(0)); - MockLiquidityPools::mock_handle(move |_, _| { + + let handler = MockLiquidityPools::mock_handle(move |_, _| { match counter.fetch_add(1, Ordering::Relaxed) { 2 => Err(DispatchError::Unavailable), _ => Ok(()), } }); - let (result, weight) = LiquidityPoolsGateway::process(GatewayMessage::Inbound { + let (result, _) = LiquidityPoolsGateway::process(GatewayMessage::Inbound { domain_address, message: Message::deserialize(&(1..=5).collect::>()).unwrap(), - router_hash: *ROUTER_HASH_1, + router_id: *ROUTER_HASH_1, }); - // 2 correct messages and 1 failed message processed. - assert_eq!(weight, LP_DEFENSIVE_WEIGHT * 3); assert_err!(result, DispatchError::Unavailable); + // 2 correct messages and 1 failed message processed. + assert_eq!(handler.times(), 3); }); } } From 121b52f43cafbc00b937a1a3f9c32faa1009f006 Mon Sep 17 00:00:00 2001 From: Cosmin Damian <17934949+cdamian@users.noreply.github.com> Date: Mon, 12 Aug 2024 23:27:54 +0300 Subject: [PATCH 11/11] lp-gateway: Add extrinsic for executing message recovery --- pallets/liquidity-pools-gateway/src/lib.rs | 88 ++++-- .../src/message_processing.rs | 12 +- pallets/liquidity-pools-gateway/src/mock.rs | 2 +- pallets/liquidity-pools-gateway/src/tests.rs | 286 ++++++++++++++++-- 4 files changed, 329 insertions(+), 59 deletions(-) diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index 663a07caa5..c28907cbfb 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -62,7 +62,7 @@ mod tests; #[frame_support::pallet] pub mod pallet { use frame_system::pallet_prelude::BlockNumberFor; - use sp_arithmetic::traits::EnsureAdd; + use sp_arithmetic::traits::{EnsureAdd, EnsureAddAssign}; use super::*; @@ -162,6 +162,7 @@ pub mod pallet { RoutersSet { domain: Domain, router_ids: BoundedVec, + session_id: T::SessionId, }, /// An instance was added to a domain. @@ -175,6 +176,13 @@ pub mod pallet { domain: Domain, hook_address: [u8; 20], }, + + /// Message recovery was executed. + MessageRecoveryExecuted { + domain: Domain, + proof: Proof, + router_id: T::RouterId, + }, } // TODO(cdamian): Add migration to clear this storage. @@ -246,14 +254,11 @@ pub mod pallet { /// Any `PendingInboundEntries` mapped to the invalid IDs are removed from /// storage during `on_idle`. #[pallet::storage] - #[pallet::getter(fn invalid_session_ids)] - pub type InvalidSessionIds = StorageMap<_, Blake2_128Concat, T::SessionId, ()>; + #[pallet::getter(fn invalid_message_sessions)] + pub type InvalidMessageSessions = StorageMap<_, Blake2_128Concat, T::SessionId, ()>; #[pallet::error] pub enum Error { - /// Router initialization failed. - RouterInitFailed, - /// The origin of the message to be processed is invalid. InvalidMessageOrigin, @@ -269,8 +274,8 @@ pub mod pallet { /// Unknown instance. UnknownInstance, - /// Router not found. - RouterNotFound, + /// Routers not found. + RoutersNotFound, /// Emitted when you call `start_batch_messages()` but that was already /// called. You should finalize the message with `end_batch_messages()` @@ -304,9 +309,6 @@ pub mod pallet { /// Pending inbound entry not found. PendingInboundEntryNotFound, - /// Multi-router not found. - MultiRouterNotFound, - /// Message proof cannot be retrieved. MessageProofRetrieval, @@ -332,19 +334,23 @@ pub mod pallet { >::insert(domain.clone(), router_ids.clone()); - let (old_session_id, new_session_id) = SessionIdStore::::try_mutate(|n| { - let old_session_id = *n; - let new_session_id = old_session_id.ensure_add(One::one())?; + if let Some(old_session_id) = InboundMessageSessions::::get(domain.clone()) { + InvalidMessageSessions::::insert(old_session_id, ()); + } - *n = new_session_id; + let session_id = SessionIdStore::::try_mutate(|n| { + n.ensure_add_assign(One::one())?; - Ok::<(T::SessionId, T::SessionId), DispatchError>((old_session_id, new_session_id)) + Ok::(*n) })?; - InboundMessageSessions::::insert(domain.clone(), new_session_id); - InvalidSessionIds::::insert(old_session_id, ()); + InboundMessageSessions::::insert(domain.clone(), session_id); - Self::deposit_event(Event::RoutersSet { domain, router_ids }); + Self::deposit_event(Event::RoutersSet { + domain, + router_ids, + session_id, + }); Ok(()) } @@ -484,10 +490,48 @@ pub mod pallet { pub fn execute_message_recovery( origin: OriginFor, domain: Domain, - message_proof: Proof, + proof: Proof, + router_id: T::RouterId, ) -> DispatchResult { - //TODO(cdamian): Implement this. - unimplemented!() + T::AdminOrigin::ensure_origin(origin)?; + + let session_id = InboundMessageSessions::::get(&domain) + .ok_or(Error::::InboundDomainSessionNotFound)?; + + let routers = Routers::::get(&domain).ok_or(Error::::RoutersNotFound)?; + + ensure!( + routers.iter().any(|x| x == &router_id), + Error::::UnknownInboundMessageRouter + ); + + PendingInboundEntries::::try_mutate( + session_id, + (proof, router_id.clone()), + |storage_entry| match storage_entry { + Some(entry) => match entry { + InboundEntry::Proof { current_count } => { + current_count.ensure_add_assign(1).map_err(|e| e.into()) + } + InboundEntry::Message { .. } => { + Err(Error::::ExpectedMessageProofType.into()) + } + }, + None => { + *storage_entry = Some(InboundEntry::::Proof { current_count: 1 }); + + Ok::<(), DispatchError>(()) + } + }, + )?; + + Self::deposit_event(Event::::MessageRecoveryExecuted { + domain, + proof, + router_id, + }); + + Ok(()) } } diff --git a/pallets/liquidity-pools-gateway/src/message_processing.rs b/pallets/liquidity-pools-gateway/src/message_processing.rs index c627aa3d9c..80828991a1 100644 --- a/pallets/liquidity-pools-gateway/src/message_processing.rs +++ b/pallets/liquidity-pools-gateway/src/message_processing.rs @@ -13,7 +13,7 @@ use sp_arithmetic::traits::{EnsureAddAssign, EnsureSub}; use sp_runtime::DispatchError; use crate::{ - message::GatewayMessage, Config, Error, InboundMessageSessions, InvalidSessionIds, Pallet, + message::GatewayMessage, Config, Error, InboundMessageSessions, InvalidMessageSessions, Pallet, PendingInboundEntries, Routers, }; @@ -48,7 +48,7 @@ impl Pallet { /// Calculates and returns the proof count required for processing one /// inbound message. fn get_expected_proof_count(domain: &Domain) -> Result { - let routers = Routers::::get(domain).ok_or(Error::::MultiRouterNotFound)?; + let routers = Routers::::get(domain).ok_or(Error::::RoutersNotFound)?; let expected_proof_count = routers.len().ensure_sub(1)?; @@ -294,7 +294,7 @@ impl Pallet { weight: &mut Weight, ) -> Result, DispatchError> { let routers = - Routers::::get(domain_address.domain()).ok_or(Error::::MultiRouterNotFound)?; + Routers::::get(domain_address.domain()).ok_or(Error::::RoutersNotFound)?; weight.saturating_accrue(T::DbWeight::get().reads(1)); @@ -404,7 +404,7 @@ impl Pallet { /// message and proofs accordingly. pub(crate) fn queue_message(destination: Domain, message: T::Message) -> DispatchResult { let router_ids = - Routers::::get(destination.clone()).ok_or(Error::::MultiRouterNotFound)?; + Routers::::get(destination.clone()).ok_or(Error::::RoutersNotFound)?; let message_proof = message.to_message_proof(); let mut message_opt = Some(message); @@ -438,7 +438,7 @@ impl Pallet { /// The invalid session IDs are removed from storage if all entries mapped /// to them were cleared. pub(crate) fn clear_invalid_session_ids(max_weight: Weight) -> Weight { - let invalid_session_ids = InvalidSessionIds::::iter_keys().collect::>(); + let invalid_session_ids = InvalidMessageSessions::::iter_keys().collect::>(); let mut weight = T::DbWeight::get().reads(1); @@ -462,7 +462,7 @@ impl Pallet { cursor = match res.maybe_cursor { None => { - InvalidSessionIds::::remove(invalid_session_id); + InvalidMessageSessions::::remove(invalid_session_id); weight.saturating_accrue(T::DbWeight::get().writes(1)); diff --git a/pallets/liquidity-pools-gateway/src/mock.rs b/pallets/liquidity-pools-gateway/src/mock.rs index a162519560..b97d1ba99a 100644 --- a/pallets/liquidity-pools-gateway/src/mock.rs +++ b/pallets/liquidity-pools-gateway/src/mock.rs @@ -155,7 +155,7 @@ impl pallet_liquidity_pools_gateway::Config for Runtime { type RuntimeEvent = RuntimeEvent; type RuntimeOrigin = RuntimeOrigin; type Sender = Sender; - type SessionId = u64; + type SessionId = u32; type WeightInfo = (); } diff --git a/pallets/liquidity-pools-gateway/src/tests.rs b/pallets/liquidity-pools-gateway/src/tests.rs index f8da936f1a..7e01d1c45b 100644 --- a/pallets/liquidity-pools-gateway/src/tests.rs +++ b/pallets/liquidity-pools-gateway/src/tests.rs @@ -10,8 +10,12 @@ use frame_support::{ use itertools::Itertools; use lazy_static::lazy_static; use parity_scale_codec::MaxEncodedLen; +use sp_arithmetic::ArithmeticError::Overflow; use sp_core::{bounded::BoundedVec, crypto::AccountId32, ByteArray, H160, H256}; -use sp_runtime::{DispatchError, DispatchError::BadOrigin}; +use sp_runtime::{ + DispatchError, + DispatchError::{Arithmetic, BadOrigin}, +}; use sp_std::sync::{ atomic::{AtomicU32, Ordering}, Arc, @@ -61,15 +65,13 @@ mod set_domain_routers { new_test_ext().execute_with(|| { let domain = Domain::EVM(0); + let mut session_id = 1; + let router_id_1 = H256::from_low_u64_be(1); let router_id_2 = H256::from_low_u64_be(2); let router_id_3 = H256::from_low_u64_be(3); - //TODO(cdamian): Enable this after we figure out router init? - // let router = RouterMock::::default(); - // router.mock_init(move || Ok(())); - - let router_ids = + let mut router_ids = BoundedVec::try_from(vec![router_id_1, router_id_2, router_id_3]).unwrap(); assert_ok!(LiquidityPoolsGateway::set_domain_routers( @@ -81,32 +83,44 @@ mod set_domain_routers { assert_eq!(Routers::::get(domain.clone()).unwrap(), router_ids); assert_eq!( InboundMessageSessions::::get(domain.clone()), - Some(1) + Some(session_id) + ); + assert_eq!(InvalidMessageSessions::::get(session_id - 1), None); + + event_exists(Event::::RoutersSet { + domain: domain.clone(), + router_ids, + session_id, + }); + + router_ids = BoundedVec::try_from(vec![router_id_3, router_id_2, router_id_1]).unwrap(); + + session_id += 1; + + assert_ok!(LiquidityPoolsGateway::set_domain_routers( + RuntimeOrigin::root(), + domain.clone(), + router_ids.clone(), + )); + + assert_eq!(Routers::::get(domain.clone()).unwrap(), router_ids); + assert_eq!( + InboundMessageSessions::::get(domain.clone()), + Some(session_id) + ); + assert_eq!( + InvalidMessageSessions::::get(session_id - 1), + Some(()) ); - assert_eq!(InvalidSessionIds::::get(0), Some(())); - event_exists(Event::::RoutersSet { domain, router_ids }); + event_exists(Event::::RoutersSet { + domain, + router_ids, + session_id, + }); }); } - //TODO(cdamian): Enable this after we figure out router init? - // - // fn router_init_error() { - // new_test_ext().execute_with(|| { - // let domain = Domain::EVM(0); - // let router = RouterMock::::default(); - // router.mock_init(move || Err(DispatchError::Other("error"))); - // - // assert_noop!( - // LiquidityPoolsGateway::set_domain_router( - // RuntimeOrigin::root(), - // domain.clone(), - // router, - // ), - // Error::::RouterInitFailed, - // ); - // }); - // } #[test] fn bad_origin() { new_test_ext().execute_with(|| { @@ -123,7 +137,7 @@ mod set_domain_routers { assert!(Routers::::get(domain.clone()).is_none()); assert!(InboundMessageSessions::::get(domain).is_none()); - assert!(InvalidSessionIds::::get(0).is_none()); + assert!(InvalidMessageSessions::::get(0).is_none()); }); } @@ -143,7 +157,25 @@ mod set_domain_routers { assert!(Routers::::get(domain.clone()).is_none()); assert!(InboundMessageSessions::::get(domain).is_none()); - assert!(InvalidSessionIds::::get(0).is_none()); + assert!(InvalidMessageSessions::::get(0).is_none()); + }); + } + + #[test] + fn session_id_overflow() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + + SessionIdStore::::set(u32::MAX); + + assert_noop!( + LiquidityPoolsGateway::set_domain_routers( + RuntimeOrigin::root(), + domain, + BoundedVec::try_from(vec![]).unwrap(), + ), + Arithmetic(Overflow) + ); }); } } @@ -826,7 +858,7 @@ mod message_processor_impl { }; let (res, _) = LiquidityPoolsGateway::process(gateway_message); - assert_noop!(res, Error::::MultiRouterNotFound); + assert_noop!(res, Error::::RoutersNotFound); }); } @@ -2866,3 +2898,197 @@ mod batches { }); } } + +mod execute_message_recovery { + use super::*; + + #[test] + fn success() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + let router_id = H256::from_low_u64_be(1); + let session_id = 1; + + Routers::::insert( + domain.clone(), + BoundedVec::try_from(vec![router_id]).unwrap(), + ); + InboundMessageSessions::::insert(domain.clone(), session_id); + + assert_ok!(LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + domain.clone(), + MESSAGE_PROOF, + router_id + )); + + event_exists(Event::::MessageRecoveryExecuted { + domain: domain.clone(), + proof: MESSAGE_PROOF, + router_id: router_id.clone(), + }); + + let inbound_entry = + PendingInboundEntries::::get(session_id, (MESSAGE_PROOF, router_id)) + .expect("inbound entry is stored"); + + assert_eq!( + inbound_entry, + InboundEntry::::Proof { current_count: 1 } + ); + + assert_ok!(LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + domain.clone(), + MESSAGE_PROOF, + router_id + )); + + let inbound_entry = + PendingInboundEntries::::get(session_id, (MESSAGE_PROOF, router_id)) + .expect("inbound entry is stored"); + + assert_eq!( + inbound_entry, + InboundEntry::::Proof { current_count: 2 } + ); + + event_exists(Event::::MessageRecoveryExecuted { + domain: domain.clone(), + proof: MESSAGE_PROOF, + router_id: router_id.clone(), + }); + }); + } + + #[test] + fn inbound_domain_session_not_found() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + let router_id = H256::from_low_u64_be(1); + + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + domain.clone(), + MESSAGE_PROOF, + router_id + ), + Error::::InboundDomainSessionNotFound + ); + }); + } + + #[test] + fn routers_not_found() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + let router_id = H256::from_low_u64_be(1); + let session_id = 1; + + InboundMessageSessions::::insert(domain.clone(), session_id); + + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + domain.clone(), + MESSAGE_PROOF, + router_id + ), + Error::::RoutersNotFound + ); + }); + } + + #[test] + fn unknown_inbound_message_router() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + let router_id_1 = H256::from_low_u64_be(1); + let router_id_2 = H256::from_low_u64_be(2); + let session_id = 1; + + Routers::::insert( + domain.clone(), + BoundedVec::try_from(vec![router_id_1]).unwrap(), + ); + InboundMessageSessions::::insert(domain.clone(), session_id); + + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + domain.clone(), + MESSAGE_PROOF, + router_id_2 + ), + Error::::UnknownInboundMessageRouter + ); + }); + } + + #[test] + fn proof_count_overflow() { + new_test_ext().execute_with(|| { + let domain = Domain::EVM(0); + let router_id = H256::from_low_u64_be(1); + let session_id = 1; + + Routers::::insert( + domain.clone(), + BoundedVec::try_from(vec![router_id]).unwrap(), + ); + InboundMessageSessions::::insert(domain.clone(), session_id); + PendingInboundEntries::::insert( + session_id, + (MESSAGE_PROOF, router_id), + InboundEntry::::Proof { + current_count: u32::MAX, + }, + ); + + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + domain.clone(), + MESSAGE_PROOF, + router_id + ), + Arithmetic(Overflow) + ); + }); + } + + #[test] + fn expected_message_proof_type() { + new_test_ext().execute_with(|| { + let domain_address = TEST_DOMAIN_ADDRESS; + let router_id = H256::from_low_u64_be(1); + let session_id = 1; + + Routers::::insert( + domain_address.domain(), + BoundedVec::try_from(vec![router_id]).unwrap(), + ); + InboundMessageSessions::::insert(domain_address.domain(), session_id); + PendingInboundEntries::::insert( + session_id, + (MESSAGE_PROOF, router_id), + InboundEntry::::Message { + domain_address: domain_address.clone(), + message: Message::Simple, + expected_proof_count: 2, + }, + ); + + assert_noop!( + LiquidityPoolsGateway::execute_message_recovery( + RuntimeOrigin::root(), + domain_address.domain(), + MESSAGE_PROOF, + router_id + ), + Error::::ExpectedMessageProofType + ); + }); + } +}