Skip to content

Commit

Permalink
Merge pull request #8851 from OpenMined/node-peer-partial-update
Browse files Browse the repository at this point in the history
Fix concurrency issue with NodePeer update
  • Loading branch information
shubham3121 committed Jun 3, 2024
2 parents 842fde9 + ba0445b commit 1239eed
Show file tree
Hide file tree
Showing 12 changed files with 289 additions and 185 deletions.
97 changes: 69 additions & 28 deletions packages/syft/src/syft/client/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

# relative
from ..service.metadata.node_metadata import NodeMetadataJSON
from ..service.network.network_service import NodePeer
from ..service.network.node_peer import NodePeer
from ..service.network.node_peer import NodePeerConnectionStatus
from ..service.response import SyftException
from ..types.grid_url import GridURL
from ..util.constants import DEFAULT_TIMEOUT
Expand Down Expand Up @@ -120,13 +121,40 @@ def _repr_html_(self) -> str:
on = self.online_networks
if len(on) == 0:
return "(no gateways online - try syft.gateways.all_networks to see offline gateways)"
return pd.DataFrame(on)._repr_html_() # type: ignore
df = pd.DataFrame(on)
total_df = pd.DataFrame(
[
[
f"{len(on)} / {len(self.all_networks)} (online networks / all networks)"
]
+ [""] * (len(df.columns) - 1)
],
columns=df.columns,
index=["Total"],
)
df = pd.concat([df, total_df])
return df._repr_html_() # type: ignore

def __repr__(self) -> str:
on = self.online_networks
if len(on) == 0:
return "(no gateways online - try syft.gateways.all_networks to see offline gateways)"
return pd.DataFrame(on).to_string()
df = pd.DataFrame(on)
total_df = pd.DataFrame(
[
[
f"{len(on)} / {len(self.all_networks)} (online networks / all networks)"
]
+ [""] * (len(df.columns) - 1)
],
columns=df.columns,
index=["Total"],
)
df = pd.concat([df, total_df])
return df.to_string()

def __len__(self) -> int:
return len(self.all_networks)

@staticmethod
def create_client(network: dict[str, Any]) -> Client:
Expand Down Expand Up @@ -228,32 +256,25 @@ def check_network(network: dict) -> dict[Any, Any] | None:

@property
def online_domains(self) -> list[tuple[NodePeer, NodeMetadataJSON | None]]:
def check_domain(
peer: NodePeer,
) -> tuple[NodePeer, NodeMetadataJSON | None] | None:
try:
guest_client = peer.guest_client
metadata = guest_client.metadata
return peer, metadata
except Exception as e: # nosec
print(f"Error in checking domain with exception {e}")
return None

networks = self.online_networks

# We can use a with statement to ensure threads are cleaned up promptly
with futures.ThreadPoolExecutor(max_workers=20) as executor:
# map
_all_online_domains = []
for network in networks:
_all_online_domains = []
for network in networks:
try:
network_client = NetworkRegistry.create_client(network)
domains: list[NodePeer] = network_client.domains.retrieve_nodes()
for domain in domains:
self.all_domains[str(domain.id)] = domain
_online_domains = list(
executor.map(lambda domain: check_domain(domain), domains)
)
_all_online_domains += _online_domains
except Exception as e:
print(f"Error in creating network client with exception {e}")
continue

domains: list[NodePeer] = network_client.domains.retrieve_nodes()
for domain in domains:
self.all_domains[str(domain.id)] = domain

_all_online_domains += [
(domain, domain.guest_client.metadata)
for domain in domains
if domain.ping_status == NodePeerConnectionStatus.ACTIVE
]

return [domain for domain in _all_online_domains if domain is not None]

Expand Down Expand Up @@ -281,13 +302,33 @@ def _repr_html_(self) -> str:
on: list[dict[str, Any]] = self.__make_dict__()
if len(on) == 0:
return "(no domains online - try syft.domains.all_domains to see offline domains)"
return pd.DataFrame(on)._repr_html_() # type: ignore
df = pd.DataFrame(on)
total_df = pd.DataFrame(
[
[f"{len(on)} / {len(self.all_domains)} (online domains / all domains)"]
+ [""] * (len(df.columns) - 1)
],
columns=df.columns,
index=["Total"],
)
df = pd.concat([df, total_df])
return df._repr_html_() # type: ignore

