Skip to content

Commit

Permalink
Merge branch 'main' into retire-xgboost-horizontal
Browse files Browse the repository at this point in the history
  • Loading branch information
danieljanes authored Jan 17, 2024
2 parents 2bfacdc + 66b3bbe commit af55e6b
Show file tree
Hide file tree
Showing 31 changed files with 460 additions and 86 deletions.
22 changes: 22 additions & 0 deletions doc/source/how-to-install-flower.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ Flower requires at least `Python 3.8 <https://docs.python.org/3.8/>`_, but `Pyth
Install stable release
----------------------

Using pip
~~~~~~~~~

Stable releases are available on `PyPI <https://pypi.org/project/flwr/>`_::

python -m pip install flwr
Expand All @@ -20,6 +23,25 @@ For simulations that use the Virtual Client Engine, ``flwr`` should be installed
python -m pip install flwr[simulation]


Using conda (or mamba)
~~~~~~~~~~~~~~~~~~~~~~

Flower can also be installed from the ``conda-forge`` channel.

If you have not added ``conda-forge`` to your channels, you will first need to run the following::

conda config --add channels conda-forge
conda config --set channel_priority strict

Once the ``conda-forge`` channel has been enabled, ``flwr`` can be installed with ``conda``::

conda install flwr

or with ``mamba``::

mamba install flwr


Verify installation
-------------------

Expand Down
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ isort = "==5.12.0"
black = { version = "==23.10.1", extras = ["jupyter"] }
docformatter = "==1.7.5"
mypy = "==1.6.1"
pylint = "==2.13.9"
pylint = "==3.0.3"
flake8 = "==5.0.4"
pytest = "==7.4.3"
pytest-cov = "==4.1.0"
Expand Down Expand Up @@ -137,7 +137,7 @@ line-length = 88
target-version = ["py38", "py39", "py310", "py311"]

[tool.pylint."MESSAGES CONTROL"]
disable = "bad-continuation,duplicate-code,too-few-public-methods,useless-import-alias"
disable = "duplicate-code,too-few-public-methods,useless-import-alias"

