Skip to content

Commit

Permalink
Merge branch 'main' into migrate-opacus-example
Browse files Browse the repository at this point in the history
  • Loading branch information
mohammadnaseri authored Sep 26, 2024
2 parents d0c04d6 + 83cd4ba commit 4c45c9c
Show file tree
Hide file tree
Showing 20 changed files with 350 additions and 71 deletions.
9 changes: 9 additions & 0 deletions datasets/doc/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,15 @@ Flower Datasets

Flower Datasets (``flwr-datasets``) is a library that enables the quick and easy creation of datasets for federated learning/analytics/evaluation. It enables heterogeneity (non-iidness) simulation and division of datasets with the preexisting notion of IDs. The library was created by the ``Flower Labs`` team that also created `Flower <https://flower.ai>`_ : A Friendly Federated Learning Framework.

.. raw:: html

<script
type="module"
src="https://gradio.s3-us-west-2.amazonaws.com/4.44.0/gradio.js"
></script>

<gradio-app src="https://flwrlabs-federated-learning-datasets-by-flwr-datasets.hf.space"></gradio-app>

Flower Datasets Framework
-------------------------

Expand Down
7 changes: 6 additions & 1 deletion src/proto/flwr/proto/fab.proto
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ syntax = "proto3";

package flwr.proto;

import "flwr/proto/node.proto";