def __repr__(self) -> str:
on: list[dict[str, Any]] = self.__make_dict__()
if len(on) == 0:
return "(no domains online - try syft.domains.all_domains to see offline domains)"
return pd.DataFrame(on).to_string()
df = pd.DataFrame(on)
total_df = pd.DataFrame(
[
[f"{len(on)} / {len(self.all_domains)} (online domains / all domains)"]
+ [""] * (len(df.columns) - 1)
],
columns=df.columns,
index=["Total"],
)
df = pd.concat([df, total_df])
return df._repr_html_() # type: ignore

def create_client(self, peer: NodePeer) -> Client:
try:
Expand Down
7 changes: 7 additions & 0 deletions packages/syft/src/syft/protocol/protocol_version.json
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,13 @@
"hash": "c1796e7b01c9eae0dbf59cfd5c2c2f0e7eba593e0cea615717246572b27aae4b",
"action": "remove"
}
},
"NodePeerUpdate": {
"1": {
"version": 1,
"hash": "9e7cd39f6a9f90e8c595452865525e0989df1688236acfd1a665ed047ba47de9",
"action": "add"
}
}
}
}
Expand Down
97 changes: 55 additions & 42 deletions packages/syft/src/syft/service/network/network_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any

# third party
from result import Err
from loguru import logger
from result import Result

