Skip to content

Commit

Permalink
Don't use domain when storing routers (#1962)
Browse files Browse the repository at this point in the history
* lp-gateway: Unit tests WIP

* lp-gateway: Don't store routers under domain

* wip
  • Loading branch information
cdamian authored Aug 13, 2024
1 parent 5c80597 commit 317631c
Show file tree
Hide file tree
Showing 4 changed files with 534 additions and 611 deletions.
67 changes: 26 additions & 41 deletions pallets/liquidity-pools-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use message::GatewayMessage;
use orml_traits::GetByKey;
pub use pallet::*;
use parity_scale_codec::FullCodec;
use sp_arithmetic::traits::{BaseArithmetic, EnsureAddAssign, One};
use sp_arithmetic::traits::{BaseArithmetic, EnsureAdd, EnsureAddAssign, One};
use sp_std::convert::TryInto;

use crate::{message_processing::InboundEntry, weights::WeightInfo};
Expand Down Expand Up @@ -102,7 +102,7 @@ pub mod pallet {
/// The Liquidity Pools message type.
type Message: LPEncoding + Clone + Debug + PartialEq + MaxEncodedLen + TypeInfo + FullCodec;

/// The target of of the messages comming from this chain
/// The target of the messages coming from this chain
type MessageSender: MessageSender<Middleware = Self::RouterId, Origin = DomainAddress>;

/// An identification of a router
Expand Down Expand Up @@ -148,7 +148,6 @@ pub mod pallet {
pub enum Event<T: Config> {
/// The routers for a given domain were set.
RoutersSet {
domain: Domain,
router_ids: BoundedVec<T::RouterId, T::MaxRouterCount>,
session_id: T::SessionId,
},
Expand All @@ -167,7 +166,6 @@ pub mod pallet {

/// Message recovery was executed.
MessageRecoveryExecuted {
domain: Domain,
proof: Proof,
router_id: T::RouterId,
},
Expand All @@ -182,13 +180,13 @@ pub mod pallet {
// pub type DomainRouters<T: Config> = StorageMap<_, Blake2_128Concat, Domain,
// T::Router>;

/// Storage for routers specific for a domain.
/// Storage for routers.
///
/// This can only be set by an admin.
#[pallet::storage]
#[pallet::getter(fn routers)]
pub type Routers<T: Config> =
StorageMap<_, Blake2_128Concat, Domain, BoundedVec<T::RouterId, T::MaxRouterCount>>;
StorageValue<_, BoundedVec<T::RouterId, T::MaxRouterCount>, ValueQuery>;

/// Storage that contains a limited number of whitelisted instances of
/// deployed liquidity pools for a particular domain.
Expand Down Expand Up @@ -227,11 +225,11 @@ pub mod pallet {
InboundEntry<T>,
>;

/// Storage for the inbound message session IDs.
#[pallet::storage]
#[pallet::getter(fn inbound_message_sessions)]
pub type InboundMessageSessions<T: Config> =
StorageMap<_, Blake2_128Concat, Domain, T::SessionId>;
// /// Storage for the inbound message session IDs.
// #[pallet::storage]
// #[pallet::getter(fn inbound_message_sessions)]
// pub type InboundMessageSessions<T: Config> =
// StorageMap<_, Blake2_128Concat, Domain, T::SessionId>;

/// Storage for inbound message session IDs.
#[pallet::storage]
Expand Down Expand Up @@ -279,8 +277,8 @@ pub mod pallet {
/// Inbound domain session not found.
InboundDomainSessionNotFound,

/// The router that sent the inbound message is unknown.
UnknownInboundMessageRouter,
/// Unknown router.
UnknownRouter,

/// The router that sent the message is not the first one.
MessageExpectedFromFirstRouter,
Expand Down Expand Up @@ -309,35 +307,28 @@ pub mod pallet {
/// Sets the router IDs used for a specific domain,
#[pallet::weight(T::WeightInfo::set_domain_routers())]
#[pallet::call_index(0)]
pub fn set_domain_routers(
pub fn set_routers(
origin: OriginFor<T>,
domain: Domain,
router_ids: BoundedVec<T::RouterId, T::MaxRouterCount>,
) -> DispatchResult {
T::AdminOrigin::ensure_origin(origin)?;

ensure!(domain != Domain::Centrifuge, Error::<T>::DomainNotSupported);

//TODO(cdamian): Outbound - Call router.init() for each router?
<Routers<T>>::set(router_ids.clone());

<Routers<T>>::insert(domain.clone(), router_ids.clone());
let (old_session_id, new_session_id) = SessionIdStore::<T>::try_mutate(|n| {
let old_session_id = *n;
let new_session_id = n.ensure_add(One::one())?;

if let Some(old_session_id) = InboundMessageSessions::<T>::get(domain.clone()) {
InvalidMessageSessions::<T>::insert(old_session_id, ());
}
*n = new_session_id;

let session_id = SessionIdStore::<T>::try_mutate(|n| {
n.ensure_add_assign(One::one())?;

Ok::<T::SessionId, DispatchError>(*n)
Ok::<(T::SessionId, T::SessionId), DispatchError>((old_session_id, new_session_id))
})?;

InboundMessageSessions::<T>::insert(domain.clone(), session_id);
InvalidMessageSessions::<T>::insert(old_session_id, ());

Self::deposit_event(Event::RoutersSet {
domain,
router_ids,
session_id,
session_id: new_session_id,
});

Ok(())
Expand Down Expand Up @@ -452,7 +443,7 @@ pub mod pallet {

match PackedMessage::<T>::take((&sender, &destination)) {
Some(msg) if msg.submessages().is_empty() => Ok(()), //No-op
Some(message) => Self::queue_message(destination, message),
Some(message) => Self::queue_outbound_message(destination, message),
None => Err(Error::<T>::MessagePackingNotStarted.into()),
}
}
Expand All @@ -465,20 +456,18 @@ pub mod pallet {
#[pallet::call_index(11)]
pub fn execute_message_recovery(
origin: OriginFor<T>,
domain: Domain,
proof: Proof,
router_id: T::RouterId,
) -> DispatchResult {
T::AdminOrigin::ensure_origin(origin)?;

let session_id = InboundMessageSessions::<T>::get(&domain)
.ok_or(Error::<T>::InboundDomainSessionNotFound)?;
let session_id = SessionIdStore::<T>::get();

let routers = Routers::<T>::get(&domain).ok_or(Error::<T>::RoutersNotFound)?;
let routers = Routers::<T>::get();

ensure!(
routers.iter().any(|x| x == &router_id),
Error::<T>::UnknownInboundMessageRouter
Error::<T>::UnknownRouter
);

PendingInboundEntries::<T>::try_mutate(
Expand All @@ -501,11 +490,7 @@ pub mod pallet {
},
)?;

Self::deposit_event(Event::<T>::MessageRecoveryExecuted {
domain,
proof,
router_id,
});
Self::deposit_event(Event::<T>::MessageRecoveryExecuted { proof, router_id });

Ok(())
}
Expand All @@ -528,7 +513,7 @@ pub mod pallet {

PackedMessage::<T>::mutate((&from, destination.clone()), |batch| match batch {
Some(batch) => batch.pack_with(message),
None => Self::queue_message(destination, message),
None => Self::queue_outbound_message(destination, message),
})
}
}
Expand Down
78 changes: 35 additions & 43 deletions pallets/liquidity-pools-gateway/src/message_processing.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use cfg_primitives::LP_DEFENSIVE_WEIGHT;
use cfg_traits::liquidity_pools::{InboundMessageHandler, LPEncoding, MessageQueue, Proof};
use cfg_traits::liquidity_pools::{
InboundMessageHandler, LPEncoding, MessageQueue, MessageSender, Proof, RouterSupport,
};
use cfg_types::domain_address::{Domain, DomainAddress};
use frame_support::{
dispatch::DispatchResult,
Expand All @@ -13,8 +15,8 @@ use sp_arithmetic::traits::{EnsureAddAssign, EnsureSub};
use sp_runtime::DispatchError;

use crate::{
message::GatewayMessage, Config, Error, InboundMessageSessions, InvalidMessageSessions, Pallet,
PendingInboundEntries, Routers,
message::GatewayMessage, Config, Error, InvalidMessageSessions, Pallet, PendingInboundEntries,
Routers, SessionIdStore,
};

/// The limit used when clearing the `PendingInboundEntries` for invalid
Expand All @@ -39,18 +41,18 @@ pub enum InboundEntry<T: Config> {
#[derive(Clone)]
pub struct InboundProcessingInfo<T: Config> {
domain_address: DomainAddress,
routers: BoundedVec<T::RouterId, T::MaxRouterCount>,
router_ids: Vec<T::RouterId, T::MaxRouterCount>,
current_session_id: T::SessionId,
expected_proof_count_per_message: u32,
}

impl<T: Config> Pallet<T> {
/// Calculates and returns the proof count required for processing one
/// inbound message.
fn get_expected_proof_count(domain: &Domain) -> Result<u32, DispatchError> {
let routers = Routers::<T>::get(domain).ok_or(Error::<T>::RoutersNotFound)?;
fn get_expected_proof_count(domain: Domain) -> Result<u32, DispatchError> {
let routers_count = T::RouterId::for_domain(domain).len();

let expected_proof_count = routers.len().ensure_sub(1)?;
let expected_proof_count = routers_count.ensure_sub(1)?;

Ok(expected_proof_count as u32)
}
Expand Down Expand Up @@ -94,25 +96,25 @@ impl<T: Config> Pallet<T> {
router_id: &T::RouterId,
inbound_entry: &InboundEntry<T>,
) -> DispatchResult {
let routers = inbound_processing_info.routers.clone();
let router_ids = inbound_processing_info.router_ids.clone();

ensure!(
routers.iter().any(|x| x == router_id),
Error::<T>::UnknownInboundMessageRouter
router_ids.iter().any(|x| x == router_id),
Error::<T>::UnknownRouter
);

match inbound_entry {
InboundEntry::Message { .. } => {
ensure!(
routers.get(0) == Some(&router_id),
router_ids.get(0) == Some(&router_id),
Error::<T>::MessageExpectedFromFirstRouter
);

Ok(())
}
InboundEntry::Proof { .. } => {
ensure!(
routers.get(0) != Some(&router_id),
router_ids.get(0) != Some(&router_id),
Error::<T>::ProofNotExpectedFromFirstRouter
);

Expand Down Expand Up @@ -202,12 +204,12 @@ impl<T: Config> Pallet<T> {
let mut message = None;
let mut votes = 0;

for router in &inbound_processing_info.routers {
for router_id in &inbound_processing_info.router_ids {
weight.saturating_accrue(T::DbWeight::get().reads(1));

match PendingInboundEntries::<T>::get(
inbound_processing_info.current_session_id,
(message_proof, router),
(message_proof, router_id),
) {
// We expected one InboundEntry for each router, if that's not the case,
// we can return.
Expand Down Expand Up @@ -240,12 +242,12 @@ impl<T: Config> Pallet<T> {
message_proof: Proof,
weight: &mut Weight,
) -> DispatchResult {
for router in &inbound_processing_info.routers {
for router_id in &inbound_processing_info.router_ids {
weight.saturating_accrue(T::DbWeight::get().writes(1));

match PendingInboundEntries::<T>::try_mutate(
inbound_processing_info.current_session_id,
(message_proof, router),
(message_proof, router_id),
|storage_entry| match storage_entry {
None => Err(Error::<T>::PendingInboundEntryNotFound.into()),
Some(stored_inbound_entry) => match stored_inbound_entry {
Expand Down Expand Up @@ -290,26 +292,24 @@ impl<T: Config> Pallet<T> {
/// Retrieves the information required for processing an inbound
/// message.
fn get_inbound_processing_info(
domain_address: DomainAddress,
domain: Domain,
weight: &mut Weight,
) -> Result<InboundProcessingInfo<T>, DispatchError> {
let routers =
Routers::<T>::get(domain_address.domain()).ok_or(Error::<T>::RoutersNotFound)?;
let router_ids = T::RouterId::for_domain(domain.clone());

weight.saturating_accrue(T::DbWeight::get().reads(1));

let current_session_id = InboundMessageSessions::<T>::get(domain_address.domain())
.ok_or(Error::<T>::InboundDomainSessionNotFound)?;
let current_session_id = SessionIdStore::<T>::get();

weight.saturating_accrue(T::DbWeight::get().reads(1));

let expected_proof_count = Self::get_expected_proof_count(&domain_address.domain())?;
let expected_proof_count = Self::get_expected_proof_count(domain)?;

weight.saturating_accrue(T::DbWeight::get().reads(1));

Ok(InboundProcessingInfo {
domain_address,
routers,
router_ids,
current_session_id,
expected_proof_count_per_message: expected_proof_count,
})
Expand All @@ -325,7 +325,7 @@ impl<T: Config> Pallet<T> {
let mut weight = Default::default();

let inbound_processing_info =
match Self::get_inbound_processing_info(domain_address.clone(), &mut weight) {
match Self::get_inbound_processing_info(domain_address.domain(), &mut weight) {
Ok(i) => i,
Err(e) => return (Err(e), weight),
};
Expand Down Expand Up @@ -382,29 +382,21 @@ impl<T: Config> Pallet<T> {
message: T::Message,
router_id: T::RouterId,
) -> (DispatchResult, Weight) {
let read_weight = T::DbWeight::get().reads(1);

// TODO(cdamian): Update when the router refactor is done.

// let Some(router) = Routers::<T>::get(router_id) else {
// return (Err(Error::<T>::RouterNotFound.into()), read_weight);
// };
//
// let (result, router_weight) = match router.send(sender, message.serialize())
// { Ok(dispatch_info) => (Ok(()), dispatch_info.actual_weight),
// Err(e) => (Err(e.error), e.post_info.actual_weight),
// };
//
// (result, router_weight.unwrap_or(read_weight))

(Ok(()), read_weight)
let weight = LP_DEFENSIVE_WEIGHT;

match T::MessageSender::send(router_id, sender, message.serialize()) {
Ok(_) => (Ok(()), weight),
Err(e) => (Err(e), weight),
}
}

/// Retrieves the IDs of the routers set for a domain and queues the
/// message and proofs accordingly.
pub(crate) fn queue_message(destination: Domain, message: T::Message) -> DispatchResult {
let router_ids =
Routers::<T>::get(destination.clone()).ok_or(Error::<T>::RoutersNotFound)?;
pub(crate) fn queue_outbound_message(
destination: Domain,
message: T::Message,
) -> DispatchResult {
let router_ids = T::RouterId::for_domain(destination);

let message_proof = message.to_message_proof();
let mut message_opt = Some(message);
Expand Down
4 changes: 2 additions & 2 deletions pallets/liquidity-pools-gateway/src/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ impl LPEncoding for Message {
}
}

#[derive(Default, Debug, Encode, Decode, Clone, PartialEq, Eq, TypeInfo, MaxEncodedLen)]
pub struct RouterId(u32);
#[derive(Default, Debug, Encode, Decode, Clone, PartialEq, Eq, TypeInfo, MaxEncodedLen, Hash)]
pub struct RouterId(pub u32);

impl RouterSupport<Domain> for RouterId {
fn for_domain(_domain: Domain) -> Vec<RouterId> {
Expand Down
Loading

0 comments on commit 317631c

Please sign in to comment.