diff --git a/pallets/liquidity-pools-gateway/src/lib.rs b/pallets/liquidity-pools-gateway/src/lib.rs index c2e53c916a..d66f224706 100644 --- a/pallets/liquidity-pools-gateway/src/lib.rs +++ b/pallets/liquidity-pools-gateway/src/lib.rs @@ -61,6 +61,8 @@ mod tests; #[frame_support::pallet] pub mod pallet { + use frame_support::dispatch::PostDispatchInfo; + use super::*; const STORAGE_VERSION: StorageVersion = StorageVersion::new(1); @@ -291,12 +293,15 @@ pub mod pallet { /// Not enough routers are stored for a domain. NotEnoughRoutersForDomain, + + /// First router for a domain was not found. + FirstRouterNotFound, } #[pallet::call] impl Pallet { /// Sets the router IDs used for a specific domain, - #[pallet::weight(T::WeightInfo::set_domain_routers())] + #[pallet::weight(T::WeightInfo::set_routers())] #[pallet::call_index(0)] pub fn set_routers( origin: OriginFor, @@ -451,22 +456,29 @@ pub mod pallet { #[pallet::call_index(11)] pub fn execute_message_recovery( origin: OriginFor, + domain_address: DomainAddress, proof: Proof, router_id: T::RouterId, - ) -> DispatchResult { + ) -> DispatchResultWithPostInfo { T::AdminOrigin::ensure_origin(origin)?; - let session_id = SessionIdStore::::get().ok_or(Error::::SessionIdNotFound)?; + let mut weight = Weight::default(); - let routers = Routers::::get().ok_or(Error::::RoutersNotFound)?; + let inbound_processing_info = + Self::get_inbound_processing_info(domain_address, &mut weight)?; ensure!( - routers.iter().any(|x| x == &router_id), + inbound_processing_info + .router_ids + .iter() + .any(|x| x == &router_id), Error::::UnknownRouter ); + weight.saturating_accrue(T::DbWeight::get().writes(1)); + PendingInboundEntries::::try_mutate( - session_id, + inbound_processing_info.current_session_id, (proof, router_id.clone()), |storage_entry| match storage_entry { Some(entry) => match entry { @@ -485,9 +497,14 @@ pub mod pallet { }, )?; + Self::execute_if_requirements_are_met(&inbound_processing_info, proof, &mut weight)?; + Self::deposit_event(Event::::MessageRecoveryExecuted { proof, router_id }); - Ok(()) + Ok(PostDispatchInfo { + actual_weight: Some(weight), + pays_fee: Pays::Yes, + }) } } diff --git a/pallets/liquidity-pools-gateway/src/message_processing.rs b/pallets/liquidity-pools-gateway/src/message_processing.rs index 87c23e3def..be334410a5 100644 --- a/pallets/liquidity-pools-gateway/src/message_processing.rs +++ b/pallets/liquidity-pools-gateway/src/message_processing.rs @@ -39,16 +39,16 @@ pub enum InboundEntry { /// Type used when processing inbound messages. #[derive(Clone)] pub struct InboundProcessingInfo { - domain_address: DomainAddress, - router_ids: Vec, - current_session_id: T::SessionId, - expected_proof_count_per_message: u32, + pub domain_address: DomainAddress, + pub router_ids: Vec, + pub current_session_id: T::SessionId, + pub expected_proof_count_per_message: u32, } impl Pallet { /// Retrieves all available routers for a domain and then filters them based /// on the routers that we have in storage. - fn get_router_ids_for_domain(domain: Domain) -> Result, DispatchError> { + pub fn get_router_ids_for_domain(domain: Domain) -> Result, DispatchError> { let all_routers_for_domain = T::RouterProvider::routers_for_domain(domain); let stored_routers = Routers::::get().ok_or(Error::::RoutersNotFound)?; @@ -68,10 +68,9 @@ impl Pallet { /// Calculates and returns the proof count required for processing one /// inbound message. - fn get_expected_proof_count(domain: Domain) -> Result { - let routers_count = Self::get_router_ids_for_domain(domain)?.len(); - - let expected_proof_count = routers_count + fn get_expected_proof_count(router_ids: &Vec) -> Result { + let expected_proof_count = router_ids + .len() .ensure_sub(1) .map_err(|_| Error::::NotEnoughRoutersForDomain)?; @@ -216,12 +215,13 @@ impl Pallet { } /// Checks if the number of proofs required for executing one message - /// were received, and returns the message if so. - fn get_executable_message( + /// were received, and if so, decreases the counts accordingly and executes + /// the message. + pub(crate) fn execute_if_requirements_are_met( inbound_processing_info: &InboundProcessingInfo, message_proof: Proof, weight: &mut Weight, - ) -> Option { + ) -> DispatchResult { let mut message = None; let mut votes = 0; @@ -234,7 +234,7 @@ impl Pallet { ) { // We expected one InboundEntry for each router, if that's not the case, // we can return. - None => return None, + None => return Ok(()), Some(inbound_entry) => match inbound_entry { InboundEntry::Message { message: stored_message, @@ -249,11 +249,25 @@ impl Pallet { }; } - if votes == inbound_processing_info.expected_proof_count_per_message { - return message; + if votes < inbound_processing_info.expected_proof_count_per_message { + return Ok(()); } - None + match message { + Some(msg) => { + Self::decrease_pending_entries_counts( + &inbound_processing_info, + message_proof, + weight, + )?; + + T::InboundMessageHandler::handle( + inbound_processing_info.domain_address.clone(), + msg, + ) + } + None => Ok(()), + } } /// Decreases the counts for inbound entries and removes them if the @@ -312,7 +326,7 @@ impl Pallet { /// Retrieves the information required for processing an inbound /// message. - fn get_inbound_processing_info( + pub(crate) fn get_inbound_processing_info( domain_address: DomainAddress, weight: &mut Weight, ) -> Result, DispatchError> { @@ -324,7 +338,7 @@ impl Pallet { weight.saturating_accrue(T::DbWeight::get().reads(1)); - let expected_proof_count = Self::get_expected_proof_count(domain_address.domain())?; + let expected_proof_count = Self::get_expected_proof_count(&router_ids)?; weight.saturating_accrue(T::DbWeight::get().reads(1)); @@ -373,23 +387,13 @@ impl Pallet { return (Err(e), weight); } - match Self::get_executable_message(&inbound_processing_info, message_proof, &mut weight) - { - Some(m) => { - if let Err(e) = Self::decrease_pending_entries_counts( - &inbound_processing_info, - message_proof, - &mut weight, - ) { - return (Err(e), weight.saturating_mul(count)); - } - - if let Err(e) = T::InboundMessageHandler::handle(domain_address.clone(), m) { - // We only consume the processed weight if error during the batch - return (Err(e), weight.saturating_mul(count)); - } - } - None => continue, + match Self::execute_if_requirements_are_met( + &inbound_processing_info, + message_proof, + &mut weight, + ) { + Err(e) => return (Err(e), weight.saturating_mul(count)), + Ok(_) => continue, } } diff --git a/pallets/liquidity-pools-gateway/src/tests.rs b/pallets/liquidity-pools-gateway/src/tests.rs index 6600aa9a3c..5904f6c082 100644 --- a/pallets/liquidity-pools-gateway/src/tests.rs +++ b/pallets/liquidity-pools-gateway/src/tests.rs @@ -2737,86 +2737,140 @@ mod execute_message_recovery { use super::*; #[test] - fn success() { + fn success_with_execution() { new_test_ext().execute_with(|| { - let router_id = ROUTER_ID_1; let session_id = 1; - Routers::::set(Some(BoundedVec::try_from(vec![router_id.clone()]).unwrap())); + Routers::::set(Some( + BoundedVec::try_from(vec![ROUTER_ID_1, ROUTER_ID_2]).unwrap(), + )); SessionIdStore::::set(Some(session_id)); + PendingInboundEntries::::insert( + session_id, + (MESSAGE_PROOF, ROUTER_ID_1), + InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 1, + }, + ); + + let handler = + MockLiquidityPools::mock_handle(move |mock_domain_address, mock_message| { + assert_eq!(mock_domain_address, TEST_DOMAIN_ADDRESS); + assert_eq!(mock_message, Message::Simple); + + Ok(()) + }); + assert_ok!(LiquidityPoolsGateway::execute_message_recovery( RuntimeOrigin::root(), + TEST_DOMAIN_ADDRESS, MESSAGE_PROOF, - router_id.clone(), + ROUTER_ID_2, )); event_exists(Event::::MessageRecoveryExecuted { proof: MESSAGE_PROOF, - router_id: router_id.clone(), + router_id: ROUTER_ID_2, }); - let inbound_entry = PendingInboundEntries::::get( + assert_eq!(handler.times(), 1); + + assert!(PendingInboundEntries::::get( session_id, - (MESSAGE_PROOF, router_id.clone()), + (MESSAGE_PROOF, ROUTER_ID_1) ) - .expect("inbound entry is stored"); + .is_none()); - assert_eq!( - inbound_entry, - InboundEntry::::Proof { current_count: 1 } + assert!(PendingInboundEntries::::get( + session_id, + (MESSAGE_PROOF, ROUTER_ID_2) + ) + .is_none()); + }); + } + + #[test] + fn success_without_execution() { + new_test_ext().execute_with(|| { + let session_id = 1; + + Routers::::set(Some( + BoundedVec::try_from(vec![ROUTER_ID_1, ROUTER_ID_2, ROUTER_ID_3]).unwrap(), + )); + SessionIdStore::::set(Some(session_id)); + + PendingInboundEntries::::insert( + session_id, + (MESSAGE_PROOF, ROUTER_ID_1), + InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }, ); assert_ok!(LiquidityPoolsGateway::execute_message_recovery( RuntimeOrigin::root(), + TEST_DOMAIN_ADDRESS, MESSAGE_PROOF, - router_id.clone() + ROUTER_ID_2, )); - let inbound_entry = PendingInboundEntries::::get( - session_id, - (MESSAGE_PROOF, router_id.clone()), - ) - .expect("inbound entry is stored"); - - assert_eq!( - inbound_entry, - InboundEntry::::Proof { current_count: 2 } - ); - event_exists(Event::::MessageRecoveryExecuted { proof: MESSAGE_PROOF, - router_id, + router_id: ROUTER_ID_2, }); + + assert_eq!( + PendingInboundEntries::::get(session_id, (MESSAGE_PROOF, ROUTER_ID_1)), + Some(InboundEntry::::Message { + domain_address: TEST_DOMAIN_ADDRESS, + message: Message::Simple, + expected_proof_count: 2, + }) + ); + assert_eq!( + PendingInboundEntries::::get(session_id, (MESSAGE_PROOF, ROUTER_ID_2)), + Some(InboundEntry::::Proof { current_count: 1 }) + ); + assert!( + PendingInboundEntries::::get(session_id, (MESSAGE_PROOF, ROUTER_ID_3)) + .is_none() + ) }); } #[test] - fn session_id_not_found() { + fn routers_not_found() { new_test_ext().execute_with(|| { assert_noop!( LiquidityPoolsGateway::execute_message_recovery( RuntimeOrigin::root(), + TEST_DOMAIN_ADDRESS, MESSAGE_PROOF, ROUTER_ID_1, ), - Error::::SessionIdNotFound + Error::::RoutersNotFound ); }); } #[test] - fn routers_not_found() { + fn session_id_not_found() { new_test_ext().execute_with(|| { - SessionIdStore::::set(Some(1)); + Routers::::set(Some(BoundedVec::try_from(vec![ROUTER_ID_1]).unwrap())); assert_noop!( LiquidityPoolsGateway::execute_message_recovery( RuntimeOrigin::root(), + TEST_DOMAIN_ADDRESS, MESSAGE_PROOF, ROUTER_ID_1, ), - Error::::RoutersNotFound + Error::::SessionIdNotFound ); }); } @@ -2830,6 +2884,7 @@ mod execute_message_recovery { assert_noop!( LiquidityPoolsGateway::execute_message_recovery( RuntimeOrigin::root(), + TEST_DOMAIN_ADDRESS, MESSAGE_PROOF, ROUTER_ID_2 ), @@ -2857,6 +2912,7 @@ mod execute_message_recovery { assert_noop!( LiquidityPoolsGateway::execute_message_recovery( RuntimeOrigin::root(), + TEST_DOMAIN_ADDRESS, MESSAGE_PROOF, router_id ), @@ -2887,6 +2943,7 @@ mod execute_message_recovery { assert_noop!( LiquidityPoolsGateway::execute_message_recovery( RuntimeOrigin::root(), + TEST_DOMAIN_ADDRESS, MESSAGE_PROOF, router_id ), diff --git a/pallets/liquidity-pools-gateway/src/weights.rs b/pallets/liquidity-pools-gateway/src/weights.rs index aaf598ec91..b330d71ac6 100644 --- a/pallets/liquidity-pools-gateway/src/weights.rs +++ b/pallets/liquidity-pools-gateway/src/weights.rs @@ -13,7 +13,7 @@ use frame_support::weights::{constants::RocksDbWeight, Weight}; pub trait WeightInfo { - fn set_domain_routers() -> Weight; + fn set_routers() -> Weight; fn add_instance() -> Weight; fn remove_instance() -> Weight; fn add_relayer() -> Weight; @@ -31,7 +31,7 @@ pub trait WeightInfo { const N: u64 = 4; impl WeightInfo for () { - fn set_domain_routers() -> Weight { + fn set_routers() -> Weight { // TODO: BENCHMARK CORRECTLY // // NOTE: Reasonable weight taken from `PoolSystem::set_max_reserve`