Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Example for Secure Aggregation with low-level API #4967

Draft
wants to merge 2 commits into
base: sa-aggregator
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions examples/flower-secure-aggregation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,15 @@ flwr run . --run-config is-demo=false

> \[!NOTE\]
> An update to this example will show how to run this Flower project with the Deployment Engine and TLS certificates, or with Docker.


## Advanced: Use with Low-level API

### Update the `pyproject.toml` file
Change the `[tool.flwr.app.components]` section as follows.

```
[tool.flwr.app.components]
serverapp = "secaggexample.server_app_low_level:app"
clientapp = "secaggexample.client_app_low_level:app"
```
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import time

import torch
from secaggexample.task import Net, get_weights, load_data, set_weights, test, train

from flwr.client import ClientApp, NumPyClient
from flwr.client.mod import secaggplus_mod
from flwr.common import Context

from secaggexample.task import Net, get_weights, load_data, set_weights, test, train


# Define Flower Client
class FlowerClient(NumPyClient):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""secaggexample: A Flower with SecAgg+ app."""

import numpy as np

from flwr.client import ClientApp
from flwr.client.mod import secaggplus_base_mod
from flwr.common import Context, Message, ParametersRecord, RecordSet, array_from_numpy

# Flower ClientApp
app = ClientApp(mods=[secaggplus_base_mod])


@app.query()
def simple_query(msg: Message, ctxt: Context) -> Message:
"""Simple query function."""
pr = ParametersRecord()
pr["simple_array"] = array_from_numpy(np.array([0.5, 1, 2]))
print(f"Sending simple array: {pr['simple_array'].numpy()}")
content = RecordSet()
content.parameters_records["simple_pr"] = pr
return msg.create_reply(content=content)
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@
from logging import DEBUG
from typing import List, Tuple

from secaggexample.task import get_weights, make_net
from secaggexample.workflow_with_log import SecAggPlusWorkflowWithLogs

from flwr.common import Context, Metrics, ndarrays_to_parameters
from flwr.common.logger import update_console_handler
from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow

from secaggexample.task import get_weights, make_net
from secaggexample.workflow_with_log import SecAggPlusWorkflowWithLogs


# Define metric aggregation function
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
"""secaggexample: A Flower with SecAgg+ app."""

import time
from collections.abc import Iterable
from logging import DEBUG, INFO

from flwr.common import Context, Message, RecordSet, log
from flwr.common.constant import MessageType
from flwr.common.logger import update_console_handler
from flwr.server import Driver, ServerApp
from flwr.server.workflow import SecAggPlusAggregator
from flwr.server.workflow.secure_aggregation.secaggplus_aggregator import (
SecAggPlusAggregatorState,
)

# Flower ServerApp
app = ServerApp()


@app.main()
def main(driver: Driver, context: Context) -> None:
# Show debug logs
update_console_handler(DEBUG, True, True)

# Sample at least 5 clients
log(INFO, "Waiting for at least 5 clients to connect...")
while True:
nids = driver.get_node_ids()
if len(nids) >= 5:
break
time.sleep(0.1)

# Create the SecAgg+ aggregator
sa_aggregator = SecAggPlusAggregator(
driver=driver,
context=context,
num_shares=context.run_config["num-shares"],
reconstruction_threshold=context.run_config["reconstruction-threshold"],
timeout=context.run_config["timeout"],
clipping_range=8.0,
on_send=on_send,
on_receive=on_receive,
on_stage_complete=on_stage_complete,
)
msgs = [
driver.create_message(
content=RecordSet(), # Empty message
message_type=MessageType.QUERY,
dst_node_id=nid,
group_id="",
)
for nid in nids
]
msgs[0].metadata.group_id = "drop"
msg = sa_aggregator.aggregate(msgs)

arr = msg.content.parameters_records["simple_pr"]["simple_array"]
log(INFO, f"Received aggregated array: {arr.numpy()}")


# Example `on_send`/`on_receive`/`on_stage_complete` callback functions
def on_send(msgs: Iterable[Message], state: SecAggPlusAggregatorState) -> None:
"""Intercept messages before sending."""
log(INFO, "Intercepted messages before sending.")


def on_receive(msgs: Iterable[Message], state: SecAggPlusAggregatorState) -> None:
"""Intercept reply messages after receiving."""
log(INFO, "Intercepted reply messages after receiving.")


def on_stage_complete(success: bool, state: SecAggPlusAggregatorState) -> None:
"""Handle stage completion event."""
log(INFO, "Handled stage completion event.")
Comment on lines +61 to +74
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you see these functions to be useful for (in addition to logging purposes)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, for sure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be a bit useful for some advanced user to adjust the protocol, but generally it's more for me to implement things like weighted averaging based on the original SecAggPlusInsAggregator.

Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@

from logging import INFO

from secaggexample.task import get_weights, make_net

import flwr.common.recordset_compat as compat
from flwr.common import Context, log, parameters_to_ndarrays, Message
from flwr.common import Context, Message, log, parameters_to_ndarrays
from flwr.common.secure_aggregation.quantization import quantize
from flwr.common.secure_aggregation.secaggplus_constants import Stage
from flwr.server import Driver, LegacyContext
from flwr.server.workflow.constant import MAIN_PARAMS_RECORD
from flwr.server.workflow.secure_aggregation.secaggplus_workflow import (
SecAggPlusWorkflow,
)
from flwr.common.secure_aggregation.secaggplus_constants import Stage
from flwr.server.workflow.secure_aggregation.secaggplus_aggregator import (
SecAggPlusAggregatorState,
)

from secaggexample.task import get_weights, make_net
from flwr.server.workflow.secure_aggregation.secaggplus_workflow import (
SecAggPlusWorkflow,
)


class SecAggPlusWorkflowWithLogs(SecAggPlusWorkflow):
Expand Down