From d0d4ecde858f4ab56b1365acb71e6e4643adca2f Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 13:33:51 +0100 Subject: [PATCH 01/49] Add authentication state and test --- .../crypto/symmetric_encryption.py | 14 ++++ .../authentication/authentication_state.py | 52 ++++++++++++++ .../authentication_state_test.py | 65 +++++++++++++++++ .../authentication/in_memory_auth_state.py | 66 ++++++++++++++++++ .../state/authentication/sqlite_auth_state.py | 69 +++++++++++++++++++ src/py/flwr/server/state/sqlite_state.py | 24 ++++++- 6 files changed, 289 insertions(+), 1 deletion(-) create mode 100644 src/py/flwr/server/state/authentication/authentication_state.py create mode 100644 src/py/flwr/server/state/authentication/authentication_state_test.py create mode 100644 src/py/flwr/server/state/authentication/in_memory_auth_state.py create mode 100644 src/py/flwr/server/state/authentication/sqlite_auth_state.py diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index 844a93f3bde9..7b22565e2803 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -98,3 +98,17 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes: # The input key must be url safe fernet = Fernet(key) return fernet.decrypt(ciphertext) + +def compute_hmac(key: bytes, message: bytes) -> bytes: + computed_hmac = hmac.HMAC(key, hashes.SHA256()) + computed_hmac.update(message) + return computed_hmac.finalize() + +def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool: + computed_hmac = hmac.HMAC(key, hashes.SHA256()) + computed_hmac.update(message) + try: + computed_hmac.verify(hmac_value) + return True + except: + return False diff --git a/src/py/flwr/server/state/authentication/authentication_state.py b/src/py/flwr/server/state/authentication/authentication_state.py new file mode 100644 index 000000000000..4a6fc9a6ab57 --- /dev/null +++ b/src/py/flwr/server/state/authentication/authentication_state.py @@ -0,0 +1,52 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Abstract base class AuthenticationState.""" + +import abc +from typing import Set + +class AuthenticationState(abc.ABC): + """Abstract State.""" + @abc.abstractmethod + def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: + """Store `node_id` and `public_key` as key-value pair in state.""" + + @abc.abstractmethod + def get_public_key_from_node_id(self, node_id: int) -> bytes: + """Get client's public key in urlsafe bytes for `node_id`.""" + + @abc.abstractmethod + def store_server_public_private_key(self, public_key: bytes, private_key: bytes) -> None: + """Store server's `public_key` and `private_key` in state.""" + + @abc.abstractmethod + def get_server_private_key(self) -> bytes: + """Get server private key in urlsafe bytes.""" + + @abc.abstractmethod + def get_server_public_key(self) -> bytes: + """Get server public key in urlsafe bytes.""" + + @abc.abstractmethod + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of client public keys in state.""" + + @abc.abstractmethod + def store_client_public_key(self, public_key: bytes) -> None: + """Retrieve a client public key in state.""" + + @abc.abstractmethod + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored client public keys as a set.""" diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py new file mode 100644 index 000000000000..473f8e37eb36 --- /dev/null +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -0,0 +1,65 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test for authentication state.""" + +import os +from in_memory_auth_state import InMemoryAuthState +from sqlite_auth_state import SqliteAuthState +from common.secure_aggregation.crypto.symmetric_encryption import ( + generate_key_pairs, + public_key_to_bytes, + generate_shared_key, + verify_hmac, + compute_hmac +) + +def test_client_public_keys() -> None: + key_pairs = [generate_key_pairs() for _ in range(3)] + public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} + + in_memory_auth_state = InMemoryAuthState() + in_memory_auth_state.store_client_public_keys(public_keys) + + assert in_memory_auth_state.get_client_public_keys == public_keys + +def test_node_id_public_key_pair() -> None: + node_id = int.from_bytes(os.urandom(8), "little", signed=True) + public_key = public_key_to_bytes(generate_key_pairs()[1]) + + in_memory_auth_state = InMemoryAuthState() + in_memory_auth_state.store_node_id_public_key_pair(node_id, public_key) + + assert in_memory_auth_state.get_public_key_from_node_id(node_id) == public_key + +def test_generate_shared_key() -> None: + client_keys = generate_key_pairs() + server_keys = generate_key_pairs() + + client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) + server_shared_secret = generate_shared_key(server_keys[0], client_keys[1]) + + assert client_shared_secret == server_shared_secret + +def test_hmac() -> None: + client_keys = generate_key_pairs() + server_keys = generate_key_pairs() + client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) + server_shared_secret = generate_shared_key(server_keys[0], client_keys[1]) + message = b"Flower is the future of AI" + + client_compute_hmac = compute_hmac(client_shared_secret, message) + + assert verify_hmac(server_shared_secret, message, client_compute_hmac) + \ No newline at end of file diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py new file mode 100644 index 000000000000..d4b387881962 --- /dev/null +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -0,0 +1,66 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""In-memory Authentication State implementation.""" + +from in_memory_state import InMemoryState +from authentication_state import AuthenticationState +from typing import Dict, Set + +class InMemoryAuthState(AuthenticationState, InMemoryState): + def __init__(self) -> None: + super().__init__() + self.node_id_public_key_dict: Dict[int, bytes] = {} + self.client_public_keys: Set[bytes] = set() + self.server_public_key: bytes = bytes() + self.server_private_key: bytes = bytes() + + def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: + """Store `node_id` and `public_key` as key-value pair in state.""" + if node_id not in self.node_ids: + raise ValueError(f"Node {node_id} not found") + if node_id in self.node_id_public_key_dict: + raise ValueError(f"Node {node_id} has already assigned a public key") + self.node_id_public_key_dict[node_id] = public_key + + def get_public_key_from_node_id(self, node_id: int) -> bytes: + """Get client's public key in urlsafe bytes for `node_id`.""" + if node_id in self.node_id_public_key_dict: + return self.node_id_public_key_dict[node_id] + return bytes() + + def store_server_public_private_key(self, public_key: bytes, private_key: bytes) -> None: + """Store server's `public_key` and `private_key` in state.""" + self.server_private_key = private_key + self.server_public_key = public_key + + def get_server_private_key(self) -> bytes: + """Get server private key in urlsafe bytes.""" + return self.server_private_key + + def get_server_public_key(self) -> bytes: + """Get server public key in urlsafe bytes.""" + return self.server_public_key + + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of client public keys in state.""" + self.client_public_keys = public_keys + + def store_client_public_key(self, public_key: bytes) -> None: + """Retrieve a client public key in state.""" + self.client_public_keys.add(public_key) + + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored client public keys as a set.""" + return self.client_public_keys diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py new file mode 100644 index 000000000000..852629896778 --- /dev/null +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -0,0 +1,69 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SQLite based implementation of server authentication state.""" + +from sqlite_state import SqliteState +from authentication_state import AuthenticationState +from typing import Set + +class SqliteAuthState(AuthenticationState, SqliteState): + def __init__(self) -> None: + super().__init__() + + def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: + """Store `node_id` and `public_key` as key-value pair in state.""" + query = "INSERT OR REPLACE INTO node_key (node_id, public_key) VALUES (:node_id, :public_key)" + self.query(query, {"node_id": node_id, "public_key": public_key}) + + def get_public_key_from_node_id(self, node_id: int) -> bytes: + """Get client's public key in urlsafe bytes for `node_id`.""" + query = "SELECT public_key FROM node_key WHERE node_id = :node_id" + rows = self.query(query, {"node_id": node_id}) + return rows[0]["public_key"] + + def store_server_public_private_key(self, public_key: bytes, private_key: bytes) -> None: + """Store server's `public_key` and `private_key` in state.""" + query = "INSERT OR REPLACE INTO credential (public_key, private_key) VALUES (:public_key, :private_key)" + self.query(query, {"public_key": public_key, "private_key": private_key}) + + def get_server_private_key(self) -> bytes: + """Get server private key in urlsafe bytes.""" + query = "SELECT private_key FROM credential" + rows = self.query(query) + return rows[0]["private_key"] + + def get_server_public_key(self) -> bytes: + """Get server public key in urlsafe bytes.""" + query = "SELECT public_key FROM credential" + rows = self.query(query) + return rows[0]["public_key"] + + def store_client_public_keys(self, public_keys: Set[bytes]) -> None: + """Store a set of client public keys in state.""" + query = "INSERT INTO public_key (public_key) VALUES (:public_key)" + for public_key in public_keys: + self.query(query, {"public_key": public_key}) + + def store_client_public_key(self, public_key: bytes) -> None: + """Retrieve a client public key in state.""" + query = "INSERT INTO public_key (public_key) VALUES (:public_key)" + self.query(query, {"public_key": public_key}) + + def get_client_public_keys(self) -> Set[bytes]: + """Retrieve all currently stored client public keys as a set.""" + query = "SELECT public_key FROM public_key" + rows = self.query(query) + result: Set[bytes] = {row["public_key"] for row in rows} + return result diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/state/sqlite_state.py index 224c16cdf013..f89df6301334 100644 --- a/src/py/flwr/server/state/sqlite_state.py +++ b/src/py/flwr/server/state/sqlite_state.py @@ -37,6 +37,26 @@ ); """ +SQL_CREATE_TABLE_NODE_KEY = """ +CREATE TABLE IF NOT EXISTS node_key( + node_id INTEGER PRIMARY KEY, + public_key BLOB +); +""" + +SQL_CREATE_TABLE_CREDENTIAL = """ +CREATE TABLE IF NOT EXISTS credential( + public_key BLOB PRIMARY KEY, + private_key BLOB +); +""" + +SQL_CREATE_TABLE_PUBLIC_KEY = """ +CREATE TABLE IF NOT EXISTS public_key( + public_key BLOB UNIQUE +); +""" + SQL_CREATE_TABLE_RUN = """ CREATE TABLE IF NOT EXISTS run( run_id INTEGER UNIQUE @@ -123,6 +143,9 @@ def initialize(self, log_queries: bool = False) -> List[Tuple[str]]: cur.execute(SQL_CREATE_TABLE_TASK_INS) cur.execute(SQL_CREATE_TABLE_TASK_RES) cur.execute(SQL_CREATE_TABLE_NODE) + cur.execute(SQL_CREATE_TABLE_CREDENTIAL) + cur.execute(SQL_CREATE_TABLE_NODE_KEY) + cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY) res = cur.execute("SELECT name FROM sqlite_schema;") return res.fetchall() @@ -519,7 +542,6 @@ def create_run(self) -> int: log(ERROR, "Unexpected run creation failure.") return 0 - def dict_factory( cursor: sqlite3.Cursor, row: sqlite3.Row, From 91a2f18d523cb25f773e3366e2d80b3e5196b6ed Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 13:46:44 +0100 Subject: [PATCH 02/49] Fix isort --- .../secure_aggregation/crypto/symmetric_encryption.py | 2 +- .../state/authentication/authentication_state_test.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index 7b22565e2803..67b49d85cc53 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -19,7 +19,7 @@ from typing import Tuple, cast from cryptography.fernet import Fernet -from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives import hashes, hmac, serialization from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.kdf.hkdf import HKDF diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py index 473f8e37eb36..4f7328894152 100644 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -15,15 +15,15 @@ """Test for authentication state.""" import os -from in_memory_auth_state import InMemoryAuthState -from sqlite_auth_state import SqliteAuthState from common.secure_aggregation.crypto.symmetric_encryption import ( - generate_key_pairs, - public_key_to_bytes, + compute_hmac, + generate_key_pairs, generate_shared_key, + public_key_to_bytes, verify_hmac, - compute_hmac ) +from in_memory_auth_state import InMemoryAuthState + def test_client_public_keys() -> None: key_pairs = [generate_key_pairs() for _ in range(3)] From db16c1003e0504ecfbd4fd7a209ae4e88248fb60 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 13:50:39 +0100 Subject: [PATCH 03/49] Fix isort --- .../server/state/authentication/authentication_state.py | 1 + .../state/authentication/authentication_state_test.py | 1 + .../server/state/authentication/in_memory_auth_state.py | 6 ++++-- .../flwr/server/state/authentication/sqlite_auth_state.py | 6 ++++-- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/server/state/authentication/authentication_state.py b/src/py/flwr/server/state/authentication/authentication_state.py index 4a6fc9a6ab57..1a9831cf2a12 100644 --- a/src/py/flwr/server/state/authentication/authentication_state.py +++ b/src/py/flwr/server/state/authentication/authentication_state.py @@ -17,6 +17,7 @@ import abc from typing import Set + class AuthenticationState(abc.ABC): """Abstract State.""" @abc.abstractmethod diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py index 4f7328894152..db4a18b27512 100644 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -15,6 +15,7 @@ """Test for authentication state.""" import os + from common.secure_aggregation.crypto.symmetric_encryption import ( compute_hmac, generate_key_pairs, diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py index d4b387881962..ba5db177f99c 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -14,10 +14,12 @@ # ============================================================================== """In-memory Authentication State implementation.""" -from in_memory_state import InMemoryState -from authentication_state import AuthenticationState from typing import Dict, Set +from authentication_state import AuthenticationState +from in_memory_state import InMemoryState + + class InMemoryAuthState(AuthenticationState, InMemoryState): def __init__(self) -> None: super().__init__() diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py index 852629896778..542fab26ca08 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -14,10 +14,12 @@ # ============================================================================== """SQLite based implementation of server authentication state.""" -from sqlite_state import SqliteState -from authentication_state import AuthenticationState from typing import Set +from authentication_state import AuthenticationState +from sqlite_state import SqliteState + + class SqliteAuthState(AuthenticationState, SqliteState): def __init__(self) -> None: super().__init__() From 28876bc415ee577bf8e48ca87007bb27dcff5015 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 13:56:06 +0100 Subject: [PATCH 04/49] Run format.sh --- .../crypto/symmetric_encryption.py | 2 ++ .../authentication/authentication_state.py | 11 +++++++---- .../authentication_state_test.py | 4 +++- .../authentication/in_memory_auth_state.py | 18 ++++++++++-------- .../state/authentication/sqlite_auth_state.py | 4 +++- src/py/flwr/server/state/sqlite_state.py | 1 + 6 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index 67b49d85cc53..0ad3bef18045 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -99,11 +99,13 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes: fernet = Fernet(key) return fernet.decrypt(ciphertext) + def compute_hmac(key: bytes, message: bytes) -> bytes: computed_hmac = hmac.HMAC(key, hashes.SHA256()) computed_hmac.update(message) return computed_hmac.finalize() + def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool: computed_hmac = hmac.HMAC(key, hashes.SHA256()) computed_hmac.update(message) diff --git a/src/py/flwr/server/state/authentication/authentication_state.py b/src/py/flwr/server/state/authentication/authentication_state.py index 1a9831cf2a12..a886f9b6510d 100644 --- a/src/py/flwr/server/state/authentication/authentication_state.py +++ b/src/py/flwr/server/state/authentication/authentication_state.py @@ -20,6 +20,7 @@ class AuthenticationState(abc.ABC): """Abstract State.""" + @abc.abstractmethod def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: """Store `node_id` and `public_key` as key-value pair in state.""" @@ -29,7 +30,9 @@ def get_public_key_from_node_id(self, node_id: int) -> bytes: """Get client's public key in urlsafe bytes for `node_id`.""" @abc.abstractmethod - def store_server_public_private_key(self, public_key: bytes, private_key: bytes) -> None: + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: """Store server's `public_key` and `private_key` in state.""" @abc.abstractmethod @@ -42,12 +45,12 @@ def get_server_public_key(self) -> bytes: @abc.abstractmethod def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" + """Store a set of client public keys in state.""" @abc.abstractmethod def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" + """Retrieve a client public key in state.""" @abc.abstractmethod def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" + """Retrieve all currently stored client public keys as a set.""" diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py index db4a18b27512..32c3363ab5d3 100644 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -35,6 +35,7 @@ def test_client_public_keys() -> None: assert in_memory_auth_state.get_client_public_keys == public_keys + def test_node_id_public_key_pair() -> None: node_id = int.from_bytes(os.urandom(8), "little", signed=True) public_key = public_key_to_bytes(generate_key_pairs()[1]) @@ -44,6 +45,7 @@ def test_node_id_public_key_pair() -> None: assert in_memory_auth_state.get_public_key_from_node_id(node_id) == public_key + def test_generate_shared_key() -> None: client_keys = generate_key_pairs() server_keys = generate_key_pairs() @@ -53,6 +55,7 @@ def test_generate_shared_key() -> None: assert client_shared_secret == server_shared_secret + def test_hmac() -> None: client_keys = generate_key_pairs() server_keys = generate_key_pairs() @@ -63,4 +66,3 @@ def test_hmac() -> None: client_compute_hmac = compute_hmac(client_shared_secret, message) assert verify_hmac(server_shared_secret, message, client_compute_hmac) - \ No newline at end of file diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py index ba5db177f99c..c6dbcf4a5e1b 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -27,7 +27,7 @@ def __init__(self) -> None: self.client_public_keys: Set[bytes] = set() self.server_public_key: bytes = bytes() self.server_private_key: bytes = bytes() - + def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: """Store `node_id` and `public_key` as key-value pair in state.""" if node_id not in self.node_ids: @@ -42,7 +42,9 @@ def get_public_key_from_node_id(self, node_id: int) -> bytes: return self.node_id_public_key_dict[node_id] return bytes() - def store_server_public_private_key(self, public_key: bytes, private_key: bytes) -> None: + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: """Store server's `public_key` and `private_key` in state.""" self.server_private_key = private_key self.server_public_key = public_key @@ -56,13 +58,13 @@ def get_server_public_key(self) -> bytes: return self.server_public_key def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" - self.client_public_keys = public_keys + """Store a set of client public keys in state.""" + self.client_public_keys = public_keys def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" - self.client_public_keys.add(public_key) + """Retrieve a client public key in state.""" + self.client_public_keys.add(public_key) def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" - return self.client_public_keys + """Retrieve all currently stored client public keys as a set.""" + return self.client_public_keys diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py index 542fab26ca08..381240df5d1a 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -35,7 +35,9 @@ def get_public_key_from_node_id(self, node_id: int) -> bytes: rows = self.query(query, {"node_id": node_id}) return rows[0]["public_key"] - def store_server_public_private_key(self, public_key: bytes, private_key: bytes) -> None: + def store_server_public_private_key( + self, public_key: bytes, private_key: bytes + ) -> None: """Store server's `public_key` and `private_key` in state.""" query = "INSERT OR REPLACE INTO credential (public_key, private_key) VALUES (:public_key, :private_key)" self.query(query, {"public_key": public_key, "private_key": private_key}) diff --git a/src/py/flwr/server/state/sqlite_state.py b/src/py/flwr/server/state/sqlite_state.py index f89df6301334..e91d8553863c 100644 --- a/src/py/flwr/server/state/sqlite_state.py +++ b/src/py/flwr/server/state/sqlite_state.py @@ -542,6 +542,7 @@ def create_run(self) -> int: log(ERROR, "Unexpected run creation failure.") return 0 + def dict_factory( cursor: sqlite3.Cursor, row: sqlite3.Row, From 42a7d386274b56d53d8268cafc698420b655911f Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 14:00:44 +0100 Subject: [PATCH 05/49] Add init.py --- .../flwr/server/state/authentication/__init__.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 src/py/flwr/server/state/authentication/__init__.py diff --git a/src/py/flwr/server/state/authentication/__init__.py b/src/py/flwr/server/state/authentication/__init__.py new file mode 100644 index 000000000000..3203b3230b5c --- /dev/null +++ b/src/py/flwr/server/state/authentication/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower server authentication state.""" From 8ec63c96203a8c19545958c47ed7661c81c53c32 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 14:17:30 +0100 Subject: [PATCH 06/49] Fix line too long --- .../secure_aggregation/crypto/symmetric_encryption.py | 4 +++- .../state/authentication/authentication_state_test.py | 4 ++++ .../server/state/authentication/in_memory_auth_state.py | 7 ++++--- .../flwr/server/state/authentication/sqlite_auth_state.py | 7 +++++-- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index 0ad3bef18045..e38bdb6d7859 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -101,16 +101,18 @@ def decrypt(key: bytes, ciphertext: bytes) -> bytes: def compute_hmac(key: bytes, message: bytes) -> bytes: + """Compute hmac of a message using key as hash.""" computed_hmac = hmac.HMAC(key, hashes.SHA256()) computed_hmac.update(message) return computed_hmac.finalize() def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool: + """Verify hmac of a message using key as hash.""" computed_hmac = hmac.HMAC(key, hashes.SHA256()) computed_hmac.update(message) try: computed_hmac.verify(hmac_value) return True - except: + except Exception: return False diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py index 32c3363ab5d3..35d36e7c8782 100644 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -27,6 +27,7 @@ def test_client_public_keys() -> None: + """Test client public keys store and get from state.""" key_pairs = [generate_key_pairs() for _ in range(3)] public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} @@ -37,6 +38,7 @@ def test_client_public_keys() -> None: def test_node_id_public_key_pair() -> None: + """Test store and get node_id public_key pair.""" node_id = int.from_bytes(os.urandom(8), "little", signed=True) public_key = public_key_to_bytes(generate_key_pairs()[1]) @@ -47,6 +49,7 @@ def test_node_id_public_key_pair() -> None: def test_generate_shared_key() -> None: + """Test util function generate_shared_key.""" client_keys = generate_key_pairs() server_keys = generate_key_pairs() @@ -57,6 +60,7 @@ def test_generate_shared_key() -> None: def test_hmac() -> None: + """Test util function compute and verify hmac.""" client_keys = generate_key_pairs() server_keys = generate_key_pairs() client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py index c6dbcf4a5e1b..2d51494e5158 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -22,11 +22,12 @@ class InMemoryAuthState(AuthenticationState, InMemoryState): def __init__(self) -> None: + """Init InMemoryAuthState.""" super().__init__() self.node_id_public_key_dict: Dict[int, bytes] = {} self.client_public_keys: Set[bytes] = set() - self.server_public_key: bytes = bytes() - self.server_private_key: bytes = bytes() + self.server_public_key: bytes = b"" + self.server_private_key: bytes = b"" def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: """Store `node_id` and `public_key` as key-value pair in state.""" @@ -40,7 +41,7 @@ def get_public_key_from_node_id(self, node_id: int) -> bytes: """Get client's public key in urlsafe bytes for `node_id`.""" if node_id in self.node_id_public_key_dict: return self.node_id_public_key_dict[node_id] - return bytes() + return b"" def store_server_public_private_key( self, public_key: bytes, private_key: bytes diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py index 381240df5d1a..5513814348f5 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -22,11 +22,13 @@ class SqliteAuthState(AuthenticationState, SqliteState): def __init__(self) -> None: + """Init SqliteAuthState.""" super().__init__() def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: """Store `node_id` and `public_key` as key-value pair in state.""" - query = "INSERT OR REPLACE INTO node_key (node_id, public_key) VALUES (:node_id, :public_key)" + query = "INSERT OR REPLACE INTO node_key (node_id, public_key) "\ + "VALUES (:node_id, :public_key)" self.query(query, {"node_id": node_id, "public_key": public_key}) def get_public_key_from_node_id(self, node_id: int) -> bytes: @@ -39,7 +41,8 @@ def store_server_public_private_key( self, public_key: bytes, private_key: bytes ) -> None: """Store server's `public_key` and `private_key` in state.""" - query = "INSERT OR REPLACE INTO credential (public_key, private_key) VALUES (:public_key, :private_key)" + query = "INSERT OR REPLACE INTO credential (public_key, private_key) "\ + "VALUES (:public_key, :private_key)" self.query(query, {"public_key": public_key, "private_key": private_key}) def get_server_private_key(self) -> bytes: From 8f04e25ca1977ee437776e7fe2ef41eb714c036e Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 14:21:32 +0100 Subject: [PATCH 07/49] Fix line too long --- .../state/authentication/in_memory_auth_state.py | 1 + .../state/authentication/sqlite_auth_state.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py index 2d51494e5158..67cdeb1347ef 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -21,6 +21,7 @@ class InMemoryAuthState(AuthenticationState, InMemoryState): + """In-memory-based authentication state implementation.""" def __init__(self) -> None: """Init InMemoryAuthState.""" super().__init__() diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py index 5513814348f5..287e75a3c4fc 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -21,14 +21,18 @@ class SqliteAuthState(AuthenticationState, SqliteState): + """SQLite-based authentication state implementation.""" + def __init__(self) -> None: """Init SqliteAuthState.""" super().__init__() def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: """Store `node_id` and `public_key` as key-value pair in state.""" - query = "INSERT OR REPLACE INTO node_key (node_id, public_key) "\ - "VALUES (:node_id, :public_key)" + query = ( + "INSERT OR REPLACE INTO node_key (node_id, public_key) " + "VALUES (:node_id, :public_key)" + ) self.query(query, {"node_id": node_id, "public_key": public_key}) def get_public_key_from_node_id(self, node_id: int) -> bytes: @@ -41,8 +45,10 @@ def store_server_public_private_key( self, public_key: bytes, private_key: bytes ) -> None: """Store server's `public_key` and `private_key` in state.""" - query = "INSERT OR REPLACE INTO credential (public_key, private_key) "\ - "VALUES (:public_key, :private_key)" + query = ( + "INSERT OR REPLACE INTO credential (public_key, private_key) " + "VALUES (:public_key, :private_key)" + ) self.query(query, {"public_key": public_key, "private_key": private_key}) def get_server_private_key(self) -> bytes: From e8813fcf729293cf239bf79c43b2a64f23da1d4c Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 14:24:18 +0100 Subject: [PATCH 08/49] Fix line too long --- src/py/flwr/server/state/authentication/in_memory_auth_state.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py index 67cdeb1347ef..9bd68ee1b7db 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -22,6 +22,7 @@ class InMemoryAuthState(AuthenticationState, InMemoryState): """In-memory-based authentication state implementation.""" + def __init__(self) -> None: """Init InMemoryAuthState.""" super().__init__() From d9f3fb04b565696eb888316afc4103b12f0865f0 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 14:35:36 +0100 Subject: [PATCH 09/49] Fix subclassing --- src/py/flwr/server/state/authentication/__init__.py | 10 ++++++++++ .../server/state/authentication/sqlite_auth_state.py | 9 ++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/state/authentication/__init__.py b/src/py/flwr/server/state/authentication/__init__.py index 3203b3230b5c..95f3e3fbbd57 100644 --- a/src/py/flwr/server/state/authentication/__init__.py +++ b/src/py/flwr/server/state/authentication/__init__.py @@ -13,3 +13,13 @@ # limitations under the License. # ============================================================================== """Flower server authentication state.""" + +from .authentication_state import AuthenticationState as AuthenticationState +from .in_memory_auth_state import InMemoryAuthState as InMemoryAuthState +from .sqlite_auth_state import SqliteAuthState as SqliteAuthState + +__all__ = [ + "AuthenticationState", + "InMemoryAuthState", + "SqliteAuthState", +] \ No newline at end of file diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py index 287e75a3c4fc..c93a148d9956 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -39,7 +39,8 @@ def get_public_key_from_node_id(self, node_id: int) -> bytes: """Get client's public key in urlsafe bytes for `node_id`.""" query = "SELECT public_key FROM node_key WHERE node_id = :node_id" rows = self.query(query, {"node_id": node_id}) - return rows[0]["public_key"] + public_key: bytes = rows[0]["public_key"] + return public_key def store_server_public_private_key( self, public_key: bytes, private_key: bytes @@ -55,13 +56,15 @@ def get_server_private_key(self) -> bytes: """Get server private key in urlsafe bytes.""" query = "SELECT private_key FROM credential" rows = self.query(query) - return rows[0]["private_key"] + private_key: bytes = rows[0]["private_key"] + return private_key def get_server_public_key(self) -> bytes: """Get server public key in urlsafe bytes.""" query = "SELECT public_key FROM credential" rows = self.query(query) - return rows[0]["public_key"] + public_key: bytes = rows[0]["public_key"] + return public_key def store_client_public_keys(self, public_keys: Set[bytes]) -> None: """Store a set of client public keys in state.""" From caf6695d7e9d87f19db2b8fd7b4c935607c1ffa4 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 14:39:29 +0100 Subject: [PATCH 10/49] Fix subclassing --- src/py/flwr/server/state/authentication/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/state/authentication/__init__.py b/src/py/flwr/server/state/authentication/__init__.py index 95f3e3fbbd57..8f5c0a97ab1f 100644 --- a/src/py/flwr/server/state/authentication/__init__.py +++ b/src/py/flwr/server/state/authentication/__init__.py @@ -22,4 +22,4 @@ "AuthenticationState", "InMemoryAuthState", "SqliteAuthState", -] \ No newline at end of file +] From fa217ae7f16430c48babf84d160717c792a0d6a8 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 15:03:56 +0100 Subject: [PATCH 11/49] Fix subclassing --- .../flwr/server/state/authentication/authentication_state.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/server/state/authentication/authentication_state.py b/src/py/flwr/server/state/authentication/authentication_state.py index a886f9b6510d..c881af432f4b 100644 --- a/src/py/flwr/server/state/authentication/authentication_state.py +++ b/src/py/flwr/server/state/authentication/authentication_state.py @@ -17,8 +17,10 @@ import abc from typing import Set +from state import State -class AuthenticationState(abc.ABC): + +class AuthenticationState(State, abc.ABC): """Abstract State.""" @abc.abstractmethod From 6edddd631dee4bba954c68074ebbe683fcf6b888 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 15:08:53 +0100 Subject: [PATCH 12/49] Fix subclassing --- .../flwr/server/state/authentication/authentication_state.py | 2 +- .../flwr/server/state/authentication/in_memory_auth_state.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/server/state/authentication/authentication_state.py b/src/py/flwr/server/state/authentication/authentication_state.py index c881af432f4b..fb538038dcbb 100644 --- a/src/py/flwr/server/state/authentication/authentication_state.py +++ b/src/py/flwr/server/state/authentication/authentication_state.py @@ -21,7 +21,7 @@ class AuthenticationState(State, abc.ABC): - """Abstract State.""" + """Abstract Authentication State.""" @abc.abstractmethod def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py index 9bd68ee1b7db..9ddc958c18d3 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -16,8 +16,8 @@ from typing import Dict, Set -from authentication_state import AuthenticationState -from in_memory_state import InMemoryState +from .authentication_state import AuthenticationState +from flwr.server.state.in_memory_state import InMemoryState class InMemoryAuthState(AuthenticationState, InMemoryState): From 8bb15a569927b8c16f214a286309ee05ccb5a3b6 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sat, 10 Feb 2024 15:11:49 +0100 Subject: [PATCH 13/49] Fix subclassing --- src/py/flwr/server/state/authentication/in_memory_auth_state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/state/authentication/in_memory_auth_state.py index 9ddc958c18d3..fe10c1301b11 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/state/authentication/in_memory_auth_state.py @@ -16,7 +16,7 @@ from typing import Dict, Set -from .authentication_state import AuthenticationState +from flwr.server.state.authentication.authentication_state import AuthenticationState from flwr.server.state.in_memory_state import InMemoryState From c5bac4f2ba5807190c501a40440b048de3c2c9f5 Mon Sep 17 00:00:00 2001 From: jafermarq Date: Sun, 11 Feb 2024 09:20:04 +0000 Subject: [PATCH 14/49] fixes --- .../server/state/authentication/authentication_state.py | 2 +- .../state/authentication/authentication_state_test.py | 7 ++++--- .../flwr/server/state/authentication/sqlite_auth_state.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/server/state/authentication/authentication_state.py b/src/py/flwr/server/state/authentication/authentication_state.py index fb538038dcbb..3adb450dc215 100644 --- a/src/py/flwr/server/state/authentication/authentication_state.py +++ b/src/py/flwr/server/state/authentication/authentication_state.py @@ -17,7 +17,7 @@ import abc from typing import Set -from state import State +from flwr.server.state import State class AuthenticationState(State, abc.ABC): diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py index 35d36e7c8782..2aaf736a8d68 100644 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -16,14 +16,15 @@ import os -from common.secure_aggregation.crypto.symmetric_encryption import ( +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( compute_hmac, generate_key_pairs, generate_shared_key, public_key_to_bytes, verify_hmac, ) -from in_memory_auth_state import InMemoryAuthState + +from .in_memory_auth_state import InMemoryAuthState def test_client_public_keys() -> None: @@ -34,7 +35,7 @@ def test_client_public_keys() -> None: in_memory_auth_state = InMemoryAuthState() in_memory_auth_state.store_client_public_keys(public_keys) - assert in_memory_auth_state.get_client_public_keys == public_keys + assert in_memory_auth_state.get_client_public_keys() == public_keys def test_node_id_public_key_pair() -> None: diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py index c93a148d9956..0e0436f0bae1 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -16,8 +16,8 @@ from typing import Set -from authentication_state import AuthenticationState -from sqlite_state import SqliteState +from flwr.server.state.authentication.authentication_state import AuthenticationState +from flwr.server.state.sqlite_state import SqliteState class SqliteAuthState(AuthenticationState, SqliteState): From c856b7c441a79bc60c82ec3359291dc30287a469 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sun, 11 Feb 2024 11:13:13 +0100 Subject: [PATCH 15/49] Fix state tests --- .../server/state/authentication/authentication_state_test.py | 5 ++--- src/py/flwr/server/state/authentication/sqlite_auth_state.py | 4 ---- src/py/flwr/server/state/state_test.py | 4 ++-- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py index 2aaf736a8d68..1495fdf5084e 100644 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -14,7 +14,6 @@ # ============================================================================== """Test for authentication state.""" -import os from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( compute_hmac, @@ -40,10 +39,10 @@ def test_client_public_keys() -> None: def test_node_id_public_key_pair() -> None: """Test store and get node_id public_key pair.""" - node_id = int.from_bytes(os.urandom(8), "little", signed=True) + in_memory_auth_state = InMemoryAuthState() + node_id = in_memory_auth_state.create_node() public_key = public_key_to_bytes(generate_key_pairs()[1]) - in_memory_auth_state = InMemoryAuthState() in_memory_auth_state.store_node_id_public_key_pair(node_id, public_key) assert in_memory_auth_state.get_public_key_from_node_id(node_id) == public_key diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/state/authentication/sqlite_auth_state.py index 0e0436f0bae1..55e4bc73a63b 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/state/authentication/sqlite_auth_state.py @@ -23,10 +23,6 @@ class SqliteAuthState(AuthenticationState, SqliteState): """SQLite-based authentication state implementation.""" - def __init__(self) -> None: - """Init SqliteAuthState.""" - super().__init__() - def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: """Store `node_id` and `public_key` as key-value pair in state.""" query = ( diff --git a/src/py/flwr/server/state/state_test.py b/src/py/flwr/server/state/state_test.py index 95d764792ff3..9395083e2648 100644 --- a/src/py/flwr/server/state/state_test.py +++ b/src/py/flwr/server/state/state_test.py @@ -477,7 +477,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 8 + assert len(result) == 13 class SqliteFileBasedTest(StateTest, unittest.TestCase): @@ -502,7 +502,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 8 + assert len(result) == 13 if __name__ == "__main__": From 475850733ef370442203dea3ccc28bb281463b79 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sun, 11 Feb 2024 11:17:10 +0100 Subject: [PATCH 16/49] Fix too broad exception --- .../common/secure_aggregation/crypto/symmetric_encryption.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py index e38bdb6d7859..1d004a398ea8 100644 --- a/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py +++ b/src/py/flwr/common/secure_aggregation/crypto/symmetric_encryption.py @@ -18,6 +18,7 @@ import base64 from typing import Tuple, cast +from cryptography.exceptions import InvalidSignature from cryptography.fernet import Fernet from cryptography.hazmat.primitives import hashes, hmac, serialization from cryptography.hazmat.primitives.asymmetric import ec @@ -114,5 +115,5 @@ def verify_hmac(key: bytes, message: bytes, hmac_value: bytes) -> bool: try: computed_hmac.verify(hmac_value) return True - except Exception: + except InvalidSignature: return False From e666da57e1c62d1d0f6c9a50f8d86c073a707e0f Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Sun, 11 Feb 2024 14:16:50 +0100 Subject: [PATCH 17/49] Add sqlite auth state test --- .../authentication_state_test.py | 29 +++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/state/authentication/authentication_state_test.py index 1495fdf5084e..f18c428d3044 100644 --- a/src/py/flwr/server/state/authentication/authentication_state_test.py +++ b/src/py/flwr/server/state/authentication/authentication_state_test.py @@ -24,9 +24,10 @@ ) from .in_memory_auth_state import InMemoryAuthState +from .sqlite_auth_state import SqliteAuthState -def test_client_public_keys() -> None: +def test_in_memory_client_public_keys() -> None: """Test client public keys store and get from state.""" key_pairs = [generate_key_pairs() for _ in range(3)] public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} @@ -37,7 +38,19 @@ def test_client_public_keys() -> None: assert in_memory_auth_state.get_client_public_keys() == public_keys -def test_node_id_public_key_pair() -> None: +def test_sqlite_client_public_keys() -> None: + """Test client public keys store and get from state.""" + key_pairs = [generate_key_pairs() for _ in range(3)] + public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} + + sqlite_auth_state = SqliteAuthState(":memory:") + sqlite_auth_state.initialize() + sqlite_auth_state.store_client_public_keys(public_keys) + + assert sqlite_auth_state.get_client_public_keys() == public_keys + + +def test_in_memory_node_id_public_key_pair() -> None: """Test store and get node_id public_key pair.""" in_memory_auth_state = InMemoryAuthState() node_id = in_memory_auth_state.create_node() @@ -48,6 +61,18 @@ def test_node_id_public_key_pair() -> None: assert in_memory_auth_state.get_public_key_from_node_id(node_id) == public_key +def test_sqlite_node_id_public_key_pair() -> None: + """Test store and get node_id public_key pair.""" + sqlite_auth_state = SqliteAuthState(":memory:") + sqlite_auth_state.initialize() + node_id = sqlite_auth_state.create_node() + public_key = public_key_to_bytes(generate_key_pairs()[1]) + + sqlite_auth_state.store_node_id_public_key_pair(node_id, public_key) + + assert sqlite_auth_state.get_public_key_from_node_id(node_id) == public_key + + def test_generate_shared_key() -> None: """Test util function generate_shared_key.""" client_keys = generate_key_pairs() From e6421cb6e6d41ec98d9aa213ee660584f39b150c Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Tue, 13 Feb 2024 16:01:54 +0100 Subject: [PATCH 18/49] Add client interceptor --- src/py/flwr/client/client_interceptor.py | 80 +++++++++++++++++++ src/py/flwr/client/client_interceptor_test.py | 16 ++++ 2 files changed, 96 insertions(+) create mode 100644 src/py/flwr/client/client_interceptor.py create mode 100644 src/py/flwr/client/client_interceptor_test.py diff --git a/src/py/flwr/client/client_interceptor.py b/src/py/flwr/client/client_interceptor.py new file mode 100644 index 000000000000..77009fdb6f25 --- /dev/null +++ b/src/py/flwr/client/client_interceptor.py @@ -0,0 +1,80 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower client interceptor.""" + +import grpc +import collections +from typing import Callable, Union, Sequence, Tuple +from flwr.proto.fleet_pb2 import ( + CreateNodeRequest, + DeleteNodeRequest, + PullTaskInsRequest, + PushTaskResRequest, +) +from cryptography.hazmat.primitives.asymmetric import ec +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + compute_hmac, + bytes_to_public_key, + generate_shared_key, + public_key_to_bytes, +) + +_PUBLIC_KEY_HEADER = "public-key" +_AUTH_TOKEN_HEADER = "auth-token" + +Request = Union[CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest] + +def _get_value_from_tuples(key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]) -> Union[str, bytes]: + return next((value[::-1] for key, value in tuples if key == key_string), "") + +class _ClientCallDetails( + collections.namedtuple( + '_ClientCallDetails', + ('method', 'timeout', 'metadata', 'credentials')), + grpc.ClientCallDetails): + pass + +class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): + def __init__(self, private_key: ec.EllipticCurvePrivateKey, public_key: ec.EllipticCurvePublicKey): + self.private_key = private_key + self.public_key = public_key + + def intercept_unary_unary(self, continuation: Callable, client_call_details: grpc.ClientCallDetails, request: Request): + """Flower client interceptor.""" + metadata = [] + postprocess = False + if client_call_details.metadata is not None: + metadata = list(client_call_details.metadata) + + if isinstance(request, CreateNodeRequest): + metadata.append((_PUBLIC_KEY_HEADER, public_key_to_bytes(self.public_key))) + postprocess = True + + elif isinstance(request, (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest)): + metadata.append((_AUTH_TOKEN_HEADER, compute_hmac(self.shared_secret, request))) + else: + pass + + client_call_details = _ClientCallDetails( + client_call_details.method, client_call_details.timeout, metadata, + client_call_details.credentials) + + response = continuation(client_call_details, request) + if postprocess: + server_public_key_bytes = _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.trailing_metadata) + self.server_public_key = bytes_to_public_key(server_public_key_bytes) + self.shared_secret = generate_shared_key(self.private_key, self.server_public_key) + return response + \ No newline at end of file diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py new file mode 100644 index 000000000000..4752633e0ad4 --- /dev/null +++ b/src/py/flwr/client/client_interceptor_test.py @@ -0,0 +1,16 @@ +# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower client interceptor tests.""" + From 3ee430ad074215d7f33e9394d672b54a72ec58d9 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Tue, 13 Feb 2024 16:39:03 +0100 Subject: [PATCH 19/49] Add mock servicer --- src/py/flwr/client/client_interceptor_test.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py index 4752633e0ad4..c08f81d67941 100644 --- a/src/py/flwr/client/client_interceptor_test.py +++ b/src/py/flwr/client/client_interceptor_test.py @@ -14,3 +14,83 @@ # ============================================================================== """Flower client interceptor tests.""" +import unittest +import threading +import grpc +from concurrent import futures +from flwr.proto.fleet_pb2 import ( + CreateNodeRequest, + DeleteNodeRequest, + PullTaskInsRequest, + PushTaskResRequest, +) + +from flwr.proto import fleet_pb2 as flwr_dot_proto_dot_fleet__pb2 +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + generate_key_pairs, + bytes_to_public_key, + generate_shared_key, + public_key_to_bytes, +) + +_PUBLIC_KEY_HEADER = "public-key" + +class _MockServicer(object): + def __init__(self): + self._lock = threading.Lock() + self._received_client_metadata = None + _, self._server_public_key = generate_key_pairs() + + + def unary_unary(self, request, context): + with self._lock: + self._received_client_metadata = context.invocation_metadata() + if isinstance(request, CreateNodeRequest): + context.set_trailing_metadata( + ( + (_PUBLIC_KEY_HEADER, self._server_public_key), + ) + ) + + return object() + + def received_client_metadata(self): + with self._lock: + return self._received_client_metadata + +def _generic_handler(servicer: _MockServicer): + rpc_method_handlers = { + 'CreateNode': grpc.unary_unary_rpc_method_handler( + servicer.unary_unary, + request_deserializer=flwr_dot_proto_dot_fleet__pb2.CreateNodeRequest.FromString, + response_serializer=flwr_dot_proto_dot_fleet__pb2.CreateNodeResponse.SerializeToString, + ), + 'DeleteNode': grpc.unary_unary_rpc_method_handler( + servicer.unary_unary, + request_deserializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeRequest.FromString, + response_serializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeResponse.SerializeToString, + ), + 'PullTaskIns': grpc.unary_unary_rpc_method_handler( + servicer.unary_unary, + request_deserializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.FromString, + response_serializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsResponse.SerializeToString, + ), + 'PushTaskRes': grpc.unary_unary_rpc_method_handler( + servicer.unary_unary, + request_deserializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.FromString, + response_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.SerializeToString, + ), + } + return grpc.method_handlers_generic_handler('flwr.proto.Fleet', rpc_method_handlers) + +class TestAuthenticateClientInterceptor(unittest.TestCase): + def setUp(self): + self._server = grpc.server( + futures.ThreadPoolExecutor(max_workers=10), + options=(("grpc.so_reuseport", int(False)),), + ) + self._server.add_generic_rpc_handlers( + (_generic_handler(self._servicer),) + ) + port = self._server.add_insecure_port("[::]:0") + self._server.start() From bcc481e341dcd8c6fdc200c3a2ef14af33b1f7f7 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Tue, 13 Feb 2024 16:40:48 +0100 Subject: [PATCH 20/49] Add unittest initializer --- src/py/flwr/client/client_interceptor_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py index c08f81d67941..fc4768b34b89 100644 --- a/src/py/flwr/client/client_interceptor_test.py +++ b/src/py/flwr/client/client_interceptor_test.py @@ -94,3 +94,7 @@ def setUp(self): ) port = self._server.add_insecure_port("[::]:0") self._server.start() + + +if __name__ == "__main__": + unittest.main(verbosity=2) \ No newline at end of file From abe18a578041a30e246edf6d489b974e615d756b Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Tue, 13 Feb 2024 17:07:13 +0100 Subject: [PATCH 21/49] Integrate client interceptor --- src/py/flwr/client/grpc_rere_client/connection.py | 5 ++++- src/py/flwr/common/grpc.py | 6 +++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 07635d002721..280cd48f1d38 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -18,7 +18,8 @@ from contextlib import contextmanager from logging import DEBUG, ERROR from pathlib import Path -from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast +from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast, Sequence +import grpc from flwr.client.message_handler.task_handler import ( configure_task_res, @@ -56,6 +57,7 @@ def grpc_request_response( insecure: bool, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 root_certificates: Optional[Union[bytes, str]] = None, + interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None, ) -> Iterator[ Tuple[ Callable[[], Optional[Message]], @@ -99,6 +101,7 @@ def grpc_request_response( insecure=insecure, root_certificates=root_certificates, max_message_length=max_message_length, + interceptors=interceptors, ) channel.subscribe(on_channel_state_change) stub = FleetStub(channel) diff --git a/src/py/flwr/common/grpc.py b/src/py/flwr/common/grpc.py index 9d0543ea8c75..436e129aeb35 100644 --- a/src/py/flwr/common/grpc.py +++ b/src/py/flwr/common/grpc.py @@ -16,7 +16,7 @@ from logging import INFO -from typing import Optional +from typing import Optional, Sequence import grpc @@ -30,6 +30,7 @@ def create_channel( insecure: bool, root_certificates: Optional[bytes] = None, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, + interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None, ) -> grpc.Channel: """Create a gRPC channel, either secure or insecure.""" # Check for conflicting parameters @@ -56,5 +57,8 @@ def create_channel( server_address, ssl_channel_credentials, options=channel_options ) log(INFO, "Opened secure gRPC connection using certificates") + + if interceptors is not None: + channel = grpc.intercept_channel(channel, interceptors) return channel From 8d0b7c5803c47449656a1ba845e291a859efe2b8 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Tue, 13 Feb 2024 19:57:42 +0100 Subject: [PATCH 22/49] Integrate client interceptor --- src/py/flwr/client/client_interceptor.py | 81 +++++++++----- src/py/flwr/client/client_interceptor_test.py | 104 +++++++++++------- .../client/grpc_rere_client/connection.py | 3 +- src/py/flwr/common/grpc.py | 2 +- 4 files changed, 122 insertions(+), 68 deletions(-) diff --git a/src/py/flwr/client/client_interceptor.py b/src/py/flwr/client/client_interceptor.py index 77009fdb6f25..cfc290f04aa2 100644 --- a/src/py/flwr/client/client_interceptor.py +++ b/src/py/flwr/client/client_interceptor.py @@ -14,44 +14,65 @@ # ============================================================================== """Flower client interceptor.""" -import grpc import collections -from typing import Callable, Union, Sequence, Tuple +from typing import Callable, Sequence, Tuple, Union + +import grpc +from cryptography.hazmat.primitives.asymmetric import ec + +from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + bytes_to_public_key, + compute_hmac, + generate_shared_key, + public_key_to_bytes, +) from flwr.proto.fleet_pb2 import ( CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, ) -from cryptography.hazmat.primitives.asymmetric import ec -from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - compute_hmac, - bytes_to_public_key, - generate_shared_key, - public_key_to_bytes, -) _PUBLIC_KEY_HEADER = "public-key" _AUTH_TOKEN_HEADER = "auth-token" -Request = Union[CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest] +Request = Union[ + CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest +] + -def _get_value_from_tuples(key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]]) -> Union[str, bytes]: +def _get_value_from_tuples( + key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] +) -> Union[str, bytes]: return next((value[::-1] for key, value in tuples if key == key_string), "") + class _ClientCallDetails( - collections.namedtuple( - '_ClientCallDetails', - ('method', 'timeout', 'metadata', 'credentials')), - grpc.ClientCallDetails): + collections.namedtuple( + "_ClientCallDetails", ("method", "timeout", "metadata", "credentials") + ), + grpc.ClientCallDetails, +): pass + class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): - def __init__(self, private_key: ec.EllipticCurvePrivateKey, public_key: ec.EllipticCurvePublicKey): + """Client interceptor for client authentication.""" + + def __init__( + self, + private_key: ec.EllipticCurvePrivateKey, + public_key: ec.EllipticCurvePublicKey, + ): self.private_key = private_key self.public_key = public_key - def intercept_unary_unary(self, continuation: Callable, client_call_details: grpc.ClientCallDetails, request: Request): + def intercept_unary_unary( + self, + continuation: Callable, + client_call_details: grpc.ClientCallDetails, + request: Request, + ): """Flower client interceptor.""" metadata = [] postprocess = False @@ -62,19 +83,29 @@ def intercept_unary_unary(self, continuation: Callable, client_call_details: grp metadata.append((_PUBLIC_KEY_HEADER, public_key_to_bytes(self.public_key))) postprocess = True - elif isinstance(request, (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest)): - metadata.append((_AUTH_TOKEN_HEADER, compute_hmac(self.shared_secret, request))) + elif isinstance( + request, (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest) + ): + metadata.append( + (_AUTH_TOKEN_HEADER, compute_hmac(self.shared_secret, request)) + ) else: pass client_call_details = _ClientCallDetails( - client_call_details.method, client_call_details.timeout, metadata, - client_call_details.credentials) - + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + ) + response = continuation(client_call_details, request) if postprocess: - server_public_key_bytes = _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.trailing_metadata) + server_public_key_bytes = _get_value_from_tuples( + _PUBLIC_KEY_HEADER, response.trailing_metadata + ) self.server_public_key = bytes_to_public_key(server_public_key_bytes) - self.shared_secret = generate_shared_key(self.private_key, self.server_public_key) + self.shared_secret = generate_shared_key( + self.private_key, self.server_public_key + ) return response - \ No newline at end of file diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py index fc4768b34b89..8df66d385577 100644 --- a/src/py/flwr/client/client_interceptor_test.py +++ b/src/py/flwr/client/client_interceptor_test.py @@ -14,87 +14,109 @@ # ============================================================================== """Flower client interceptor tests.""" -import unittest import threading -import grpc +import unittest from concurrent import futures -from flwr.proto.fleet_pb2 import ( - CreateNodeRequest, - DeleteNodeRequest, - PullTaskInsRequest, - PushTaskResRequest, -) +from typing import Callable, ContextManager, Optional, Tuple, Union -from flwr.proto import fleet_pb2 as flwr_dot_proto_dot_fleet__pb2 +import grpc + +from flwr.client.grpc_rere_client.connection import grpc_request_response +from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from flwr.common.message import Message from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( generate_key_pairs, - bytes_to_public_key, - generate_shared_key, - public_key_to_bytes, ) +from flwr.proto import fleet_pb2 as flwr_dot_proto_dot_fleet__pb2 +from flwr.proto.fleet_pb2 import CreateNodeRequest _PUBLIC_KEY_HEADER = "public-key" -class _MockServicer(object): + +class _MockServicer: def __init__(self): self._lock = threading.Lock() self._received_client_metadata = None _, self._server_public_key = generate_key_pairs() - def unary_unary(self, request, context): with self._lock: self._received_client_metadata = context.invocation_metadata() if isinstance(request, CreateNodeRequest): context.set_trailing_metadata( - ( - (_PUBLIC_KEY_HEADER, self._server_public_key), - ) + ((_PUBLIC_KEY_HEADER, self._server_public_key),) ) - + return object() def received_client_metadata(self): with self._lock: return self._received_client_metadata - + + def _generic_handler(servicer: _MockServicer): rpc_method_handlers = { - 'CreateNode': grpc.unary_unary_rpc_method_handler( - servicer.unary_unary, - request_deserializer=flwr_dot_proto_dot_fleet__pb2.CreateNodeRequest.FromString, - response_serializer=flwr_dot_proto_dot_fleet__pb2.CreateNodeResponse.SerializeToString, + "CreateNode": grpc.unary_unary_rpc_method_handler( + servicer.unary_unary, + request_deserializer=flwr_dot_proto_dot_fleet__pb2.CreateNodeRequest.FromString, + response_serializer=flwr_dot_proto_dot_fleet__pb2.CreateNodeResponse.SerializeToString, ), - 'DeleteNode': grpc.unary_unary_rpc_method_handler( - servicer.unary_unary, - request_deserializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeRequest.FromString, - response_serializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeResponse.SerializeToString, + "DeleteNode": grpc.unary_unary_rpc_method_handler( + servicer.unary_unary, + request_deserializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeRequest.FromString, + response_serializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeResponse.SerializeToString, ), - 'PullTaskIns': grpc.unary_unary_rpc_method_handler( - servicer.unary_unary, - request_deserializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.FromString, - response_serializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsResponse.SerializeToString, + "PullTaskIns": grpc.unary_unary_rpc_method_handler( + servicer.unary_unary, + request_deserializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.FromString, + response_serializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsResponse.SerializeToString, ), - 'PushTaskRes': grpc.unary_unary_rpc_method_handler( - servicer.unary_unary, - request_deserializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.FromString, - response_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.SerializeToString, + "PushTaskRes": grpc.unary_unary_rpc_method_handler( + servicer.unary_unary, + request_deserializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.FromString, + response_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.SerializeToString, ), } - return grpc.method_handlers_generic_handler('flwr.proto.Fleet', rpc_method_handlers) + return grpc.method_handlers_generic_handler("flwr.proto.Fleet", rpc_method_handlers) + class TestAuthenticateClientInterceptor(unittest.TestCase): + """Test for client interceptor client authentication.""" + def setUp(self): + """Initialize mock server and client.""" self._server = grpc.server( futures.ThreadPoolExecutor(max_workers=10), options=(("grpc.so_reuseport", int(False)),), ) - self._server.add_generic_rpc_handlers( - (_generic_handler(self._servicer),) - ) + self._server.add_generic_rpc_handlers((_generic_handler(self._servicer),)) port = self._server.add_insecure_port("[::]:0") self._server.start() - + + self._connection: Callable[ + [str, bool, int, Union[bytes, str, None]], + ContextManager[ + Tuple[ + Callable[[], Optional[Message]], + Callable[[Message], None], + Optional[Callable[[], None]], + Optional[Callable[[], None]], + ] + ], + ] = grpc_request_response(f"localhost:{port}") + self._address = f"localhost:{port}" + + def test_client_auth_create_node(self) -> None: + """Test client authentication during create node.""" + with self._connection( + self._address, + True, + GRPC_MAX_MESSAGE_LENGTH, + None, + ) as conn: + receive, send, create_node, delete_node = conn + create_node() + if __name__ == "__main__": - unittest.main(verbosity=2) \ No newline at end of file + unittest.main(verbosity=2) diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 280cd48f1d38..4d997f10f4b7 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -18,7 +18,8 @@ from contextlib import contextmanager from logging import DEBUG, ERROR from pathlib import Path -from typing import Callable, Dict, Iterator, Optional, Tuple, Union, cast, Sequence +from typing import Callable, Dict, Iterator, Optional, Sequence, Tuple, Union, cast + import grpc from flwr.client.message_handler.task_handler import ( diff --git a/src/py/flwr/common/grpc.py b/src/py/flwr/common/grpc.py index 436e129aeb35..62ac6d6caf1d 100644 --- a/src/py/flwr/common/grpc.py +++ b/src/py/flwr/common/grpc.py @@ -57,7 +57,7 @@ def create_channel( server_address, ssl_channel_credentials, options=channel_options ) log(INFO, "Opened secure gRPC connection using certificates") - + if interceptors is not None: channel = grpc.intercept_channel(channel, interceptors) From 328118e2be29a7e5e08cb56ad349a426fba72a60 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 14 Feb 2024 10:49:50 +0100 Subject: [PATCH 23/49] Add type:ignore to grpc primitives --- src/py/flwr/client/client_interceptor.py | 5 +-- src/py/flwr/client/client_interceptor_test.py | 33 ++++++++++++++----- .../state/authentication/__init__.py | 0 .../authentication/authentication_state.py | 2 +- .../authentication_state_test.py | 0 .../authentication/in_memory_auth_state.py | 4 +-- .../state/authentication/sqlite_auth_state.py | 4 +-- 7 files changed, 33 insertions(+), 15 deletions(-) rename src/py/flwr/server/{ => superlink}/state/authentication/__init__.py (100%) rename src/py/flwr/server/{ => superlink}/state/authentication/authentication_state.py (97%) rename src/py/flwr/server/{ => superlink}/state/authentication/authentication_state_test.py (100%) rename src/py/flwr/server/{ => superlink}/state/authentication/in_memory_auth_state.py (95%) rename src/py/flwr/server/{ => superlink}/state/authentication/sqlite_auth_state.py (96%) diff --git a/src/py/flwr/client/client_interceptor.py b/src/py/flwr/client/client_interceptor.py index cfc290f04aa2..3aae58908f77 100644 --- a/src/py/flwr/client/client_interceptor.py +++ b/src/py/flwr/client/client_interceptor.py @@ -19,6 +19,7 @@ import grpc from cryptography.hazmat.primitives.asymmetric import ec +from grpc import ClientCallDetails, UnaryUnaryClientInterceptor from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( bytes_to_public_key, @@ -51,12 +52,12 @@ class _ClientCallDetails( collections.namedtuple( "_ClientCallDetails", ("method", "timeout", "metadata", "credentials") ), - grpc.ClientCallDetails, + ClientCallDetails, # type: ignore ): pass -class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): +class AuthenticateClientInterceptor(UnaryUnaryClientInterceptor): # type: ignore """Client interceptor for client authentication.""" def __init__( diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py index 8df66d385577..24c68f7a1eab 100644 --- a/src/py/flwr/client/client_interceptor_test.py +++ b/src/py/flwr/client/client_interceptor_test.py @@ -30,16 +30,18 @@ from flwr.proto import fleet_pb2 as flwr_dot_proto_dot_fleet__pb2 from flwr.proto.fleet_pb2 import CreateNodeRequest +from .client_interceptor import AuthenticateClientInterceptor, Request + _PUBLIC_KEY_HEADER = "public-key" class _MockServicer: - def __init__(self): + def __init__(self) -> None: self._lock = threading.Lock() self._received_client_metadata = None _, self._server_public_key = generate_key_pairs() - def unary_unary(self, request, context): + def unary_unary(self, request: Request, context: grpc.ServicerContext) -> object: with self._lock: self._received_client_metadata = context.invocation_metadata() if isinstance(request, CreateNodeRequest): @@ -54,7 +56,7 @@ def received_client_metadata(self): return self._received_client_metadata -def _generic_handler(servicer: _MockServicer): +def _add_generic_handler(servicer: _MockServicer, server: grpc.Server) -> None: rpc_method_handlers = { "CreateNode": grpc.unary_unary_rpc_method_handler( servicer.unary_unary, @@ -77,21 +79,29 @@ def _generic_handler(servicer: _MockServicer): response_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.SerializeToString, ), } - return grpc.method_handlers_generic_handler("flwr.proto.Fleet", rpc_method_handlers) + generic_handler = grpc.method_handlers_generic_handler( + "flwr.proto.Fleet", rpc_method_handlers + ) + server.add_generic_rpc_handlers((generic_handler,)) class TestAuthenticateClientInterceptor(unittest.TestCase): """Test for client interceptor client authentication.""" - def setUp(self): + def setUp(self) -> None: """Initialize mock server and client.""" self._server = grpc.server( futures.ThreadPoolExecutor(max_workers=10), options=(("grpc.so_reuseport", int(False)),), ) - self._server.add_generic_rpc_handlers((_generic_handler(self._servicer),)) + self._servicer = _MockServicer() + _add_generic_handler(self._servicer, self._server) port = self._server.add_insecure_port("[::]:0") self._server.start() + self._client_private_key, self._client_public_key = generate_key_pairs() + self._client_interceptor = AuthenticateClientInterceptor( + self._client_private_key, self._client_public_key + ) self._connection: Callable[ [str, bool, int, Union[bytes, str, None]], @@ -103,7 +113,13 @@ def setUp(self): Optional[Callable[[], None]], ] ], - ] = grpc_request_response(f"localhost:{port}") + ] = grpc_request_response( + f"localhost:{port}", + False, + GRPC_MAX_MESSAGE_LENGTH, + None, + [self._client_interceptor], + ) self._address = f"localhost:{port}" def test_client_auth_create_node(self) -> None: @@ -114,8 +130,9 @@ def test_client_auth_create_node(self) -> None: GRPC_MAX_MESSAGE_LENGTH, None, ) as conn: - receive, send, create_node, delete_node = conn + _, _, create_node, _ = conn create_node() + assert self._servicer.received_client_metadata is not None if __name__ == "__main__": diff --git a/src/py/flwr/server/state/authentication/__init__.py b/src/py/flwr/server/superlink/state/authentication/__init__.py similarity index 100% rename from src/py/flwr/server/state/authentication/__init__.py rename to src/py/flwr/server/superlink/state/authentication/__init__.py diff --git a/src/py/flwr/server/state/authentication/authentication_state.py b/src/py/flwr/server/superlink/state/authentication/authentication_state.py similarity index 97% rename from src/py/flwr/server/state/authentication/authentication_state.py rename to src/py/flwr/server/superlink/state/authentication/authentication_state.py index 3adb450dc215..2c0b1ee567fa 100644 --- a/src/py/flwr/server/state/authentication/authentication_state.py +++ b/src/py/flwr/server/superlink/state/authentication/authentication_state.py @@ -17,7 +17,7 @@ import abc from typing import Set -from flwr.server.state import State +from flwr.server.superlink.state import State class AuthenticationState(State, abc.ABC): diff --git a/src/py/flwr/server/state/authentication/authentication_state_test.py b/src/py/flwr/server/superlink/state/authentication/authentication_state_test.py similarity index 100% rename from src/py/flwr/server/state/authentication/authentication_state_test.py rename to src/py/flwr/server/superlink/state/authentication/authentication_state_test.py diff --git a/src/py/flwr/server/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py similarity index 95% rename from src/py/flwr/server/state/authentication/in_memory_auth_state.py rename to src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py index fe10c1301b11..77aa5fdc5af6 100644 --- a/src/py/flwr/server/state/authentication/in_memory_auth_state.py +++ b/src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py @@ -16,8 +16,8 @@ from typing import Dict, Set -from flwr.server.state.authentication.authentication_state import AuthenticationState -from flwr.server.state.in_memory_state import InMemoryState +from flwr.server.superlink.state import InMemoryState +from flwr.server.superlink.state.authentication import AuthenticationState class InMemoryAuthState(AuthenticationState, InMemoryState): diff --git a/src/py/flwr/server/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py similarity index 96% rename from src/py/flwr/server/state/authentication/sqlite_auth_state.py rename to src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py index 55e4bc73a63b..f19a9f340ac3 100644 --- a/src/py/flwr/server/state/authentication/sqlite_auth_state.py +++ b/src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py @@ -16,8 +16,8 @@ from typing import Set -from flwr.server.state.authentication.authentication_state import AuthenticationState -from flwr.server.state.sqlite_state import SqliteState +from flwr.server.superlink.state import SqliteState +from flwr.server.superlink.state.authentication import AuthenticationState class SqliteAuthState(AuthenticationState, SqliteState): From a63d664f25ec290d785454c0fb6c12c50f74a965 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 14 Feb 2024 12:34:33 +0100 Subject: [PATCH 24/49] Remove auth state and format client interceptor test --- src/py/flwr/client/client_interceptor.py | 28 ++++-- src/py/flwr/client/client_interceptor_test.py | 64 ++++++------ .../state/authentication/__init__.py | 25 ----- .../authentication/authentication_state.py | 58 ----------- .../authentication_state_test.py | 97 ------------------- .../authentication/in_memory_auth_state.py | 73 -------------- .../state/authentication/sqlite_auth_state.py | 81 ---------------- 7 files changed, 51 insertions(+), 375 deletions(-) delete mode 100644 src/py/flwr/server/superlink/state/authentication/__init__.py delete mode 100644 src/py/flwr/server/superlink/state/authentication/authentication_state.py delete mode 100644 src/py/flwr/server/superlink/state/authentication/authentication_state_test.py delete mode 100644 src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py delete mode 100644 src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py diff --git a/src/py/flwr/client/client_interceptor.py b/src/py/flwr/client/client_interceptor.py index 3aae58908f77..9962764f6c60 100644 --- a/src/py/flwr/client/client_interceptor.py +++ b/src/py/flwr/client/client_interceptor.py @@ -15,11 +15,10 @@ """Flower client interceptor.""" import collections -from typing import Callable, Sequence, Tuple, Union +from typing import Any, Callable, Optional, Sequence, Tuple, Union import grpc from cryptography.hazmat.primitives.asymmetric import ec -from grpc import ClientCallDetails, UnaryUnaryClientInterceptor from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( bytes_to_public_key, @@ -27,7 +26,7 @@ generate_shared_key, public_key_to_bytes, ) -from flwr.proto.fleet_pb2 import ( +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, @@ -44,20 +43,24 @@ def _get_value_from_tuples( key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] -) -> Union[str, bytes]: - return next((value[::-1] for key, value in tuples if key == key_string), "") +) -> bytes: + value = next((value[::-1] for key, value in tuples if key == key_string), "") + if isinstance(value, str): + return value.encode() + + return value class _ClientCallDetails( collections.namedtuple( "_ClientCallDetails", ("method", "timeout", "metadata", "credentials") ), - ClientCallDetails, # type: ignore + grpc.ClientCallDetails, # type: ignore ): pass -class AuthenticateClientInterceptor(UnaryUnaryClientInterceptor): # type: ignore +class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore """Client interceptor for client authentication.""" def __init__( @@ -67,13 +70,15 @@ def __init__( ): self.private_key = private_key self.public_key = public_key + self.shared_secret = b"" + self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None def intercept_unary_unary( self, - continuation: Callable, + continuation: Callable[[Any, Any], Any], client_call_details: grpc.ClientCallDetails, request: Request, - ): + ) -> grpc.Call: """Flower client interceptor.""" metadata = [] postprocess = False @@ -88,7 +93,10 @@ def intercept_unary_unary( request, (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest) ): metadata.append( - (_AUTH_TOKEN_HEADER, compute_hmac(self.shared_secret, request)) + ( + _AUTH_TOKEN_HEADER, + compute_hmac(self.shared_secret, request.SerializeToString(True)), + ) ) else: pass diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py index 24c68f7a1eab..9f4d65a9333a 100644 --- a/src/py/flwr/client/client_interceptor_test.py +++ b/src/py/flwr/client/client_interceptor_test.py @@ -17,18 +17,25 @@ import threading import unittest from concurrent import futures -from typing import Callable, ContextManager, Optional, Tuple, Union +from typing import Optional, Sequence, Tuple, Union import grpc from flwr.client.grpc_rere_client.connection import grpc_request_response from flwr.common import GRPC_MAX_MESSAGE_LENGTH -from flwr.common.message import Message from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( generate_key_pairs, ) -from flwr.proto import fleet_pb2 as flwr_dot_proto_dot_fleet__pb2 -from flwr.proto.fleet_pb2 import CreateNodeRequest +from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 + CreateNodeRequest, + CreateNodeResponse, + DeleteNodeRequest, + DeleteNodeResponse, + PullTaskInsRequest, + PullTaskInsResponse, + PushTaskResRequest, + PushTaskResResponse, +) from .client_interceptor import AuthenticateClientInterceptor, Request @@ -36,12 +43,18 @@ class _MockServicer: + """Mock Servicer for Flower clients.""" + def __init__(self) -> None: + """Initialize mock servicer.""" self._lock = threading.Lock() - self._received_client_metadata = None + self._received_client_metadata: Optional[ + Sequence[Tuple[str, Union[str, bytes]]] + ] = None _, self._server_public_key = generate_key_pairs() def unary_unary(self, request: Request, context: grpc.ServicerContext) -> object: + """Handle unary call.""" with self._lock: self._received_client_metadata = context.invocation_metadata() if isinstance(request, CreateNodeRequest): @@ -51,7 +64,10 @@ def unary_unary(self, request: Request, context: grpc.ServicerContext) -> object return object() - def received_client_metadata(self): + def received_client_metadata( + self, + ) -> Optional[Sequence[Tuple[str, Union[str, bytes]]]]: + """Return received client metadata.""" with self._lock: return self._received_client_metadata @@ -60,23 +76,23 @@ def _add_generic_handler(servicer: _MockServicer, server: grpc.Server) -> None: rpc_method_handlers = { "CreateNode": grpc.unary_unary_rpc_method_handler( servicer.unary_unary, - request_deserializer=flwr_dot_proto_dot_fleet__pb2.CreateNodeRequest.FromString, - response_serializer=flwr_dot_proto_dot_fleet__pb2.CreateNodeResponse.SerializeToString, + request_deserializer=CreateNodeRequest.FromString, + response_serializer=CreateNodeResponse.SerializeToString, ), "DeleteNode": grpc.unary_unary_rpc_method_handler( servicer.unary_unary, - request_deserializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeRequest.FromString, - response_serializer=flwr_dot_proto_dot_fleet__pb2.DeleteNodeResponse.SerializeToString, + request_deserializer=DeleteNodeRequest.FromString, + response_serializer=DeleteNodeResponse.SerializeToString, ), "PullTaskIns": grpc.unary_unary_rpc_method_handler( servicer.unary_unary, - request_deserializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsRequest.FromString, - response_serializer=flwr_dot_proto_dot_fleet__pb2.PullTaskInsResponse.SerializeToString, + request_deserializer=PullTaskInsRequest.FromString, + response_serializer=PullTaskInsResponse.SerializeToString, ), "PushTaskRes": grpc.unary_unary_rpc_method_handler( servicer.unary_unary, - request_deserializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResRequest.FromString, - response_serializer=flwr_dot_proto_dot_fleet__pb2.PushTaskResResponse.SerializeToString, + request_deserializer=PushTaskResRequest.FromString, + response_serializer=PushTaskResResponse.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -103,23 +119,7 @@ def setUp(self) -> None: self._client_private_key, self._client_public_key ) - self._connection: Callable[ - [str, bool, int, Union[bytes, str, None]], - ContextManager[ - Tuple[ - Callable[[], Optional[Message]], - Callable[[Message], None], - Optional[Callable[[], None]], - Optional[Callable[[], None]], - ] - ], - ] = grpc_request_response( - f"localhost:{port}", - False, - GRPC_MAX_MESSAGE_LENGTH, - None, - [self._client_interceptor], - ) + self._connection = grpc_request_response self._address = f"localhost:{port}" def test_client_auth_create_node(self) -> None: @@ -129,8 +129,10 @@ def test_client_auth_create_node(self) -> None: True, GRPC_MAX_MESSAGE_LENGTH, None, + (self._client_interceptor), ) as conn: _, _, create_node, _ = conn + assert create_node is not None create_node() assert self._servicer.received_client_metadata is not None diff --git a/src/py/flwr/server/superlink/state/authentication/__init__.py b/src/py/flwr/server/superlink/state/authentication/__init__.py deleted file mode 100644 index 8f5c0a97ab1f..000000000000 --- a/src/py/flwr/server/superlink/state/authentication/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Flower server authentication state.""" - -from .authentication_state import AuthenticationState as AuthenticationState -from .in_memory_auth_state import InMemoryAuthState as InMemoryAuthState -from .sqlite_auth_state import SqliteAuthState as SqliteAuthState - -__all__ = [ - "AuthenticationState", - "InMemoryAuthState", - "SqliteAuthState", -] diff --git a/src/py/flwr/server/superlink/state/authentication/authentication_state.py b/src/py/flwr/server/superlink/state/authentication/authentication_state.py deleted file mode 100644 index 2c0b1ee567fa..000000000000 --- a/src/py/flwr/server/superlink/state/authentication/authentication_state.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Abstract base class AuthenticationState.""" - -import abc -from typing import Set - -from flwr.server.superlink.state import State - - -class AuthenticationState(State, abc.ABC): - """Abstract Authentication State.""" - - @abc.abstractmethod - def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: - """Store `node_id` and `public_key` as key-value pair in state.""" - - @abc.abstractmethod - def get_public_key_from_node_id(self, node_id: int) -> bytes: - """Get client's public key in urlsafe bytes for `node_id`.""" - - @abc.abstractmethod - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes - ) -> None: - """Store server's `public_key` and `private_key` in state.""" - - @abc.abstractmethod - def get_server_private_key(self) -> bytes: - """Get server private key in urlsafe bytes.""" - - @abc.abstractmethod - def get_server_public_key(self) -> bytes: - """Get server public key in urlsafe bytes.""" - - @abc.abstractmethod - def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" - - @abc.abstractmethod - def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" - - @abc.abstractmethod - def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" diff --git a/src/py/flwr/server/superlink/state/authentication/authentication_state_test.py b/src/py/flwr/server/superlink/state/authentication/authentication_state_test.py deleted file mode 100644 index f18c428d3044..000000000000 --- a/src/py/flwr/server/superlink/state/authentication/authentication_state_test.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Test for authentication state.""" - - -from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( - compute_hmac, - generate_key_pairs, - generate_shared_key, - public_key_to_bytes, - verify_hmac, -) - -from .in_memory_auth_state import InMemoryAuthState -from .sqlite_auth_state import SqliteAuthState - - -def test_in_memory_client_public_keys() -> None: - """Test client public keys store and get from state.""" - key_pairs = [generate_key_pairs() for _ in range(3)] - public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} - - in_memory_auth_state = InMemoryAuthState() - in_memory_auth_state.store_client_public_keys(public_keys) - - assert in_memory_auth_state.get_client_public_keys() == public_keys - - -def test_sqlite_client_public_keys() -> None: - """Test client public keys store and get from state.""" - key_pairs = [generate_key_pairs() for _ in range(3)] - public_keys = {public_key_to_bytes(pair[1]) for pair in key_pairs} - - sqlite_auth_state = SqliteAuthState(":memory:") - sqlite_auth_state.initialize() - sqlite_auth_state.store_client_public_keys(public_keys) - - assert sqlite_auth_state.get_client_public_keys() == public_keys - - -def test_in_memory_node_id_public_key_pair() -> None: - """Test store and get node_id public_key pair.""" - in_memory_auth_state = InMemoryAuthState() - node_id = in_memory_auth_state.create_node() - public_key = public_key_to_bytes(generate_key_pairs()[1]) - - in_memory_auth_state.store_node_id_public_key_pair(node_id, public_key) - - assert in_memory_auth_state.get_public_key_from_node_id(node_id) == public_key - - -def test_sqlite_node_id_public_key_pair() -> None: - """Test store and get node_id public_key pair.""" - sqlite_auth_state = SqliteAuthState(":memory:") - sqlite_auth_state.initialize() - node_id = sqlite_auth_state.create_node() - public_key = public_key_to_bytes(generate_key_pairs()[1]) - - sqlite_auth_state.store_node_id_public_key_pair(node_id, public_key) - - assert sqlite_auth_state.get_public_key_from_node_id(node_id) == public_key - - -def test_generate_shared_key() -> None: - """Test util function generate_shared_key.""" - client_keys = generate_key_pairs() - server_keys = generate_key_pairs() - - client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) - server_shared_secret = generate_shared_key(server_keys[0], client_keys[1]) - - assert client_shared_secret == server_shared_secret - - -def test_hmac() -> None: - """Test util function compute and verify hmac.""" - client_keys = generate_key_pairs() - server_keys = generate_key_pairs() - client_shared_secret = generate_shared_key(client_keys[0], server_keys[1]) - server_shared_secret = generate_shared_key(server_keys[0], client_keys[1]) - message = b"Flower is the future of AI" - - client_compute_hmac = compute_hmac(client_shared_secret, message) - - assert verify_hmac(server_shared_secret, message, client_compute_hmac) diff --git a/src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py b/src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py deleted file mode 100644 index 77aa5fdc5af6..000000000000 --- a/src/py/flwr/server/superlink/state/authentication/in_memory_auth_state.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""In-memory Authentication State implementation.""" - -from typing import Dict, Set - -from flwr.server.superlink.state import InMemoryState -from flwr.server.superlink.state.authentication import AuthenticationState - - -class InMemoryAuthState(AuthenticationState, InMemoryState): - """In-memory-based authentication state implementation.""" - - def __init__(self) -> None: - """Init InMemoryAuthState.""" - super().__init__() - self.node_id_public_key_dict: Dict[int, bytes] = {} - self.client_public_keys: Set[bytes] = set() - self.server_public_key: bytes = b"" - self.server_private_key: bytes = b"" - - def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: - """Store `node_id` and `public_key` as key-value pair in state.""" - if node_id not in self.node_ids: - raise ValueError(f"Node {node_id} not found") - if node_id in self.node_id_public_key_dict: - raise ValueError(f"Node {node_id} has already assigned a public key") - self.node_id_public_key_dict[node_id] = public_key - - def get_public_key_from_node_id(self, node_id: int) -> bytes: - """Get client's public key in urlsafe bytes for `node_id`.""" - if node_id in self.node_id_public_key_dict: - return self.node_id_public_key_dict[node_id] - return b"" - - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes - ) -> None: - """Store server's `public_key` and `private_key` in state.""" - self.server_private_key = private_key - self.server_public_key = public_key - - def get_server_private_key(self) -> bytes: - """Get server private key in urlsafe bytes.""" - return self.server_private_key - - def get_server_public_key(self) -> bytes: - """Get server public key in urlsafe bytes.""" - return self.server_public_key - - def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" - self.client_public_keys = public_keys - - def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" - self.client_public_keys.add(public_key) - - def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" - return self.client_public_keys diff --git a/src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py b/src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py deleted file mode 100644 index f19a9f340ac3..000000000000 --- a/src/py/flwr/server/superlink/state/authentication/sqlite_auth_state.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SQLite based implementation of server authentication state.""" - -from typing import Set - -from flwr.server.superlink.state import SqliteState -from flwr.server.superlink.state.authentication import AuthenticationState - - -class SqliteAuthState(AuthenticationState, SqliteState): - """SQLite-based authentication state implementation.""" - - def store_node_id_public_key_pair(self, node_id: int, public_key: bytes) -> None: - """Store `node_id` and `public_key` as key-value pair in state.""" - query = ( - "INSERT OR REPLACE INTO node_key (node_id, public_key) " - "VALUES (:node_id, :public_key)" - ) - self.query(query, {"node_id": node_id, "public_key": public_key}) - - def get_public_key_from_node_id(self, node_id: int) -> bytes: - """Get client's public key in urlsafe bytes for `node_id`.""" - query = "SELECT public_key FROM node_key WHERE node_id = :node_id" - rows = self.query(query, {"node_id": node_id}) - public_key: bytes = rows[0]["public_key"] - return public_key - - def store_server_public_private_key( - self, public_key: bytes, private_key: bytes - ) -> None: - """Store server's `public_key` and `private_key` in state.""" - query = ( - "INSERT OR REPLACE INTO credential (public_key, private_key) " - "VALUES (:public_key, :private_key)" - ) - self.query(query, {"public_key": public_key, "private_key": private_key}) - - def get_server_private_key(self) -> bytes: - """Get server private key in urlsafe bytes.""" - query = "SELECT private_key FROM credential" - rows = self.query(query) - private_key: bytes = rows[0]["private_key"] - return private_key - - def get_server_public_key(self) -> bytes: - """Get server public key in urlsafe bytes.""" - query = "SELECT public_key FROM credential" - rows = self.query(query) - public_key: bytes = rows[0]["public_key"] - return public_key - - def store_client_public_keys(self, public_keys: Set[bytes]) -> None: - """Store a set of client public keys in state.""" - query = "INSERT INTO public_key (public_key) VALUES (:public_key)" - for public_key in public_keys: - self.query(query, {"public_key": public_key}) - - def store_client_public_key(self, public_key: bytes) -> None: - """Retrieve a client public key in state.""" - query = "INSERT INTO public_key (public_key) VALUES (:public_key)" - self.query(query, {"public_key": public_key}) - - def get_client_public_keys(self) -> Set[bytes]: - """Retrieve all currently stored client public keys as a set.""" - query = "SELECT public_key FROM public_key" - rows = self.query(query) - result: Set[bytes] = {row["public_key"] for row in rows} - return result From 01be61b23a49bd326fc507e4e22cd2500fceaa8f Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 14 Feb 2024 12:39:35 +0100 Subject: [PATCH 25/49] Revert auth state changes --- .../server/superlink/state/sqlite_state.py | 23 ------------------- .../flwr/server/superlink/state/state_test.py | 4 ++-- 2 files changed, 2 insertions(+), 25 deletions(-) diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index e91d8553863c..224c16cdf013 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -37,26 +37,6 @@ ); """ -SQL_CREATE_TABLE_NODE_KEY = """ -CREATE TABLE IF NOT EXISTS node_key( - node_id INTEGER PRIMARY KEY, - public_key BLOB -); -""" - -SQL_CREATE_TABLE_CREDENTIAL = """ -CREATE TABLE IF NOT EXISTS credential( - public_key BLOB PRIMARY KEY, - private_key BLOB -); -""" - -SQL_CREATE_TABLE_PUBLIC_KEY = """ -CREATE TABLE IF NOT EXISTS public_key( - public_key BLOB UNIQUE -); -""" - SQL_CREATE_TABLE_RUN = """ CREATE TABLE IF NOT EXISTS run( run_id INTEGER UNIQUE @@ -143,9 +123,6 @@ def initialize(self, log_queries: bool = False) -> List[Tuple[str]]: cur.execute(SQL_CREATE_TABLE_TASK_INS) cur.execute(SQL_CREATE_TABLE_TASK_RES) cur.execute(SQL_CREATE_TABLE_NODE) - cur.execute(SQL_CREATE_TABLE_CREDENTIAL) - cur.execute(SQL_CREATE_TABLE_NODE_KEY) - cur.execute(SQL_CREATE_TABLE_PUBLIC_KEY) res = cur.execute("SELECT name FROM sqlite_schema;") return res.fetchall() diff --git a/src/py/flwr/server/superlink/state/state_test.py b/src/py/flwr/server/superlink/state/state_test.py index 803702bb97bb..d0470a7ce7f7 100644 --- a/src/py/flwr/server/superlink/state/state_test.py +++ b/src/py/flwr/server/superlink/state/state_test.py @@ -477,7 +477,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 13 + assert len(result) == 8 class SqliteFileBasedTest(StateTest, unittest.TestCase): @@ -502,7 +502,7 @@ def test_initialize(self) -> None: result = state.query("SELECT name FROM sqlite_schema;") # Assert - assert len(result) == 13 + assert len(result) == 8 if __name__ == "__main__": From d922cccec0ce61f13beabf1d1cfeb1cd378fb3e3 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 14 Feb 2024 15:01:22 +0100 Subject: [PATCH 26/49] add auth-token metadata client interceptor test --- src/py/flwr/client/client_interceptor_test.py | 57 +++++++++++++++++-- 1 file changed, 51 insertions(+), 6 deletions(-) diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py index 9f4d65a9333a..b9aff476da8b 100644 --- a/src/py/flwr/client/client_interceptor_test.py +++ b/src/py/flwr/client/client_interceptor_test.py @@ -24,7 +24,10 @@ from flwr.client.grpc_rere_client.connection import grpc_request_response from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( + compute_hmac, generate_key_pairs, + generate_shared_key, + public_key_to_bytes, ) from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, @@ -37,9 +40,12 @@ PushTaskResResponse, ) -from .client_interceptor import AuthenticateClientInterceptor, Request - -_PUBLIC_KEY_HEADER = "public-key" +from .client_interceptor import ( + _AUTH_TOKEN_HEADER, + _PUBLIC_KEY_HEADER, + AuthenticateClientInterceptor, + Request, +) class _MockServicer: @@ -51,15 +57,17 @@ def __init__(self) -> None: self._received_client_metadata: Optional[ Sequence[Tuple[str, Union[str, bytes]]] ] = None - _, self._server_public_key = generate_key_pairs() + self.server_private_key, self.server_public_key = generate_key_pairs() + self._received_message_bytes: bytes = b"" def unary_unary(self, request: Request, context: grpc.ServicerContext) -> object: """Handle unary call.""" with self._lock: self._received_client_metadata = context.invocation_metadata() + self._received_message_bytes = request.SerializeToString(True) if isinstance(request, CreateNodeRequest): context.set_trailing_metadata( - ((_PUBLIC_KEY_HEADER, self._server_public_key),) + ((_PUBLIC_KEY_HEADER, self.server_public_key),) ) return object() @@ -71,6 +79,11 @@ def received_client_metadata( with self._lock: return self._received_client_metadata + def received_message_bytes(self) -> bytes: + """Return received message bytes.""" + with self._lock: + return self._received_message_bytes + def _add_generic_handler(servicer: _MockServicer, server: grpc.Server) -> None: rpc_method_handlers = { @@ -134,7 +147,39 @@ def test_client_auth_create_node(self) -> None: _, _, create_node, _ = conn assert create_node is not None create_node() - assert self._servicer.received_client_metadata is not None + expected_client_metadata = ( + _PUBLIC_KEY_HEADER, + public_key_to_bytes(self._client_public_key), + ) + assert self._servicer.received_client_metadata() == expected_client_metadata + assert ( + self._client_interceptor.server_public_key + == self._servicer.server_public_key + ) + + def test_client_auth_delete_node(self) -> None: + """Test client authentication during delete node.""" + with self._connection( + self._address, + True, + GRPC_MAX_MESSAGE_LENGTH, + None, + (self._client_interceptor), + ) as conn: + _, _, _, delete_node = conn + assert delete_node is not None + delete_node() + shared_secret = generate_shared_key( + self._servicer.server_private_key, self._client_public_key + ) + expected_hmac = compute_hmac( + shared_secret, self._servicer.received_message_bytes() + ) + expected_client_metadata = ( + _AUTH_TOKEN_HEADER, + expected_hmac, + ) + assert self._servicer.received_client_metadata() == expected_client_metadata if __name__ == "__main__": From f8fefe1dc5495f8a1901d4f7e96cd4c067a621b3 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 14 Feb 2024 15:19:01 +0100 Subject: [PATCH 27/49] Remove pass --- src/py/flwr/client/client_interceptor.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/py/flwr/client/client_interceptor.py b/src/py/flwr/client/client_interceptor.py index 9962764f6c60..e091d92b0bca 100644 --- a/src/py/flwr/client/client_interceptor.py +++ b/src/py/flwr/client/client_interceptor.py @@ -98,8 +98,6 @@ def intercept_unary_unary( compute_hmac(self.shared_secret, request.SerializeToString(True)), ) ) - else: - pass client_call_details = _ClientCallDetails( client_call_details.method, From bcde8755e92db4291166110880c6b0b09d0408bc Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 21 Feb 2024 19:59:42 +0100 Subject: [PATCH 28/49] Update client-interceptor-test --- src/py/flwr/client/client_interceptor.py | 26 +++++++++++++--- src/py/flwr/client/client_interceptor_test.py | 31 +++++++++++++++---- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/src/py/flwr/client/client_interceptor.py b/src/py/flwr/client/client_interceptor.py index e091d92b0bca..4fa4158bd48e 100644 --- a/src/py/flwr/client/client_interceptor.py +++ b/src/py/flwr/client/client_interceptor.py @@ -14,6 +14,7 @@ # ============================================================================== """Flower client interceptor.""" +import base64 import collections from typing import Any, Callable, Optional, Sequence, Tuple, Union @@ -44,7 +45,7 @@ def _get_value_from_tuples( key_string: str, tuples: Sequence[Tuple[str, Union[str, bytes]]] ) -> bytes: - value = next((value[::-1] for key, value in tuples if key == key_string), "") + value = next((value for key, value in tuples if key == key_string), "") if isinstance(value, str): return value.encode() @@ -86,16 +87,31 @@ def intercept_unary_unary( metadata = list(client_call_details.metadata) if isinstance(request, CreateNodeRequest): - metadata.append((_PUBLIC_KEY_HEADER, public_key_to_bytes(self.public_key))) + metadata.append( + ( + _PUBLIC_KEY_HEADER, + base64.urlsafe_b64encode(public_key_to_bytes(self.public_key)), + ) + ) postprocess = True elif isinstance( request, (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest) ): + metadata.append( + ( + _PUBLIC_KEY_HEADER, + base64.urlsafe_b64encode(public_key_to_bytes(self.public_key)), + ) + ) metadata.append( ( _AUTH_TOKEN_HEADER, - compute_hmac(self.shared_secret, request.SerializeToString(True)), + base64.urlsafe_b64encode( + compute_hmac( + self.shared_secret, request.SerializeToString(True) + ) + ), ) ) @@ -108,8 +124,8 @@ def intercept_unary_unary( response = continuation(client_call_details, request) if postprocess: - server_public_key_bytes = _get_value_from_tuples( - _PUBLIC_KEY_HEADER, response.trailing_metadata + server_public_key_bytes = base64.urlsafe_b64decode( + _get_value_from_tuples(_PUBLIC_KEY_HEADER, response.initial_metadata()) ) self.server_public_key = bytes_to_public_key(server_public_key_bytes) self.shared_secret = generate_shared_key( diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py index b9aff476da8b..3d2ef2c5e5fd 100644 --- a/src/py/flwr/client/client_interceptor_test.py +++ b/src/py/flwr/client/client_interceptor_test.py @@ -14,6 +14,7 @@ # ============================================================================== """Flower client interceptor tests.""" +import base64 import threading import unittest from concurrent import futures @@ -60,17 +61,27 @@ def __init__(self) -> None: self.server_private_key, self.server_public_key = generate_key_pairs() self._received_message_bytes: bytes = b"" - def unary_unary(self, request: Request, context: grpc.ServicerContext) -> object: + def unary_unary( + self, request: Request, context: grpc.ServicerContext + ) -> Union[ + CreateNodeResponse, DeleteNodeResponse, PushTaskResResponse, PullTaskInsResponse + ]: """Handle unary call.""" with self._lock: self._received_client_metadata = context.invocation_metadata() self._received_message_bytes = request.SerializeToString(True) + if isinstance(request, CreateNodeRequest): - context.set_trailing_metadata( + context.send_initial_metadata( ((_PUBLIC_KEY_HEADER, self.server_public_key),) ) + return CreateNodeResponse() + if isinstance(request, DeleteNodeRequest): + return DeleteNodeResponse() + if isinstance(request, PushTaskResRequest): + return PushTaskResResponse() - return object() + return PullTaskInsResponse() def received_client_metadata( self, @@ -149,7 +160,7 @@ def test_client_auth_create_node(self) -> None: create_node() expected_client_metadata = ( _PUBLIC_KEY_HEADER, - public_key_to_bytes(self._client_public_key), + base64.urlsafe_b64encode(public_key_to_bytes(self._client_public_key)), ) assert self._servicer.received_client_metadata() == expected_client_metadata assert ( @@ -176,8 +187,16 @@ def test_client_auth_delete_node(self) -> None: shared_secret, self._servicer.received_message_bytes() ) expected_client_metadata = ( - _AUTH_TOKEN_HEADER, - expected_hmac, + ( + _PUBLIC_KEY_HEADER, + base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ), + ), + ( + _AUTH_TOKEN_HEADER, + base64.urlsafe_b64encode(expected_hmac), + ), ) assert self._servicer.received_client_metadata() == expected_client_metadata From b0473911f740b653b24b3d87ec65afd15bab5cc4 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 29 Feb 2024 19:12:29 +0100 Subject: [PATCH 29/49] Address review --- src/py/flwr/client/client_interceptor.py | 10 ++++++++-- src/py/flwr/client/client_interceptor_test.py | 1 + 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/client/client_interceptor.py b/src/py/flwr/client/client_interceptor.py index 4fa4158bd48e..b92a97002942 100644 --- a/src/py/flwr/client/client_interceptor.py +++ b/src/py/flwr/client/client_interceptor.py @@ -1,4 +1,4 @@ -# Copyright 2020 Flower Labs GmbH. All Rights Reserved. +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ # ============================================================================== """Flower client interceptor.""" + import base64 import collections from typing import Any, Callable, Optional, Sequence, Tuple, Union @@ -80,7 +81,12 @@ def intercept_unary_unary( client_call_details: grpc.ClientCallDetails, request: Request, ) -> grpc.Call: - """Flower client interceptor.""" + """Flower client interceptor. + + Intercept unary call from client and do authentication process by validating + metadata sent from client. Continue RPC call if client is authenticated, else, + terminate RPC call by setting context to abort. + """ metadata = [] postprocess = False if client_call_details.metadata is not None: diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py index 3d2ef2c5e5fd..db067d0a7db0 100644 --- a/src/py/flwr/client/client_interceptor_test.py +++ b/src/py/flwr/client/client_interceptor_test.py @@ -14,6 +14,7 @@ # ============================================================================== """Flower client interceptor tests.""" + import base64 import threading import unittest From c43c3b55c222dce7434ea79a23eae467c8d54bbd Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 29 Feb 2024 19:23:56 +0100 Subject: [PATCH 30/49] Fix docstring --- src/py/flwr/client/client_interceptor.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/client/client_interceptor.py b/src/py/flwr/client/client_interceptor.py index b92a97002942..d95d4ea7626c 100644 --- a/src/py/flwr/client/client_interceptor.py +++ b/src/py/flwr/client/client_interceptor.py @@ -83,9 +83,8 @@ def intercept_unary_unary( ) -> grpc.Call: """Flower client interceptor. - Intercept unary call from client and do authentication process by validating - metadata sent from client. Continue RPC call if client is authenticated, else, - terminate RPC call by setting context to abort. + Intercept unary call from client and add necessary authentication header + in the RPC metadata. """ metadata = [] postprocess = False From 031c9f7673a54a4e43ba380def22a8527ad45369 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Fri, 1 Mar 2024 07:42:40 +0100 Subject: [PATCH 31/49] Format --- src/py/flwr/client/client_interceptor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/client/client_interceptor.py b/src/py/flwr/client/client_interceptor.py index d95d4ea7626c..a9510f0e1dc7 100644 --- a/src/py/flwr/client/client_interceptor.py +++ b/src/py/flwr/client/client_interceptor.py @@ -83,8 +83,8 @@ def intercept_unary_unary( ) -> grpc.Call: """Flower client interceptor. - Intercept unary call from client and add necessary authentication header - in the RPC metadata. + Intercept unary call from client and add necessary authentication header in the + RPC metadata. """ metadata = [] postprocess = False From f6832e8ac113d4b38050309205f324d0c296969d Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 4 Apr 2024 07:53:19 +0200 Subject: [PATCH 32/49] Add retry invoker --- src/py/flwr/client/client_interceptor_test.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py index db067d0a7db0..b70cf6ca4908 100644 --- a/src/py/flwr/client/client_interceptor_test.py +++ b/src/py/flwr/client/client_interceptor_test.py @@ -20,6 +20,7 @@ import unittest from concurrent import futures from typing import Optional, Sequence, Tuple, Union +from flwr.common.retry_invoker import RetryInvoker, exponential import grpc @@ -149,9 +150,19 @@ def setUp(self) -> None: def test_client_auth_create_node(self) -> None: """Test client authentication during create node.""" + retry_invoker = RetryInvoker( + wait_gen_factory=exponential, + recoverable_exceptions=grpc.RpcError, + max_tries=None, + max_time=None, + on_giveup=lambda retry_state: (), + on_success=lambda retry_state: (), + on_backoff=lambda retry_state: (), + ) with self._connection( self._address, True, + retry_invoker, GRPC_MAX_MESSAGE_LENGTH, None, (self._client_interceptor), @@ -171,9 +182,19 @@ def test_client_auth_create_node(self) -> None: def test_client_auth_delete_node(self) -> None: """Test client authentication during delete node.""" + retry_invoker = RetryInvoker( + wait_gen_factory=exponential, + recoverable_exceptions=grpc.RpcError, + max_tries=None, + max_time=None, + on_giveup=lambda retry_state: (), + on_success=lambda retry_state: (), + on_backoff=lambda retry_state: (), + ) with self._connection( self._address, True, + retry_invoker, GRPC_MAX_MESSAGE_LENGTH, None, (self._client_interceptor), From 9ca4e561efb27e84e0c77d43720d847fa22461b8 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Thu, 4 Apr 2024 08:36:54 +0200 Subject: [PATCH 33/49] Fix mypy --- src/py/flwr/client/client_interceptor_test.py | 68 +++++++++++++++++-- .../client/grpc_rere_client/connection.py | 2 +- 2 files changed, 62 insertions(+), 8 deletions(-) diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py index b70cf6ca4908..2b24a79c0161 100644 --- a/src/py/flwr/client/client_interceptor_test.py +++ b/src/py/flwr/client/client_interceptor_test.py @@ -19,13 +19,15 @@ import threading import unittest from concurrent import futures +from logging import DEBUG, INFO, WARN from typing import Optional, Sequence, Tuple, Union -from flwr.common.retry_invoker import RetryInvoker, exponential import grpc from flwr.client.grpc_rere_client.connection import grpc_request_response from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from flwr.common.logger import log +from flwr.common.retry_invoker import RetryInvoker, exponential from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( compute_hmac, generate_key_pairs, @@ -155,9 +157,35 @@ def test_client_auth_create_node(self) -> None: recoverable_exceptions=grpc.RpcError, max_tries=None, max_time=None, - on_giveup=lambda retry_state: (), - on_success=lambda retry_state: (), - on_backoff=lambda retry_state: (), + on_giveup=lambda retry_state: ( + log( + WARN, + "Giving up reconnection after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + if retry_state.tries > 1 + else None + ), + on_success=lambda retry_state: ( + log( + INFO, + "Connection successful after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + if retry_state.tries > 1 + else None + ), + on_backoff=lambda retry_state: ( + log(WARN, "Connection attempt failed, retrying...") + if retry_state.tries == 1 + else log( + DEBUG, + "Connection attempt failed, retrying in %.2f seconds", + retry_state.actual_wait, + ) + ), ) with self._connection( self._address, @@ -187,9 +215,35 @@ def test_client_auth_delete_node(self) -> None: recoverable_exceptions=grpc.RpcError, max_tries=None, max_time=None, - on_giveup=lambda retry_state: (), - on_success=lambda retry_state: (), - on_backoff=lambda retry_state: (), + on_giveup=lambda retry_state: ( + log( + WARN, + "Giving up reconnection after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + if retry_state.tries > 1 + else None + ), + on_success=lambda retry_state: ( + log( + INFO, + "Connection successful after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + if retry_state.tries > 1 + else None + ), + on_backoff=lambda retry_state: ( + log(WARN, "Connection attempt failed, retrying...") + if retry_state.tries == 1 + else log( + DEBUG, + "Connection attempt failed, retrying in %.2f seconds", + retry_state.actual_wait, + ) + ), ) with self._connection( self._address, diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 1cccbc23e48a..06987d9c5d5d 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -59,7 +59,7 @@ def on_channel_state_change(channel_connectivity: str) -> None: @contextmanager -def grpc_request_response( # pylint: disable=R0914, R0915 +def grpc_request_response( # pylint: disable=R0913, R0914, R0915 server_address: str, insecure: bool, retry_invoker: RetryInvoker, From 03039fc7f4e0571c54ce8e23a22423fa2b44f2b6 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 12:47:22 +0200 Subject: [PATCH 34/49] Format --- src/py/flwr/client/client_interceptor_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py index 2b24a79c0161..b814d23bb364 100644 --- a/src/py/flwr/client/client_interceptor_test.py +++ b/src/py/flwr/client/client_interceptor_test.py @@ -195,7 +195,7 @@ def test_client_auth_create_node(self) -> None: None, (self._client_interceptor), ) as conn: - _, _, create_node, _ = conn + _, _, create_node, _, _ = conn assert create_node is not None create_node() expected_client_metadata = ( @@ -253,7 +253,7 @@ def test_client_auth_delete_node(self) -> None: None, (self._client_interceptor), ) as conn: - _, _, _, delete_node = conn + _, _, _, delete_node, _ = conn assert delete_node is not None delete_node() shared_secret = generate_shared_key( From 4a6dc38bfad7373228809a010d74076a20c0cc3b Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 13:51:31 +0200 Subject: [PATCH 35/49] Add more tests --- src/py/flwr/client/client_interceptor_test.py | 200 ++++++++++++------ 1 file changed, 130 insertions(+), 70 deletions(-) diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py index b814d23bb364..0f3570d1bde6 100644 --- a/src/py/flwr/client/client_interceptor_test.py +++ b/src/py/flwr/client/client_interceptor_test.py @@ -27,6 +27,7 @@ from flwr.client.grpc_rere_client.connection import grpc_request_response from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.logger import log +from flwr.common.message import Message, Metadata, RecordSet from flwr.common.retry_invoker import RetryInvoker, exponential from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( compute_hmac, @@ -129,6 +130,44 @@ def _add_generic_handler(servicer: _MockServicer, server: grpc.Server) -> None: server.add_generic_rpc_handlers((generic_handler,)) +def _init_retry_invoker() -> RetryInvoker: + return RetryInvoker( + wait_gen_factory=exponential, + recoverable_exceptions=grpc.RpcError, + max_tries=None, + max_time=None, + on_giveup=lambda retry_state: ( + log( + WARN, + "Giving up reconnection after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + if retry_state.tries > 1 + else None + ), + on_success=lambda retry_state: ( + log( + INFO, + "Connection successful after %.2f seconds and %s tries.", + retry_state.elapsed_time, + retry_state.tries, + ) + if retry_state.tries > 1 + else None + ), + on_backoff=lambda retry_state: ( + log(WARN, "Connection attempt failed, retrying...") + if retry_state.tries == 1 + else log( + DEBUG, + "Connection attempt failed, retrying in %.2f seconds", + retry_state.actual_wait, + ) + ), + ) + + class TestAuthenticateClientInterceptor(unittest.TestCase): """Test for client interceptor client authentication.""" @@ -152,41 +191,10 @@ def setUp(self) -> None: def test_client_auth_create_node(self) -> None: """Test client authentication during create node.""" - retry_invoker = RetryInvoker( - wait_gen_factory=exponential, - recoverable_exceptions=grpc.RpcError, - max_tries=None, - max_time=None, - on_giveup=lambda retry_state: ( - log( - WARN, - "Giving up reconnection after %.2f seconds and %s tries.", - retry_state.elapsed_time, - retry_state.tries, - ) - if retry_state.tries > 1 - else None - ), - on_success=lambda retry_state: ( - log( - INFO, - "Connection successful after %.2f seconds and %s tries.", - retry_state.elapsed_time, - retry_state.tries, - ) - if retry_state.tries > 1 - else None - ), - on_backoff=lambda retry_state: ( - log(WARN, "Connection attempt failed, retrying...") - if retry_state.tries == 1 - else log( - DEBUG, - "Connection attempt failed, retrying in %.2f seconds", - retry_state.actual_wait, - ) - ), - ) + # Prepare + retry_invoker = _init_retry_invoker() + + # Execute with self._connection( self._address, True, @@ -202,6 +210,8 @@ def test_client_auth_create_node(self) -> None: _PUBLIC_KEY_HEADER, base64.urlsafe_b64encode(public_key_to_bytes(self._client_public_key)), ) + + # Assert assert self._servicer.received_client_metadata() == expected_client_metadata assert ( self._client_interceptor.server_public_key @@ -210,41 +220,10 @@ def test_client_auth_create_node(self) -> None: def test_client_auth_delete_node(self) -> None: """Test client authentication during delete node.""" - retry_invoker = RetryInvoker( - wait_gen_factory=exponential, - recoverable_exceptions=grpc.RpcError, - max_tries=None, - max_time=None, - on_giveup=lambda retry_state: ( - log( - WARN, - "Giving up reconnection after %.2f seconds and %s tries.", - retry_state.elapsed_time, - retry_state.tries, - ) - if retry_state.tries > 1 - else None - ), - on_success=lambda retry_state: ( - log( - INFO, - "Connection successful after %.2f seconds and %s tries.", - retry_state.elapsed_time, - retry_state.tries, - ) - if retry_state.tries > 1 - else None - ), - on_backoff=lambda retry_state: ( - log(WARN, "Connection attempt failed, retrying...") - if retry_state.tries == 1 - else log( - DEBUG, - "Connection attempt failed, retrying in %.2f seconds", - retry_state.actual_wait, - ) - ), - ) + # Prepare + retry_invoker = _init_retry_invoker() + + # Execute with self._connection( self._address, True, @@ -274,6 +253,87 @@ def test_client_auth_delete_node(self) -> None: base64.urlsafe_b64encode(expected_hmac), ), ) + + # Assert + assert self._servicer.received_client_metadata() == expected_client_metadata + + def test_client_auth_receive(self) -> None: + """Test client authentication during receive node.""" + # Prepare + retry_invoker = _init_retry_invoker() + + # Execute + with self._connection( + self._address, + True, + retry_invoker, + GRPC_MAX_MESSAGE_LENGTH, + None, + (self._client_interceptor), + ) as conn: + receive, _, _, _, _ = conn + assert receive is not None + receive() + shared_secret = generate_shared_key( + self._servicer.server_private_key, self._client_public_key + ) + expected_hmac = compute_hmac( + shared_secret, self._servicer.received_message_bytes() + ) + expected_client_metadata = ( + ( + _PUBLIC_KEY_HEADER, + base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ), + ), + ( + _AUTH_TOKEN_HEADER, + base64.urlsafe_b64encode(expected_hmac), + ), + ) + + # Assert + assert self._servicer.received_client_metadata() == expected_client_metadata + + def test_client_auth_send(self) -> None: + """Test client authentication during send node.""" + # Prepare + retry_invoker = _init_retry_invoker() + message = Message(Metadata(0, "1", 0, 0, "", "", 0, ""), RecordSet()) + + # Execute + with self._connection( + self._address, + True, + retry_invoker, + GRPC_MAX_MESSAGE_LENGTH, + None, + (self._client_interceptor), + ) as conn: + _, send, _, _, _ = conn + assert send is not None + send(message) + shared_secret = generate_shared_key( + self._servicer.server_private_key, self._client_public_key + ) + expected_hmac = compute_hmac( + shared_secret, self._servicer.received_message_bytes() + ) + expected_client_metadata = ( + ( + _PUBLIC_KEY_HEADER, + base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ), + ), + ( + _AUTH_TOKEN_HEADER, + base64.urlsafe_b64encode(expected_hmac), + ), + ) + + # Assert assert self._servicer.received_client_metadata() == expected_client_metadata From 05edfbdd5ec74a93de4b2b975f78a70bcf08218b Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 14:01:06 +0200 Subject: [PATCH 36/49] Format --- src/py/flwr/client/client_interceptor_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py index 0f3570d1bde6..e7d6e360c3e9 100644 --- a/src/py/flwr/client/client_interceptor_test.py +++ b/src/py/flwr/client/client_interceptor_test.py @@ -27,7 +27,8 @@ from flwr.client.grpc_rere_client.connection import grpc_request_response from flwr.common import GRPC_MAX_MESSAGE_LENGTH from flwr.common.logger import log -from flwr.common.message import Message, Metadata, RecordSet +from flwr.common.message import Message, Metadata +from flwr.common.record import RecordSet from flwr.common.retry_invoker import RetryInvoker, exponential from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( compute_hmac, From 0d9363138847f1999aeddb487a36c979affd15af Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 16:06:21 +0200 Subject: [PATCH 37/49] Add interceptors to other contextmanager --- src/py/flwr/client/grpc_client/connection.py | 9 +++++++-- src/py/flwr/client/rest_client/connection.py | 8 ++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 9b722037b52c..399d9b295e0a 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -20,7 +20,9 @@ from logging import DEBUG from pathlib import Path from queue import Queue -from typing import Callable, Iterator, Optional, Tuple, Union, cast +from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, cast + +import grpc from flwr.common import ( DEFAULT_TTL, @@ -56,12 +58,15 @@ def on_channel_state_change(channel_connectivity: str) -> None: @contextmanager -def grpc_connection( # pylint: disable=R0915 +def grpc_connection( # pylint: disable=R0913, R0915 server_address: str, insecure: bool, retry_invoker: RetryInvoker, # pylint: disable=unused-argument max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[Union[bytes, str]] = None, + interceptors: Optional[ # pylint: disable=unused-argument + Sequence[grpc.UnaryUnaryClientInterceptor] + ] = None, ) -> Iterator[ Tuple[ Callable[[], Optional[Message]], diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index 54d919f619e0..79e0d1d70f86 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -21,8 +21,9 @@ from contextlib import contextmanager from copy import copy from logging import ERROR, INFO, WARN -from typing import Callable, Iterator, Optional, Tuple, Type, TypeVar, Union +from typing import Callable, Iterator, Optional, Sequence, Tuple, Type, TypeVar, Union +import grpc from google.protobuf.message import Message as GrpcMessage from flwr.client.heartbeat import start_ping_loop @@ -74,7 +75,7 @@ @contextmanager -def http_request_response( # pylint: disable=R0914, R0915 +def http_request_response( # pylint: disable=,R0913, R0914, R0915 server_address: str, insecure: bool, # pylint: disable=unused-argument retry_invoker: RetryInvoker, @@ -82,6 +83,9 @@ def http_request_response( # pylint: disable=R0914, R0915 root_certificates: Optional[ Union[bytes, str] ] = None, # pylint: disable=unused-argument + interceptors: Optional[ # pylint: disable=unused-argument + Sequence[grpc.UnaryUnaryClientInterceptor] + ] = None, ) -> Iterator[ Tuple[ Callable[[], Optional[Message]], From 7f695f5dcfcd122c23fe4e293f53dc4647198638 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 16:21:04 +0200 Subject: [PATCH 38/49] Replace interceptors with authentication keys --- src/py/flwr/client/client_interceptor_test.py | 22 +++++-------------- src/py/flwr/client/grpc_client/connection.py | 8 +++---- .../client/grpc_rere_client/connection.py | 12 +++++++++- src/py/flwr/client/rest_client/connection.py | 8 +++---- 4 files changed, 24 insertions(+), 26 deletions(-) diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/client_interceptor_test.py index e7d6e360c3e9..eeacc9765977 100644 --- a/src/py/flwr/client/client_interceptor_test.py +++ b/src/py/flwr/client/client_interceptor_test.py @@ -47,12 +47,7 @@ PushTaskResResponse, ) -from .client_interceptor import ( - _AUTH_TOKEN_HEADER, - _PUBLIC_KEY_HEADER, - AuthenticateClientInterceptor, - Request, -) +from .client_interceptor import _AUTH_TOKEN_HEADER, _PUBLIC_KEY_HEADER, Request class _MockServicer: @@ -183,9 +178,6 @@ def setUp(self) -> None: port = self._server.add_insecure_port("[::]:0") self._server.start() self._client_private_key, self._client_public_key = generate_key_pairs() - self._client_interceptor = AuthenticateClientInterceptor( - self._client_private_key, self._client_public_key - ) self._connection = grpc_request_response self._address = f"localhost:{port}" @@ -202,7 +194,7 @@ def test_client_auth_create_node(self) -> None: retry_invoker, GRPC_MAX_MESSAGE_LENGTH, None, - (self._client_interceptor), + (self._client_private_key, self._client_public_key), ) as conn: _, _, create_node, _, _ = conn assert create_node is not None @@ -214,10 +206,6 @@ def test_client_auth_create_node(self) -> None: # Assert assert self._servicer.received_client_metadata() == expected_client_metadata - assert ( - self._client_interceptor.server_public_key - == self._servicer.server_public_key - ) def test_client_auth_delete_node(self) -> None: """Test client authentication during delete node.""" @@ -231,7 +219,7 @@ def test_client_auth_delete_node(self) -> None: retry_invoker, GRPC_MAX_MESSAGE_LENGTH, None, - (self._client_interceptor), + (self._client_private_key, self._client_public_key), ) as conn: _, _, _, delete_node, _ = conn assert delete_node is not None @@ -270,7 +258,7 @@ def test_client_auth_receive(self) -> None: retry_invoker, GRPC_MAX_MESSAGE_LENGTH, None, - (self._client_interceptor), + (self._client_private_key, self._client_public_key), ) as conn: receive, _, _, _, _ = conn assert receive is not None @@ -310,7 +298,7 @@ def test_client_auth_send(self) -> None: retry_invoker, GRPC_MAX_MESSAGE_LENGTH, None, - (self._client_interceptor), + (self._client_private_key, self._client_public_key), ) as conn: _, send, _, _, _ = conn assert send is not None diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index 399d9b295e0a..ef7e565492ab 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -20,9 +20,9 @@ from logging import DEBUG from pathlib import Path from queue import Queue -from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, cast +from typing import Callable, Iterator, Optional, Tuple, Union, cast -import grpc +from cryptography.hazmat.primitives.asymmetric import ec from flwr.common import ( DEFAULT_TTL, @@ -64,8 +64,8 @@ def grpc_connection( # pylint: disable=R0913, R0915 retry_invoker: RetryInvoker, # pylint: disable=unused-argument max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[Union[bytes, str]] = None, - interceptors: Optional[ # pylint: disable=unused-argument - Sequence[grpc.UnaryUnaryClientInterceptor] + authentication_keys: Optional[ # pylint: disable=unused-argument + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, ) -> Iterator[ Tuple[ diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 913d3cc0c217..eba90ff19670 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -24,7 +24,9 @@ from typing import Callable, Iterator, Optional, Sequence, Tuple, Union, cast import grpc +from cryptography.hazmat.primitives.asymmetric import ec +from flwr.client.client_interceptor import AuthenticateClientInterceptor from flwr.client.heartbeat import start_ping_loop from flwr.client.message_handler.message_handler import validate_out_message from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins @@ -67,7 +69,9 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 retry_invoker: RetryInvoker, max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 root_certificates: Optional[Union[bytes, str]] = None, - interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None, + authentication_keys: Optional[ + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + ] = None, ) -> Iterator[ Tuple[ Callable[[], Optional[Message]], @@ -112,6 +116,12 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 if isinstance(root_certificates, str): root_certificates = Path(root_certificates).read_bytes() + interceptors: Optional[Sequence[grpc.UnaryUnaryClientInterceptor]] = None + if authentication_keys is not None: + interceptors = AuthenticateClientInterceptor( + authentication_keys[0], authentication_keys[1] + ) + channel = create_channel( server_address=server_address, insecure=insecure, diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index 79e0d1d70f86..4f060b9ef236 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -21,9 +21,9 @@ from contextlib import contextmanager from copy import copy from logging import ERROR, INFO, WARN -from typing import Callable, Iterator, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Callable, Iterator, Optional, Tuple, Type, TypeVar, Union -import grpc +from cryptography.hazmat.primitives.asymmetric import ec from google.protobuf.message import Message as GrpcMessage from flwr.client.heartbeat import start_ping_loop @@ -83,8 +83,8 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915 root_certificates: Optional[ Union[bytes, str] ] = None, # pylint: disable=unused-argument - interceptors: Optional[ # pylint: disable=unused-argument - Sequence[grpc.UnaryUnaryClientInterceptor] + authentication_keys: Optional[ # pylint: disable=unused-argument + tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, ) -> Iterator[ Tuple[ From 8a80eada16fad140f37d69a006601cf96a8b1cb8 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 16:26:31 +0200 Subject: [PATCH 39/49] Replace interceptors with authentication keys --- src/py/flwr/client/grpc_client/connection.py | 2 +- src/py/flwr/client/grpc_rere_client/connection.py | 2 +- src/py/flwr/client/rest_client/connection.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/py/flwr/client/grpc_client/connection.py b/src/py/flwr/client/grpc_client/connection.py index ef7e565492ab..6e5227cf5e5f 100644 --- a/src/py/flwr/client/grpc_client/connection.py +++ b/src/py/flwr/client/grpc_client/connection.py @@ -65,7 +65,7 @@ def grpc_connection( # pylint: disable=R0913, R0915 max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, root_certificates: Optional[Union[bytes, str]] = None, authentication_keys: Optional[ # pylint: disable=unused-argument - tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, ) -> Iterator[ Tuple[ diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index eba90ff19670..484f98ec3d78 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -70,7 +70,7 @@ def grpc_request_response( # pylint: disable=R0913, R0914, R0915 max_message_length: int = GRPC_MAX_MESSAGE_LENGTH, # pylint: disable=W0613 root_certificates: Optional[Union[bytes, str]] = None, authentication_keys: Optional[ - tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, ) -> Iterator[ Tuple[ diff --git a/src/py/flwr/client/rest_client/connection.py b/src/py/flwr/client/rest_client/connection.py index 4f060b9ef236..da8fbd351ab1 100644 --- a/src/py/flwr/client/rest_client/connection.py +++ b/src/py/flwr/client/rest_client/connection.py @@ -84,7 +84,7 @@ def http_request_response( # pylint: disable=,R0913, R0914, R0915 Union[bytes, str] ] = None, # pylint: disable=unused-argument authentication_keys: Optional[ # pylint: disable=unused-argument - tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] + Tuple[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey] ] = None, ) -> Iterator[ Tuple[ From 03e7c12d25d49113fb40dbe26af85bb413c9fe82 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 16:30:25 +0200 Subject: [PATCH 40/49] Move to grpc_rere_client --- src/py/flwr/client/{ => grpc_rere_client}/client_interceptor.py | 0 .../client/{ => grpc_rere_client}/client_interceptor_test.py | 0 src/py/flwr/client/grpc_rere_client/connection.py | 2 +- 3 files changed, 1 insertion(+), 1 deletion(-) rename src/py/flwr/client/{ => grpc_rere_client}/client_interceptor.py (100%) rename src/py/flwr/client/{ => grpc_rere_client}/client_interceptor_test.py (100%) diff --git a/src/py/flwr/client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py similarity index 100% rename from src/py/flwr/client/client_interceptor.py rename to src/py/flwr/client/grpc_rere_client/client_interceptor.py diff --git a/src/py/flwr/client/client_interceptor_test.py b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py similarity index 100% rename from src/py/flwr/client/client_interceptor_test.py rename to src/py/flwr/client/grpc_rere_client/client_interceptor_test.py diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 484f98ec3d78..39882c28b8f6 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -26,7 +26,7 @@ import grpc from cryptography.hazmat.primitives.asymmetric import ec -from flwr.client.client_interceptor import AuthenticateClientInterceptor +from .client_interceptor import AuthenticateClientInterceptor from flwr.client.heartbeat import start_ping_loop from flwr.client.message_handler.message_handler import validate_out_message from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins From c655c7bcbd4d9dfee0a931f369fbb216bd3e8351 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 16:33:10 +0200 Subject: [PATCH 41/49] Format --- src/py/flwr/client/grpc_rere_client/connection.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/client/grpc_rere_client/connection.py b/src/py/flwr/client/grpc_rere_client/connection.py index 39882c28b8f6..3778fd4061f9 100644 --- a/src/py/flwr/client/grpc_rere_client/connection.py +++ b/src/py/flwr/client/grpc_rere_client/connection.py @@ -26,7 +26,6 @@ import grpc from cryptography.hazmat.primitives.asymmetric import ec -from .client_interceptor import AuthenticateClientInterceptor from flwr.client.heartbeat import start_ping_loop from flwr.client.message_handler.message_handler import validate_out_message from flwr.client.message_handler.task_handler import get_task_ins, validate_task_ins @@ -56,6 +55,8 @@ from flwr.proto.node_pb2 import Node # pylint: disable=E0611 from flwr.proto.task_pb2 import TaskIns # pylint: disable=E0611 +from .client_interceptor import AuthenticateClientInterceptor + def on_channel_state_change(channel_connectivity: str) -> None: """Log channel connectivity.""" From 537e05ede33003f34d980603f25b26c0ab0af0fa Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 16:39:18 +0200 Subject: [PATCH 42/49] Add get_run --- .../grpc_rere_client/client_interceptor.py | 5 +- .../client_interceptor_test.py | 46 +++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index a9510f0e1dc7..61969c28d949 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -33,13 +33,14 @@ DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, + GetRunRequest ) _PUBLIC_KEY_HEADER = "public-key" _AUTH_TOKEN_HEADER = "auth-token" Request = Union[ - CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest + CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, GetRunRequest ] @@ -101,7 +102,7 @@ def intercept_unary_unary( postprocess = True elif isinstance( - request, (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest) + request, (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, GetRunRequest) ): metadata.append( ( diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py index eeacc9765977..fed9f3fcefaa 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py @@ -45,6 +45,8 @@ PullTaskInsResponse, PushTaskResRequest, PushTaskResResponse, + GetRunRequest, + GetRunResponse, ) from .client_interceptor import _AUTH_TOKEN_HEADER, _PUBLIC_KEY_HEADER, Request @@ -119,6 +121,11 @@ def _add_generic_handler(servicer: _MockServicer, server: grpc.Server) -> None: request_deserializer=PushTaskResRequest.FromString, response_serializer=PushTaskResResponse.SerializeToString, ), + "GetRun": grpc.unary_unary_rpc_method_handler( + servicer.unary_unary, + request_deserializer=GetRunRequest.FromString, + response_serializer=GetRunResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( "flwr.proto.Fleet", rpc_method_handlers @@ -325,6 +332,45 @@ def test_client_auth_send(self) -> None: # Assert assert self._servicer.received_client_metadata() == expected_client_metadata + def test_client_auth_get_run(self) -> None: + """Test client authentication during send node.""" + # Prepare + retry_invoker = _init_retry_invoker() + + # Execute + with self._connection( + self._address, + True, + retry_invoker, + GRPC_MAX_MESSAGE_LENGTH, + None, + (self._client_private_key, self._client_public_key), + ) as conn: + _, _, _, _, get_run = conn + assert get_run is not None + get_run(0) + shared_secret = generate_shared_key( + self._servicer.server_private_key, self._client_public_key + ) + expected_hmac = compute_hmac( + shared_secret, self._servicer.received_message_bytes() + ) + expected_client_metadata = ( + ( + _PUBLIC_KEY_HEADER, + base64.urlsafe_b64encode( + public_key_to_bytes(self._client_public_key) + ), + ), + ( + _AUTH_TOKEN_HEADER, + base64.urlsafe_b64encode(expected_hmac), + ), + ) + + # Assert + assert self._servicer.received_client_metadata() == expected_client_metadata + if __name__ == "__main__": unittest.main(verbosity=2) From f313480f08ef42ff972c7e952efe6879dfe08e78 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 16:42:40 +0200 Subject: [PATCH 43/49] Encode only once --- .../grpc_rere_client/client_interceptor.py | 33 ++++++++++--------- .../client_interceptor_test.py | 4 +-- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index 61969c28d949..ca5777e6c282 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -31,16 +31,20 @@ from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611 CreateNodeRequest, DeleteNodeRequest, + GetRunRequest, PullTaskInsRequest, PushTaskResRequest, - GetRunRequest ) _PUBLIC_KEY_HEADER = "public-key" _AUTH_TOKEN_HEADER = "auth-token" Request = Union[ - CreateNodeRequest, DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, GetRunRequest + CreateNodeRequest, + DeleteNodeRequest, + PullTaskInsRequest, + PushTaskResRequest, + GetRunRequest, ] @@ -75,6 +79,9 @@ def __init__( self.public_key = public_key self.shared_secret = b"" self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None + self.encoded_public_key = base64.urlsafe_b64encode( + public_key_to_bytes(self.public_key) + ) def intercept_unary_unary( self, @@ -92,24 +99,20 @@ def intercept_unary_unary( if client_call_details.metadata is not None: metadata = list(client_call_details.metadata) - if isinstance(request, CreateNodeRequest): - metadata.append( - ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode(public_key_to_bytes(self.public_key)), - ) + metadata.append( + ( + _PUBLIC_KEY_HEADER, + self.encoded_public_key, ) + ) + + if isinstance(request, CreateNodeRequest): postprocess = True elif isinstance( - request, (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, GetRunRequest) + request, + (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, GetRunRequest), ): - metadata.append( - ( - _PUBLIC_KEY_HEADER, - base64.urlsafe_b64encode(public_key_to_bytes(self.public_key)), - ) - ) metadata.append( ( _AUTH_TOKEN_HEADER, diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py index fed9f3fcefaa..487361a06026 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor_test.py @@ -41,12 +41,12 @@ CreateNodeResponse, DeleteNodeRequest, DeleteNodeResponse, + GetRunRequest, + GetRunResponse, PullTaskInsRequest, PullTaskInsResponse, PushTaskResRequest, PushTaskResResponse, - GetRunRequest, - GetRunResponse, ) from .client_interceptor import _AUTH_TOKEN_HEADER, _PUBLIC_KEY_HEADER, Request From 28b2f4c43325059de871802a021d68de923e9821 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 16:48:55 +0200 Subject: [PATCH 44/49] Check if shared secret is not none --- .../grpc_rere_client/client_interceptor.py | 23 +++++++++++-------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index ca5777e6c282..53dc7cce9b5c 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -77,7 +77,7 @@ def __init__( ): self.private_key = private_key self.public_key = public_key - self.shared_secret = b"" + self.shared_secret: Optional[bytes] = None self.server_public_key: Optional[ec.EllipticCurvePublicKey] = None self.encoded_public_key = base64.urlsafe_b64encode( public_key_to_bytes(self.public_key) @@ -113,16 +113,19 @@ def intercept_unary_unary( request, (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, GetRunRequest), ): - metadata.append( - ( - _AUTH_TOKEN_HEADER, - base64.urlsafe_b64encode( - compute_hmac( - self.shared_secret, request.SerializeToString(True) - ) - ), + if self.shared_secret is not None: + metadata.append( + ( + _AUTH_TOKEN_HEADER, + base64.urlsafe_b64encode( + compute_hmac( + self.shared_secret, request.SerializeToString(True) + ) + ), + ) ) - ) + else: + raise RuntimeError("Failure to compute hmac") client_call_details = _ClientCallDetails( client_call_details.method, From 5da8cdba3a7738ece52b1f74e55fe560d6686f5e Mon Sep 17 00:00:00 2001 From: "Daniel J. Beutel" Date: Wed, 24 Apr 2024 16:50:55 +0200 Subject: [PATCH 45/49] Update src/py/flwr/client/grpc_rere_client/client_interceptor.py --- src/py/flwr/client/grpc_rere_client/client_interceptor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index 53dc7cce9b5c..6e7c948fae9a 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -99,6 +99,7 @@ def intercept_unary_unary( if client_call_details.metadata is not None: metadata = list(client_call_details.metadata) + # Always add the public key header metadata.append( ( _PUBLIC_KEY_HEADER, @@ -108,7 +109,6 @@ def intercept_unary_unary( if isinstance(request, CreateNodeRequest): postprocess = True - elif isinstance( request, (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, GetRunRequest), From c678f11f482d10cec469519a40e9ed1f1822a5e4 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 16:53:56 +0200 Subject: [PATCH 46/49] Simplify if/else --- .../grpc_rere_client/client_interceptor.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index 53dc7cce9b5c..ed9dc3cc6476 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -113,19 +113,18 @@ def intercept_unary_unary( request, (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, GetRunRequest), ): - if self.shared_secret is not None: - metadata.append( - ( - _AUTH_TOKEN_HEADER, - base64.urlsafe_b64encode( - compute_hmac( - self.shared_secret, request.SerializeToString(True) - ) - ), - ) - ) - else: + if self.shared_secret None: raise RuntimeError("Failure to compute hmac") + metadata.append( + ( + _AUTH_TOKEN_HEADER, + base64.urlsafe_b64encode( + compute_hmac( + self.shared_secret, request.SerializeToString(True) + ) + ), + ) + ) client_call_details = _ClientCallDetails( client_call_details.method, From 9f6d8ec6bbb3958d5ca3658d9aece14825ab6760 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 16:54:28 +0200 Subject: [PATCH 47/49] Simplify --- src/py/flwr/client/grpc_rere_client/client_interceptor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index ed9dc3cc6476..2f05d4baeb2b 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -113,8 +113,9 @@ def intercept_unary_unary( request, (DeleteNodeRequest, PullTaskInsRequest, PushTaskResRequest, GetRunRequest), ): - if self.shared_secret None: + if self.shared_secret is None: raise RuntimeError("Failure to compute hmac") + metadata.append( ( _AUTH_TOKEN_HEADER, From 4e2a0c5d02cc740b0e658e9d08a020591d494f56 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 16:55:02 +0200 Subject: [PATCH 48/49] Format --- src/py/flwr/client/grpc_rere_client/client_interceptor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index d557c431325f..3a95f1edc1ec 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -115,7 +115,7 @@ def intercept_unary_unary( ): if self.shared_secret is None: raise RuntimeError("Failure to compute hmac") - + metadata.append( ( _AUTH_TOKEN_HEADER, From ca7267a45e020fdf9834cf068d9f233372862572 Mon Sep 17 00:00:00 2001 From: Daniel Nata Nugraha Date: Wed, 24 Apr 2024 17:12:00 +0200 Subject: [PATCH 49/49] Add docstring to ClientCallDetails --- src/py/flwr/client/grpc_rere_client/client_interceptor.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/py/flwr/client/grpc_rere_client/client_interceptor.py b/src/py/flwr/client/grpc_rere_client/client_interceptor.py index 3a95f1edc1ec..7e044266c2e7 100644 --- a/src/py/flwr/client/grpc_rere_client/client_interceptor.py +++ b/src/py/flwr/client/grpc_rere_client/client_interceptor.py @@ -64,7 +64,11 @@ class _ClientCallDetails( ), grpc.ClientCallDetails, # type: ignore ): - pass + """Details for each client call. + + The class will be passed on as the first argument in continuation function. + In our case, `AuthenticateClientInterceptor` adds new metadata to the construct. + """ class AuthenticateClientInterceptor(grpc.UnaryUnaryClientInterceptor): # type: ignore