[tool.pytest.ini_options]
minversion = "6.2"
Expand Down Expand Up @@ -184,7 +184,7 @@ target-version = "py38"
line-length = 88
select = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"]
fixable = ["D", "E", "F", "W", "B", "ISC", "C4", "UP"]
ignore = ["B024", "B027"]
ignore = ["B024", "B027", "D205", "D209"]
exclude = [
".bzr",
".direnv",
Expand Down
9 changes: 6 additions & 3 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,12 @@ def _check_actionable_client(
client: Optional[Client], client_fn: Optional[ClientFn]
) -> None:
if client_fn is None and client is None:
raise Exception("Both `client_fn` and `client` are `None`, but one is required")
raise ValueError(
"Both `client_fn` and `client` are `None`, but one is required"
)

if client_fn is not None and client is not None:
raise Exception(
raise ValueError(
"Both `client_fn` and `client` are provided, but only one is allowed"
)

Expand All @@ -150,6 +152,7 @@ def _check_actionable_client(
# pylint: disable=too-many-branches
# pylint: disable=too-many-locals
# pylint: disable=too-many-statements
# pylint: disable=too-many-arguments
def start_client(
*,
server_address: str,
Expand Down Expand Up @@ -299,7 +302,7 @@ def single_client_factory(
cid: str, # pylint: disable=unused-argument
) -> Client:
if client is None: # Added this to keep mypy happy
raise Exception(
raise ValueError(
"Both `client_fn` and `client` are `None`, but one is required"
)
return client # Always return the same instance
Expand Down
16 changes: 8 additions & 8 deletions src/py/flwr/client/app_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,43 +41,43 @@ class PlainClient(Client):

def get_properties(self, ins: GetPropertiesIns) -> GetPropertiesRes:
"""Raise an Exception because this method is not expected to be called."""
raise Exception()
raise NotImplementedError()

def get_parameters(self, ins: GetParametersIns) -> GetParametersRes:
"""Raise an Exception because this method is not expected to be called."""
raise Exception()
raise NotImplementedError()

def fit(self, ins: FitIns) -> FitRes:
"""Raise an Exception because this method is not expected to be called."""
raise Exception()
raise NotImplementedError()

def evaluate(self, ins: EvaluateIns) -> EvaluateRes:
"""Raise an Exception because this method is not expected to be called."""
raise Exception()
raise NotImplementedError()


class NeedsWrappingClient(NumPyClient):
"""Client implementation extending the high-level NumPyClient."""

def get_properties(self, config: Config) -> Dict[str, Scalar]:
"""Raise an Exception because this method is not expected to be called."""
raise Exception()
raise NotImplementedError()

def get_parameters(self, config: Config) -> NDArrays:
"""Raise an Exception because this method is not expected to be called."""
raise Exception()
raise NotImplementedError()

def fit(
self, parameters: NDArrays, config: Config
) -> Tuple[NDArrays, int, Dict[str, Scalar]]:
"""Raise an Exception because this method is not expected to be called."""
raise Exception()
raise NotImplementedError()

def evaluate(
self, parameters: NDArrays, config: Config
) -> Tuple[float, int, Dict[str, Scalar]]:
"""Raise an Exception because this method is not expected to be called."""
raise Exception()
raise NotImplementedError()


def test_to_client_with_client() -> None:
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/client/dpfedavg_numpy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,16 @@ def fit(
update = [np.subtract(x, y) for (x, y) in zip(updated_params, original_params)]

if "dpfedavg_clip_norm" not in config:
raise Exception("Clipping threshold not supplied by the server.")
raise KeyError("Clipping threshold not supplied by the server.")
if not isinstance(config["dpfedavg_clip_norm"], float):
raise Exception("Clipping threshold should be a floating point value.")
raise TypeError("Clipping threshold should be a floating point value.")

# Clipping
update, clipped = clip_by_l2(update, config["dpfedavg_clip_norm"])

if "dpfedavg_noise_stddev" in config:
if not isinstance(config["dpfedavg_noise_stddev"], float):
raise Exception(
raise TypeError(
"Scale of noise to be added should be a floating point value."
)
# Noising
Expand All @@ -138,7 +138,7 @@ def fit(
# Calculating value of norm indicator bit, required for adaptive clipping
if "dpfedavg_adaptive_clip_enabled" in config:
if not isinstance(config["dpfedavg_adaptive_clip_enabled"], bool):
raise Exception(
raise TypeError(
"dpfedavg_adaptive_clip_enabled should be a boolean-valued flag."
)
metrics["dpfedavg_norm_bit"] = not clipped
Expand Down
3 changes: 1 addition & 2 deletions src/py/flwr/client/message_handler/task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ def validate_task_res(task_res: TaskRes) -> bool:
initialized_fields_in_task = {field.name for field, _ in task_res.task.ListFields()}

# Check if certain fields are already initialized
# pylint: disable-next=too-many-boolean-expressions
if (
if ( # pylint: disable-next=too-many-boolean-expressions
"task_id" in initialized_fields_in_task_res
or "group_id" in initialized_fields_in_task_res
or "run_id" in initialized_fields_in_task_res
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/client/numpy_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def _fit(self: Client, ins: FitIns) -> FitRes:
and isinstance(results[1], int)
and isinstance(results[2], dict)
):
raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)
raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_FIT)

# Return FitRes
parameters_prime, num_examples, metrics = results
Expand All @@ -266,7 +266,7 @@ def _evaluate(self: Client, ins: EvaluateIns) -> EvaluateRes:
and isinstance(results[1], int)
and isinstance(results[2], dict)
):
raise Exception(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)
raise TypeError(EXCEPTION_MESSAGE_WRONG_RETURN_TYPE_EVALUATE)

# Return EvaluateRes
loss, num_examples, metrics = results
Expand Down
4 changes: 4 additions & 0 deletions src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def create_node() -> None:
},
data=create_node_req_bytes,
verify=verify,
timeout=None,
)

# Check status code and headers
Expand Down Expand Up @@ -185,6 +186,7 @@ def delete_node() -> None:
},
data=delete_node_req_req_bytes,
verify=verify,
timeout=None,
)

# Check status code and headers
Expand Down Expand Up @@ -225,6 +227,7 @@ def receive() -> Optional[TaskIns]:
},
data=pull_task_ins_req_bytes,
verify=verify,
timeout=None,
)

# Check status code and headers
Expand Down Expand Up @@ -303,6 +306,7 @@ def send(task_res: TaskRes) -> None:
},
data=push_task_res_request_bytes,
verify=verify,
timeout=None,
)

