diff --git a/src/proto/flwr/proto/clientappio.proto b/src/proto/flwr/proto/clientappio.proto index 898cb04c5b5b..376df1e28001 100644 --- a/src/proto/flwr/proto/clientappio.proto +++ b/src/proto/flwr/proto/clientappio.proto @@ -30,9 +30,9 @@ message ClientAppOutputStatus { } message GetTokenRequest {} -message GetTokenResponse { sint64 token = 1; } +message GetTokenResponse { uint64 token = 1; } -message PullClientAppInputsRequest { sint64 token = 1; } +message PullClientAppInputsRequest { uint64 token = 1; } message PullClientAppInputsResponse { Message message = 1; Context context = 2; @@ -41,7 +41,7 @@ message PullClientAppInputsResponse { } message PushClientAppOutputsRequest { - sint64 token = 1; + uint64 token = 1; Message message = 2; Context context = 3; } diff --git a/src/proto/flwr/proto/driver.proto b/src/proto/flwr/proto/driver.proto index 63a2f78e6f6d..c7ae7dcf30f0 100644 --- a/src/proto/flwr/proto/driver.proto +++ b/src/proto/flwr/proto/driver.proto @@ -50,10 +50,10 @@ message CreateRunRequest { map override_config = 3; Fab fab = 4; } -message CreateRunResponse { sint64 run_id = 1; } +message CreateRunResponse { uint64 run_id = 1; } // GetNodes messages -message GetNodesRequest { sint64 run_id = 1; } +message GetNodesRequest { uint64 run_id = 1; } message GetNodesResponse { repeated Node nodes = 1; } // PushTaskIns messages diff --git a/src/proto/flwr/proto/exec.proto b/src/proto/flwr/proto/exec.proto index 65faf4386ea0..ad0723c0480c 100644 --- a/src/proto/flwr/proto/exec.proto +++ b/src/proto/flwr/proto/exec.proto @@ -33,6 +33,6 @@ message StartRunRequest { map override_config = 2; map federation_config = 3; } -message StartRunResponse { sint64 run_id = 1; } -message StreamLogsRequest { sint64 run_id = 1; } +message StartRunResponse { uint64 run_id = 1; } +message StreamLogsRequest { uint64 run_id = 1; } message StreamLogsResponse { string log_output = 1; } diff --git a/src/proto/flwr/proto/message.proto b/src/proto/flwr/proto/message.proto index 3230ab0609a9..7066da5b7e76 100644 --- a/src/proto/flwr/proto/message.proto +++ b/src/proto/flwr/proto/message.proto @@ -28,17 +28,17 @@ message Message { } message Context { - sint64 node_id = 1; + uint64 node_id = 1; map node_config = 2; RecordSet state = 3; map run_config = 4; } message Metadata { - sint64 run_id = 1; + uint64 run_id = 1; string message_id = 2; - sint64 src_node_id = 3; - sint64 dst_node_id = 4; + uint64 src_node_id = 3; + uint64 dst_node_id = 4; string reply_to_message = 5; string group_id = 6; double ttl = 7; diff --git a/src/proto/flwr/proto/node.proto b/src/proto/flwr/proto/node.proto index e61d44f0f783..ec72b51b44ec 100644 --- a/src/proto/flwr/proto/node.proto +++ b/src/proto/flwr/proto/node.proto @@ -18,6 +18,6 @@ syntax = "proto3"; package flwr.proto; message Node { - sint64 node_id = 1; + uint64 node_id = 1; bool anonymous = 2; } diff --git a/src/proto/flwr/proto/run.proto b/src/proto/flwr/proto/run.proto index 6adca5c2437b..fc3294f7a583 100644 --- a/src/proto/flwr/proto/run.proto +++ b/src/proto/flwr/proto/run.proto @@ -20,11 +20,11 @@ package flwr.proto; import "flwr/proto/transport.proto"; message Run { - sint64 run_id = 1; + uint64 run_id = 1; string fab_id = 2; string fab_version = 3; map override_config = 4; string fab_hash = 5; } -message GetRunRequest { sint64 run_id = 1; } +message GetRunRequest { uint64 run_id = 1; } message GetRunResponse { Run run = 1; } diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 936b8120e495..324a70a5359c 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -37,13 +37,13 @@ message Task { message TaskIns { string task_id = 1; string group_id = 2; - sint64 run_id = 3; + uint64 run_id = 3; Task task = 4; } message TaskRes { string task_id = 1; string group_id = 2; - sint64 run_id = 3; + uint64 run_id = 3; Task task = 4; } diff --git a/src/py/flwr/server/superlink/state/sqlite_state.py b/src/py/flwr/server/superlink/state/sqlite_state.py index daa211560912..4be0e60d4708 100644 --- a/src/py/flwr/server/superlink/state/sqlite_state.py +++ b/src/py/flwr/server/superlink/state/sqlite_state.py @@ -32,7 +32,12 @@ from flwr.server.utils.validator import validate_task_ins_or_res from .state import State -from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres +from .utils import ( + generate_rand_int_from_bytes, + make_node_unavailable_taskres, + uint64_to_sint64, + sint64_to_uint64, +) SQL_CREATE_TABLE_NODE = """ CREATE TABLE IF NOT EXISTS node( @@ -219,6 +224,9 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]: # Create task_id task_id = uuid4() + # Convert a uint64 value to sint64 for SQLite + task_ins.run_id = uint64_to_sint64(task_ins.run_id) + # Store TaskIns task_ins.task_id = str(task_id) data = (task_ins_to_dict(task_ins),) @@ -291,7 +299,10 @@ def get_task_ins( AND consumer_node_id == :node_id AND delivered_at = "" """ - data["node_id"] = node_id + + # Convert a uint64 value to sint64 for SQLite + sint64_node_id = uint64_to_sint64(node_id) + data["node_id"] = sint64_node_id if limit is not None: query += " LIMIT :limit" @@ -350,6 +361,9 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]: # Create task_id task_id = uuid4() + # Convert a uint64 value to sint64 for SQLite + task_res.run_id = uint64_to_sint64(task_res.run_id) + # Store TaskIns task_res.task_id = str(task_id) data = (task_res_to_dict(task_res),) @@ -473,6 +487,10 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe for row in task_ins_rows: if limit and len(result) == limit: break + + # Convert run_id from sint64 to uint64 + row["run_id"] = sint64_to_uint64(row["run_id"]) + task_ins = dict_to_task_ins(row) err_taskres = make_node_unavailable_taskres( ref_taskins=task_ins, @@ -546,6 +564,9 @@ def create_node( # Sample a random int64 as node_id node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES) + # Convert a uint64 value to sint64 for SQLite + sint64_node_id = uint64_to_sint64(node_id) + query = "SELECT node_id FROM node WHERE public_key = :public_key;" row = self.query(query, {"public_key": public_key}) @@ -561,17 +582,29 @@ def create_node( try: self.query( - query, (node_id, time.time() + ping_interval, ping_interval, public_key) + query, + ( + sint64_node_id, + time.time() + ping_interval, + ping_interval, + public_key, + ), ) except sqlite3.IntegrityError: log(ERROR, "Unexpected node registration failure.") return 0 + + # Return the uint64 value of the node_id return node_id def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None: """Delete a node.""" + + # Convert the uint64 value to sint64 for SQLite + sint64_node_id = uint64_to_sint64(node_id) + query = "DELETE FROM node WHERE node_id = ?" - params = (node_id,) + params = (sint64_node_id,) if public_key is not None: query += " AND public_key = ?" @@ -596,15 +629,22 @@ def get_nodes(self, run_id: int) -> Set[int]: If the provided `run_id` does not exist or has no matching nodes, an empty `Set` MUST be returned. """ + + # Convert the uint64 value to sint64 for SQLite + sint64_run_id = uint64_to_sint64(run_id) + # Validate run ID - query = "SELECT COUNT(*) FROM run WHERE run_id = ?;" - if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: + query = "SELECT COUNT(*) FROM run WHERE sint64_run_id = ?;" + if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0: return set() # Get nodes query = "SELECT node_id FROM node WHERE online_until > ?;" rows = self.query(query, (time.time(),)) - result: Set[int] = {row["node_id"] for row in rows} + + # Convert sint64 node_ids to uint64 + result: Set[int] = {sint64_to_uint64(row["node_id"]) for row in rows} + return result def get_node_id(self, node_public_key: bytes) -> Optional[int]: @@ -613,7 +653,11 @@ def get_node_id(self, node_public_key: bytes) -> Optional[int]: row = self.query(query, {"public_key": node_public_key}) if len(row) > 0: node_id: int = row[0]["node_id"] - return node_id + + # Convert a sint64 value to uint64 after reading from SQLite + uint64_node_id = sint64_to_uint64(node_id) + + return uint64_node_id return None def create_run( @@ -627,24 +671,35 @@ def create_run( # Sample a random int64 as run_id run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES) + # Convert a uint64 value to sint64 for SQLite + sint64_run_id = uint64_to_sint64(run_id) + # Check conflicts - query = "SELECT COUNT(*) FROM run WHERE run_id = ?;" - # If run_id does not exist - if self.query(query, (run_id,))[0]["COUNT(*)"] == 0: + query = "SELECT COUNT(*) FROM run WHERE sint64_run_id = ?;" + # If sint64_run_id does not exist + if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0: query = ( "INSERT INTO run " - "(run_id, fab_id, fab_version, fab_hash, override_config)" + "(sint64_run_id, fab_id, fab_version, fab_hash, override_config)" "VALUES (?, ?, ?, ?, ?);" ) if fab_hash: self.query( - query, (run_id, "", "", fab_hash, json.dumps(override_config)) + query, + (sint64_run_id, "", "", fab_hash, json.dumps(override_config)), ) else: self.query( query, - (run_id, fab_id, fab_version, "", json.dumps(override_config)), + ( + sint64_run_id, + fab_id, + fab_version, + "", + json.dumps(override_config), + ), ) + # Return the uint64 value of the run_id return run_id log(ERROR, "Unexpected run creation failure.") return 0 diff --git a/src/py/flwr/server/superlink/state/utils.py b/src/py/flwr/server/superlink/state/utils.py index b12a87ac998d..913a051b9946 100644 --- a/src/py/flwr/server/superlink/state/utils.py +++ b/src/py/flwr/server/superlink/state/utils.py @@ -33,8 +33,38 @@ def generate_rand_int_from_bytes(num_bytes: int) -> int: - """Generate a random `num_bytes` integer.""" - return int.from_bytes(urandom(num_bytes), "little", signed=True) + """Generate a random unsigned integer from `num_bytes` bytes.""" + return int.from_bytes(urandom(num_bytes), "little", signed=False) + + +def uint64_to_sint64(u: int) -> int: + """ + Convert a uint64 value to sint64. + + Args: + u (int): The unsigned 64-bit integer to convert. + + Returns: + int: The signed 64-bit integer equivalent. + """ + if u >= 2**63: + return u - 2**64 + return u + + +def sint64_to_uint64(s: int) -> int: + """ + Convert a sint64 value to uint64. + + Args: + s (int): The signed 64-bit integer to convert. + + Returns: + int: The unsigned 64-bit integer equivalent. + """ + if s < 0: + return s + 2**64 + return s def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes: diff --git a/src/py/flwr/server/superlink/state/utils_test.py b/src/py/flwr/server/superlink/state/utils_test.py new file mode 100644 index 000000000000..3d16676431d3 --- /dev/null +++ b/src/py/flwr/server/superlink/state/utils_test.py @@ -0,0 +1,53 @@ +# Copyright 2023 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utils tests.""" + +import unittest + +from .utils import uint64_to_sint64, sint64_to_uint64, generate_rand_int_from_bytes + + +class UtilsTest(unittest.TestCase): + """Test utils code.""" + + def test_uint64_to_sint64(self): + # Test values below 2^63 + self.assertEqual(uint64_to_sint64(0), 0) + self.assertEqual(uint64_to_sint64(2**62), 2**62) + self.assertEqual(uint64_to_sint64(2**63 - 1), 2**63 - 1) + + # Test values at and above 2^63 + self.assertEqual(uint64_to_sint64(2**63), -(2**63)) + self.assertEqual(uint64_to_sint64(2**63 + 1), -(2**63) + 1) + self.assertEqual(uint64_to_sint64(2**64 - 1), -1) + + def test_sint64_to_uint64(self): + # Test values within the range of sint64 + self.assertEqual(sint64_to_uint64(-(2**63)), 2**63) + self.assertEqual(sint64_to_uint64(-(2**63) + 1), 2**63 + 1) + self.assertEqual(sint64_to_uint64(-1), 2**64 - 1) + self.assertEqual(sint64_to_uint64(0), 0) + self.assertEqual(sint64_to_uint64(2**63 - 1), 2**63 - 1) + + # Test values above 2^63 + self.assertEqual(sint64_to_uint64(2**63), 2**63) + self.assertEqual(sint64_to_uint64(2**64 - 1), 2**64 - 1) + + def test_generate_rand_int_from_bytes_unsigned_int(self): + """Test that the generated integer is unsigned (non-negative).""" + for num_bytes in range(1, 9): + with self.subTest(num_bytes=num_bytes): + rand_int = generate_rand_int_from_bytes(num_bytes) + self.assertGreaterEqual(rand_int, 0)