message Fab {
// This field is the hash of the data field. It is used to identify the data.
// The hash is calculated using the SHA-256 algorithm and is represented as a
Expand All @@ -26,5 +28,8 @@ message Fab {
bytes content = 2;
}

message GetFabRequest { string hash_str = 1; }
message GetFabRequest {
Node node = 1;
string hash_str = 2;
}
message GetFabResponse { Fab fab = 1; }
5 changes: 4 additions & 1 deletion src/proto/flwr/proto/fleet.proto
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ message PullTaskInsResponse {
}

// PushTaskRes messages
message PushTaskResRequest { repeated TaskRes task_res_list = 1; }
message PushTaskResRequest {
Node node = 1;
repeated TaskRes task_res_list = 2;
}
message PushTaskResResponse {
Reconnect reconnect = 1;
map<string, uint32> results = 2;
Expand Down
11 changes: 9 additions & 2 deletions src/proto/flwr/proto/run.proto
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ syntax = "proto3";
package flwr.proto;

import "flwr/proto/fab.proto";
import "flwr/proto/node.proto";
import "flwr/proto/transport.proto";

message Run {
Expand Down Expand Up @@ -47,7 +48,10 @@ message CreateRunRequest {
message CreateRunResponse { uint64 run_id = 1; }

// GetRun
message GetRunRequest { uint64 run_id = 1; }
message GetRunRequest {
Node node = 1;
uint64 run_id = 2;
}
message GetRunResponse { Run run = 1; }

// UpdateRunStatus
Expand All @@ -58,5 +62,8 @@ message UpdateRunStatusRequest {
message UpdateRunStatusResponse {}

// GetRunStatus
message GetRunStatusRequest { repeated uint64 run_ids = 1; }
message GetRunStatusRequest {
Node node = 1;
repeated uint64 run_ids = 2;
}
message GetRunStatusResponse { map<uint64, RunStatus> run_status_dict = 1; }
37 changes: 34 additions & 3 deletions src/py/flwr/cli/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,49 @@
from flwr.cli.config_utils import load_and_validate
from flwr.common.grpc import GRPC_MAX_MESSAGE_LENGTH, create_channel
from flwr.common.logger import log as logger
from flwr.proto.exec_pb2 import StreamLogsRequest # pylint: disable=E0611
from flwr.proto.exec_pb2_grpc import ExecStub

CONN_REFRESH_PERIOD = 60 # Connection refresh period for log streaming (seconds)


# pylint: disable=unused-argument
def stream_logs(run_id: int, channel: grpc.Channel, period: int) -> None:
def stream_logs(run_id: int, channel: grpc.Channel, duration: int) -> None:
"""Stream logs from the beginning of a run with connection refresh."""
start_time = time.time()
stub = ExecStub(channel)
req = StreamLogsRequest(run_id=run_id)

for res in stub.StreamLogs(req):
print(res.log_output)
if time.time() - start_time > duration:
break


# pylint: disable=unused-argument
def print_logs(run_id: int, channel: grpc.Channel, timeout: int) -> None:
"""Print logs from the beginning of a run."""
stub = ExecStub(channel)
req = StreamLogsRequest(run_id=run_id)

try:
while True:
try:
# Enforce timeout for graceful exit
for res in stub.StreamLogs(req, timeout=timeout):
print(res.log_output)
except grpc.RpcError as e:
# pylint: disable=E1101
if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED:
break
if e.code() == grpc.StatusCode.NOT_FOUND:
logger(ERROR, "Invalid run_id `%s`, exiting", run_id)
break
if e.code() == grpc.StatusCode.CANCELLED:
break
except KeyboardInterrupt:
logger(DEBUG, "Stream interrupted by user")
finally:
channel.close()
logger(DEBUG, "Channel closed")


def on_channel_state_change(channel_connectivity: str) -> None:
Expand Down
78 changes: 78 additions & 0 deletions src/py/flwr/cli/log_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for Flower command line interface `log` command."""


import unittest
from typing import NoReturn
from unittest.mock import Mock, call, patch

from flwr.proto.exec_pb2 import StreamLogsResponse # pylint: disable=E0611

from .log import print_logs, stream_logs


class InterruptedStreamLogsResponse:
"""Create a StreamLogsResponse object with KeyboardInterrupt."""

@property
def log_output(self) -> NoReturn:
"""Raise KeyboardInterrupt to exit logstream test gracefully."""
raise KeyboardInterrupt


class TestFlwrLog(unittest.TestCase):
"""Unit tests for `flwr log` CLI functions."""

def setUp(self) -> None:
"""Initialize mock ExecStub before each test."""
self.expected_calls = [
call("log_output_1"),
call("log_output_2"),
call("log_output_3"),
]
mock_response_iterator = [
iter(
[StreamLogsResponse(log_output=f"log_output_{i}") for i in range(1, 4)]
+ [InterruptedStreamLogsResponse()]
)
]
self.mock_stub = Mock()
self.mock_stub.StreamLogs.side_effect = mock_response_iterator
self.patcher = patch("flwr.cli.log.ExecStub", return_value=self.mock_stub)

self.patcher.start()

# Create mock channel
self.mock_channel = Mock()

def tearDown(self) -> None:
"""Cleanup."""
self.patcher.stop()

def test_flwr_log_stream_method(self) -> None:
"""Test stream_logs."""
with patch("builtins.print") as mock_print:
with self.assertRaises(KeyboardInterrupt):
stream_logs(run_id=123, channel=self.mock_channel, duration=1)
# Assert that mock print was called with the expected arguments
mock_print.assert_has_calls(self.expected_calls)

def test_flwr_log_print_method(self) -> None:
"""Test print_logs."""
with patch("builtins.print") as mock_print:
print_logs(run_id=123, channel=self.mock_channel, timeout=0)
# Assert that mock print was called with the expected arguments
mock_print.assert_has_calls(self.expected_calls)
6 changes: 3 additions & 3 deletions src/py/flwr/client/grpc_rere_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,15 +269,15 @@ def send(message: Message) -> None:
task_res = message_to_taskres(message)

# Serialize ProtoBuf to bytes
request = PushTaskResRequest(task_res_list=[task_res])
request = PushTaskResRequest(node=node, task_res_list=[task_res])
_ = retry_invoker.invoke(stub.PushTaskRes, request)

# Cleanup
metadata = None

def get_run(run_id: int) -> Run:
# Call FleetAPI
get_run_request = GetRunRequest(run_id=run_id)
get_run_request = GetRunRequest(node=node, run_id=run_id)
get_run_response: GetRunResponse = retry_invoker.invoke(
stub.GetRun,
request=get_run_request,
Expand All @@ -294,7 +294,7 @@ def get_run(run_id: int) -> Run:

def get_fab(fab_hash: str) -> Fab:
# Call FleetAPI
get_fab_request = GetFabRequest(hash_str=fab_hash)
get_fab_request = GetFabRequest(node=node, hash_str=fab_hash)
get_fab_response: GetFabResponse = retry_invoker.invoke(
stub.GetFab,
request=get_fab_request,
Expand Down
6 changes: 3 additions & 3 deletions src/py/flwr/client/rest_client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def send(message: Message) -> None:
task_res = message_to_taskres(message)

# Serialize ProtoBuf to bytes
req = PushTaskResRequest(task_res_list=[task_res])
req = PushTaskResRequest(node=node, task_res_list=[task_res])

# Send the request
res = _request(req, PushTaskResResponse, PATH_PUSH_TASK_RES)
Expand All @@ -356,7 +356,7 @@ def send(message: Message) -> None:

def get_run(run_id: int) -> Run:
# Construct the request
req = GetRunRequest(run_id=run_id)
req = GetRunRequest(node=node, run_id=run_id)

# Send the request
res = _request(req, GetRunResponse, PATH_GET_RUN)
Expand All @@ -373,7 +373,7 @@ def get_run(run_id: int) -> Run:

def get_fab(fab_hash: str) -> Fab:
# Construct the request
req = GetFabRequest(hash_str=fab_hash)
req = GetFabRequest(node=node, hash_str=fab_hash)

# Send the request
res = _request(req, GetFabResponse, PATH_GET_FAB)
Expand Down
8 changes: 4 additions & 4 deletions src/py/flwr/common/secure_aggregation/secaggplus_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def share_keys_plaintext_concat(
"""
return b"".join(
[
int.to_bytes(src_node_id, 8, "little", signed=True),
int.to_bytes(dst_node_id, 8, "little", signed=True),
int.to_bytes(src_node_id, 8, "little", signed=False),
int.to_bytes(dst_node_id, 8, "little", signed=False),
int.to_bytes(len(b_share), 4, "little"),
b_share,
sk_share,
Expand Down Expand Up @@ -72,8 +72,8 @@ def share_keys_plaintext_separate(plaintext: bytes) -> tuple[int, int, bytes, by
the secret key share of the source sent to the destination.
"""
src, dst, mark = (
int.from_bytes(plaintext[:8], "little", signed=True),
int.from_bytes(plaintext[8:16], "little", signed=True),
int.from_bytes(plaintext[:8], "little", signed=False),
int.from_bytes(plaintext[8:16], "little", signed=False),
int.from_bytes(plaintext[16:20], "little"),
)
ret = (src, dst, plaintext[20 : 20 + mark], plaintext[20 + mark :])
Expand Down
15 changes: 8 additions & 7 deletions src/py/flwr/proto/fab_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion src/py/flwr/proto/fab_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
isort:skip_file
"""
import builtins
import flwr.proto.node_pb2
import google.protobuf.descriptor
import google.protobuf.message
import typing
Expand Down Expand Up @@ -33,13 +34,18 @@ global___Fab = Fab

class GetFabRequest(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
NODE_FIELD_NUMBER: builtins.int
HASH_STR_FIELD_NUMBER: builtins.int
@property
def node(self) -> flwr.proto.node_pb2.Node: ...
hash_str: typing.Text
def __init__(self,
*,
node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
hash_str: typing.Text = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str"]) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["hash_str",b"hash_str","node",b"node"]) -> None: ...
global___GetFabRequest = GetFabRequest

class GetFabResponse(google.protobuf.message.Message):
Expand Down
20 changes: 10 additions & 10 deletions src/py/flwr/proto/fleet_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion src/py/flwr/proto/fleet_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,19 @@ global___PullTaskInsResponse = PullTaskInsResponse
class PushTaskResRequest(google.protobuf.message.Message):
"""PushTaskRes messages"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
NODE_FIELD_NUMBER: builtins.int
TASK_RES_LIST_FIELD_NUMBER: builtins.int
@property
def node(self) -> flwr.proto.node_pb2.Node: ...
@property
def task_res_list(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[flwr.proto.task_pb2.TaskRes]: ...
def __init__(self,
*,
node: typing.Optional[flwr.proto.node_pb2.Node] = ...,
task_res_list: typing.Optional[typing.Iterable[flwr.proto.task_pb2.TaskRes]] = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["task_res_list",b"task_res_list"]) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["node",b"node"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["node",b"node","task_res_list",b"task_res_list"]) -> None: ...
global___PushTaskResRequest = PushTaskResRequest

class PushTaskResResponse(google.protobuf.message.Message):
Expand Down
Loading

0 comments on commit 4c45c9c

Please sign in to comment.