From efb259e4cf6f941857bfe330924cca74d11e7f83 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 11 Mar 2024 11:14:01 +0000 Subject: [PATCH 01/20] Create Secure Aggregation example (#3091) Co-authored-by: Daniel J. Beutel --- examples/app-pytorch/client.py | 13 +- examples/app-pytorch/server.py | 15 +- examples/app-secure-aggregation/README.md | 93 +++++ examples/app-secure-aggregation/client.py | 34 ++ .../pyproject.toml | 9 +- .../app-secure-aggregation/requirements.txt | 1 + .../run.sh | 6 +- examples/app-secure-aggregation/server.py | 45 +++ .../workflow_with_log.py | 90 +++++ examples/secaggplus-mt/README.md | 36 -- examples/secaggplus-mt/client.py | 48 --- examples/secaggplus-mt/driver.py | 204 ---------- examples/secaggplus-mt/requirements.txt | 1 - examples/secaggplus-mt/workflows.py | 369 ------------------ 14 files changed, 289 insertions(+), 675 deletions(-) create mode 100644 examples/app-secure-aggregation/README.md create mode 100644 examples/app-secure-aggregation/client.py rename examples/{secaggplus-mt => app-secure-aggregation}/pyproject.toml (50%) create mode 100644 examples/app-secure-aggregation/requirements.txt rename examples/{secaggplus-mt => app-secure-aggregation}/run.sh (86%) create mode 100644 examples/app-secure-aggregation/server.py create mode 100644 examples/app-secure-aggregation/workflow_with_log.py delete mode 100644 examples/secaggplus-mt/README.md delete mode 100644 examples/secaggplus-mt/client.py delete mode 100644 examples/secaggplus-mt/driver.py delete mode 100644 examples/secaggplus-mt/requirements.txt delete mode 100644 examples/secaggplus-mt/workflows.py diff --git a/examples/app-pytorch/client.py b/examples/app-pytorch/client.py index 64fb71e9917a..ebbe977ecab1 100644 --- a/examples/app-pytorch/client.py +++ b/examples/app-pytorch/client.py @@ -16,7 +16,7 @@ trainloader, testloader = load_data() -# Define Flower client +# Define FlowerClient and client_fn class FlowerClient(NumPyClient): def fit(self, parameters, config): @@ -31,16 +31,21 @@ def evaluate(self, parameters, config): def client_fn(cid: str): + """Create and return an instance of Flower `Client`.""" return FlowerClient().to_client() -# Run via `flower-client-app client:app` -app = ClientApp(client_fn=client_fn) +# Flower ClientApp +app = ClientApp( + client_fn=client_fn, +) # Legacy mode if __name__ == "__main__": - fl.client.start_client( + from flwr.client import start_client + + start_client( server_address="127.0.0.1:8080", client=FlowerClient().to_client(), ) diff --git a/examples/app-pytorch/server.py b/examples/app-pytorch/server.py index 5c22d334e0f5..0b4ad1ddba46 100644 --- a/examples/app-pytorch/server.py +++ b/examples/app-pytorch/server.py @@ -1,6 +1,7 @@ from typing import List, Tuple -import flwr as fl +from flwr.server import ServerApp, ServerConfig +from flwr.server.strategy import FedAvg from flwr.common import Metrics, ndarrays_to_parameters from task import Net, get_weights @@ -33,7 +34,7 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # Define strategy -strategy = fl.server.strategy.FedAvg( +strategy = FedAvg( fraction_fit=1.0, # Select all available clients fraction_evaluate=0.0, # Disable evaluation min_available_clients=2, @@ -43,11 +44,11 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # Define config -config = fl.server.ServerConfig(num_rounds=3) +config = ServerConfig(num_rounds=3) -# Run via `flower-server-app server:app` -app = fl.server.ServerApp( +# Flower ServerApp +app = ServerApp( config=config, 
strategy=strategy, ) @@ -55,7 +56,9 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: # Legacy mode if __name__ == "__main__": - fl.server.start_server( + from flwr.server import start_server + + start_server( server_address="0.0.0.0:8080", config=config, strategy=strategy, diff --git a/examples/app-secure-aggregation/README.md b/examples/app-secure-aggregation/README.md new file mode 100644 index 000000000000..bdc95c5b62fc --- /dev/null +++ b/examples/app-secure-aggregation/README.md @@ -0,0 +1,93 @@ +# Secure aggregation with Flower (the SecAgg+ protocol) ๐Ÿงช + +> ๐Ÿงช = This example covers experimental features that might change in future versions of Flower +> Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. + +The following steps describe how to use Secure Aggregation in flower, with `ClientApp` using `secaggplus_mod` and `ServerApp` using `SecAggPlusWorkflow`. + +## Preconditions + +Let's assume the following project structure: + +```bash +$ tree . +. +โ”œโ”€โ”€ client.py # Client application using `secaggplus_mod` +โ”œโ”€โ”€ server.py # Server application using `SecAggPlusWorkflow` +โ”œโ”€โ”€ workflow_with_log.py # Augmented `SecAggPlusWorkflow` +โ”œโ”€โ”€ run.sh # Quick start script +โ”œโ”€โ”€ pyproject.toml # Project dependencies (poetry) +โ””โ”€โ”€ requirements.txt # Project dependencies (pip) +``` + +## Installing dependencies + +Project dependencies (such as and `flwr`) are defined in `pyproject.toml`. We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. + +### Poetry + +```shell +poetry install +poetry shell +``` + +Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: + +```shell +poetry run python3 -c "import flwr" +``` + +### pip + +Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. + +```shell +pip install -r requirements.txt +``` + +If you don't see any errors you're good to go! 
+ +## Run the example with one command (recommended) + +```bash +./run.sh +``` + +## Run the example with the simulation engine + +```bash +flower-simulation --server-app server:app --client-app client:app --num-supernodes 5 +``` + +## Alternatively, run the example (in 7 terminal windows) + +Start the Flower Superlink in one terminal window: + +```bash +flower-superlink --insecure +``` + +Start 5 Flower `ClientApp` in 5 separate terminal windows: + +```bash +flower-client-app client:app --insecure +``` + +Start the Flower `ServerApp`: + +```bash +flower-server-app server:app --insecure +``` + +## Amend the example for practical usage + +For real-world applications, modify the `workflow` in `server.py` as follows: + +```python +workflow = fl.server.workflow.DefaultWorkflow( + fit_workflow=SecAggPlusWorkflow( + num_shares=, + reconstruction_threshold=, + ) +) +``` diff --git a/examples/app-secure-aggregation/client.py b/examples/app-secure-aggregation/client.py new file mode 100644 index 000000000000..b2fd02ec00d4 --- /dev/null +++ b/examples/app-secure-aggregation/client.py @@ -0,0 +1,34 @@ +import time + +from flwr.client import ClientApp, NumPyClient +from flwr.client.mod import secaggplus_mod +import numpy as np + + +# Define FlowerClient and client_fn +class FlowerClient(NumPyClient): + def fit(self, parameters, config): + # Instead of training and returning model parameters, + # the client directly returns [1.0, 1.0, 1.0] for demonstration purposes. + ret_vec = [np.ones(3)] + # Force a significant delay for testing purposes + if "drop" in config and config["drop"]: + print(f"Client dropped for testing purposes.") + time.sleep(8) + else: + print(f"Client uploading {ret_vec[0]}...") + return ret_vec, 1, {} + + +def client_fn(cid: str): + """Create and return an instance of Flower `Client`.""" + return FlowerClient().to_client() + + +# Flower ClientApp +app = ClientApp( + client_fn=client_fn, + mods=[ + secaggplus_mod, + ], +) diff --git a/examples/secaggplus-mt/pyproject.toml b/examples/app-secure-aggregation/pyproject.toml similarity index 50% rename from examples/secaggplus-mt/pyproject.toml rename to examples/app-secure-aggregation/pyproject.toml index fe6fc67252b8..84b6502064c8 100644 --- a/examples/secaggplus-mt/pyproject.toml +++ b/examples/app-secure-aggregation/pyproject.toml @@ -3,11 +3,12 @@ requires = ["poetry-core>=1.4.0"] build-backend = "poetry.core.masonry.api" [tool.poetry] -name = "secaggplus-mt" +name = "app-secure-aggregation" version = "0.1.0" -description = "Secure Aggregation with Driver API" +description = "Flower Secure Aggregation example." 
authors = ["The Flower Authors "] [tool.poetry.dependencies] -python = ">=3.8,<3.11" -flwr-nightly = { version = "^1.5.0.dev20230629", extras = ["simulation", "rest"] } +python = "^3.8" +# Mandatory dependencies +flwr-nightly = { version = "1.8.0.dev20240309", extras = ["simulation"] } diff --git a/examples/app-secure-aggregation/requirements.txt b/examples/app-secure-aggregation/requirements.txt new file mode 100644 index 000000000000..5bac63a0d44c --- /dev/null +++ b/examples/app-secure-aggregation/requirements.txt @@ -0,0 +1 @@ +flwr-nightly[simulation]==1.8.0.dev20240309 diff --git a/examples/secaggplus-mt/run.sh b/examples/app-secure-aggregation/run.sh similarity index 86% rename from examples/secaggplus-mt/run.sh rename to examples/app-secure-aggregation/run.sh index 659c1aaee8ce..834699692ada 100755 --- a/examples/secaggplus-mt/run.sh +++ b/examples/app-secure-aggregation/run.sh @@ -13,7 +13,7 @@ sleep 2 # Number of client processes to start N=5 # Replace with your desired value -echo "Starting $N clients in background..." +echo "Starting $N ClientApps in background..." # Start N client processes for i in $(seq 1 $N) @@ -22,8 +22,8 @@ do sleep 0.1 done -echo "Starting driver..." -python driver.py +echo "Starting ServerApp..." +flower-server-app --insecure server:app echo "Clearing background processes..." diff --git a/examples/app-secure-aggregation/server.py b/examples/app-secure-aggregation/server.py new file mode 100644 index 000000000000..e9737a5a3c7f --- /dev/null +++ b/examples/app-secure-aggregation/server.py @@ -0,0 +1,45 @@ +from flwr.common import Context +from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig +from flwr.server.strategy import FedAvg +from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow + +from workflow_with_log import SecAggPlusWorkflowWithLogs + + +# Define strategy +strategy = FedAvg( + fraction_fit=1.0, # Select all available clients + fraction_evaluate=0.0, # Disable evaluation + min_available_clients=5, +) + + +# Flower ServerApp +app = ServerApp() + + +@app.main() +def main(driver: Driver, context: Context) -> None: + # Construct the LegacyContext + context = LegacyContext( + state=context.state, + config=ServerConfig(num_rounds=3), + strategy=strategy, + ) + + # Create the workflow + workflow = DefaultWorkflow( + fit_workflow=SecAggPlusWorkflowWithLogs( + num_shares=3, + reconstruction_threshold=2, + timeout=5, + ) + # # For real-world applications, use the following code instead + # fit_workflow=SecAggPlusWorkflow( + # num_shares=, + # reconstruction_threshold=, + # ) + ) + + # Execute + workflow(driver, context) diff --git a/examples/app-secure-aggregation/workflow_with_log.py b/examples/app-secure-aggregation/workflow_with_log.py new file mode 100644 index 000000000000..5a83bd69190f --- /dev/null +++ b/examples/app-secure-aggregation/workflow_with_log.py @@ -0,0 +1,90 @@ +from flwr.server import Driver, LegacyContext +from flwr.server.workflow.secure_aggregation.secaggplus_workflow import ( + SecAggPlusWorkflow, + WorkflowState, +) +import numpy as np +from flwr.common.secure_aggregation.quantization import quantize + + +class SecAggPlusWorkflowWithLogs(SecAggPlusWorkflow): + """The SecAggPlusWorkflow augmented for this example. + + This class includes additional logging and modifies one of the FitIns to instruct + the target client to simulate a dropout. 
+ """ + + node_ids = [] + + def setup_stage( + self, driver: Driver, context: LegacyContext, state: WorkflowState + ) -> bool: + _quantized = quantize( + [np.ones(3) for _ in range(5)], self.clipping_range, self.quantization_range + ) + print( + "\n\n################################ Introduction ################################\n" + "In the example, each client will upload a vector [1.0, 1.0, 1.0] instead of\n" + "model updates for demonstration purposes.\n" + "Client 0 is configured to drop out before uploading the masked vector.\n" + f"After quantization, the raw vectors will look like:" + ) + for i in range(1, 5): + print(f"\t{_quantized[i]} from Client {i}") + print( + f"Numbers are rounded to integers stochastically during the quantization\n" + ", and thus entries may not be identical." + ) + print( + "The above raw vectors are hidden from the driver through adding masks.\n" + ) + print( + "########################## Secure Aggregation Start ##########################" + ) + print(f"Sending configurations to 5 clients...") + ret = super().setup_stage(driver, context, state) + print(f"Received public keys from {len(state.active_node_ids)} clients.") + self.node_ids = list(state.active_node_ids) + state.nid_to_fitins[self.node_ids[0]].configs_records["fitins.config"][ + "drop" + ] = True + return ret + + def share_keys_stage( + self, driver: Driver, context: LegacyContext, state: WorkflowState + ) -> bool: + print(f"\nForwarding public keys...") + ret = super().share_keys_stage(driver, context, state) + print( + f"Received encrypted key shares from {len(state.active_node_ids)} clients." + ) + return ret + + def collect_masked_input_stage( + self, driver: Driver, context: LegacyContext, state: WorkflowState + ) -> bool: + print(f"\nForwarding encrypted key shares and requesting masked vectors...") + ret = super().collect_masked_input_stage(driver, context, state) + for node_id in state.sampled_node_ids - state.active_node_ids: + print(f"Client {self.node_ids.index(node_id)} dropped out.") + for node_id in state.active_node_ids: + print( + f"Received masked vectors from Client {self.node_ids.index(node_id)}." + ) + print(f"Obtained sum of masked vectors: {state.aggregate_ndarrays[1]}") + return ret + + def unmask_stage( + self, driver: Driver, context: LegacyContext, state: WorkflowState + ) -> bool: + print("\nRequesting key shares to unmask the aggregate vector...") + ret = super().unmask_stage(driver, context, state) + print(f"Received key shares from {len(state.active_node_ids)} clients.") + + print( + f"Weighted average of vectors (dequantized): {state.aggregate_ndarrays[0]}" + ) + print( + "########################### Secure Aggregation End ###########################\n\n" + ) + return ret diff --git a/examples/secaggplus-mt/README.md b/examples/secaggplus-mt/README.md deleted file mode 100644 index 0b3b4db3942e..000000000000 --- a/examples/secaggplus-mt/README.md +++ /dev/null @@ -1,36 +0,0 @@ -# Secure Aggregation with Driver API - -This example contains highly experimental code. Please consult the regular PyTorch code examples ([quickstart](https://github.com/adap/flower/tree/main/examples/quickstart-pytorch), [advanced](https://github.com/adap/flower/tree/main/examples/advanced-pytorch)) to learn how to use Flower with PyTorch. - -## Installing Dependencies - -Project dependencies (such as and `flwr`) are defined in `pyproject.toml`. 
We recommend [Poetry](https://python-poetry.org/docs/) to install those dependencies and manage your virtual environment ([Poetry installation](https://python-poetry.org/docs/#installation)), but feel free to use a different way of installing dependencies and managing virtual environments if you have other preferences. - -### Poetry - -```shell -poetry install -poetry shell -``` - -Poetry will install all your dependencies in a newly created virtual environment. To verify that everything works correctly you can run the following command: - -```shell -poetry run python3 -c "import flwr" -``` - -### pip - -Write the command below in your terminal to install the dependencies according to the configuration file requirements.txt. - -```shell -pip install -r requirements.txt -``` - -If you don't see any errors you're good to go! - -## Run with Driver API - -```bash -./run.sh -``` diff --git a/examples/secaggplus-mt/client.py b/examples/secaggplus-mt/client.py deleted file mode 100644 index 164a261213be..000000000000 --- a/examples/secaggplus-mt/client.py +++ /dev/null @@ -1,48 +0,0 @@ -import time - -import numpy as np - -import flwr as fl -from flwr.common import Status, FitIns, FitRes, Code -from flwr.common.parameter import ndarrays_to_parameters -from flwr.client.mod import secaggplus_mod - - -# Define Flower client with the SecAgg+ protocol -class FlowerClient(fl.client.Client): - def fit(self, fit_ins: FitIns) -> FitRes: - ret_vec = [np.ones(3)] - ret = FitRes( - status=Status(code=Code.OK, message="Success"), - parameters=ndarrays_to_parameters(ret_vec), - num_examples=1, - metrics={}, - ) - # Force a significant delay for testing purposes - if fit_ins.config["drop"]: - print(f"Client dropped for testing purposes.") - time.sleep(4) - return ret - print(f"Client uploading {ret_vec[0]}...") - return ret - - -def client_fn(cid: str): - """.""" - return FlowerClient().to_client() - - -# To run this: `flower-client-app client:app` -app = fl.client.ClientApp( - client_fn=client_fn, - mods=[secaggplus_mod], -) - - -if __name__ == "__main__": - # Start Flower client - fl.client.start_client( - server_address="0.0.0.0:9092", - client=FlowerClient(), - transport="grpc-rere", - ) diff --git a/examples/secaggplus-mt/driver.py b/examples/secaggplus-mt/driver.py deleted file mode 100644 index 42559c2f4a21..000000000000 --- a/examples/secaggplus-mt/driver.py +++ /dev/null @@ -1,204 +0,0 @@ -import random -import time -from typing import Dict, List, Tuple - -import numpy as np -from workflows import get_workflow_factory - -from flwr.common import Metrics, ndarrays_to_parameters -from flwr.server.driver import GrpcDriver -from flwr.proto import driver_pb2, node_pb2, task_pb2 -from flwr.server import History - - -# Convert instruction/result dict to/from list of TaskIns/TaskRes -def task_dict_to_task_ins_list( - task_dict: Dict[int, task_pb2.Task] -) -> List[task_pb2.TaskIns]: - def merge(_task: task_pb2.Task, _merge_task: task_pb2.Task) -> task_pb2.Task: - _task.MergeFrom(_merge_task) - return _task - - return [ - task_pb2.TaskIns( - task_id="", # Do not set, will be created and set by the DriverAPI - group_id="", - run_id=run_id, - task=merge( - task, - task_pb2.Task( - producer=node_pb2.Node( - node_id=0, - anonymous=True, - ), - consumer=node_pb2.Node( - node_id=sampled_node_id, - # Must be False for this Secure Aggregation example - anonymous=False, - ), - ), - ), - ) - for sampled_node_id, task in task_dict.items() - ] - - -def task_res_list_to_task_dict( - task_res_list: List[task_pb2.TaskRes], -) 
-> Dict[int, task_pb2.Task]: - return {task_res.task.producer.node_id: task_res.task for task_res in task_res_list} - - -# Define metric aggregation function -def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: - examples = [num_examples for num_examples, _ in metrics] - - # Multiply accuracy of each client by number of examples used - train_losses = [num_examples * m["train_loss"] for num_examples, m in metrics] - train_accuracies = [ - num_examples * m["train_accuracy"] for num_examples, m in metrics - ] - val_losses = [num_examples * m["val_loss"] for num_examples, m in metrics] - val_accuracies = [num_examples * m["val_accuracy"] for num_examples, m in metrics] - - # Aggregate and return custom metric (weighted average) - return { - "train_loss": sum(train_losses) / sum(examples), - "train_accuracy": sum(train_accuracies) / sum(examples), - "val_loss": sum(val_losses) / sum(examples), - "val_accuracy": sum(val_accuracies) / sum(examples), - } - - -# -------------------------------------------------------------------------- Driver SDK -driver = GrpcDriver(driver_service_address="0.0.0.0:9091", root_certificates=None) -# -------------------------------------------------------------------------- Driver SDK - -anonymous_client_nodes = False -num_client_nodes_per_round = 5 -sleep_time = 0.5 -time_out = 3.9 -num_rounds = 3 -parameters = ndarrays_to_parameters([np.ones(3)]) -wf_factory = get_workflow_factory() - -# -------------------------------------------------------------------------- Driver SDK -driver.connect() -create_run_res: driver_pb2.CreateRunResponse = driver.create_run( - req=driver_pb2.CreateRunRequest() -) -# -------------------------------------------------------------------------- Driver SDK - -run_id = create_run_res.run_id -print(f"Created run id {run_id}") - -history = History() -for server_round in range(num_rounds): - print(f"Commencing server round {server_round + 1}") - - # List of sampled node IDs in this round - sampled_node_ids: List[int] = [] - - # Sample node ids - if anonymous_client_nodes: - # If we're working with anonymous clients, we don't know their identities, and - # we don't know how many of them we have. We, therefore, have to assume that - # enough anonymous client nodes are available or become available over time. - # - # To schedule a TaskIns for an anonymous client node, we set the node_id to 0 - # (and `anonymous` to True) - # Here, we create an array with only zeros in it: - sampled_node_ids = [0] * num_client_nodes_per_round - else: - # If our client nodes have identiy (i.e., they are not anonymous), we can get - # those IDs from the Driver API using `get_nodes`. If enough clients are - # available via the Driver API, we can select a subset by taking a random - # sample. - # - # The Driver API might not immediately return enough client node IDs, so we - # loop and wait until enough client nodes are available. 
- while True: - # Get a list of node ID's from the server - get_nodes_req = driver_pb2.GetNodesRequest(run_id=run_id) - - # ---------------------------------------------------------------------- Driver SDK - get_nodes_res: driver_pb2.GetNodesResponse = driver.get_nodes( - req=get_nodes_req - ) - # ---------------------------------------------------------------------- Driver SDK - - all_node_ids: List[int] = [node.node_id for node in get_nodes_res.nodes] - - if len(all_node_ids) >= num_client_nodes_per_round: - # Sample client nodes - sampled_node_ids = random.sample( - all_node_ids, num_client_nodes_per_round - ) - break - - time.sleep(3) - - # Log sampled node IDs - time.sleep(sleep_time) - - workflow = wf_factory(parameters, sampled_node_ids) - node_messages = None - - while True: - try: - instructions: Dict[int, task_pb2.Task] = workflow.send(node_messages) - next(workflow) - except StopIteration: - break - # Schedule a task for all sampled nodes - task_ins_list: List[task_pb2.TaskIns] = task_dict_to_task_ins_list(instructions) - - push_task_ins_req = driver_pb2.PushTaskInsRequest(task_ins_list=task_ins_list) - - # ---------------------------------------------------------------------- Driver SDK - push_task_ins_res: driver_pb2.PushTaskInsResponse = driver.push_task_ins( - req=push_task_ins_req - ) - # ---------------------------------------------------------------------- Driver SDK - - time.sleep(sleep_time) - - # Wait for results, ignore empty task_ids - start_time = time.time() - task_ids: List[str] = [ - task_id for task_id in push_task_ins_res.task_ids if task_id != "" - ] - all_task_res: List[task_pb2.TaskRes] = [] - while True: - if time.time() - start_time >= time_out: - break - pull_task_res_req = driver_pb2.PullTaskResRequest( - node=node_pb2.Node(node_id=0, anonymous=True), - task_ids=task_ids, - ) - - # ------------------------------------------------------------------ Driver SDK - pull_task_res_res: driver_pb2.PullTaskResResponse = driver.pull_task_res( - req=pull_task_res_req - ) - # ------------------------------------------------------------------ Driver SDK - - task_res_list: List[task_pb2.TaskRes] = pull_task_res_res.task_res_list - - time.sleep(sleep_time) - - all_task_res += task_res_list - if len(all_task_res) == len(task_ids): - break - - # Collect correct results - node_messages = task_res_list_to_task_dict(all_task_res) - workflow.close() - - # Slow down the start of the next round - time.sleep(sleep_time) - -# -------------------------------------------------------------------------- Driver SDK -driver.disconnect() -# -------------------------------------------------------------------------- Driver SDK -print("Driver disconnected") diff --git a/examples/secaggplus-mt/requirements.txt b/examples/secaggplus-mt/requirements.txt deleted file mode 100644 index eeed6941afc7..000000000000 --- a/examples/secaggplus-mt/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -flwr-nightly[simulation,rest] diff --git a/examples/secaggplus-mt/workflows.py b/examples/secaggplus-mt/workflows.py deleted file mode 100644 index 4177e0e4278f..000000000000 --- a/examples/secaggplus-mt/workflows.py +++ /dev/null @@ -1,369 +0,0 @@ -import random -from logging import WARNING -from typing import Callable, Dict, Generator, List, Optional - -import numpy as np - -from flwr.common import ( - Parameters, - Scalar, - bytes_to_ndarray, - log, - ndarray_to_bytes, - ndarrays_to_parameters, - parameters_to_ndarrays, -) -from flwr.common.secure_aggregation.crypto.shamir import combine_shares -from 
flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - bytes_to_private_key, - bytes_to_public_key, - generate_shared_key, -) -from flwr.common.secure_aggregation.ndarrays_arithmetic import ( - factor_extract, - get_parameters_shape, - get_zero_parameters, - parameters_addition, - parameters_mod, - parameters_subtraction, -) -from flwr.common.secure_aggregation.quantization import dequantize, quantize -from flwr.common.secure_aggregation.secaggplus_constants import ( - Key, - Stage, - RECORD_KEY_CONFIGS, -) -from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen -from flwr.common.typing import ConfigsRecordValues, FitIns -from flwr.proto.task_pb2 import Task -from flwr.common import serde -from flwr.common.constant import MessageType -from flwr.common import RecordSet -from flwr.common import recordset_compat as compat -from flwr.common import ConfigsRecord - - -LOG_EXPLAIN = True - - -def get_workflow_factory() -> ( - Callable[[Parameters, List[int]], Generator[Dict[int, Task], Dict[int, Task], None]] -): - return _wrap_workflow_with_sec_agg - - -def _wrap_in_task( - named_values: Dict[str, ConfigsRecordValues], fit_ins: Optional[FitIns] = None -) -> Task: - if fit_ins is not None: - recordset = compat.fitins_to_recordset(fit_ins, keep_input=True) - else: - recordset = RecordSet() - recordset.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(named_values) - return Task( - task_type=MessageType.TRAIN, - recordset=serde.recordset_to_proto(recordset), - ) - - -def _get_from_task(task: Task) -> Dict[str, ConfigsRecordValues]: - recordset = serde.recordset_from_proto(task.recordset) - return recordset.configs_records[RECORD_KEY_CONFIGS] - - -_secure_aggregation_configuration = { - Key.SHARE_NUMBER: 3, - Key.THRESHOLD: 2, - Key.CLIPPING_RANGE: 3.0, - Key.TARGET_RANGE: 1 << 20, - Key.MOD_RANGE: 1 << 30, -} - - -def workflow_with_sec_agg( - parameters: Parameters, - sampled_node_ids: List[int], - sec_agg_config: Dict[str, Scalar], -) -> Generator[Dict[int, Task], Dict[int, Task], None]: - """ - =============== Setup stage =============== - """ - # Protocol config - num_samples = len(sampled_node_ids) - num_shares = sec_agg_config[Key.SHARE_NUMBER] - threshold = sec_agg_config[Key.THRESHOLD] - mod_range = sec_agg_config[Key.MOD_RANGE] - # Quantization config - clipping_range = sec_agg_config[Key.CLIPPING_RANGE] - target_range = sec_agg_config[Key.TARGET_RANGE] - - if LOG_EXPLAIN: - _quantized = quantize( - [np.ones(3) for _ in range(num_samples)], clipping_range, target_range - ) - print( - "\n\n################################ Introduction ################################\n" - "In the example, each client will upload a vector [1.0, 1.0, 1.0] instead of\n" - "model updates for demonstration purposes.\n" - "Client 0 is configured to drop out before uploading the masked vector.\n" - f"After quantization, the raw vectors will be:" - ) - for i in range(1, num_samples): - print(f"\t{_quantized[i]} from Client {i}") - print( - f"Numbers are rounded to integers stochastically during the quantization\n" - ", and thus not all entries are identical." 
- ) - print( - "The above raw vectors are hidden from the driver through adding masks.\n" - ) - print( - "########################## Secure Aggregation Start ##########################" - ) - cfg = { - Key.STAGE: Stage.SETUP, - Key.SAMPLE_NUMBER: num_samples, - Key.SHARE_NUMBER: num_shares, - Key.THRESHOLD: threshold, - Key.CLIPPING_RANGE: clipping_range, - Key.TARGET_RANGE: target_range, - Key.MOD_RANGE: mod_range, - } - # The number of shares should better be odd in the SecAgg+ protocol. - if num_samples != num_shares and num_shares & 0x1 == 0: - log(WARNING, "Number of shares in the SecAgg+ protocol should be odd.") - num_shares += 1 - - # Randomly assign secure IDs to clients - sids = [i for i in range(len(sampled_node_ids))] - random.shuffle(sids) - nid2sid = dict(zip(sampled_node_ids, sids)) - sid2nid = {sid: nid for nid, sid in nid2sid.items()} - # Build neighbour relations (node ID -> secure IDs of neighbours) - half_share = num_shares >> 1 - nid2neighbours = { - node_id: { - (nid2sid[node_id] + offset) % num_samples - for offset in range(-half_share, half_share + 1) - } - for node_id in sampled_node_ids - } - - surviving_node_ids = sampled_node_ids - if LOG_EXPLAIN: - print( - f"Sending configurations to {num_samples} clients and allocating secure IDs..." - ) - # Send setup configuration to clients - yield { - node_id: _wrap_in_task( - named_values={ - **cfg, - Key.SECURE_ID: nid2sid[node_id], - } - ) - for node_id in surviving_node_ids - } - # Receive public keys from clients and build the dict - node_messages = yield - surviving_node_ids = [node_id for node_id in node_messages] - - if LOG_EXPLAIN: - print(f"Received public keys from {len(surviving_node_ids)} clients.") - - sid2public_keys = {} - for node_id, task in node_messages.items(): - key_dict = _get_from_task(task) - pk1, pk2 = key_dict[Key.PUBLIC_KEY_1], key_dict[Key.PUBLIC_KEY_2] - sid2public_keys[nid2sid[node_id]] = [pk1, pk2] - - """ - =============== Share keys stage =============== - """ - if LOG_EXPLAIN: - print(f"\nForwarding public keys...") - # Broadcast public keys to clients - yield { - node_id: _wrap_in_task( - named_values={ - Key.STAGE: Stage.SHARE_KEYS, - **{ - str(sid): value - for sid, value in sid2public_keys.items() - if sid in nid2neighbours[node_id] - }, - } - ) - for node_id in surviving_node_ids - } - - # Receive secret key shares from clients - node_messages = yield - surviving_node_ids = [node_id for node_id in node_messages] - if LOG_EXPLAIN: - print(f"Received encrypted key shares from {len(surviving_node_ids)} clients.") - # Build forward packet list dictionary - srcs, dsts, ciphertexts = [], [], [] - fwd_ciphertexts: Dict[int, List[bytes]] = { - nid2sid[nid]: [] for nid in surviving_node_ids - } # dest secure ID -> list of ciphertexts - fwd_srcs: Dict[int, List[bytes]] = { - sid: [] for sid in fwd_ciphertexts - } # dest secure ID -> list of src secure IDs - for node_id, task in node_messages.items(): - res_dict = _get_from_task(task) - srcs += [nid2sid[node_id]] * len(res_dict[Key.DESTINATION_LIST]) - dsts += res_dict[Key.DESTINATION_LIST] - ciphertexts += res_dict[Key.CIPHERTEXT_LIST] - - for src, dst, ciphertext in zip(srcs, dsts, ciphertexts): - if dst in fwd_ciphertexts: - fwd_ciphertexts[dst].append(ciphertext) - fwd_srcs[dst].append(src) - - """ - =============== Collect masked input stage =============== - """ - - if LOG_EXPLAIN: - print(f"\nForwarding encrypted key shares and requesting masked input...") - # Send encrypted secret key shares to clients (plus model parameters) - yield { 
- node_id: _wrap_in_task( - named_values={ - Key.STAGE: Stage.COLLECT_MASKED_INPUT, - Key.CIPHERTEXT_LIST: fwd_ciphertexts[nid2sid[node_id]], - Key.SOURCE_LIST: fwd_srcs[nid2sid[node_id]], - }, - fit_ins=FitIns( - parameters=parameters, config={"drop": nid2sid[node_id] == 0} - ), - ) - for node_id in surviving_node_ids - } - # Collect masked input from clients - node_messages = yield - surviving_node_ids = [node_id for node_id in node_messages] - # Get shape of vector sent by first client - weights = parameters_to_ndarrays(parameters) - masked_vector = [np.array([0], dtype=int)] + get_zero_parameters( - [w.shape for w in weights] - ) - # Add all collected masked vectors and compuute available and dropout clients set - dead_sids = { - nid2sid[node_id] - for node_id in sampled_node_ids - if node_id not in surviving_node_ids - } - active_sids = {nid2sid[node_id] for node_id in surviving_node_ids} - if LOG_EXPLAIN: - for sid in dead_sids: - print(f"Client {sid} dropped out.") - for node_id, task in node_messages.items(): - named_values = _get_from_task(task) - client_masked_vec = named_values[Key.MASKED_PARAMETERS] - client_masked_vec = [bytes_to_ndarray(b) for b in client_masked_vec] - if LOG_EXPLAIN: - print(f"Received {client_masked_vec[1]} from Client {nid2sid[node_id]}.") - masked_vector = parameters_addition(masked_vector, client_masked_vec) - masked_vector = parameters_mod(masked_vector, mod_range) - """ - =============== Unmask stage =============== - """ - - if LOG_EXPLAIN: - print("\nRequesting key shares to unmask the aggregate vector...") - # Send secure IDs of active and dead clients. - yield { - node_id: _wrap_in_task( - named_values={ - Key.STAGE: Stage.UNMASK, - Key.DEAD_SECURE_ID_LIST: list(dead_sids & nid2neighbours[node_id]), - Key.ACTIVE_SECURE_ID_LIST: list(active_sids & nid2neighbours[node_id]), - } - ) - for node_id in surviving_node_ids - } - # Collect key shares from clients - node_messages = yield - surviving_node_ids = [node_id for node_id in node_messages] - if LOG_EXPLAIN: - print(f"Received key shares from {len(surviving_node_ids)} clients.") - # Build collected shares dict - collected_shares_dict: Dict[int, List[bytes]] = {} - for nid in sampled_node_ids: - collected_shares_dict[nid2sid[nid]] = [] - - if len(surviving_node_ids) < threshold: - raise Exception("Not enough available clients after unmask vectors stage") - for _, task in node_messages.items(): - named_values = _get_from_task(task) - for owner_sid, share in zip( - named_values[Key.SECURE_ID_LIST], named_values[Key.SHARE_LIST] - ): - collected_shares_dict[owner_sid].append(share) - # Remove mask for every client who is available before ask vectors stage, - # divide vector by first element - active_sids, dead_sids = set(active_sids), set(dead_sids) - for sid, share_list in collected_shares_dict.items(): - if len(share_list) < threshold: - raise Exception( - "Not enough shares to recover secret in unmask vectors stage" - ) - secret = combine_shares(share_list) - if sid in active_sids: - # The seed for PRG is the private mask seed of an active client. - private_mask = pseudo_rand_gen( - secret, mod_range, get_parameters_shape(masked_vector) - ) - masked_vector = parameters_subtraction(masked_vector, private_mask) - else: - # The seed for PRG is the secret key 1 of a dropped client. 
- neighbor_list = list(nid2neighbours[sid2nid[sid]]) - neighbor_list.remove(sid) - - for neighbor_sid in neighbor_list: - shared_key = generate_shared_key( - bytes_to_private_key(secret), - bytes_to_public_key(sid2public_keys[neighbor_sid][0]), - ) - pairwise_mask = pseudo_rand_gen( - shared_key, mod_range, get_parameters_shape(masked_vector) - ) - if sid > neighbor_sid: - masked_vector = parameters_addition(masked_vector, pairwise_mask) - else: - masked_vector = parameters_subtraction(masked_vector, pairwise_mask) - recon_parameters = parameters_mod(masked_vector, mod_range) - # Divide vector by number of clients who have given us their masked vector - # i.e. those participating in final unmask vectors stage - total_weights_factor, recon_parameters = factor_extract(recon_parameters) - if LOG_EXPLAIN: - print(f"Unmasked sum of vectors (quantized): {recon_parameters[0]}") - # recon_parameters = parameters_divide(recon_parameters, total_weights_factor) - aggregated_vector = dequantize( - quantized_parameters=recon_parameters, - clipping_range=clipping_range, - target_range=target_range, - ) - aggregated_vector[0] -= (len(active_sids) - 1) * clipping_range - if LOG_EXPLAIN: - print(f"Unmasked sum of vectors (dequantized): {aggregated_vector[0]}") - print( - f"Aggregate vector using FedAvg: {aggregated_vector[0] / len(active_sids)}" - ) - print( - "########################### Secure Aggregation End ###########################\n\n" - ) - aggregated_parameters = ndarrays_to_parameters(aggregated_vector) - # Update model parameters - parameters.tensors = aggregated_parameters.tensors - parameters.tensor_type = aggregated_parameters.tensor_type - - -def _wrap_workflow_with_sec_agg( - parameters: Parameters, sampled_node_ids: List[int] -) -> Generator[Dict[int, Task], Dict[int, Task], None]: - return workflow_with_sec_agg( - parameters, sampled_node_ids, sec_agg_config=_secure_aggregation_configuration - ) From aba7c08ba279534e73df05e7c75c066d307b41bb Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 11 Mar 2024 11:36:00 +0000 Subject: [PATCH 02/20] Update app-pytorch with server-side parameter initialization (#3099) --- examples/app-pytorch/server_workflow.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/app-pytorch/server_workflow.py b/examples/app-pytorch/server_workflow.py index 920e266c99e9..6923010ecf7b 100644 --- a/examples/app-pytorch/server_workflow.py +++ b/examples/app-pytorch/server_workflow.py @@ -1,7 +1,9 @@ from typing import List, Tuple +from task import Net, get_weights + import flwr as fl -from flwr.common import Context, Metrics +from flwr.common import Context, Metrics, ndarrays_to_parameters from flwr.server import Driver, LegacyContext @@ -26,12 +28,18 @@ def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics: } +# Initialize model parameters +ndarrays = get_weights(Net()) +parameters = ndarrays_to_parameters(ndarrays) + + # Define strategy strategy = fl.server.strategy.FedAvg( fraction_fit=1.0, # Select all available clients fraction_evaluate=0.0, # Disable evaluation min_available_clients=2, fit_metrics_aggregation_fn=weighted_average, + initial_parameters=parameters, ) From 4d4ae6b567d0048a1e32dde8cbcc6bcdd5125e5f Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 11 Mar 2024 15:08:31 +0100 Subject: [PATCH 03/20] Use correct indent for logging feature warnings (#3104) --- src/py/flwr/common/logger.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/common/logger.py 
b/src/py/flwr/common/logger.py index 660710ed7f12..2bc41773ed61 100644 --- a/src/py/flwr/common/logger.py +++ b/src/py/flwr/common/logger.py @@ -168,11 +168,10 @@ def warn_experimental_feature(name: str) -> None: """Warn the user when they use an experimental feature.""" log( WARN, - """ - EXPERIMENTAL FEATURE: %s + """EXPERIMENTAL FEATURE: %s - This is an experimental feature. It could change significantly or be removed - entirely in future versions of Flower. + This is an experimental feature. It could change significantly or be removed + entirely in future versions of Flower. """, name, ) @@ -182,11 +181,10 @@ def warn_deprecated_feature(name: str) -> None: """Warn the user when they use a deprecated feature.""" log( WARN, - """ - DEPRECATED FEATURE: %s + """DEPRECATED FEATURE: %s - This is a deprecated feature. It will be removed - entirely in future versions of Flower. + This is a deprecated feature. It will be removed + entirely in future versions of Flower. """, name, ) From 34cfd1d59bca579e79015de48bce2a76a2fe9c76 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 11 Mar 2024 14:20:31 +0000 Subject: [PATCH 04/20] Update logging in `DefaultWorkflow` (#3101) --- .../flwr/server/workflow/default_workflows.py | 73 +++++++++---------- 1 file changed, 36 insertions(+), 37 deletions(-) diff --git a/src/py/flwr/server/workflow/default_workflows.py b/src/py/flwr/server/workflow/default_workflows.py index 2b51ac1b06a6..fad85d8eecf8 100644 --- a/src/py/flwr/server/workflow/default_workflows.py +++ b/src/py/flwr/server/workflow/default_workflows.py @@ -15,8 +15,9 @@ """Legacy default workflows.""" +import io import timeit -from logging import DEBUG, INFO +from logging import INFO from typing import Optional, cast import flwr.common.recordset_compat as compat @@ -58,16 +59,18 @@ def __call__(self, driver: Driver, context: Context) -> None: ) # Initialize parameters + log(INFO, "[INIT]") default_init_params_workflow(driver, context) # Run federated learning for num_rounds - log(INFO, "FL starting") start_time = timeit.default_timer() cfg = ConfigsRecord() cfg[Key.START_TIME] = start_time context.state.configs_records[MAIN_CONFIGS_RECORD] = cfg for current_round in range(1, context.config.num_rounds + 1): + log(INFO, "") + log(INFO, "[ROUND %s]", current_round) cfg[Key.CURRENT_ROUND] = current_round # Fit round @@ -79,22 +82,19 @@ def __call__(self, driver: Driver, context: Context) -> None: # Evaluate round self.evaluate_workflow(driver, context) - # Bookkeeping + # Bookkeeping and log results end_time = timeit.default_timer() elapsed = end_time - start_time - log(INFO, "FL finished in %s", elapsed) - - # Log results hist = context.history - log(INFO, "app_fit: losses_distributed %s", str(hist.losses_distributed)) - log( - INFO, - "app_fit: metrics_distributed_fit %s", - str(hist.metrics_distributed_fit), - ) - log(INFO, "app_fit: metrics_distributed %s", str(hist.metrics_distributed)) - log(INFO, "app_fit: losses_centralized %s", str(hist.losses_centralized)) - log(INFO, "app_fit: metrics_centralized %s", str(hist.metrics_centralized)) + log(INFO, "") + log(INFO, "[SUMMARY]") + log(INFO, "Run finished %s rounds in %.2fs", context.config.num_rounds, elapsed) + for idx, line in enumerate(io.StringIO(str(hist))): + if idx == 0: + log(INFO, "%s", line.strip("\n")) + else: + log(INFO, "\t%s", line.strip("\n")) + log(INFO, "") # Terminate the thread f_stop.set() @@ -107,12 +107,11 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None: if not isinstance(context, LegacyContext): 
raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.") - log(INFO, "Initializing global parameters") parameters = context.strategy.initialize_parameters( client_manager=context.client_manager ) if parameters is not None: - log(INFO, "Using initial parameters provided by strategy") + log(INFO, "Using initial global parameters provided by strategy") paramsrecord = compat.parameters_to_parametersrecord( parameters, keep_input=True ) @@ -140,7 +139,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None: context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord # Evaluate initial parameters - log(INFO, "Evaluating initial parameters") + log(INFO, "Evaluating initial global parameters") parameters = compat.parametersrecord_to_parameters(paramsrecord, keep_input=True) res = context.strategy.evaluate(0, parameters=parameters) if res is not None: @@ -186,7 +185,9 @@ def default_centralized_evaluation_workflow(_: Driver, context: Context) -> None ) -def default_fit_workflow(driver: Driver, context: Context) -> None: +def default_fit_workflow( # pylint: disable=R0914 + driver: Driver, context: Context +) -> None: """Execute the default workflow for a single fit round.""" if not isinstance(context, LegacyContext): raise TypeError(f"Expect a LegacyContext, but get {type(context).__name__}.") @@ -207,12 +208,11 @@ def default_fit_workflow(driver: Driver, context: Context) -> None: ) if not client_instructions: - log(INFO, "fit_round %s: no clients selected, cancel", current_round) + log(INFO, "configure_fit: no clients selected, cancel") return log( - DEBUG, - "fit_round %s: strategy sampled %s clients (out of %s)", - current_round, + INFO, + "configure_fit: strategy sampled %s clients (out of %s)", len(client_instructions), context.client_manager.num_available(), ) @@ -236,14 +236,14 @@ def default_fit_workflow(driver: Driver, context: Context) -> None: # collect `fit` results from all clients participating in this round messages = list(driver.send_and_receive(out_messages)) del out_messages + num_failures = len([msg for msg in messages if msg.has_error()]) # No exception/failure handling currently log( - DEBUG, - "fit_round %s received %s results and %s failures", - current_round, - len(messages), - 0, + INFO, + "aggregate_fit: received %s results and %s failures", + len(messages) - num_failures, + num_failures, ) # Aggregate training results @@ -288,12 +288,11 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None: client_manager=context.client_manager, ) if not client_instructions: - log(INFO, "evaluate_round %s: no clients selected, cancel", current_round) + log(INFO, "configure_evaluate: no clients selected, skipping evaluation") return log( - DEBUG, - "evaluate_round %s: strategy sampled %s clients (out of %s)", - current_round, + INFO, + "configure_evaluate: strategy sampled %s clients (out of %s)", len(client_instructions), context.client_manager.num_available(), ) @@ -317,14 +316,14 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None: # collect `evaluate` results from all clients participating in this round messages = list(driver.send_and_receive(out_messages)) del out_messages + num_failures = len([msg for msg in messages if msg.has_error()]) # No exception/failure handling currently log( - DEBUG, - "evaluate_round %s received %s results and %s failures", - current_round, - len(messages), - 0, + INFO, + "aggregate_evaluate: received %s results and %s failures", + len(messages) - num_failures, 
+ num_failures, ) # Aggregate the evaluation results From b2f968a878bd92b77e344c924f45de343ee15ba2 Mon Sep 17 00:00:00 2001 From: Taner Topal Date: Mon, 11 Mar 2024 15:31:47 +0100 Subject: [PATCH 05/20] Make the CLI prompt for project_name when not given (#3094) --- src/py/flwr/cli/new/new.py | 9 ++++++--- src/py/flwr/cli/utils.py | 15 ++++++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/cli/new/new.py b/src/py/flwr/cli/new/new.py index 1e6739c86880..8d644391ca5b 100644 --- a/src/py/flwr/cli/new/new.py +++ b/src/py/flwr/cli/new/new.py @@ -22,7 +22,7 @@ import typer from typing_extensions import Annotated -from ..utils import prompt_options +from ..utils import prompt_options, prompt_text class MlFramework(str, Enum): @@ -72,9 +72,9 @@ def render_and_create(file_path: str, template: str, context: Dict[str, str]) -> def new( project_name: Annotated[ - str, + Optional[str], typer.Argument(metavar="project_name", help="The name of the project"), - ], + ] = None, framework: Annotated[ Optional[MlFramework], typer.Option(case_sensitive=False, help="The ML framework to use"), @@ -83,6 +83,9 @@ def new( """Create new Flower project.""" print(f"Creating Flower project {project_name}...") + if project_name is None: + project_name = prompt_text("Please provide project name") + if framework is not None: framework_str = str(framework.value) else: diff --git a/src/py/flwr/cli/utils.py b/src/py/flwr/cli/utils.py index d61189ffc4e3..4e86f0c3b8c8 100644 --- a/src/py/flwr/cli/utils.py +++ b/src/py/flwr/cli/utils.py @@ -14,11 +14,24 @@ # ============================================================================== """Flower command line interface utils.""" -from typing import List +from typing import List, cast import typer +def prompt_text(text: str) -> str: + """Ask user to enter text input.""" + while True: + result = typer.prompt( + typer.style(f"\n๐Ÿ’ฌ {text}", fg=typer.colors.MAGENTA, bold=True) + ) + if len(result) > 0: + break + print(typer.style("โŒ Invalid entry", fg=typer.colors.RED, bold=True)) + + return cast(str, result) + + def prompt_options(text: str, options: List[str]) -> str: """Ask user to select one of the given options and return the selected item.""" # Turn options into a list with index as in " [ 0] quickstart-pytorch" From 3d41776d83e0d7b06f67aeb8f5bd0f8eae712cb0 Mon Sep 17 00:00:00 2001 From: Yan Gao Date: Mon, 11 Mar 2024 14:44:51 +0000 Subject: [PATCH 06/20] Change default model and quantization value for LLM example (#3100) --- examples/llm-flowertune/README.md | 6 +++--- examples/llm-flowertune/conf/config.yaml | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/llm-flowertune/README.md b/examples/llm-flowertune/README.md index 0d5d067b2b8c..60e183d2a9c0 100644 --- a/examples/llm-flowertune/README.md +++ b/examples/llm-flowertune/README.md @@ -50,11 +50,11 @@ With an activated Python environment, run the example with default config values python main.py ``` -This command will run FL simulations with an 8-bit [OpenLLaMA 3Bv2](https://huggingface.co/openlm-research/open_llama_3b_v2) model involving 2 clients per rounds for 100 FL rounds. You can override configuration parameters directly from the command line. Below are a few settings you might want to test: +This command will run FL simulations with a 4-bit [OpenLLaMA 7Bv2](https://huggingface.co/openlm-research/open_llama_7b_v2) model involving 2 clients per rounds for 100 FL rounds. You can override configuration parameters directly from the command line. 
Below are a few settings you might want to test: ```bash -# Use OpenLLaMA-7B instead of 3B and 4-bits quantization -python main.py model.name="openlm-research/open_llama_7b_v2" model.quantization=4 +# Use OpenLLaMA-3B instead of 7B and 8-bits quantization +python main.py model.name="openlm-research/open_llama_3b_v2" model.quantization=8 # Run for 50 rounds but increasing the fraction of clients that participate per round to 25% python main.py num_rounds=50 fraction_fit.fraction_fit=0.25 diff --git a/examples/llm-flowertune/conf/config.yaml b/examples/llm-flowertune/conf/config.yaml index 32ab759b0ddf..0b769d351479 100644 --- a/examples/llm-flowertune/conf/config.yaml +++ b/examples/llm-flowertune/conf/config.yaml @@ -8,8 +8,8 @@ dataset: name: "vicgalle/alpaca-gpt4" model: - name: "openlm-research/open_llama_3b_v2" - quantization: 8 # 8 or 4 if you want to do quantization with BitsAndBytes + name: "openlm-research/open_llama_7b_v2" + quantization: 4 # 8 or 4 if you want to do quantization with BitsAndBytes gradient_checkpointing: True lora: peft_lora_r: 32 From bb5b7ba35d6a911e2396160b0dabbca6eef3cab9 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Mon, 11 Mar 2024 15:58:29 +0100 Subject: [PATCH 07/20] Improve flower-client-app logs (#3105) --- examples/app-pytorch/task.py | 4 +++- src/py/flwr/client/app.py | 27 +++++++++++++++++++++++++-- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/examples/app-pytorch/task.py b/examples/app-pytorch/task.py index cc7554ba3001..240f290df320 100644 --- a/examples/app-pytorch/task.py +++ b/examples/app-pytorch/task.py @@ -1,8 +1,10 @@ from collections import OrderedDict +from logging import INFO import torch import torch.nn as nn import torch.nn.functional as F +from flwr.common.logger import log from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 from torchvision.transforms import Compose, Normalize, ToTensor @@ -42,7 +44,7 @@ def load_data(): def train(net, trainloader, valloader, epochs, device): """Train the model on the training set.""" - print("Starting training...") + log(INFO, "Starting training...") net.to(device) # move model to GPU if available criterion = torch.nn.CrossEntropyLoss().to(device) optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9) diff --git a/src/py/flwr/client/app.py b/src/py/flwr/client/app.py index c2544e10c9c8..c8287afc0fd0 100644 --- a/src/py/flwr/client/app.py +++ b/src/py/flwr/client/app.py @@ -456,7 +456,19 @@ def _load_client_app() -> ClientApp: time.sleep(3) # Wait for 3s before asking again continue - log(INFO, "Received message") + log(INFO, "") + log( + INFO, + "[RUN %s, ROUND %s]", + message.metadata.run_id, + message.metadata.group_id, + ) + log( + INFO, + "Received: %s message %s", + message.metadata.message_type, + message.metadata.message_id, + ) # Handle control message out_message, sleep_duration = handle_control_message(message) @@ -484,7 +496,18 @@ def _load_client_app() -> ClientApp: # Send send(out_message) - log(INFO, "Sent reply") + log( + INFO, + "[RUN %s, ROUND %s]", + out_message.metadata.run_id, + out_message.metadata.group_id, + ) + log( + INFO, + "Sent: %s reply to message %s", + out_message.metadata.message_type, + message.metadata.message_id, + ) # Unregister node if delete_node is not None: From 100b0abf18cd0703a9996e1ff9e81937f918a626 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 11 Mar 2024 15:55:48 +0000 Subject: [PATCH 08/20] Update secaggplus workflow log (#3108) --- examples/app-secure-aggregation/README.md | 2 +- 
examples/app-secure-aggregation/run.sh | 2 +- .../workflow_with_log.py | 108 +++++++++--------- .../mod/secure_aggregation/secaggplus_mod.py | 24 ++-- .../secure_aggregation/secaggplus_mod_test.py | 44 +++---- .../secaggplus_constants.py | 4 +- .../secure_aggregation/secaggplus_workflow.py | 95 +++++++++++++-- 7 files changed, 178 insertions(+), 101 deletions(-) diff --git a/examples/app-secure-aggregation/README.md b/examples/app-secure-aggregation/README.md index bdc95c5b62fc..d1ea7bdc893f 100644 --- a/examples/app-secure-aggregation/README.md +++ b/examples/app-secure-aggregation/README.md @@ -76,7 +76,7 @@ flower-client-app client:app --insecure Start the Flower `ServerApp`: ```bash -flower-server-app server:app --insecure +flower-server-app server:app --insecure --verbose ``` ## Amend the example for practical usage diff --git a/examples/app-secure-aggregation/run.sh b/examples/app-secure-aggregation/run.sh index 834699692ada..fa8dc47f26ef 100755 --- a/examples/app-secure-aggregation/run.sh +++ b/examples/app-secure-aggregation/run.sh @@ -23,7 +23,7 @@ do done echo "Starting ServerApp..." -flower-server-app --insecure server:app +flower-server-app --insecure server:app --verbose echo "Clearing background processes..." diff --git a/examples/app-secure-aggregation/workflow_with_log.py b/examples/app-secure-aggregation/workflow_with_log.py index 5a83bd69190f..a03ff8c13b6c 100644 --- a/examples/app-secure-aggregation/workflow_with_log.py +++ b/examples/app-secure-aggregation/workflow_with_log.py @@ -1,3 +1,5 @@ +from flwr.common import Context, log, parameters_to_ndarrays +from logging import INFO from flwr.server import Driver, LegacyContext from flwr.server.workflow.secure_aggregation.secaggplus_workflow import ( SecAggPlusWorkflow, @@ -5,6 +7,8 @@ ) import numpy as np from flwr.common.secure_aggregation.quantization import quantize +from flwr.server.workflow.constant import MAIN_PARAMS_RECORD +import flwr.common.recordset_compat as compat class SecAggPlusWorkflowWithLogs(SecAggPlusWorkflow): @@ -16,75 +20,73 @@ class SecAggPlusWorkflowWithLogs(SecAggPlusWorkflow): node_ids = [] - def setup_stage( - self, driver: Driver, context: LegacyContext, state: WorkflowState - ) -> bool: + def __call__(self, driver: Driver, context: Context) -> None: _quantized = quantize( [np.ones(3) for _ in range(5)], self.clipping_range, self.quantization_range ) - print( - "\n\n################################ Introduction ################################\n" - "In the example, each client will upload a vector [1.0, 1.0, 1.0] instead of\n" - "model updates for demonstration purposes.\n" - "Client 0 is configured to drop out before uploading the masked vector.\n" - f"After quantization, the raw vectors will look like:" + log(INFO, "") + log( + INFO, + "################################ Introduction ################################", + ) + log( + INFO, + "In the example, each client will upload a vector [1.0, 1.0, 1.0] instead of", ) + log(INFO, "model updates for demonstration purposes.") + log( + INFO, + "Client 0 is configured to drop out before uploading the masked vector.", + ) + log(INFO, "After quantization, the raw vectors will look like:") for i in range(1, 5): - print(f"\t{_quantized[i]} from Client {i}") - print( - f"Numbers are rounded to integers stochastically during the quantization\n" - ", and thus entries may not be identical." 
+ log(INFO, "\t%s from Client %s", _quantized[i], i) + log( + INFO, + "Numbers are rounded to integers stochastically during the quantization", ) - print( - "The above raw vectors are hidden from the driver through adding masks.\n" + log(INFO, ", and thus entries may not be identical.") + log( + INFO, + "The above raw vectors are hidden from the driver through adding masks.", ) - print( - "########################## Secure Aggregation Start ##########################" + log(INFO, "") + log( + INFO, + "########################## Secure Aggregation Start ##########################", ) - print(f"Sending configurations to 5 clients...") + + super().__call__(driver, context) + + paramsrecord = context.state.parameters_records[MAIN_PARAMS_RECORD] + parameters = compat.parametersrecord_to_parameters(paramsrecord, True) + ndarrays = parameters_to_ndarrays(parameters) + log( + INFO, + "Weighted average of vectors (dequantized): %s", + ndarrays[0], + ) + log( + INFO, + "########################### Secure Aggregation End ###########################", + ) + log(INFO, "") + + def setup_stage( + self, driver: Driver, context: LegacyContext, state: WorkflowState + ) -> bool: ret = super().setup_stage(driver, context, state) - print(f"Received public keys from {len(state.active_node_ids)} clients.") self.node_ids = list(state.active_node_ids) state.nid_to_fitins[self.node_ids[0]].configs_records["fitins.config"][ "drop" ] = True return ret - def share_keys_stage( + def collect_masked_vectors_stage( self, driver: Driver, context: LegacyContext, state: WorkflowState ) -> bool: - print(f"\nForwarding public keys...") - ret = super().share_keys_stage(driver, context, state) - print( - f"Received encrypted key shares from {len(state.active_node_ids)} clients." - ) - return ret - - def collect_masked_input_stage( - self, driver: Driver, context: LegacyContext, state: WorkflowState - ) -> bool: - print(f"\nForwarding encrypted key shares and requesting masked vectors...") - ret = super().collect_masked_input_stage(driver, context, state) + ret = super().collect_masked_vectors_stage(driver, context, state) for node_id in state.sampled_node_ids - state.active_node_ids: - print(f"Client {self.node_ids.index(node_id)} dropped out.") - for node_id in state.active_node_ids: - print( - f"Received masked vectors from Client {self.node_ids.index(node_id)}." 
- ) - print(f"Obtained sum of masked vectors: {state.aggregate_ndarrays[1]}") - return ret - - def unmask_stage( - self, driver: Driver, context: LegacyContext, state: WorkflowState - ) -> bool: - print("\nRequesting key shares to unmask the aggregate vector...") - ret = super().unmask_stage(driver, context, state) - print(f"Received key shares from {len(state.active_node_ids)} clients.") - - print( - f"Weighted average of vectors (dequantized): {state.aggregate_ndarrays[0]}" - ) - print( - "########################### Secure Aggregation End ###########################\n\n" - ) + log(INFO, "Client %s dropped out.", self.node_ids.index(node_id)) + log(INFO, "Obtained sum of masked vectors: %s", state.aggregate_ndarrays[1]) return ret diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py index e34ff4376a43..ed0f8f4fd7b5 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py @@ -178,9 +178,9 @@ def secaggplus_mod( res = _setup(state, configs) elif state.current_stage == Stage.SHARE_KEYS: res = _share_keys(state, configs) - elif state.current_stage == Stage.COLLECT_MASKED_INPUT: + elif state.current_stage == Stage.COLLECT_MASKED_VECTORS: fit = _get_fit_fn(msg, ctxt, call_next) - res = _collect_masked_input(state, configs, fit) + res = _collect_masked_vectors(state, configs, fit) elif state.current_stage == Stage.UNMASK: res = _unmask(state, configs) else: @@ -199,7 +199,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None: # Check the existence of Config.STAGE if Key.STAGE not in configs: raise KeyError( - f"The required key '{Key.STAGE}' is missing from the input `named_values`." + f"The required key '{Key.STAGE}' is missing from the ConfigsRecord." ) # Check the value type of the Config.STAGE @@ -215,7 +215,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None: if current_stage != Stage.UNMASK: log(WARNING, "Restart from the setup stage") # If stage is not "setup", - # the stage from `named_values` should be the expected next stage + # the stage from configs should be the expected next stage else: stages = Stage.all() expected_next_stage = stages[(stages.index(current_stage) + 1) % len(stages)] @@ -229,7 +229,7 @@ def check_stage(current_stage: str, configs: ConfigsRecord) -> None: # pylint: disable-next=too-many-branches def check_configs(stage: str, configs: ConfigsRecord) -> None: """Check the validity of the configs.""" - # Check `named_values` for the setup stage + # Check configs for the setup stage if stage == Stage.SETUP: key_type_pairs = [ (Key.SAMPLE_NUMBER, int), @@ -243,7 +243,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None: if key not in configs: raise KeyError( f"Stage {Stage.SETUP}: the required key '{key}' is " - "missing from the input `named_values`." + "missing from the ConfigsRecord." ) # Bool is a subclass of int in Python, # so `isinstance(v, int)` will return True even if v is a boolean. @@ -266,7 +266,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None: f"Stage {Stage.SHARE_KEYS}: " f"the value for the key '{key}' must be a list of two bytes." 
) - elif stage == Stage.COLLECT_MASKED_INPUT: + elif stage == Stage.COLLECT_MASKED_VECTORS: key_type_pairs = [ (Key.CIPHERTEXT_LIST, bytes), (Key.SOURCE_LIST, int), @@ -274,9 +274,9 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None: for key, expected_type in key_type_pairs: if key not in configs: raise KeyError( - f"Stage {Stage.COLLECT_MASKED_INPUT}: " + f"Stage {Stage.COLLECT_MASKED_VECTORS}: " f"the required key '{key}' is " - "missing from the input `named_values`." + "missing from the ConfigsRecord." ) if not isinstance(configs[key], list) or any( elm @@ -285,7 +285,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None: if type(elm) is not expected_type ): raise TypeError( - f"Stage {Stage.COLLECT_MASKED_INPUT}: " + f"Stage {Stage.COLLECT_MASKED_VECTORS}: " f"the value for the key '{key}' " f"must be of type List[{expected_type.__name__}]" ) @@ -299,7 +299,7 @@ def check_configs(stage: str, configs: ConfigsRecord) -> None: raise KeyError( f"Stage {Stage.UNMASK}: " f"the required key '{key}' is " - "missing from the input `named_values`." + "missing from the ConfigsRecord." ) if not isinstance(configs[key], list) or any( elm @@ -414,7 +414,7 @@ def _share_keys( # pylint: disable-next=too-many-locals -def _collect_masked_input( +def _collect_masked_vectors( state: SecAggPlusState, configs: ConfigsRecord, fit: Callable[[], FitRes], diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index d78775cfe24d..d72d8b414f65 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -93,7 +93,7 @@ def test_stage_transition(self) -> None: assert Stage.all() == ( Stage.SETUP, Stage.SHARE_KEYS, - Stage.COLLECT_MASKED_INPUT, + Stage.COLLECT_MASKED_VECTORS, Stage.UNMASK, ) @@ -101,13 +101,13 @@ def test_stage_transition(self) -> None: # From one stage to the next stage (Stage.UNMASK, Stage.SETUP), (Stage.SETUP, Stage.SHARE_KEYS), - (Stage.SHARE_KEYS, Stage.COLLECT_MASKED_INPUT), - (Stage.COLLECT_MASKED_INPUT, Stage.UNMASK), + (Stage.SHARE_KEYS, Stage.COLLECT_MASKED_VECTORS), + (Stage.COLLECT_MASKED_VECTORS, Stage.UNMASK), # From any stage to the initial stage # Such transitions will log a warning. 
(Stage.SETUP, Stage.SETUP), (Stage.SHARE_KEYS, Stage.SETUP), - (Stage.COLLECT_MASKED_INPUT, Stage.SETUP), + (Stage.COLLECT_MASKED_VECTORS, Stage.SETUP), } invalid_transitions = set(product(Stage.all(), Stage.all())).difference( @@ -159,17 +159,17 @@ def test_stage_setup_check(self) -> None: for key, value_type in valid_key_type_pairs } - # Test valid `named_values` + # Test valid configs try: check_configs(Stage.SETUP, ConfigsRecord(valid_configs)) # pylint: disable-next=broad-except except Exception as exc: - self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") + self.fail(f"check_configs() raised {type(exc)} unexpectedly!") # Set the stage valid_configs[Key.STAGE] = Stage.SETUP - # Test invalid `named_values` + # Test invalid configs for key, value_type in valid_key_type_pairs: invalid_configs = valid_configs.copy() @@ -202,17 +202,17 @@ def test_stage_share_keys_check(self) -> None: "3": [b"public key 1", b"public key 2"], } - # Test valid `named_values` + # Test valid configs try: check_configs(Stage.SHARE_KEYS, ConfigsRecord(valid_configs)) # pylint: disable-next=broad-except except Exception as exc: - self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") + self.fail(f"check_configs() raised {type(exc)} unexpectedly!") # Set the stage valid_configs[Key.STAGE] = Stage.SHARE_KEYS - # Test invalid `named_values` + # Test invalid configs invalid_values: List[ConfigsRecordValues] = [ b"public key 1", [b"public key 1"], @@ -227,8 +227,8 @@ def test_stage_share_keys_check(self) -> None: with self.assertRaises(TypeError): handler(invalid_configs.copy()) - def test_stage_collect_masked_input_check(self) -> None: - """Test content checking for the collect masked input stage.""" + def test_stage_collect_masked_vectors_check(self) -> None: + """Test content checking for the collect masked vectors stage.""" ctxt = _make_ctxt() handler = get_test_handler(ctxt) set_stage = _make_set_state_fn(ctxt) @@ -238,17 +238,17 @@ def test_stage_collect_masked_input_check(self) -> None: Key.SOURCE_LIST: [32, 51324, 32324123, -3], } - # Test valid `named_values` + # Test valid configs try: - check_configs(Stage.COLLECT_MASKED_INPUT, ConfigsRecord(valid_configs)) + check_configs(Stage.COLLECT_MASKED_VECTORS, ConfigsRecord(valid_configs)) # pylint: disable-next=broad-except except Exception as exc: - self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") + self.fail(f"check_configs() raised {type(exc)} unexpectedly!") # Set the stage - valid_configs[Key.STAGE] = Stage.COLLECT_MASKED_INPUT + valid_configs[Key.STAGE] = Stage.COLLECT_MASKED_VECTORS - # Test invalid `named_values` + # Test invalid configs # Test missing keys for key in list(valid_configs.keys()): if key == Key.STAGE: @@ -282,17 +282,17 @@ def test_stage_unmask_check(self) -> None: Key.DEAD_NODE_ID_LIST: [32, 51324, 32324123, -3], } - # Test valid `named_values` + # Test valid configs try: check_configs(Stage.UNMASK, ConfigsRecord(valid_configs)) # pylint: disable-next=broad-except except Exception as exc: - self.fail(f"check_named_values() raised {type(exc)} unexpectedly!") + self.fail(f"check_configs() raised {type(exc)} unexpectedly!") # Set the stage valid_configs[Key.STAGE] = Stage.UNMASK - # Test invalid `named_values` + # Test invalid configs # Test missing keys for key in list(valid_configs.keys()): if key == Key.STAGE: @@ -300,7 +300,7 @@ def test_stage_unmask_check(self) -> None: invalid_configs = valid_configs.copy() invalid_configs.pop(key) - set_stage(Stage.COLLECT_MASKED_INPUT) + 
set_stage(Stage.COLLECT_MASKED_VECTORS) with self.assertRaises(KeyError): handler(invalid_configs) @@ -311,6 +311,6 @@ def test_stage_unmask_check(self) -> None: invalid_configs = valid_configs.copy() invalid_configs[key] = [True, False, True, False] - set_stage(Stage.COLLECT_MASKED_INPUT) + set_stage(Stage.COLLECT_MASKED_VECTORS) with self.assertRaises(TypeError): handler(invalid_configs) diff --git a/src/py/flwr/common/secure_aggregation/secaggplus_constants.py b/src/py/flwr/common/secure_aggregation/secaggplus_constants.py index 1c6c1d14f56e..8a15908c13c5 100644 --- a/src/py/flwr/common/secure_aggregation/secaggplus_constants.py +++ b/src/py/flwr/common/secure_aggregation/secaggplus_constants.py @@ -27,9 +27,9 @@ class Stage: SETUP = "setup" SHARE_KEYS = "share_keys" - COLLECT_MASKED_INPUT = "collect_masked_input" + COLLECT_MASKED_VECTORS = "collect_masked_vectors" UNMASK = "unmask" - _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_INPUT, UNMASK) + _stages = (SETUP, SHARE_KEYS, COLLECT_MASKED_VECTORS, UNMASK) @classmethod def all(cls) -> tuple[str, str, str, str]: diff --git a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py index a9814738a7d4..559dc1cf8739 100644 --- a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +++ b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py @@ -17,7 +17,7 @@ import random from dataclasses import dataclass, field -from logging import ERROR, WARN +from logging import DEBUG, ERROR, INFO, WARN from typing import Dict, List, Optional, Set, Union, cast import flwr.common.recordset_compat as compat @@ -101,7 +101,7 @@ class SecAggPlusWorkflow: - 'setup': Send SecAgg+ configuration to clients and collect their public keys. - 'share keys': Broadcast public keys among clients and collect encrypted secret key shares. - - 'collect masked inputs': Forward encrypted secret key shares to target clients + - 'collect masked vectors': Forward encrypted secret key shares to target clients and collect masked model parameters. - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters. 
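To make the four stages listed above concrete, here is a minimal server-side sketch of how `SecAggPlusWorkflow` can be plugged into a `ServerApp` through `DefaultWorkflow`, in the spirit of the `examples/app-secure-aggregation` server from this series. The `@app.main()` wiring, the `LegacyContext` keyword arguments, and the `num_shares`/`reconstruction_threshold` values are illustrative assumptions rather than part of this patch.

```python
from flwr.common import Context
from flwr.server import Driver, LegacyContext, ServerApp, ServerConfig
from flwr.server.strategy import FedAvg
from flwr.server.workflow import DefaultWorkflow, SecAggPlusWorkflow

app = ServerApp()


@app.main()
def main(driver: Driver, context: Context) -> None:
    # Wrap the incoming Context so strategy, config, and history are available
    # (constructor arguments shown here are an assumption, not from this patch)
    context = LegacyContext(
        state=context.state,
        config=ServerConfig(num_rounds=3),
        strategy=FedAvg(),
    )

    # Run the default fit/evaluate loop, but execute the fit phase with SecAgg+
    workflow = DefaultWorkflow(
        fit_workflow=SecAggPlusWorkflow(
            num_shares=3,                # illustrative: split each private key into 3 shares
            reconstruction_threshold=2,  # illustrative: any 2 shares reconstruct a key
        )
    )
    workflow(driver, context)
```

On the client side, the counterpart is registering `secaggplus_mod` with the `ClientApp`, so the setup, share-keys, collect-masked-vectors, and unmask messages are handled before the user's `fit` is called. Again a hedged sketch; the constant vector update mirrors the demonstration vector used in `workflow_with_log.py` and is an assumption, not required behavior:

```python
import numpy as np

from flwr.client import ClientApp, NumPyClient
from flwr.client.mod import secaggplus_mod


class FlowerClient(NumPyClient):
    def fit(self, parameters, config):
        # Upload a constant vector so the secure aggregate is easy to verify;
        # a real client would train a model and return its updated weights
        return [np.ones(3)], 1, {}


def client_fn(cid: str):
    return FlowerClient().to_client()


# Register the SecAgg+ modifier so the protocol stages run around `fit`
app = ClientApp(
    client_fn=client_fn,
    mods=[secaggplus_mod],
)
```

With both apps in place, they can be launched as in the example's `run.sh`, i.e. `flower-client-app client:app --insecure` for each client and `flower-server-app server:app --insecure --verbose` for the server.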
@@ -195,12 +195,15 @@ def __call__(self, driver: Driver, context: Context) -> None: steps = ( self.setup_stage, self.share_keys_stage, - self.collect_masked_input_stage, + self.collect_masked_vectors_stage, self.unmask_stage, ) + log(INFO, "Secure aggregation commencing.") for step in steps: if not step(driver, context, state): + log(INFO, "Secure aggregation halted.") return + log(INFO, "Secure aggregation completed.") def _check_init_params(self) -> None: # pylint: disable=R0912 # Check `num_shares` @@ -287,6 +290,16 @@ def setup_stage( # pylint: disable=R0912, R0914, R0915 proxy_fitins_lst = context.strategy.configure_fit( current_round, parameters, context.client_manager ) + if not proxy_fitins_lst: + log(INFO, "configure_fit: no clients selected, cancel") + return False + log( + INFO, + "configure_fit: strategy sampled %s clients (out of %s)", + len(proxy_fitins_lst), + context.client_manager.num_available(), + ) + state.nid_to_fitins = { proxy.node_id: compat.fitins_to_recordset(fitins, False) for proxy, fitins in proxy_fitins_lst @@ -362,12 +375,22 @@ def make(nid: int) -> Message: ttl="", ) + log( + DEBUG, + "[Stage 0] Sending configurations to %s clients.", + len(state.active_node_ids), + ) msgs = driver.send_and_receive( [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout ) state.active_node_ids = { msg.metadata.src_node_id for msg in msgs if not msg.has_error() } + log( + DEBUG, + "[Stage 0] Received public keys from %s clients.", + len(state.active_node_ids), + ) for msg in msgs: if msg.has_error(): @@ -401,12 +424,22 @@ def make(nid: int) -> Message: ) # Broadcast public keys to clients and receive secret key shares + log( + DEBUG, + "[Stage 1] Forwarding public keys to %s clients.", + len(state.active_node_ids), + ) msgs = driver.send_and_receive( [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout ) state.active_node_ids = { msg.metadata.src_node_id for msg in msgs if not msg.has_error() } + log( + DEBUG, + "[Stage 1] Received encrypted key shares from %s clients.", + len(state.active_node_ids), + ) # Build forward packet list dictionary srcs: List[int] = [] @@ -437,16 +470,16 @@ def make(nid: int) -> Message: return self._check_threshold(state) - def collect_masked_input_stage( + def collect_masked_vectors_stage( self, driver: Driver, context: LegacyContext, state: WorkflowState ) -> bool: - """Execute the 'collect masked input' stage.""" + """Execute the 'collect masked vectors' stage.""" cfg = context.state.configs_records[MAIN_CONFIGS_RECORD] - # Send secret key shares to clients (plus FitIns) and collect masked input + # Send secret key shares to clients (plus FitIns) and collect masked vectors def make(nid: int) -> Message: cfgs_dict = { - Key.STAGE: Stage.COLLECT_MASKED_INPUT, + Key.STAGE: Stage.COLLECT_MASKED_VECTORS, Key.CIPHERTEXT_LIST: state.forward_ciphertexts[nid], Key.SOURCE_LIST: state.forward_srcs[nid], } @@ -461,12 +494,22 @@ def make(nid: int) -> Message: ttl="", ) + log( + DEBUG, + "[Stage 2] Forwarding encrypted key shares to %s clients.", + len(state.active_node_ids), + ) msgs = driver.send_and_receive( [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout ) state.active_node_ids = { msg.metadata.src_node_id for msg in msgs if not msg.has_error() } + log( + DEBUG, + "[Stage 2] Received masked vectors from %s clients.", + len(state.active_node_ids), + ) # Clear cache del state.forward_ciphertexts, state.forward_srcs, state.nid_to_fitins @@ -487,7 +530,7 @@ def make(nid: int) -> Message: 
return self._check_threshold(state) - def unmask_stage( # pylint: disable=R0912, R0914 + def unmask_stage( # pylint: disable=R0912, R0914, R0915 self, driver: Driver, context: LegacyContext, state: WorkflowState ) -> bool: """Execute the 'unmask' stage.""" @@ -516,12 +559,22 @@ def make(nid: int) -> Message: ttl="", ) + log( + DEBUG, + "[Stage 3] Requesting key shares from %s clients to remove masks.", + len(state.active_node_ids), + ) msgs = driver.send_and_receive( [make(node_id) for node_id in state.active_node_ids], timeout=self.timeout ) state.active_node_ids = { msg.metadata.src_node_id for msg in msgs if not msg.has_error() } + log( + DEBUG, + "[Stage 3] Received key shares from %s clients.", + len(state.active_node_ids), + ) # Build collected shares dict collected_shares_dict: Dict[int, List[bytes]] = {} @@ -534,7 +587,7 @@ def make(nid: int) -> Message: for owner_nid, share in zip(nids, shares): collected_shares_dict[owner_nid].append(share) - # Remove mask for every client who is available after collect_masked_input stage + # Remove masks for every active client after collect_masked_vectors stage masked_vector = state.aggregate_ndarrays del state.aggregate_ndarrays for nid, share_list in collected_shares_dict.items(): @@ -585,6 +638,15 @@ def make(nid: int) -> Message: vec += offset vec *= inv_dq_total_ratio state.aggregate_ndarrays = aggregated_vector + + # No exception/failure handling currently + log( + INFO, + "aggregate_fit: received %s results and %s failures", + 1, + 0, + ) + final_fitres = FitRes( status=Status(code=Code.OK, message=""), parameters=ndarrays_to_parameters(aggregated_vector), @@ -597,5 +659,18 @@ def make(nid: int) -> Message: False, driver.run_id, # type: ignore ) - context.strategy.aggregate_fit(current_round, [(empty_proxy, final_fitres)], []) + aggregated_result = context.strategy.aggregate_fit( + current_round, [(empty_proxy, final_fitres)], [] + ) + parameters_aggregated, metrics_aggregated = aggregated_result + + # Update the parameters and write history + if parameters_aggregated: + paramsrecord = compat.parameters_to_parametersrecord( + parameters_aggregated, True + ) + context.state.parameters_records[MAIN_PARAMS_RECORD] = paramsrecord + context.history.add_metrics_distributed_fit( + server_round=current_round, metrics=metrics_aggregated + ) return True From a3d83d97d262f48753e497f51877d46663d08f7a Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Mon, 11 Mar 2024 16:12:21 +0000 Subject: [PATCH 09/20] Add SecAggWorkflow (#3110) --- src/py/flwr/client/mod/__init__.py | 3 +- .../client/mod/secure_aggregation/__init__.py | 2 + .../mod/secure_aggregation/secagg_mod.py | 30 +++++ src/py/flwr/server/workflow/__init__.py | 3 +- .../workflow/secure_aggregation/__init__.py | 2 + .../secure_aggregation/secagg_workflow.py | 112 ++++++++++++++++++ 6 files changed, 150 insertions(+), 2 deletions(-) create mode 100644 src/py/flwr/client/mod/secure_aggregation/secagg_mod.py create mode 100644 src/py/flwr/server/workflow/secure_aggregation/secagg_workflow.py diff --git a/src/py/flwr/client/mod/__init__.py b/src/py/flwr/client/mod/__init__.py index ae7722849915..69a7d76ce95f 100644 --- a/src/py/flwr/client/mod/__init__.py +++ b/src/py/flwr/client/mod/__init__.py @@ -17,7 +17,7 @@ from .centraldp_mods import adaptiveclipping_mod, fixedclipping_mod from .localdp_mod import LocalDpMod -from .secure_aggregation.secaggplus_mod import secaggplus_mod +from .secure_aggregation import secagg_mod, secaggplus_mod from .utils import make_ffn __all__ = [ @@ -25,5 +25,6 @@ 
"fixedclipping_mod", "LocalDpMod", "make_ffn", + "secagg_mod", "secaggplus_mod", ] diff --git a/src/py/flwr/client/mod/secure_aggregation/__init__.py b/src/py/flwr/client/mod/secure_aggregation/__init__.py index 863c149f952e..8892d8c03935 100644 --- a/src/py/flwr/client/mod/secure_aggregation/__init__.py +++ b/src/py/flwr/client/mod/secure_aggregation/__init__.py @@ -15,8 +15,10 @@ """Secure Aggregation mods.""" +from .secagg_mod import secagg_mod from .secaggplus_mod import secaggplus_mod __all__ = [ + "secagg_mod", "secaggplus_mod", ] diff --git a/src/py/flwr/client/mod/secure_aggregation/secagg_mod.py b/src/py/flwr/client/mod/secure_aggregation/secagg_mod.py new file mode 100644 index 000000000000..d87af59a4e6e --- /dev/null +++ b/src/py/flwr/client/mod/secure_aggregation/secagg_mod.py @@ -0,0 +1,30 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Modifier for the SecAgg protocol.""" + + +from flwr.client.typing import ClientAppCallable +from flwr.common import Context, Message + +from .secaggplus_mod import secaggplus_mod + + +def secagg_mod( + msg: Message, + ctxt: Context, + call_next: ClientAppCallable, +) -> Message: + """Handle incoming message and return results, following the SecAgg protocol.""" + return secaggplus_mod(msg, ctxt, call_next) diff --git a/src/py/flwr/server/workflow/__init__.py b/src/py/flwr/server/workflow/__init__.py index 7b09119726a9..31dee89a185d 100644 --- a/src/py/flwr/server/workflow/__init__.py +++ b/src/py/flwr/server/workflow/__init__.py @@ -16,9 +16,10 @@ from .default_workflows import DefaultWorkflow -from .secure_aggregation import SecAggPlusWorkflow +from .secure_aggregation import SecAggPlusWorkflow, SecAggWorkflow __all__ = [ "DefaultWorkflow", "SecAggPlusWorkflow", + "SecAggWorkflow", ] diff --git a/src/py/flwr/server/workflow/secure_aggregation/__init__.py b/src/py/flwr/server/workflow/secure_aggregation/__init__.py index 03924df5fe91..25e2a32da334 100644 --- a/src/py/flwr/server/workflow/secure_aggregation/__init__.py +++ b/src/py/flwr/server/workflow/secure_aggregation/__init__.py @@ -15,8 +15,10 @@ """Secure Aggregation workflows.""" +from .secagg_workflow import SecAggWorkflow from .secaggplus_workflow import SecAggPlusWorkflow __all__ = [ "SecAggPlusWorkflow", + "SecAggWorkflow", ] diff --git a/src/py/flwr/server/workflow/secure_aggregation/secagg_workflow.py b/src/py/flwr/server/workflow/secure_aggregation/secagg_workflow.py new file mode 100644 index 000000000000..f56423e4a0d0 --- /dev/null +++ b/src/py/flwr/server/workflow/secure_aggregation/secagg_workflow.py @@ -0,0 +1,112 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Workflow for the SecAgg protocol.""" + + +from typing import Optional, Union + +from .secaggplus_workflow import SecAggPlusWorkflow + + +class SecAggWorkflow(SecAggPlusWorkflow): + """The workflow for the SecAgg protocol. + + The SecAgg protocol ensures the secure summation of integer vectors owned by + multiple parties, without accessing any individual integer vector. This workflow + allows the server to compute the weighted average of model parameters across all + clients, ensuring individual contributions remain private. This is achieved by + clients sending both, a weighting factor and a weighted version of the locally + updated parameters, both of which are masked for privacy. Specifically, each + client uploads "[w, w * params]" with masks, where weighting factor 'w' is the + number of examples ('num_examples') and 'params' represents the model parameters + ('parameters') from the client's `FitRes`. The server then aggregates these + contributions to compute the weighted average of model parameters. + + The protocol involves four main stages: + - 'setup': Send SecAgg configuration to clients and collect their public keys. + - 'share keys': Broadcast public keys among clients and collect encrypted secret + key shares. + - 'collect masked vectors': Forward encrypted secret key shares to target clients + and collect masked model parameters. + - 'unmask': Collect secret key shares to decrypt and aggregate the model parameters. + + Only the aggregated model parameters are exposed and passed to + `Strategy.aggregate_fit`, ensuring individual data privacy. + + Parameters + ---------- + reconstruction_threshold : Union[int, float] + The minimum number of shares required to reconstruct a client's private key, + or, if specified as a float, it represents the proportion of the total number + of shares needed for reconstruction. This threshold ensures privacy by allowing + for the recovery of contributions from dropped clients during aggregation, + without compromising individual client data. + max_weight : Optional[float] (default: 1000.0) + The maximum value of the weight that can be assigned to any single client's + update during the weighted average calculation on the server side, e.g., in the + FedAvg algorithm. + clipping_range : float, optional (default: 8.0) + The range within which model parameters are clipped before quantization. + This parameter ensures each model parameter is bounded within + [-clipping_range, clipping_range], facilitating quantization. + quantization_range : int, optional (default: 4194304, this equals 2**22) + The size of the range into which floating-point model parameters are quantized, + mapping each parameter to an integer in [0, quantization_range-1]. This + facilitates cryptographic operations on the model updates. + modulus_range : int, optional (default: 4294967296, this equals 2**32) + The range of values from which random mask entries are uniformly sampled + ([0, modulus_range-1]). `modulus_range` must be less than 4294967296. 
+ Please use 2**n values for `modulus_range` to prevent overflow issues. + timeout : Optional[float] (default: None) + The timeout duration in seconds. If specified, the workflow will wait for + replies for this duration each time. If `None`, there is no time limit and + the workflow will wait until replies for all messages are received. + + Notes + ----- + - Each client's private key is split into N shares under the SecAgg protocol, where + N is the number of selected clients. + - Generally, higher `reconstruction_threshold` means better privacy guarantees but + less tolerance to dropouts. + - Too large `max_weight` may compromise the precision of the quantization. + - `modulus_range` must be 2**n and larger than `quantization_range`. + - When `reconstruction_threshold` is a float, it is interpreted as the proportion of + the number of all selected clients needed for the reconstruction of a private key. + This feature enables flexibility in setting the security threshold relative to the + number of selected clients. + - `reconstruction_threshold`, and the quantization parameters + (`clipping_range`, `quantization_range`, `modulus_range`) play critical roles in + balancing privacy, robustness, and efficiency within the SecAgg protocol. + """ + + def __init__( # pylint: disable=R0913 + self, + reconstruction_threshold: Union[int, float], + *, + max_weight: float = 1000.0, + clipping_range: float = 8.0, + quantization_range: int = 4194304, + modulus_range: int = 4294967296, + timeout: Optional[float] = None, + ) -> None: + super().__init__( + num_shares=1.0, + reconstruction_threshold=reconstruction_threshold, + max_weight=max_weight, + clipping_range=clipping_range, + quantization_range=quantization_range, + modulus_range=modulus_range, + timeout=timeout, + ) From efc3c909c133e73420255b17cad5607c5682d044 Mon Sep 17 00:00:00 2001 From: Javier Date: Mon, 11 Mar 2024 18:56:37 +0100 Subject: [PATCH 10/20] Make InMemoryState thread-safe (again) (#3113) --- .../server/superlink/state/in_memory_state.py | 66 ++++++++++--------- 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/src/py/flwr/server/superlink/state/in_memory_state.py b/src/py/flwr/server/superlink/state/in_memory_state.py index 690fadc032d7..ac1ab158e254 100644 --- a/src/py/flwr/server/superlink/state/in_memory_state.py +++ b/src/py/flwr/server/superlink/state/in_memory_state.py @@ -122,7 +122,8 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: task_res.task_id = str(task_id) task_res.task.created_at = created_at.isoformat() task_res.task.ttl = ttl.isoformat() - self.task_res_store[task_id] = task_res + with self.lock: + self.task_res_store[task_id] = task_res # Return the new task_id return task_id @@ -132,46 +133,47 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe if limit is not None and limit < 1: raise AssertionError("`limit` must be >= 1") - # Find TaskRes that were not delivered yet - task_res_list: List[TaskRes] = [] - for _, task_res in self.task_res_store.items(): - if ( - UUID(task_res.task.ancestry[0]) in task_ids - and task_res.task.delivered_at == "" - ): - task_res_list.append(task_res) - if limit and len(task_res_list) == limit: - break + with self.lock: + # Find TaskRes that were not delivered yet + task_res_list: List[TaskRes] = [] + for _, task_res in self.task_res_store.items(): + if ( + UUID(task_res.task.ancestry[0]) in task_ids + and task_res.task.delivered_at == "" + ): + task_res_list.append(task_res) + if limit and len(task_res_list) == limit: + 
break - # Mark all of them as delivered - delivered_at = now().isoformat() - for task_res in task_res_list: - task_res.task.delivered_at = delivered_at + # Mark all of them as delivered + delivered_at = now().isoformat() + for task_res in task_res_list: + task_res.task.delivered_at = delivered_at - # Return TaskRes - return task_res_list + # Return TaskRes + return task_res_list def delete_tasks(self, task_ids: Set[UUID]) -> None: """Delete all delivered TaskIns/TaskRes pairs.""" task_ins_to_be_deleted: Set[UUID] = set() task_res_to_be_deleted: Set[UUID] = set() - for task_ins_id in task_ids: - # Find the task_id of the matching task_res - for task_res_id, task_res in self.task_res_store.items(): - if UUID(task_res.task.ancestry[0]) != task_ins_id: - continue - if task_res.task.delivered_at == "": - continue - - task_ins_to_be_deleted.add(task_ins_id) - task_res_to_be_deleted.add(task_res_id) - - for task_id in task_ins_to_be_deleted: - with self.lock: + with self.lock: + for task_ins_id in task_ids: + # Find the task_id of the matching task_res + for task_res_id, task_res in self.task_res_store.items(): + if UUID(task_res.task.ancestry[0]) != task_ins_id: + continue + if task_res.task.delivered_at == "": + continue + + task_ins_to_be_deleted.add(task_ins_id) + task_res_to_be_deleted.add(task_res_id) + + for task_id in task_ins_to_be_deleted: del self.task_ins_store[task_id] - for task_id in task_res_to_be_deleted: - del self.task_res_store[task_id] + for task_id in task_res_to_be_deleted: + del self.task_res_store[task_id] def num_task_ins(self) -> int: """Calculate the number of task_ins in store. From 0864aa729b0d7adb0488cdf4d6442ddbf8692bff Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Mon, 11 Mar 2024 22:01:43 +0100 Subject: [PATCH 11/20] Add partition division (#2951) Co-authored-by: Javier --- datasets/flwr_datasets/__init__.py | 2 + datasets/flwr_datasets/federated_dataset.py | 133 +++++++++++++-- .../flwr_datasets/federated_dataset_test.py | 44 ++++- datasets/flwr_datasets/utils.py | 156 +++++++++++++++++- datasets/flwr_datasets/utils_test.py | 70 ++++++++ 5 files changed, 392 insertions(+), 13 deletions(-) create mode 100644 datasets/flwr_datasets/utils_test.py diff --git a/datasets/flwr_datasets/__init__.py b/datasets/flwr_datasets/__init__.py index 48d993037708..0b9a6685427b 100644 --- a/datasets/flwr_datasets/__init__.py +++ b/datasets/flwr_datasets/__init__.py @@ -16,6 +16,7 @@ from flwr_datasets import partitioner, resplitter +from flwr_datasets import utils as utils from flwr_datasets.common.version import package_version as _package_version from flwr_datasets.federated_dataset import FederatedDataset @@ -23,6 +24,7 @@ "FederatedDataset", "partitioner", "resplitter", + "utils", ] __version__ = _package_version diff --git a/datasets/flwr_datasets/federated_dataset.py b/datasets/flwr_datasets/federated_dataset.py index c40f8cc34857..588d1ab40aec 100644 --- a/datasets/flwr_datasets/federated_dataset.py +++ b/datasets/flwr_datasets/federated_dataset.py @@ -15,7 +15,7 @@ """FederatedDataset.""" -from typing import Dict, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union, cast import datasets from datasets import Dataset, DatasetDict @@ -25,9 +25,12 @@ _check_if_dataset_tested, _instantiate_partitioners, _instantiate_resplitter_if_needed, + divide_dataset, ) +# flake8: noqa: E501 +# pylint: disable=line-too-long class FederatedDataset: """Representation of a dataset for federated 
learning/evaluation/analytics. @@ -51,6 +54,19 @@ class FederatedDataset: (representing the number of IID partitions that this split should be partitioned into). One or multiple `Partitioner` objects can be specified in that manner, but at most, one per split. + partition_division : Optional[Union[List[float], Tuple[float, ...], + Dict[str, float], Dict[str, Optional[Union[List[float], Tuple[float, ...], + Dict[str, float]]]]]] + Fractions specifing the division of the partition assiciated with certain split + (and partitioner) that enable returning already divided partition from the + `load_partition` method. You can think of this as on-edge division of the data + into multiple divisions (e.g. into train and validation). You can also name the + divisions by using the Dict or create specify it as a List/Tuple. If you + specified a single partitioner you can provide the simplified form e.g. + [0.8, 0.2] or {"partition_train": 0.8, "partition_test": 0.2} but when multiple + partitioners are specified you need to indicate the result of which partitioner + are further divided e.g. {"train": [0.8, 0.2]} would result in dividing only the + partitions that are created from the "train" split. shuffle : bool Whether to randomize the order of samples. Applied prior to resplitting, speratelly to each of the present splits in the dataset. It uses the `seed` @@ -64,14 +80,18 @@ class FederatedDataset: Use MNIST dataset for Federated Learning with 100 clients (edge devices): >>> mnist_fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) - - Load partition for client with ID 10. - + >>> # Load partition for client with ID 10. >>> partition = mnist_fds.load_partition(10, "train") - - Use test split for centralized evaluation. - + >>> # Use test split for centralized evaluation. >>> centralized = mnist_fds.load_full("test") + + Automatically divde the data returned from `load_partition` + >>> mnist_fds = FederatedDataset( + >>> dataset="mnist", + >>> partitioners={"train": 100}, + >>> partition_division=[0.8, 0.2], + >>> ) + >>> partition_train, partition_test = mnist_fds.load_partition(10, "train") """ # pylint: disable=too-many-instance-attributes @@ -82,6 +102,17 @@ def __init__( subset: Optional[str] = None, resplitter: Optional[Union[Resplitter, Dict[str, Tuple[str, ...]]]] = None, partitioners: Dict[str, Union[Partitioner, int]], + partition_division: Optional[ + Union[ + List[float], + Tuple[float, ...], + Dict[str, float], + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + ] + ] = None, shuffle: bool = True, seed: Optional[int] = 42, ) -> None: @@ -94,6 +125,9 @@ def __init__( self._partitioners: Dict[str, Partitioner] = _instantiate_partitioners( partitioners ) + self._partition_division = self._initialize_partition_division( + partition_division + ) self._shuffle = shuffle self._seed = seed # _dataset is prepared lazily on the first call to `load_partition` @@ -102,7 +136,11 @@ def __init__( # Indicate if the dataset is prepared for `load_partition` or `load_full` self._dataset_prepared: bool = False - def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset: + def load_partition( + self, + node_id: int, + split: Optional[str] = None, + ) -> Union[Dataset, List[Dataset], DatasetDict]: """Load the partition specified by the idx in the selected split. 
The dataset is downloaded only when the first call to `load_partition` or @@ -122,8 +160,13 @@ def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset: Returns ------- - partition : Dataset - Single partition from the dataset split. + partition : Union[Dataset, List[Dataset], DatasetDict] + Undivided or divided partition from the dataset split. + If `partition_division` is not specified then `Dataset` is returned. + If `partition_division` is specified as `List` or `Tuple` then + `List[Dataset]` is returned. + If `partition_division` is specified as `Dict` then `DatasetDict` is + returned. """ if not self._dataset_prepared: self._prepare_dataset() @@ -136,7 +179,16 @@ def load_partition(self, node_id: int, split: Optional[str] = None) -> Dataset: self._check_if_split_possible_to_federate(split) partitioner: Partitioner = self._partitioners[split] self._assign_dataset_to_partitioner(split) - return partitioner.load_partition(node_id) + partition = partitioner.load_partition(node_id) + if self._partition_division is None: + return partition + partition_division = self._partition_division.get(split) + if partition_division is None: + return partition + divided_partition: Union[List[Dataset], DatasetDict] = divide_dataset( + partition, partition_division + ) + return divided_partition def load_full(self, split: str) -> Dataset: """Load the full split of the dataset. @@ -230,3 +282,62 @@ def _check_if_no_split_keyword_possible(self) -> None: "Please set the `split` argument. You can only omit the split keyword " "if there is exactly one partitioner specified." ) + + def _initialize_partition_division( + self, + partition_division: Optional[ + Union[ + List[float], + Tuple[float, ...], + Dict[str, float], + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + ] + ], + ) -> Optional[ + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ] + ]: + """Create the partition division in the full format. + + Reduced format (possible if only one partitioner exist): + + Union[List[float], Tuple[float, ...], Dict[str, float] + + Full format: Dict[str, Reduced format] + Full format represents the split to division mapping. + """ + # Check for simple dict, list, or tuple types directly + if isinstance(partition_division, (list, tuple)) or ( + isinstance(partition_division, dict) + and all(isinstance(value, float) for value in partition_division.values()) + ): + if len(self._partitioners) > 1: + raise ValueError( + f"The specified partition_division {partition_division} does not " + f"provide mapping to split but more than one partitioners is " + f"specified. Please adjust the partition_division specification to " + f"have the split names as the keys." 
+ ) + return cast( + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + {list(self._partitioners.keys())[0]: partition_division}, + ) + if isinstance(partition_division, dict): + return cast( + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + partition_division, + ) + if partition_division is None: + return None + raise TypeError("Unsupported type for partition_division") diff --git a/datasets/flwr_datasets/federated_dataset_test.py b/datasets/flwr_datasets/federated_dataset_test.py index e02b6ed5add8..e01f56342954 100644 --- a/datasets/flwr_datasets/federated_dataset_test.py +++ b/datasets/flwr_datasets/federated_dataset_test.py @@ -17,7 +17,7 @@ import unittest -from typing import Dict, Union +from typing import Dict, List, Optional, Tuple, Union from unittest.mock import Mock, patch import pytest @@ -67,6 +67,48 @@ def test_load_partition_size(self, _: str, train_num_partitions: int) -> None: len(dataset_partition0), len(dataset["train"]) // train_num_partitions ) + @parameterized.expand( # type: ignore + [ + ((0.2, 0.8), 2, False), + ({"train": 0.2, "test": 0.8}, 2, False), + ({"train": {"train": 0.2, "test": 0.8}}, 2, True), + # Not full dataset + ([0.2, 0.1], 2, False), + ({"train": 0.2, "test": 0.1}, 2, False), + (None, None, False), + ], + ) + def test_divide_partition_integration_size( + self, + partition_division: Optional[ + Union[ + List[float], + Tuple[float, ...], + Dict[str, float], + Dict[ + str, + Optional[Union[List[float], Tuple[float, ...], Dict[str, float]]], + ], + ] + ], + expected_length: Optional[int], + add_test_partitioner: bool, + ): + """Test is the `partition_division` create correct data.""" + partitioners: Dict[str, Union[Partitioner, int]] = {"train": 10} + if add_test_partitioner: + partitioners[self.test_split] = 10 + dataset_fds = FederatedDataset( + dataset=self.dataset_name, + partitioners=partitioners, + partition_division=partition_division, + ) + partition = dataset_fds.load_partition(0, "train") + if partition_division is None: + self.assertEqual(expected_length, None) + else: + self.assertEqual(len(partition), expected_length) + def test_load_full(self) -> None: """Test if the load_full works with the correct split name.""" dataset_fds = FederatedDataset( diff --git a/datasets/flwr_datasets/utils.py b/datasets/flwr_datasets/utils.py index e3d0fdfffa63..38382508035c 100644 --- a/datasets/flwr_datasets/utils.py +++ b/datasets/flwr_datasets/utils.py @@ -16,8 +16,9 @@ import warnings -from typing import Dict, Optional, Tuple, Union, cast +from typing import Dict, List, Optional, Tuple, Union, cast +from datasets import Dataset, DatasetDict from flwr_datasets.partitioner import IidPartitioner, Partitioner from flwr_datasets.resplitter import Resplitter from flwr_datasets.resplitter.merge_resplitter import MergeResplitter @@ -85,3 +86,156 @@ def _check_if_dataset_tested(dataset: str) -> None: f"The currently tested dataset are {tested_datasets}. Given: {dataset}.", stacklevel=1, ) + + +def divide_dataset( + dataset: Dataset, division: Union[List[float], Tuple[float, ...], Dict[str, float]] +) -> Union[List[Dataset], DatasetDict]: + """Divide the dataset according to the `division`. + + The division support varying number of splits, which you can name. The splits are + created from the beginning of the dataset. + + Parameters + ---------- + dataset : Dataset + Dataset to be divided. 
+ division: Union[List[float], Tuple[float, ...], Dict[str, float]] + Configuration specifying how the dataset is divided. Each fraction has to be + >0 and <=1. They have to sum up to at most 1 (smaller sum is possible). + + Returns + ------- + divided_dataset : Union[List[Dataset], DatasetDict] + If `division` is `List` or `Tuple` then `List[Dataset]` is returned else if + `division` is `Dict` then `DatasetDict` is returned. + + Examples + -------- + Use `divide_dataset` with division specified as a list. + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.utils import divide_dataset + >>> + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) + >>> partition = fds.load_partition(0) + >>> division = [0.8, 0.2] + >>> train, test = divide_dataset(dataset=partition, division=division) + + Use `divide_dataset` with division specified as a dict. + >>> from flwr_datasets import FederatedDataset + >>> from flwr_datasets.utils import divide_dataset + >>> + >>> fds = FederatedDataset(dataset="mnist", partitioners={"train": 100}) + >>> partition = fds.load_partition(0) + >>> division = {"train": 0.8, "test": 0.2} + >>> train_test = divide_dataset(dataset=partition, division=division) + >>> train, test = train_test["train"], train_test["test"] + """ + dataset_length = len(dataset) + ranges = _create_division_indices_ranges(dataset_length, division) + if isinstance(division, (list, tuple)): + split_partition: List[Dataset] = [] + for single_range in ranges: + split_partition.append(dataset.select(single_range)) + return split_partition + if isinstance(division, dict): + split_partition_dict: Dict[str, Dataset] = {} + for split_name, single_range in zip(division.keys(), ranges): + split_partition_dict[split_name] = dataset.select(single_range) + return DatasetDict(split_partition_dict) + raise TypeError( + f"The type of the `division` should be dict, " + f"tuple or list but is {type(division)} instead." + ) + + +def _create_division_indices_ranges( + dataset_length: int, + division: Union[List[float], Tuple[float, ...], Dict[str, float]], +) -> List[range]: + ranges = [] + if isinstance(division, (list, tuple)): + start_idx = 0 + end_idx = 0 + for fraction in division: + end_idx += int(dataset_length * fraction) + ranges.append(range(start_idx, end_idx)) + start_idx += end_idx + elif isinstance(division, dict): + ranges = [] + start_idx = 0 + end_idx = 0 + for fraction in division.values(): + end_idx += int(dataset_length * fraction) + ranges.append(range(start_idx, end_idx)) + start_idx += end_idx + else: + TypeError( + f"The type of the `division` should be dict, " + f"tuple or list but is {type(division)} instead. " + ) + return ranges + + +def _check_division_config_types_correctness( + division: Union[List[float], Tuple[float, ...], Dict[str, float]] +) -> None: + if isinstance(division, (list, tuple)): + if not all(isinstance(x, float) for x in division): + raise TypeError( + "List or tuple values of `division` must contain only floats, " + "other types are not allowed." + ) + elif isinstance(division, dict): + if not all(isinstance(x, float) for x in division.values()): + raise TypeError( + "Dict values of `division` must be only floats, " + "other types are not allowed." 
+ ) + else: + raise TypeError("`division` must be a list, tuple, or dict.") + + +def _check_division_config_values_correctness( + division: Union[List[float], Tuple[float, ...], Dict[str, float]] +) -> None: + if isinstance(division, (list, tuple)): + if not all(0 < x <= 1 for x in division): + raise ValueError( + "All fractions for the division must be greater than 0 and smaller or " + "equal to 1." + ) + fraction_sum_from_list_tuple = sum(division) + if fraction_sum_from_list_tuple > 1: + raise ValueError("Sum of fractions for division must not exceed 1.") + if fraction_sum_from_list_tuple < 1: + warnings.warn( + f"Sum of fractions for division is {sum(division)}, which is below 1. " + f"Make sure that's the desired behavior. Some data will not be used " + f"in the current specification.", + stacklevel=1, + ) + elif isinstance(division, dict): + values = list(division.values()) + if not all(0 < x <= 1 for x in values): + raise ValueError( + "All fractions must be greater than 0 and smaller or equal to 1." + ) + if sum(values) > 1: + raise ValueError("Sum of fractions must not exceed 1.") + if sum(values) < 1: + warnings.warn( + f"Sum of fractions in `division` is {values}, which is below 1. " + f"Make sure that's the desired behavior. Some data will not be used " + f"in the current specification.", + stacklevel=1, + ) + else: + raise TypeError("`division` must be a list, tuple, or dict.") + + +def _check_division_config_correctness( + division: Union[List[float], Tuple[float, ...], Dict[str, float]] +) -> None: + _check_division_config_types_correctness(division) + _check_division_config_values_correctness(division) diff --git a/datasets/flwr_datasets/utils_test.py b/datasets/flwr_datasets/utils_test.py new file mode 100644 index 000000000000..26f24519eb76 --- /dev/null +++ b/datasets/flwr_datasets/utils_test.py @@ -0,0 +1,70 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utils tests.""" +import unittest +from typing import Dict, List, Tuple, Union + +from parameterized import parameterized_class + +from datasets import Dataset, DatasetDict +from flwr_datasets.utils import divide_dataset + + +@parameterized_class( + ( + "divide", + "sizes", + ), + [ + ((0.2, 0.8), [8, 32]), + ([0.2, 0.8], [8, 32]), + ({"train": 0.2, "test": 0.8}, [8, 32]), + # Not full dataset + ([0.2, 0.1], [8, 4]), + ((0.2, 0.1), [8, 4]), + ({"train": 0.2, "test": 0.1}, [8, 4]), + ], +) +class UtilsTests(unittest.TestCase): + """Utils tests.""" + + divide: Union[List[float], Tuple[float, ...], Dict[str, float]] + sizes: Tuple[int] + + def setUp(self) -> None: + """Set up a dataset.""" + self.dataset = Dataset.from_dict({"data": range(40)}) + + def test_correct_sizes(self) -> None: + """Test correct size of the division.""" + divided_dataset = divide_dataset(self.dataset, self.divide) + if isinstance(divided_dataset, (list, tuple)): + lengths = [len(split) for split in divided_dataset] + else: + lengths = [len(split) for split in divided_dataset.values()] + + self.assertEqual(lengths, self.sizes) + + def test_correct_return_types(self) -> None: + """Test correct types of the divided dataset based on the config.""" + divided_dataset = divide_dataset(self.dataset, self.divide) + if isinstance(self.divide, (list, tuple)): + self.assertIsInstance(divided_dataset, list) + else: + self.assertIsInstance(divided_dataset, DatasetDict) + + +if __name__ == "__main__": + unittest.main() From bde613cb63f0171ba75f8d8101317a41121a5976 Mon Sep 17 00:00:00 2001 From: Yan Gao Date: Tue, 12 Mar 2024 10:36:11 +0000 Subject: [PATCH 12/20] Update required package bitsandbytes==0.41.3 (#3114) --- examples/llm-flowertune/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/llm-flowertune/requirements.txt b/examples/llm-flowertune/requirements.txt index e557dbfc2ff8..c7ff57b403f7 100644 --- a/examples/llm-flowertune/requirements.txt +++ b/examples/llm-flowertune/requirements.txt @@ -2,7 +2,7 @@ flwr-nightly[rest,simulation] flwr_datasets==0.0.2 hydra-core==1.3.2 trl==0.7.2 -bitsandbytes==0.40.2 +bitsandbytes==0.41.3 scipy==1.11.2 peft==0.4.0 fschat[model_worker,webui]==0.2.35 From 0fe721b15ec1877c65eb0abdc85514598cb12ad8 Mon Sep 17 00:00:00 2001 From: tabdar-khan <71217662+tabdar-khan@users.noreply.github.com> Date: Tue, 12 Mar 2024 12:28:16 +0100 Subject: [PATCH 13/20] Add validation function for project name to comply with PEP 621 and PEP 503 (#3111) --- src/py/flwr/common/pyproject.py | 41 ++++++++++ src/py/flwr/common/pyproject_test.py | 108 +++++++++++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 src/py/flwr/common/pyproject.py create mode 100644 src/py/flwr/common/pyproject_test.py diff --git a/src/py/flwr/common/pyproject.py b/src/py/flwr/common/pyproject.py new file mode 100644 index 000000000000..66585e422397 --- /dev/null +++ b/src/py/flwr/common/pyproject.py @@ -0,0 +1,41 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Validates the project's name property.""" + +import re + + +def validate_project_name(name: str) -> bool: + """Validate the project name against PEP 621 and PEP 503 specifications. + + Conventions at a glance: + - Must be lowercase + - Must not contain special characters + - Must use hyphens(recommended) or underscores. No spaces. + - Recommended to be no more than 40 characters long (But it can be) + + Parameters + ---------- + name : str + The project name to validate. + + Returns + ------- + bool + True if the name is valid, False otherwise. + """ + if not name or len(name) > 40 or not re.match(r"^[a-z0-9-_]+$", name): + return False + return True diff --git a/src/py/flwr/common/pyproject_test.py b/src/py/flwr/common/pyproject_test.py new file mode 100644 index 000000000000..88a945054b83 --- /dev/null +++ b/src/py/flwr/common/pyproject_test.py @@ -0,0 +1,108 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the function that validates name property.""" + +from .pyproject import validate_project_name + + +# Happy Flow +def test_valid_name_with_lower_case() -> None: + """Test a valid single-word project name with all lower case.""" + # Prepare + name = "myproject" + expected = True + # Execute + actual = validate_project_name(name) + # Assert + assert actual == expected, f"Expected {name} to be valid" + + +def test_valid_name_with_dashes() -> None: + """Test a valid project name with hyphens inbetween.""" + # Prepare + name = "valid-project-name" + expected = True + # Execute + actual = validate_project_name(name) + # Assert + assert actual == expected, f"Expected {name} to be valid" + + +def test_valid_name_with_underscores() -> None: + """Test a valid project name with underscores inbetween.""" + # Prepare + name = "valid_project_name" + expected = True + # Execute + actual = validate_project_name(name) + # Assert + assert actual == expected, f"Expected {name} to be valid" + + +def test_invalid_name_with_upper_letters() -> None: + """Tests a project name with Spaces and Uppercase letter.""" + # Prepare + name = "Invalid Project Name" + expected = False + # Execute + actual = validate_project_name(name) + # Assert + assert actual == expected, "Upper Case and Spaces are not allowed" + + +def test_name_with_spaces() -> None: + """Tests a project name with spaces inbetween.""" + # Prepare + name = "name with spaces" + expected = False + # Execute + actual = validate_project_name(name) + # Assert + assert actual == expected, "Spaces are not allowed" + + +def test_empty_name() -> None: + """Tests use-case for an empty project name.""" + # Prepare + name = "" + expected = False + # Execute + actual = validate_project_name(name) + # Assert + assert actual == expected, "Empty name is not valid" + + +def test_long_name() -> None: + """Tests for long project names.""" + # Prepare + name = "a" * 41 + expected = False + # Execute + actual = validate_project_name(name) + # Assert + # It can be more than 40 but generally + # it is recommended not to be more than 40 + assert actual == expected, "Name longer than 40 characters is not recommended" + + +def test_name_with_special_characters() -> None: + """Tests for project names with special characters.""" + # Prepare + name = "name!@#" + expected = False + # Execute + actual = validate_project_name(name) + # Assert + assert actual == expected, "Special characters are not allowed" From ad627835d5182265f46e4b2c8e48b51c344703d0 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 12 Mar 2024 11:58:36 +0000 Subject: [PATCH 14/20] Set correct `group_id` in `DefaultWorkflow` (#3115) --- src/py/flwr/server/workflow/default_workflows.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/workflow/default_workflows.py b/src/py/flwr/server/workflow/default_workflows.py index fad85d8eecf8..a5c726b0b191 100644 --- a/src/py/flwr/server/workflow/default_workflows.py +++ b/src/py/flwr/server/workflow/default_workflows.py @@ -127,7 +127,7 @@ def default_init_params_workflow(driver: Driver, context: Context) -> None: content=content, message_type=MessageTypeLegacy.GET_PARAMETERS, dst_node_id=random_client.node_id, - group_id="", + group_id="0", ttl="", ) ] @@ -226,7 +226,7 @@ def default_fit_workflow( # pylint: disable=R0914 content=compat.fitins_to_recordset(fitins, True), message_type=MessageType.TRAIN, dst_node_id=proxy.node_id, - 
group_id="", + group_id=str(current_round), ttl="", ) for proxy, fitins in client_instructions @@ -306,7 +306,7 @@ def default_evaluate_workflow(driver: Driver, context: Context) -> None: content=compat.evaluateins_to_recordset(evalins, True), message_type=MessageType.EVALUATE, dst_node_id=proxy.node_id, - group_id="", + group_id=str(current_round), ttl="", ) for proxy, evalins in client_instructions From 25b797dd97e30722202bba849779014771542404 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 12 Mar 2024 17:38:16 +0000 Subject: [PATCH 15/20] Change the log-level to `DEBUG` for logs in `secaggplus_mod` (#3119) --- .../mod/secure_aggregation/secaggplus_mod.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py index ed0f8f4fd7b5..3e33438c9ddc 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py @@ -17,7 +17,7 @@ import os from dataclasses import dataclass, field -from logging import INFO, WARNING +from logging import DEBUG, WARNING from typing import Any, Callable, Dict, List, Tuple, cast from flwr.client.typing import ClientAppCallable @@ -322,7 +322,7 @@ def _setup( # Assigning parameter values to object fields sec_agg_param_dict = configs state.sample_num = cast(int, sec_agg_param_dict[Key.SAMPLE_NUMBER]) - log(INFO, "Node %d: starting stage 0...", state.nid) + log(DEBUG, "Node %d: starting stage 0...", state.nid) state.share_num = cast(int, sec_agg_param_dict[Key.SHARE_NUMBER]) state.threshold = cast(int, sec_agg_param_dict[Key.THRESHOLD]) @@ -347,7 +347,7 @@ def _setup( state.sk1, state.pk1 = private_key_to_bytes(sk1), public_key_to_bytes(pk1) state.sk2, state.pk2 = private_key_to_bytes(sk2), public_key_to_bytes(pk2) - log(INFO, "Node %d: stage 0 completes. uploading public keys...", state.nid) + log(DEBUG, "Node %d: stage 0 completes. uploading public keys...", state.nid) return {Key.PUBLIC_KEY_1: state.pk1, Key.PUBLIC_KEY_2: state.pk2} @@ -357,7 +357,7 @@ def _share_keys( ) -> Dict[str, ConfigsRecordValues]: named_bytes_tuples = cast(Dict[str, Tuple[bytes, bytes]], configs) key_dict = {int(sid): (pk1, pk2) for sid, (pk1, pk2) in named_bytes_tuples.items()} - log(INFO, "Node %d: starting stage 1...", state.nid) + log(DEBUG, "Node %d: starting stage 1...", state.nid) state.public_keys_dict = key_dict # Check if the size is larger than threshold @@ -409,7 +409,7 @@ def _share_keys( dsts.append(nid) ciphertexts.append(ciphertext) - log(INFO, "Node %d: stage 1 completes. uploading key shares...", state.nid) + log(DEBUG, "Node %d: stage 1 completes. 
uploading key shares...", state.nid) return {Key.DESTINATION_LIST: dsts, Key.CIPHERTEXT_LIST: ciphertexts} @@ -419,7 +419,7 @@ def _collect_masked_vectors( configs: ConfigsRecord, fit: Callable[[], FitRes], ) -> Dict[str, ConfigsRecordValues]: - log(INFO, "Node %d: starting stage 2...", state.nid) + log(DEBUG, "Node %d: starting stage 2...", state.nid) available_clients: List[int] = [] ciphertexts = cast(List[bytes], configs[Key.CIPHERTEXT_LIST]) srcs = cast(List[int], configs[Key.SOURCE_LIST]) @@ -500,7 +500,7 @@ def _collect_masked_vectors( # Take mod of final weight update vector and return to server quantized_parameters = parameters_mod(quantized_parameters, state.mod_range) - log(INFO, "Node %d: stage 2 completed, uploading masked parameters...", state.nid) + log(DEBUG, "Node %d: stage 2 completed, uploading masked parameters...", state.nid) return { Key.MASKED_PARAMETERS: [ndarray_to_bytes(arr) for arr in quantized_parameters] } @@ -509,7 +509,7 @@ def _collect_masked_vectors( def _unmask( state: SecAggPlusState, configs: ConfigsRecord ) -> Dict[str, ConfigsRecordValues]: - log(INFO, "Node %d: starting stage 3...", state.nid) + log(DEBUG, "Node %d: starting stage 3...", state.nid) active_nids = cast(List[int], configs[Key.ACTIVE_NODE_ID_LIST]) dead_nids = cast(List[int], configs[Key.DEAD_NODE_ID_LIST]) @@ -523,5 +523,5 @@ def _unmask( shares += [state.rd_seed_share_dict[nid] for nid in active_nids] shares += [state.sk1_share_dict[nid] for nid in dead_nids] - log(INFO, "Node %d: stage 3 completes. uploading key shares...", state.nid) + log(DEBUG, "Node %d: stage 3 completes. uploading key shares...", state.nid) return {Key.NODE_ID_LIST: all_nids, Key.SHARE_LIST: shares} From 5866311f8d20ce8c7e113b5486cf476dd9be09e3 Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 12 Mar 2024 21:00:00 +0000 Subject: [PATCH 16/20] Fix the module doc string of `secaggplus_mod` (#3106) --- src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py | 2 +- .../flwr/client/mod/secure_aggregation/secaggplus_mod_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py index 3e33438c9ddc..7d965cb031cb 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Message handler for the SecAgg+ protocol.""" +"""Modifier for the SecAgg+ protocol.""" import os diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py index d72d8b414f65..db5ed67c02a4 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""The SecAgg+ protocol handler tests.""" +"""The SecAgg+ protocol modifier tests.""" import unittest from itertools import product From d6f274bf14697b07e45ac6cc254e4635f33e8b6d Mon Sep 17 00:00:00 2001 From: Heng Pan Date: Tue, 12 Mar 2024 21:11:42 +0000 Subject: [PATCH 17/20] Fix `SecAggPlusWorkflow` and `secaggplus_mod` (#3120) --- .../mod/secure_aggregation/secaggplus_mod.py | 51 ++++++++----------- .../secure_aggregation/secaggplus_workflow.py | 43 ++++++++-------- 2 files changed, 41 insertions(+), 53 deletions(-) diff --git a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py index 7d965cb031cb..989d5f6e1361 100644 --- a/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py +++ b/src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py @@ -18,13 +18,14 @@ import os from dataclasses import dataclass, field from logging import DEBUG, WARNING -from typing import Any, Callable, Dict, List, Tuple, cast +from typing import Any, Dict, List, Tuple, cast from flwr.client.typing import ClientAppCallable from flwr.common import ( ConfigsRecord, Context, Message, + Parameters, RecordSet, ndarray_to_bytes, parameters_to_ndarrays, @@ -62,7 +63,7 @@ share_keys_plaintext_concat, share_keys_plaintext_separate, ) -from flwr.common.typing import ConfigsRecordValues, FitRes +from flwr.common.typing import ConfigsRecordValues @dataclass @@ -132,18 +133,6 @@ def to_dict(self) -> Dict[str, ConfigsRecordValues]: return ret -def _get_fit_fn( - msg: Message, ctxt: Context, call_next: ClientAppCallable -) -> Callable[[], FitRes]: - """Get the fit function.""" - - def fit() -> FitRes: - out_msg = call_next(msg, ctxt) - return compat.recordset_to_fitres(out_msg.content, keep_input=False) - - return fit - - def secaggplus_mod( msg: Message, ctxt: Context, @@ -173,25 +162,32 @@ def secaggplus_mod( check_configs(state.current_stage, configs) # Execute + out_content = RecordSet() if state.current_stage == Stage.SETUP: state.nid = msg.metadata.dst_node_id res = _setup(state, configs) elif state.current_stage == Stage.SHARE_KEYS: res = _share_keys(state, configs) elif state.current_stage == Stage.COLLECT_MASKED_VECTORS: - fit = _get_fit_fn(msg, ctxt, call_next) - res = _collect_masked_vectors(state, configs, fit) + out_msg = call_next(msg, ctxt) + out_content = out_msg.content + fitres = compat.recordset_to_fitres(out_content, keep_input=True) + res = _collect_masked_vectors( + state, configs, fitres.num_examples, fitres.parameters + ) + for p_record in out_content.parameters_records.values(): + p_record.clear() elif state.current_stage == Stage.UNMASK: res = _unmask(state, configs) else: - raise ValueError(f"Unknown secagg stage: {state.current_stage}") + raise ValueError(f"Unknown SecAgg/SecAgg+ stage: {state.current_stage}") # Save state ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict()) # Return message - content = RecordSet(configs_records={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)}) - return msg.create_reply(content, ttl="") + out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False) + return msg.create_reply(out_content, ttl="") def check_stage(current_stage: str, configs: ConfigsRecord) -> None: @@ -417,7 +413,8 @@ def _share_keys( def _collect_masked_vectors( state: SecAggPlusState, configs: ConfigsRecord, - fit: Callable[[], FitRes], + num_examples: int, + updated_parameters: Parameters, ) -> 
Dict[str, ConfigsRecordValues]: log(DEBUG, "Node %d: starting stage 2...", state.nid) available_clients: List[int] = [] @@ -447,26 +444,20 @@ def _collect_masked_vectors( state.rd_seed_share_dict[src] = rd_seed_share state.sk1_share_dict[src] = sk1_share - # Fit client - fit_res = fit() - if len(fit_res.metrics) > 0: - log( - WARNING, - "The metrics in FitRes will not be preserved or sent to the server.", - ) - ratio = fit_res.num_examples / state.max_weight + # Fit + ratio = num_examples / state.max_weight if ratio > 1: log( WARNING, "Potential overflow warning: the provided weight (%s) exceeds the specified" " max_weight (%s). This may lead to overflow issues.", - fit_res.num_examples, + num_examples, state.max_weight, ) q_ratio = round(ratio * state.target_range) dq_ratio = q_ratio / state.target_range - parameters = parameters_to_ndarrays(fit_res.parameters) + parameters = parameters_to_ndarrays(updated_parameters) parameters = parameters_multiply(parameters, dq_ratio) # Quantize parameter update (vector) diff --git a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py index 559dc1cf8739..42ee9c15f1cd 100644 --- a/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py +++ b/src/py/flwr/server/workflow/secure_aggregation/secaggplus_workflow.py @@ -18,11 +18,10 @@ import random from dataclasses import dataclass, field from logging import DEBUG, ERROR, INFO, WARN -from typing import Dict, List, Optional, Set, Union, cast +from typing import Dict, List, Optional, Set, Tuple, Union, cast import flwr.common.recordset_compat as compat from flwr.common import ( - Code, ConfigsRecord, Context, FitRes, @@ -30,7 +29,6 @@ MessageType, NDArrays, RecordSet, - Status, bytes_to_ndarray, log, ndarrays_to_parameters, @@ -55,7 +53,7 @@ Stage, ) from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen -from flwr.server.compat.driver_client_proxy import DriverClientProxy +from flwr.server.client_proxy import ClientProxy from flwr.server.compat.legacy_context import LegacyContext from flwr.server.driver import Driver @@ -67,6 +65,7 @@ class WorkflowState: # pylint: disable=R0902 """The state of the SecAgg+ protocol.""" + nid_to_proxies: Dict[int, ClientProxy] = field(default_factory=dict) nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict) sampled_node_ids: Set[int] = field(default_factory=set) active_node_ids: Set[int] = field(default_factory=set) @@ -81,6 +80,7 @@ class WorkflowState: # pylint: disable=R0902 forward_srcs: Dict[int, List[int]] = field(default_factory=dict) forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict) aggregate_ndarrays: NDArrays = field(default_factory=list) + legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list) class SecAggPlusWorkflow: @@ -301,9 +301,10 @@ def setup_stage( # pylint: disable=R0912, R0914, R0915 ) state.nid_to_fitins = { - proxy.node_id: compat.fitins_to_recordset(fitins, False) + proxy.node_id: compat.fitins_to_recordset(fitins, True) for proxy, fitins in proxy_fitins_lst } + state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst} # Protocol config sampled_node_ids = list(state.nid_to_fitins.keys()) @@ -528,6 +529,12 @@ def make(nid: int) -> Message: masked_vector = parameters_mod(masked_vector, state.mod_range) state.aggregate_ndarrays = masked_vector + # Backward compatibility with Strategy + for msg in msgs: + fitres = 
compat.recordset_to_fitres(msg.content, True) + proxy = state.nid_to_proxies[msg.metadata.src_node_id] + state.legacy_results.append((proxy, fitres)) + return self._check_threshold(state) def unmask_stage( # pylint: disable=R0912, R0914, R0915 @@ -637,31 +644,21 @@ def make(nid: int) -> Message: for vec in aggregated_vector: vec += offset vec *= inv_dq_total_ratio - state.aggregate_ndarrays = aggregated_vector + + # Backward compatibility with Strategy + results = state.legacy_results + parameters = ndarrays_to_parameters(aggregated_vector) + for _, fitres in results: + fitres.parameters = parameters # No exception/failure handling currently log( INFO, "aggregate_fit: received %s results and %s failures", - 1, - 0, - ) - - final_fitres = FitRes( - status=Status(code=Code.OK, message=""), - parameters=ndarrays_to_parameters(aggregated_vector), - num_examples=round(state.max_weight / inv_dq_total_ratio), - metrics={}, - ) - empty_proxy = DriverClientProxy( + len(results), 0, - driver.grpc_driver, # type: ignore - False, - driver.run_id, # type: ignore - ) - aggregated_result = context.strategy.aggregate_fit( - current_round, [(empty_proxy, final_fitres)], [] ) + aggregated_result = context.strategy.aggregate_fit(current_round, results, []) parameters_aggregated, metrics_aggregated = aggregated_result # Update the parameters and write history From 1057001fc05ace6dcb87b373ff251bc870f7fc72 Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Tue, 12 Mar 2024 22:36:40 +0100 Subject: [PATCH 18/20] Fds add num_partitions property to partitioners (#3095) * Add num_partition property * Trigger the partitioning in the num_partitions --------- Co-authored-by: Daniel J. Beutel --- .../partitioner/dirichlet_partitioner.py | 7 +++++++ datasets/flwr_datasets/partitioner/iid_partitioner.py | 5 +++++ .../partitioner/inner_dirichlet_partitioner.py | 11 +++++++++++ .../partitioner/natural_id_partitioner.py | 7 +++++++ datasets/flwr_datasets/partitioner/partitioner.py | 5 +++++ .../flwr_datasets/partitioner/shard_partitioner.py | 9 +++++++++ .../flwr_datasets/partitioner/size_partitioner.py | 6 ++++++ 7 files changed, 50 insertions(+) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index 5f1df71991bb..5271aad74a1e 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -132,6 +132,13 @@ def load_partition(self, node_id: int) -> datasets.Dataset: self._determine_node_id_to_indices_if_needed() return self.dataset.select(self._node_id_to_indices[node_id]) + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + self._check_num_partitions_correctness_if_needed() + self._determine_node_id_to_indices_if_needed() + return self._num_partitions + def _initialize_alpha( self, alpha: Union[int, float, List[float], NDArrayFloat] ) -> NDArrayFloat: diff --git a/datasets/flwr_datasets/partitioner/iid_partitioner.py b/datasets/flwr_datasets/partitioner/iid_partitioner.py index c72b34f081f2..faa1dfa10615 100644 --- a/datasets/flwr_datasets/partitioner/iid_partitioner.py +++ b/datasets/flwr_datasets/partitioner/iid_partitioner.py @@ -50,3 +50,8 @@ def load_partition(self, node_id: int) -> datasets.Dataset: return self.dataset.shard( num_shards=self._num_partitions, index=node_id, contiguous=True ) + + @property + def num_partitions(self) -> int: + """Total number of partitions.""" 
+ return self._num_partitions diff --git a/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py index c25a9b059d18..bf07ab3591f5 100644 --- a/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/inner_dirichlet_partitioner.py @@ -119,6 +119,17 @@ def load_partition(self, node_id: int) -> datasets.Dataset: self._determine_node_id_to_indices_if_needed() return self.dataset.select(self._node_id_to_indices[node_id]) + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + self._check_num_partitions_correctness_if_needed() + self._check_partition_sizes_correctness_if_needed() + self._check_the_sum_of_partition_sizes() + self._determine_num_unique_classes_if_needed() + self._alpha = self._initialize_alpha_if_needed(self._initial_alpha) + self._determine_node_id_to_indices_if_needed() + return self._num_partitions + def _initialize_alpha_if_needed( self, alpha: Union[int, float, List[float], NDArrayFloat] ) -> NDArrayFloat: diff --git a/datasets/flwr_datasets/partitioner/natural_id_partitioner.py b/datasets/flwr_datasets/partitioner/natural_id_partitioner.py index b8f28696f3b7..947501965cc6 100644 --- a/datasets/flwr_datasets/partitioner/natural_id_partitioner.py +++ b/datasets/flwr_datasets/partitioner/natural_id_partitioner.py @@ -65,6 +65,13 @@ def load_partition(self, node_id: int) -> datasets.Dataset: lambda row: row[self._partition_by] == self._node_id_to_natural_id[node_id] ) + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + if len(self._node_id_to_natural_id) == 0: + self._create_int_node_id_to_natural_id() + return len(self._node_id_to_natural_id) + @property def node_id_to_natural_id(self) -> Dict[int, str]: """Node id to corresponding natural id present. diff --git a/datasets/flwr_datasets/partitioner/partitioner.py b/datasets/flwr_datasets/partitioner/partitioner.py index 92405152efc6..73eb6f4a17b3 100644 --- a/datasets/flwr_datasets/partitioner/partitioner.py +++ b/datasets/flwr_datasets/partitioner/partitioner.py @@ -79,3 +79,8 @@ def is_dataset_assigned(self) -> bool: True if a dataset is assigned, otherwise False. """ return self._dataset is not None + + @property + @abstractmethod + def num_partitions(self) -> int: + """Total number of partitions.""" diff --git a/datasets/flwr_datasets/partitioner/shard_partitioner.py b/datasets/flwr_datasets/partitioner/shard_partitioner.py index 7c86570fe487..05444f537c8c 100644 --- a/datasets/flwr_datasets/partitioner/shard_partitioner.py +++ b/datasets/flwr_datasets/partitioner/shard_partitioner.py @@ -179,6 +179,15 @@ def load_partition(self, node_id: int) -> datasets.Dataset: self._determine_node_id_to_indices_if_needed() return self.dataset.select(self._node_id_to_indices[node_id]) + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + self._check_num_partitions_correctness_if_needed() + self._check_possibility_of_partitions_creation() + self._sort_dataset_if_needed() + self._determine_node_id_to_indices_if_needed() + return self._num_partitions + def _determine_node_id_to_indices_if_needed(self) -> None: # pylint: disable=R0914 """Assign sample indices to each node id. 
diff --git a/datasets/flwr_datasets/partitioner/size_partitioner.py b/datasets/flwr_datasets/partitioner/size_partitioner.py index 35ca750949ee..29fc2e5b1add 100644 --- a/datasets/flwr_datasets/partitioner/size_partitioner.py +++ b/datasets/flwr_datasets/partitioner/size_partitioner.py @@ -84,6 +84,12 @@ def load_partition(self, node_id: int) -> datasets.Dataset: self._determine_node_id_to_indices_if_needed() return self.dataset.select(self._node_id_to_indices[node_id]) + @property + def num_partitions(self) -> int: + """Total number of partitions.""" + self._determine_node_id_to_indices_if_needed() + return self._num_partitions + @property def node_id_to_size(self) -> Dict[int, int]: """Node id to the number of samples.""" From 9e7e4a8035911fb2307045aa19fce280f74906b7 Mon Sep 17 00:00:00 2001 From: Charles Beauville Date: Wed, 13 Mar 2024 12:21:46 +0000 Subject: [PATCH 19/20] Improve formatting for changelog generating script (#3122) --- src/py/flwr_tool/update_changelog.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/py/flwr_tool/update_changelog.py b/src/py/flwr_tool/update_changelog.py index a158cca21765..e3cffff7e36c 100644 --- a/src/py/flwr_tool/update_changelog.py +++ b/src/py/flwr_tool/update_changelog.py @@ -50,7 +50,7 @@ def _get_pull_requests_since_tag(gh_api, tag): def _format_pr_reference(title, number, url): """Format a pull request reference as a markdown list item.""" - return f"- **{title}** ([#{number}]({url}))" + return f"- **{title.replace('*', '')}** ([#{number}]({url}))" def _extract_changelog_entry(pr_info): @@ -193,11 +193,24 @@ def _insert_new_entry(content, pr_info, pr_reference, pr_entry_text, unreleased_ content = content[:pr_ref_end] + updated_entry + content[existing_entry_start:] else: insert_index = content.find("\n", unreleased_index) + 1 + + # Split the pr_entry_text into paragraphs + paragraphs = pr_entry_text.split("\n") + + # Indent each paragraph + indented_paragraphs = [ + " " + paragraph if paragraph else paragraph for paragraph in paragraphs + ] + + # Join the paragraphs back together, ensuring each is separated by a newline + indented_pr_entry_text = "\n".join(indented_paragraphs) + content = ( content[:insert_index] + + "\n" + pr_reference - + "\n " - + pr_entry_text + + "\n\n" + + indented_pr_entry_text + "\n" + content[insert_index:] ) From 930cdafe2405cd0f5064d33664aad9ab6869a23c Mon Sep 17 00:00:00 2001 From: Adam Narozniak <51029327+adam-narozniak@users.noreply.github.com> Date: Wed, 13 Mar 2024 16:00:45 +0100 Subject: [PATCH 20/20] Change the `self_balancing` to `False` (#3123) --- datasets/flwr_datasets/partitioner/dirichlet_partitioner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py index 5271aad74a1e..cb23acea01b6 100644 --- a/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py +++ b/datasets/flwr_datasets/partitioner/dirichlet_partitioner.py @@ -89,7 +89,7 @@ def __init__( # pylint: disable=R0913 partition_by: str, alpha: Union[int, float, List[float], NDArrayFloat], min_partition_size: int = 10, - self_balancing: bool = True, + self_balancing: bool = False, shuffle: bool = True, seed: Optional[int] = 42, ) -> None:
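
A minimal usage sketch of the partitioner changes above: the new `num_partitions` property from patch 18 and the `self_balancing=False` default from patch 20, assuming the public `FederatedDataset` API from Flower Datasets. The dataset name "cifar10", the "label" column, and the `alpha` value are illustrative assumptions, not part of the patches:

from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner

# Label-skewed Dirichlet partitioning; self-balancing is now opt-in
# (its default changed to False in patch 20).
partitioner = DirichletPartitioner(
    num_partitions=10,
    partition_by="label",  # assumed label column of the example dataset
    alpha=0.5,
    self_balancing=False,
)

# Assumption: FederatedDataset accepts a Partitioner instance per split.
fds = FederatedDataset(dataset="cifar10", partitioners={"train": partitioner})
partition = fds.load_partition(0, "train")  # assigns the dataset and triggers partitioning

# Patch 18: generic code can ask the partitioner for the partition count
# instead of hard-coding it.
assert partitioner.num_partitions == 10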