state[KEY_TASK_INS] = None
Expand Down
12 changes: 6 additions & 6 deletions src/py/flwr/client/secure_aggregation/secaggplus_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,22 +333,22 @@ def _share_keys(

# Check if the size is larger than threshold
if len(state.public_keys_dict) < state.threshold:
raise Exception("Available neighbours number smaller than threshold")
raise ValueError("Available neighbours number smaller than threshold")

# Check if all public keys are unique
pk_list: List[bytes] = []
for pk1, pk2 in state.public_keys_dict.values():
pk_list.append(pk1)
pk_list.append(pk2)
if len(set(pk_list)) != len(pk_list):
raise Exception("Some public keys are identical")
raise ValueError("Some public keys are identical")

# Check if public keys of this client are correct in the dictionary
if (
state.public_keys_dict[state.sid][0] != state.pk1
or state.public_keys_dict[state.sid][1] != state.pk2
):
raise Exception(
raise ValueError(
"Own public keys are displayed in dict incorrectly, should not happen!"
)

Expand Down Expand Up @@ -393,7 +393,7 @@ def _collect_masked_input(
ciphertexts = cast(List[bytes], named_values[KEY_CIPHERTEXT_LIST])
srcs = cast(List[int], named_values[KEY_SOURCE_LIST])
if len(ciphertexts) + 1 < state.threshold:
raise Exception("Not enough available neighbour clients.")
raise ValueError("Not enough available neighbour clients.")

# Decrypt ciphertexts, verify their sources, and store shares.
for src, ciphertext in zip(srcs, ciphertexts):
Expand All @@ -409,7 +409,7 @@ def _collect_masked_input(
f"from {actual_src} instead of {src}."
)
if dst != state.sid:
ValueError(
raise ValueError(
f"Client {state.sid}: received an encrypted message"
f"for Client {dst} from Client {src}."
)
Expand Down Expand Up @@ -476,7 +476,7 @@ def _unmask(state: SecAggPlusState, named_values: Dict[str, Value]) -> Dict[str,
# Send private mask seed share for every avaliable client (including itclient)
# Send first private key share for building pairwise mask for every dropped client
if len(active_sids) < state.threshold:
raise Exception("Available neighbours number smaller than threshold")
raise ValueError("Available neighbours number smaller than threshold")

sids, shares = [], []
sids += active_sids
Expand Down
110 changes: 110 additions & 0 deletions src/py/flwr/common/parametersrecord.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ParametersRecord and Array."""


from dataclasses import dataclass, field
from typing import List, Optional, OrderedDict


@dataclass
class Array:
"""Array type.
A dataclass containing serialized data from an array-like or tensor-like object
along with some metadata about it.
Parameters
----------
dtype : str
A string representing the data type of the serialised object (e.g. `np.float32`)
shape : List[int]
A list representing the shape of the unserialized array-like object. This is
used to deserialize the data (depending on the serialization method) or simply
as a metadata field.
stype : str
A string indicating the type of serialisation mechanism used to generate the
bytes in `data` from an array-like or tensor-like object.
data: bytes
A buffer of bytes containing the data.
"""

dtype: str
shape: List[int]
stype: str
data: bytes


@dataclass
class ParametersRecord:
"""Parameters record.
A dataclass storing named Arrays in order. This means that it holds entries as an
OrderedDict[str, Array]. ParametersRecord objects can be viewed as an equivalent to
PyTorch's state_dict, but holding serialised tensors instead.
"""

keep_input: bool
data: OrderedDict[str, Array] = field(default_factory=OrderedDict[str, Array])

def __init__(
self,
array_dict: Optional[OrderedDict[str, Array]] = None,
keep_input: bool = False,
) -> None:
"""Construct a ParametersRecord object.
Parameters
----------
array_dict : Optional[OrderedDict[str, Array]]
A dictionary that stores serialized array-like or tensor-like objects.
keep_input : bool (default: False)
A boolean indicating whether parameters should be deleted from the input
dictionary immediately after adding them to the record. If False, the
dictionary passed to `set_parameters()` will be empty once exiting from that
function. This is the desired behaviour when working with very large
models/tensors/arrays. However, if you plan to continue working with your
parameters after adding it to the record, set this flag to True. When set
to True, the data is duplicated in memory.
"""
self.keep_input = keep_input
self.data = OrderedDict()
if array_dict:
self.set_parameters(array_dict)

def set_parameters(self, array_dict: OrderedDict[str, Array]) -> None:
"""Add parameters to record.
Parameters
----------
array_dict : OrderedDict[str, Array]
A dictionary that stores serialized array-like or tensor-like objects.
"""
if any(not isinstance(k, str) for k in array_dict.keys()):
raise TypeError(f"Not all keys are of valid type. Expected {str}")
if any(not isinstance(v, Array) for v in array_dict.values()):
raise TypeError(f"Not all values are of valid type. Expected {Array}")

if self.keep_input:
# Copy
self.data = OrderedDict(array_dict)
else:
# Add entries to dataclass without duplicating memory
for key in list(array_dict.keys()):
self.data[key] = array_dict[key]
del array_dict[key]
Loading

0 comments on commit af55e6b

Please sign in to comment.