# relative
Expand Down Expand Up @@ -50,6 +50,7 @@
from ..warnings import CRUDWarning
from .association_request import AssociationRequestChange
from .node_peer import NodePeer
from .node_peer import NodePeerUpdate
from .routes import HTTPNodeRoute
from .routes import NodeRoute
from .routes import NodeRouteType
Expand Down Expand Up @@ -87,13 +88,13 @@ def get_by_name(
def update(
self,
credentials: SyftVerifyKey,
peer: NodePeer,
peer_update: NodePeerUpdate,
has_permission: bool = False,
) -> Result[NodePeer, str]:
valid = self.check_type(peer, NodePeer)
valid = self.check_type(peer_update, NodePeerUpdate)
if valid.is_err():
return Err(SyftError(message=valid.err()))
return super().update(credentials, peer)
return SyftError(message=valid.err())
return super().update(credentials, peer_update, has_permission=has_permission)

def create_or_update_peer(
self, credentials: SyftVerifyKey, peer: NodePeer
Expand All @@ -113,13 +114,15 @@ def create_or_update_peer(
valid = self.check_type(peer, NodePeer)
if valid.is_err():
return SyftError(message=valid.err())
existing: Result | NodePeer = self.get_by_uid(
credentials=credentials, uid=peer.id
)
if existing.is_ok() and existing.ok():
existing = existing.ok()
existing.update_routes(peer.node_routes)
result = self.update(credentials, existing)

existing = self.get_by_uid(credentials=credentials, uid=peer.id)
if existing.is_ok() and existing.ok() is not None:
existing_peer: NodePeer = existing.ok()
existing_peer.update_routes(peer.node_routes)
peer_update = NodePeerUpdate(
id=peer.id, node_routes=existing_peer.node_routes
)
result = self.update(credentials, peer_update)
return result
else:
result = self.set(credentials, peer)
Expand Down Expand Up @@ -150,8 +153,6 @@ def __init__(self, store: DocumentStore) -> None:
self.store = store
self.stash = NetworkStash(store=store)

# TODO: Check with MADHAVA, can we even allow guest user to introduce routes to
# domain nodes?
@service_method(
path="network.exchange_credentials_with",
name="exchange_credentials_with",
Expand Down Expand Up @@ -191,26 +192,21 @@ def exchange_credentials_with(
existing_peer_result.is_ok()
and (existing_peer := existing_peer_result.ok()) is not None
):
msg = [
(
f"{existing_peer.node_type} peer '{existing_peer.name}' already exist for "
f"{self_node_peer.node_type} '{self_node_peer.name}'."
)
]
logger.info(
f"{remote_node_peer.node_type} '{remote_node_peer.name}' already exist as a peer for "
f"{self_node_peer.node_type} '{self_node_peer.name}'."
)

if existing_peer != remote_node_peer:
result = self.stash.create_or_update_peer(
context.node.verify_key,
remote_node_peer,
)
msg.append(
f"{existing_peer.node_type} peer '{existing_peer.name}' information change detected."
)
if result.is_err():
msg.append(
f"Attempt to update peer '{existing_peer.name}' information failed."
return SyftError(
message=f"Failed to update peer: {remote_node_peer.name} information."
)
return SyftError(message="\n".join(msg))
msg.append(
logger.info(
f"{existing_peer.node_type} peer '{existing_peer.name}' information successfully updated."
)

Expand All @@ -219,28 +215,32 @@ def exchange_credentials_with(
name=self_node_peer.name
)
if isinstance(remote_self_node_peer, NodePeer):
msg.append(
logger.info(
f"{self_node_peer.node_type} '{self_node_peer.name}' already exist "
f"as a peer for {remote_node_peer.node_type} '{remote_node_peer.name}'."
)
if remote_self_node_peer != self_node_peer:
updated_peer = NodePeerUpdate(
id=self_node_peer.id, node_routes=self_node_peer.node_routes
)
result = remote_client.api.services.network.update_peer(
peer=self_node_peer,
peer_update=updated_peer
)
msg.append(
logger.info(
f"{self_node_peer.node_type} peer '{self_node_peer.name}' information change detected."
)
if isinstance(result, SyftError):
msg.apnpend(
logger.error(
f"Attempt to remotely update {self_node_peer.node_type} peer "
f"'{self_node_peer.name}' information remotely failed."
f"'{self_node_peer.name}' information remotely failed. Error: {result.message}"
)
return SyftError(message="\n".join(msg))
msg.append(
return SyftError(message="Failed to update peer information.")

logger.info(
f"{self_node_peer.node_type} peer '{self_node_peer.name}' "
f"information successfully updated."
)
msg.append(
msg = (
f"Routes between {remote_node_peer.node_type} '{remote_node_peer.name}' and "
f"{self_node_peer.node_type} '{self_node_peer.name}' already exchanged."
)
Expand Down Expand Up @@ -465,20 +465,24 @@ def get_peers_by_type(
return result.ok() or []

@service_method(
path="network.update_peer", name="update_peer", roles=GUEST_ROLE_LEVEL
path="network.update_peer",
name="update_peer",
roles=GUEST_ROLE_LEVEL,
)
def update_peer(
self,
context: AuthedServiceContext,
peer: NodePeer,
peer_update: NodePeerUpdate,
) -> SyftSuccess | SyftError:
# try setting all fields of NodePeerUpdate according to NodePeer

result = self.stash.update(
credentials=context.node.verify_key,
peer=peer,
peer_update=peer_update,
)
if result.is_err():
return SyftError(
message=f"Failed to update peer '{peer.name}'. Error: {result.err()}"
message=f"Failed to update peer '{peer_update.name}'. Error: {result.err()}"
)
return SyftSuccess(
message=f"Peer '{result.ok().name}' information successfully updated."
Expand Down Expand Up @@ -589,9 +593,12 @@ def add_route(
f"peer '{remote_node_peer.name}' with id '{existed_route.id}'."
)
# update the peer in the store with the updated routes
peer_update = NodePeerUpdate(
id=remote_node_peer.id, node_routes=remote_node_peer.node_routes
)
result = self.stash.update(
credentials=context.node.verify_key,
peer=remote_node_peer,
peer_update=peer_update,
)
if result.is_err():
return SyftError(message=str(result.err()))
Expand Down Expand Up @@ -747,8 +754,11 @@ def delete_route(
)
else:
# update the peer with the route removed
peer_update = NodePeerUpdate(
id=remote_node_peer.id, node_routes=remote_node_peer.node_routes
)
result = self.stash.update(
credentials=context.node.verify_key, peer=remote_node_peer
credentials=context.node.verify_key, peer_update=peer_update
)
if result.is_err():
return SyftError(message=str(result.err()))
Expand Down Expand Up @@ -846,7 +856,10 @@ def update_route_priority(
return updated_node_route
new_priority: int = updated_node_route.priority
# update the peer in the store
result = self.stash.update(context.node.verify_key, remote_node_peer)
peer_update = NodePeerUpdate(
id=remote_node_peer.id, node_routes=remote_node_peer.node_routes
)
result = self.stash.update(context.node.verify_key, peer_update)
if result.is_err():
return SyftError(message=str(result.err()))

Expand Down
Loading

0 comments on commit 1239eed

Please sign in to comment.