Skip to content

Commit

Permalink
Fix SecAggPlusWorkflow and secaggplus_mod (#3120)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Mar 12, 2024
1 parent 5866311 commit d6f274b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 53 deletions.
51 changes: 21 additions & 30 deletions src/py/flwr/client/mod/secure_aggregation/secaggplus_mod.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
import os
from dataclasses import dataclass, field
from logging import DEBUG, WARNING
from typing import Any, Callable, Dict, List, Tuple, cast
from typing import Any, Dict, List, Tuple, cast

from flwr.client.typing import ClientAppCallable
from flwr.common import (
ConfigsRecord,
Context,
Message,
Parameters,
RecordSet,
ndarray_to_bytes,
parameters_to_ndarrays,
Expand Down Expand Up @@ -62,7 +63,7 @@
share_keys_plaintext_concat,
share_keys_plaintext_separate,
)
from flwr.common.typing import ConfigsRecordValues, FitRes
from flwr.common.typing import ConfigsRecordValues


@dataclass
Expand Down Expand Up @@ -132,18 +133,6 @@ def to_dict(self) -> Dict[str, ConfigsRecordValues]:
return ret


def _get_fit_fn(
msg: Message, ctxt: Context, call_next: ClientAppCallable
) -> Callable[[], FitRes]:
"""Get the fit function."""

def fit() -> FitRes:
out_msg = call_next(msg, ctxt)
return compat.recordset_to_fitres(out_msg.content, keep_input=False)

return fit


def secaggplus_mod(
msg: Message,
ctxt: Context,
Expand Down Expand Up @@ -173,25 +162,32 @@ def secaggplus_mod(
check_configs(state.current_stage, configs)

# Execute
out_content = RecordSet()
if state.current_stage == Stage.SETUP:
state.nid = msg.metadata.dst_node_id
res = _setup(state, configs)
elif state.current_stage == Stage.SHARE_KEYS:
res = _share_keys(state, configs)
elif state.current_stage == Stage.COLLECT_MASKED_VECTORS:
fit = _get_fit_fn(msg, ctxt, call_next)
res = _collect_masked_vectors(state, configs, fit)
out_msg = call_next(msg, ctxt)
out_content = out_msg.content
fitres = compat.recordset_to_fitres(out_content, keep_input=True)
res = _collect_masked_vectors(
state, configs, fitres.num_examples, fitres.parameters
)
for p_record in out_content.parameters_records.values():
p_record.clear()
elif state.current_stage == Stage.UNMASK:
res = _unmask(state, configs)
else:
raise ValueError(f"Unknown secagg stage: {state.current_stage}")
raise ValueError(f"Unknown SecAgg/SecAgg+ stage: {state.current_stage}")

# Save state
ctxt.state.configs_records[RECORD_KEY_STATE] = ConfigsRecord(state.to_dict())

# Return message
content = RecordSet(configs_records={RECORD_KEY_CONFIGS: ConfigsRecord(res, False)})
return msg.create_reply(content, ttl="")
out_content.configs_records[RECORD_KEY_CONFIGS] = ConfigsRecord(res, False)
return msg.create_reply(out_content, ttl="")


def check_stage(current_stage: str, configs: ConfigsRecord) -> None:
Expand Down Expand Up @@ -417,7 +413,8 @@ def _share_keys(
def _collect_masked_vectors(
state: SecAggPlusState,
configs: ConfigsRecord,
fit: Callable[[], FitRes],
num_examples: int,
updated_parameters: Parameters,
) -> Dict[str, ConfigsRecordValues]:
log(DEBUG, "Node %d: starting stage 2...", state.nid)
available_clients: List[int] = []
Expand Down Expand Up @@ -447,26 +444,20 @@ def _collect_masked_vectors(
state.rd_seed_share_dict[src] = rd_seed_share
state.sk1_share_dict[src] = sk1_share

# Fit client
fit_res = fit()
if len(fit_res.metrics) > 0:
log(
WARNING,
"The metrics in FitRes will not be preserved or sent to the server.",
)
ratio = fit_res.num_examples / state.max_weight
# Fit
ratio = num_examples / state.max_weight
if ratio > 1:
log(
WARNING,
"Potential overflow warning: the provided weight (%s) exceeds the specified"
" max_weight (%s). This may lead to overflow issues.",
fit_res.num_examples,
num_examples,
state.max_weight,
)
q_ratio = round(ratio * state.target_range)
dq_ratio = q_ratio / state.target_range

parameters = parameters_to_ndarrays(fit_res.parameters)
parameters = parameters_to_ndarrays(updated_parameters)
parameters = parameters_multiply(parameters, dq_ratio)

# Quantize parameter update (vector)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,17 @@
import random
from dataclasses import dataclass, field
from logging import DEBUG, ERROR, INFO, WARN
from typing import Dict, List, Optional, Set, Union, cast
from typing import Dict, List, Optional, Set, Tuple, Union, cast

import flwr.common.recordset_compat as compat
from flwr.common import (
Code,
ConfigsRecord,
Context,
FitRes,
Message,
MessageType,
NDArrays,
RecordSet,
Status,
bytes_to_ndarray,
log,
ndarrays_to_parameters,
Expand All @@ -55,7 +53,7 @@
Stage,
)
from flwr.common.secure_aggregation.secaggplus_utils import pseudo_rand_gen
from flwr.server.compat.driver_client_proxy import DriverClientProxy
from flwr.server.client_proxy import ClientProxy
from flwr.server.compat.legacy_context import LegacyContext
from flwr.server.driver import Driver

Expand All @@ -67,6 +65,7 @@
class WorkflowState: # pylint: disable=R0902
"""The state of the SecAgg+ protocol."""

nid_to_proxies: Dict[int, ClientProxy] = field(default_factory=dict)
nid_to_fitins: Dict[int, RecordSet] = field(default_factory=dict)
sampled_node_ids: Set[int] = field(default_factory=set)
active_node_ids: Set[int] = field(default_factory=set)
Expand All @@ -81,6 +80,7 @@ class WorkflowState: # pylint: disable=R0902
forward_srcs: Dict[int, List[int]] = field(default_factory=dict)
forward_ciphertexts: Dict[int, List[bytes]] = field(default_factory=dict)
aggregate_ndarrays: NDArrays = field(default_factory=list)
legacy_results: List[Tuple[ClientProxy, FitRes]] = field(default_factory=list)


class SecAggPlusWorkflow:
Expand Down Expand Up @@ -301,9 +301,10 @@ def setup_stage( # pylint: disable=R0912, R0914, R0915
)

state.nid_to_fitins = {
proxy.node_id: compat.fitins_to_recordset(fitins, False)
proxy.node_id: compat.fitins_to_recordset(fitins, True)
for proxy, fitins in proxy_fitins_lst
}
state.nid_to_proxies = {proxy.node_id: proxy for proxy, _ in proxy_fitins_lst}

# Protocol config
sampled_node_ids = list(state.nid_to_fitins.keys())
Expand Down Expand Up @@ -528,6 +529,12 @@ def make(nid: int) -> Message:
masked_vector = parameters_mod(masked_vector, state.mod_range)
state.aggregate_ndarrays = masked_vector

# Backward compatibility with Strategy
for msg in msgs:
fitres = compat.recordset_to_fitres(msg.content, True)
proxy = state.nid_to_proxies[msg.metadata.src_node_id]
state.legacy_results.append((proxy, fitres))

return self._check_threshold(state)

def unmask_stage( # pylint: disable=R0912, R0914, R0915
Expand Down Expand Up @@ -637,31 +644,21 @@ def make(nid: int) -> Message:
for vec in aggregated_vector:
vec += offset
vec *= inv_dq_total_ratio
state.aggregate_ndarrays = aggregated_vector

# Backward compatibility with Strategy
results = state.legacy_results
parameters = ndarrays_to_parameters(aggregated_vector)
for _, fitres in results:
fitres.parameters = parameters

# No exception/failure handling currently
log(
INFO,
"aggregate_fit: received %s results and %s failures",
1,
0,
)

final_fitres = FitRes(
status=Status(code=Code.OK, message=""),
parameters=ndarrays_to_parameters(aggregated_vector),
num_examples=round(state.max_weight / inv_dq_total_ratio),
metrics={},
)
empty_proxy = DriverClientProxy(
len(results),
0,
driver.grpc_driver, # type: ignore
False,
driver.run_id, # type: ignore
)
aggregated_result = context.strategy.aggregate_fit(
current_round, [(empty_proxy, final_fitres)], []
)
aggregated_result = context.strategy.aggregate_fit(current_round, results, [])
parameters_aggregated, metrics_aggregated = aggregated_result

# Update the parameters and write history
Expand Down

0 comments on commit d6f274b

Please sign in to comment.