Skip to content

Commit

Permalink
Update sint64 to uint64
Browse files Browse the repository at this point in the history
  • Loading branch information
mohammadnaseri committed Sep 10, 2024
1 parent 8f8639f commit 2a8b57f
Show file tree
Hide file tree
Showing 10 changed files with 170 additions and 32 deletions.
6 changes: 3 additions & 3 deletions src/proto/flwr/proto/clientappio.proto
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ message ClientAppOutputStatus {
}

message GetTokenRequest {}
message GetTokenResponse { sint64 token = 1; }
message GetTokenResponse { uint64 token = 1; }

message PullClientAppInputsRequest { sint64 token = 1; }
message PullClientAppInputsRequest { uint64 token = 1; }
message PullClientAppInputsResponse {
Message message = 1;
Context context = 2;
Expand All @@ -41,7 +41,7 @@ message PullClientAppInputsResponse {
}

message PushClientAppOutputsRequest {
sint64 token = 1;
uint64 token = 1;
Message message = 2;
Context context = 3;
}
Expand Down
4 changes: 2 additions & 2 deletions src/proto/flwr/proto/driver.proto
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ message CreateRunRequest {
map<string, Scalar> override_config = 3;
Fab fab = 4;
}
message CreateRunResponse { sint64 run_id = 1; }
message CreateRunResponse { uint64 run_id = 1; }

// GetNodes messages
message GetNodesRequest { sint64 run_id = 1; }
message GetNodesRequest { uint64 run_id = 1; }
message GetNodesResponse { repeated Node nodes = 1; }

// PushTaskIns messages
Expand Down
4 changes: 2 additions & 2 deletions src/proto/flwr/proto/exec.proto
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ message StartRunRequest {
map<string, Scalar> override_config = 2;
map<string, Scalar> federation_config = 3;
}
message StartRunResponse { sint64 run_id = 1; }
message StreamLogsRequest { sint64 run_id = 1; }
message StartRunResponse { uint64 run_id = 1; }
message StreamLogsRequest { uint64 run_id = 1; }
message StreamLogsResponse { string log_output = 1; }
8 changes: 4 additions & 4 deletions src/proto/flwr/proto/message.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ message Message {
}

message Context {
sint64 node_id = 1;
uint64 node_id = 1;
map<string, Scalar> node_config = 2;
RecordSet state = 3;
map<string, Scalar> run_config = 4;
}

message Metadata {
sint64 run_id = 1;
uint64 run_id = 1;
string message_id = 2;
sint64 src_node_id = 3;
sint64 dst_node_id = 4;
uint64 src_node_id = 3;
uint64 dst_node_id = 4;
string reply_to_message = 5;
string group_id = 6;
double ttl = 7;
Expand Down
2 changes: 1 addition & 1 deletion src/proto/flwr/proto/node.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@ syntax = "proto3";
package flwr.proto;

message Node {
sint64 node_id = 1;
uint64 node_id = 1;
bool anonymous = 2;
}
4 changes: 2 additions & 2 deletions src/proto/flwr/proto/run.proto
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ package flwr.proto;
import "flwr/proto/transport.proto";

message Run {
sint64 run_id = 1;
uint64 run_id = 1;
string fab_id = 2;
string fab_version = 3;
map<string, Scalar> override_config = 4;
string fab_hash = 5;
}
message GetRunRequest { sint64 run_id = 1; }
message GetRunRequest { uint64 run_id = 1; }
message GetRunResponse { Run run = 1; }
4 changes: 2 additions & 2 deletions src/proto/flwr/proto/task.proto
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ message Task {
message TaskIns {
string task_id = 1;
string group_id = 2;
sint64 run_id = 3;
uint64 run_id = 3;
Task task = 4;
}

message TaskRes {
string task_id = 1;
string group_id = 2;
sint64 run_id = 3;
uint64 run_id = 3;
Task task = 4;
}
83 changes: 69 additions & 14 deletions src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
from flwr.server.utils.validator import validate_task_ins_or_res

from .state import State
from .utils import generate_rand_int_from_bytes, make_node_unavailable_taskres
from .utils import (
generate_rand_int_from_bytes,
make_node_unavailable_taskres,
uint64_to_sint64,
sint64_to_uint64,
)

SQL_CREATE_TABLE_NODE = """
CREATE TABLE IF NOT EXISTS node(
Expand Down Expand Up @@ -219,6 +224,9 @@ def store_task_ins(self, task_ins: TaskIns) -> Optional[UUID]:
# Create task_id
task_id = uuid4()

# Convert a uint64 value to sint64 for SQLite
task_ins.run_id = uint64_to_sint64(task_ins.run_id)

# Store TaskIns
task_ins.task_id = str(task_id)
data = (task_ins_to_dict(task_ins),)
Expand Down Expand Up @@ -291,7 +299,10 @@ def get_task_ins(
AND consumer_node_id == :node_id
AND delivered_at = ""
"""
data["node_id"] = node_id

# Convert a uint64 value to sint64 for SQLite
sint64_node_id = uint64_to_sint64(node_id)
data["node_id"] = sint64_node_id

if limit is not None:
query += " LIMIT :limit"
Expand Down Expand Up @@ -350,6 +361,9 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
# Create task_id
task_id = uuid4()

# Convert a uint64 value to sint64 for SQLite
task_res.run_id = uint64_to_sint64(task_res.run_id)

# Store TaskIns
task_res.task_id = str(task_id)
data = (task_res_to_dict(task_res),)
Expand Down Expand Up @@ -473,6 +487,10 @@ def get_task_res(self, task_ids: Set[UUID], limit: Optional[int]) -> List[TaskRe
for row in task_ins_rows:
if limit and len(result) == limit:
break

# Convert run_id from sint64 to uint64
row["run_id"] = sint64_to_uint64(row["run_id"])

task_ins = dict_to_task_ins(row)
err_taskres = make_node_unavailable_taskres(
ref_taskins=task_ins,
Expand Down Expand Up @@ -546,6 +564,9 @@ def create_node(
# Sample a random int64 as node_id
node_id = generate_rand_int_from_bytes(NODE_ID_NUM_BYTES)

# Convert a uint64 value to sint64 for SQLite
sint64_node_id = uint64_to_sint64(node_id)

query = "SELECT node_id FROM node WHERE public_key = :public_key;"
row = self.query(query, {"public_key": public_key})

Expand All @@ -561,17 +582,29 @@ def create_node(

try:
self.query(
query, (node_id, time.time() + ping_interval, ping_interval, public_key)
query,
(
sint64_node_id,
time.time() + ping_interval,
ping_interval,
public_key,
),
)
except sqlite3.IntegrityError:
log(ERROR, "Unexpected node registration failure.")
return 0

# Return the uint64 value of the node_id
return node_id

def delete_node(self, node_id: int, public_key: Optional[bytes] = None) -> None:
"""Delete a node."""

# Convert the uint64 value to sint64 for SQLite
sint64_node_id = uint64_to_sint64(node_id)

query = "DELETE FROM node WHERE node_id = ?"
params = (node_id,)
params = (sint64_node_id,)

if public_key is not None:
query += " AND public_key = ?"
Expand All @@ -596,15 +629,22 @@ def get_nodes(self, run_id: int) -> Set[int]:
If the provided `run_id` does not exist or has no matching nodes,
an empty `Set` MUST be returned.
"""

# Convert the uint64 value to sint64 for SQLite
sint64_run_id = uint64_to_sint64(run_id)

# Validate run ID
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
query = "SELECT COUNT(*) FROM run WHERE sint64_run_id = ?;"
if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
return set()

# Get nodes
query = "SELECT node_id FROM node WHERE online_until > ?;"
rows = self.query(query, (time.time(),))
result: Set[int] = {row["node_id"] for row in rows}

# Convert sint64 node_ids to uint64
result: Set[int] = {sint64_to_uint64(row["node_id"]) for row in rows}

return result

def get_node_id(self, node_public_key: bytes) -> Optional[int]:
Expand All @@ -613,7 +653,11 @@ def get_node_id(self, node_public_key: bytes) -> Optional[int]:
row = self.query(query, {"public_key": node_public_key})
if len(row) > 0:
node_id: int = row[0]["node_id"]
return node_id

# Convert a sint64 value to uint64 after reading from SQLite
uint64_node_id = sint64_to_uint64(node_id)

return uint64_node_id
return None

def create_run(
Expand All @@ -627,24 +671,35 @@ def create_run(
# Sample a random int64 as run_id
run_id = generate_rand_int_from_bytes(RUN_ID_NUM_BYTES)

# Convert a uint64 value to sint64 for SQLite
sint64_run_id = uint64_to_sint64(run_id)

# Check conflicts
query = "SELECT COUNT(*) FROM run WHERE run_id = ?;"
# If run_id does not exist
if self.query(query, (run_id,))[0]["COUNT(*)"] == 0:
query = "SELECT COUNT(*) FROM run WHERE sint64_run_id = ?;"
# If sint64_run_id does not exist
if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0:
query = (
"INSERT INTO run "
"(run_id, fab_id, fab_version, fab_hash, override_config)"
"(sint64_run_id, fab_id, fab_version, fab_hash, override_config)"
"VALUES (?, ?, ?, ?, ?);"
)
if fab_hash:
self.query(
query, (run_id, "", "", fab_hash, json.dumps(override_config))
query,
(sint64_run_id, "", "", fab_hash, json.dumps(override_config)),
)
else:
self.query(
query,
(run_id, fab_id, fab_version, "", json.dumps(override_config)),
(
sint64_run_id,
fab_id,
fab_version,
"",
json.dumps(override_config),
),
)
# Return the uint64 value of the run_id
return run_id
log(ERROR, "Unexpected run creation failure.")
return 0
Expand Down
34 changes: 32 additions & 2 deletions src/py/flwr/server/superlink/state/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,38 @@


def generate_rand_int_from_bytes(num_bytes: int) -> int:
"""Generate a random `num_bytes` integer."""
return int.from_bytes(urandom(num_bytes), "little", signed=True)
"""Generate a random unsigned integer from `num_bytes` bytes."""
return int.from_bytes(urandom(num_bytes), "little", signed=False)


def uint64_to_sint64(u: int) -> int:
"""
Convert a uint64 value to sint64.
Args:
u (int): The unsigned 64-bit integer to convert.
Returns:
int: The signed 64-bit integer equivalent.
"""
if u >= 2**63:
return u - 2**64
return u


def sint64_to_uint64(s: int) -> int:
"""
Convert a sint64 value to uint64.
Args:
s (int): The signed 64-bit integer to convert.
Returns:
int: The unsigned 64-bit integer equivalent.
"""
if s < 0:
return s + 2**64
return s


def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes:
Expand Down
53 changes: 53 additions & 0 deletions src/py/flwr/server/superlink/state/utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2023 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils tests."""

import unittest

from .utils import uint64_to_sint64, sint64_to_uint64, generate_rand_int_from_bytes


class UtilsTest(unittest.TestCase):
"""Test utils code."""

def test_uint64_to_sint64(self):
# Test values below 2^63
self.assertEqual(uint64_to_sint64(0), 0)
self.assertEqual(uint64_to_sint64(2**62), 2**62)
self.assertEqual(uint64_to_sint64(2**63 - 1), 2**63 - 1)

# Test values at and above 2^63
self.assertEqual(uint64_to_sint64(2**63), -(2**63))
self.assertEqual(uint64_to_sint64(2**63 + 1), -(2**63) + 1)
self.assertEqual(uint64_to_sint64(2**64 - 1), -1)

def test_sint64_to_uint64(self):
# Test values within the range of sint64
self.assertEqual(sint64_to_uint64(-(2**63)), 2**63)
self.assertEqual(sint64_to_uint64(-(2**63) + 1), 2**63 + 1)
self.assertEqual(sint64_to_uint64(-1), 2**64 - 1)
self.assertEqual(sint64_to_uint64(0), 0)
self.assertEqual(sint64_to_uint64(2**63 - 1), 2**63 - 1)

# Test values above 2^63
self.assertEqual(sint64_to_uint64(2**63), 2**63)
self.assertEqual(sint64_to_uint64(2**64 - 1), 2**64 - 1)

def test_generate_rand_int_from_bytes_unsigned_int(self):
"""Test that the generated integer is unsigned (non-negative)."""
for num_bytes in range(1, 9):
with self.subTest(num_bytes=num_bytes):
rand_int = generate_rand_int_from_bytes(num_bytes)
self.assertGreaterEqual(rand_int, 0)

0 comments on commit 2a8b57f

Please sign in to comment.