Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(framework) One gRPC request per Message in GrpcDriver #5073

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 33 additions & 23 deletions src/py/flwr/server/driver/grpc_driver.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
import time
import warnings
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor
from logging import DEBUG, WARNING
from typing import Optional, cast

@@ -33,7 +34,6 @@
from flwr.common.retry_invoker import _make_simple_grpc_retry_invoker, _wrap_stub
from flwr.common.serde import message_from_proto, message_to_proto, run_from_proto
from flwr.common.typing import Run
from flwr.proto.message_pb2 import Message as ProtoMessage # pylint: disable=E0611
from flwr.proto.node_pb2 import Node # pylint: disable=E0611
from flwr.proto.run_pb2 import GetRunRequest, GetRunResponse # pylint: disable=E0611
from flwr.proto.serverappio_pb2 import ( # pylint: disable=E0611
@@ -197,48 +197,58 @@ def push_messages(self, messages: Iterable[Message]) -> Iterable[str]:
This method takes an iterable of messages and sends each message
to the node specified in `dst_node_id`.
"""
# Construct Messages
message_proto_list: list[ProtoMessage] = []
for msg in messages:

def push_msg(msg: Message) -> Optional[str]:
# Check message
self._check_message(msg)
# Convert to proto
msg_proto = message_to_proto(msg)
# Add to list
message_proto_list.append(msg_proto)
# Call GrpcDriverStub method
res: PushInsMessagesResponse = self._stub.PushMessages(
PushInsMessagesRequest(
messages_list=message_proto_list, run_id=cast(Run, self._run).run_id
# Call GrpcDriverStub method
res: PushInsMessagesResponse = self._stub.PushMessages(
PushInsMessagesRequest(
messages_list=[msg_proto],
run_id=cast(Run, self._run).run_id,
)
)
)
if len([msg_id for msg_id in res.message_ids if msg_id]) != len(
list(message_proto_list)
):
return res.message_ids[0] if res.message_ids else None

# Use ThreadPoolExecutor to push messages concurrently with map
with ThreadPoolExecutor() as executor:
results = list(executor.map(push_msg, messages))

if None in results:
log(
WARNING,
"Not all messages could be pushed to the SuperLink. The returned "
"list has `None` for those messages (the order is preserved as passed "
"to `push_messages`). This could be due to a malformed message.",
)
return list(res.message_ids)
return [msg_id for msg_id in results if msg_id is not None]

def pull_messages(self, message_ids: Iterable[str]) -> Iterable[Message]:
"""Pull messages based on message IDs.

This method is used to collect messages from the SuperLink that correspond to a
set of given message IDs.
"""

# Pull Messages
res: PullResMessagesResponse = self._stub.PullMessages(
PullResMessagesRequest(
message_ids=message_ids,
run_id=cast(Run, self._run).run_id,
def pull_msg(message_id: str) -> Optional[Message]:
res: PullResMessagesResponse = self._stub.PullMessages(
PullResMessagesRequest(
message_ids=[message_id],
run_id=cast(Run, self._run).run_id,
)
)
# Convert Message from Protobuf representation
if not res.messages_list:
return None
return message_from_proto(res.messages_list[0])

with ThreadPoolExecutor(max_workers=5) as executor:
yield from (
msg for msg in executor.map(pull_msg, message_ids) if msg is not None
)
)
# Convert Message from Protobuf representation
msgs = [message_from_proto(msg_proto) for msg_proto in res.messages_list]
return msgs

def send_and_receive(
self,