diff --git a/examples/rust/.gitignore b/examples/rust/.gitignore new file mode 100644 index 000000000000..03314f77b5aa --- /dev/null +++ b/examples/rust/.gitignore @@ -0,0 +1 @@ +Cargo.lock diff --git a/examples/rust/Cargo.toml b/examples/rust/Cargo.toml new file mode 100644 index 000000000000..d6283698ea64 --- /dev/null +++ b/examples/rust/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "rust" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +flwr = { path = "../../src/rust/flwr/" } +tokio = "1.33.0" diff --git a/examples/rust/README.md b/examples/rust/README.md new file mode 100644 index 000000000000..329e8bd62e3e --- /dev/null +++ b/examples/rust/README.md @@ -0,0 +1,10 @@ +--- +tags: [quickstart, dummy, sdk] +dataset: [] +framework: [Rust] +--- + +# Rust quickstart example + +Currently this example only provides an empty skeleton for implementing +a Flower client in Rust. diff --git a/examples/rust/src/main.rs b/examples/rust/src/main.rs new file mode 100644 index 000000000000..08302de59a62 --- /dev/null +++ b/examples/rust/src/main.rs @@ -0,0 +1,69 @@ +use flwr::client; +use flwr::start; +use flwr::typing; + +struct TestClient; + +impl client::Client for TestClient { + fn get_parameters(&self) -> typing::GetParametersRes { + println!("get_parameters"); + typing::GetParametersRes { + parameters: typing::Parameters { + tensors: vec![vec![1 as u8]], + tensor_type: "".to_string(), + }, + status: typing::Status { + code: typing::Code::OK, + message: "".to_string(), + }, + } + } + + fn get_properties(&self, ins: typing::GetPropertiesIns) -> typing::GetPropertiesRes { + println!("get_properties"); + typing::GetPropertiesRes { + properties: std::collections::HashMap::new(), + status: typing::Status { + code: typing::Code::OK, + message: "".to_string(), + }, + } + } + + fn fit(&self, ins: typing::FitIns) -> typing::FitRes { + println!("fit"); + typing::FitRes { + parameters: typing::Parameters { + tensors: vec![vec![1 as u8]], + tensor_type: "".to_string(), + }, + num_examples: 1, + metrics: std::collections::HashMap::new(), + status: typing::Status { + code: typing::Code::OK, + message: "".to_string(), + }, + } + } + + fn evaluate(&self, ins: typing::EvaluateIns) -> typing::EvaluateRes { + println!("evaluate"); + typing::EvaluateRes { + num_examples: 1, + metrics: std::collections::HashMap::new(), + loss: 1.0, + status: typing::Status { + code: typing::Code::OK, + message: "".to_string(), + }, + } + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("Start client..."); + let _client = + start::start_client("http://127.0.0.1:9092", &TestClient, None, Some("rere")).await?; + Ok(()) +} diff --git a/src/rust/flwr/.gitignore b/src/rust/flwr/.gitignore new file mode 100644 index 000000000000..03314f77b5aa --- /dev/null +++ b/src/rust/flwr/.gitignore @@ -0,0 +1 @@ +Cargo.lock diff --git a/src/rust/flwr/Cargo.toml b/src/rust/flwr/Cargo.toml new file mode 100644 index 000000000000..f4ab5aceb4c4 --- /dev/null +++ b/src/rust/flwr/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "flwr" +version = "0.1.0" +edition = "2021" +include = ["src/**/*", "proto/**/*"] + +[lib] +name = "flwr" +path = "src/lib.rs" + +[dependencies] +async-channel = "2.0.0" +futures = "0.3.29" +prost = "0.12.1" +rustls = "0.21.8" +rustls-native-certs = "0.6.3" +tokio = { version = "1.33.0", features = ["full"] } +tokio-stream = "0.1.14" +tonic = { version = "0.10.2", features = ["transport", "tls"] } +uuid = { version = "1.5.0", features = ["v4"] } + +[build-dependencies] +tonic-build = "0.10.2" diff --git a/src/rust/flwr/build.rs b/src/rust/flwr/build.rs new file mode 100644 index 000000000000..08732a871eef --- /dev/null +++ b/src/rust/flwr/build.rs @@ -0,0 +1,11 @@ +fn main() { + let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap(); + let proto_src = format!("{}/../../proto/flwr/proto/fleet.proto", manifest_dir); + let proto_include = format!("{}/../../proto", manifest_dir); + + tonic_build::configure() + .build_server(false) + .out_dir(format!("{}/src", manifest_dir)) + .compile(&[&proto_src], &[&proto_include]) + .unwrap(); +} diff --git a/src/rust/flwr/src/client.rs b/src/rust/flwr/src/client.rs new file mode 100644 index 000000000000..adc2d86cb067 --- /dev/null +++ b/src/rust/flwr/src/client.rs @@ -0,0 +1,22 @@ +use crate::typing as local; + +pub trait Client { + /// Return the current local model parameters + fn get_parameters(&self) -> local::GetParametersRes; + + fn get_properties(&self, ins: local::GetPropertiesIns) -> local::GetPropertiesRes; + + /// Refine the provided weights using the locally held dataset + /// + /// The training instructions contain (global) model parameters + /// received from the server and a dictionary of configuration + /// values used to customize the local training process. + fn fit(&self, ins: local::FitIns) -> local::FitRes; + + /// Evaluate the provided weights using the locally held dataset. + /// + /// The evaluation instructions contain (global) model parameters + /// received from the server and a dictionary of configuration + /// values used to customize the local evaluation process. + fn evaluate(&self, ins: local::EvaluateIns) -> local::EvaluateRes; +} diff --git a/src/rust/flwr/src/flwr.proto.rs b/src/rust/flwr/src/flwr.proto.rs new file mode 100644 index 000000000000..b9c63c01fc95 --- /dev/null +++ b/src/rust/flwr/src/flwr.proto.rs @@ -0,0 +1,764 @@ +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Node { + #[prost(sint64, tag = "1")] + pub node_id: i64, + #[prost(bool, tag = "2")] + pub anonymous: bool, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Status { + #[prost(enumeration = "Code", tag = "1")] + pub code: i32, + #[prost(string, tag = "2")] + pub message: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Parameters { + #[prost(bytes = "vec", repeated, tag = "1")] + pub tensors: ::prost::alloc::vec::Vec<::prost::alloc::vec::Vec>, + #[prost(string, tag = "2")] + pub tensor_type: ::prost::alloc::string::String, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ServerMessage { + #[prost(oneof = "server_message::Msg", tags = "1, 2, 3, 4, 5")] + pub msg: ::core::option::Option, +} +/// Nested message and enum types in `ServerMessage`. +pub mod server_message { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct ReconnectIns { + #[prost(int64, tag = "1")] + pub seconds: i64, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct GetPropertiesIns { + #[prost(map = "string, message", tag = "1")] + pub config: ::std::collections::HashMap< + ::prost::alloc::string::String, + super::Scalar, + >, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct GetParametersIns { + #[prost(map = "string, message", tag = "1")] + pub config: ::std::collections::HashMap< + ::prost::alloc::string::String, + super::Scalar, + >, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct FitIns { + #[prost(message, optional, tag = "1")] + pub parameters: ::core::option::Option, + #[prost(map = "string, message", tag = "2")] + pub config: ::std::collections::HashMap< + ::prost::alloc::string::String, + super::Scalar, + >, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct EvaluateIns { + #[prost(message, optional, tag = "1")] + pub parameters: ::core::option::Option, + #[prost(map = "string, message", tag = "2")] + pub config: ::std::collections::HashMap< + ::prost::alloc::string::String, + super::Scalar, + >, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Msg { + #[prost(message, tag = "1")] + ReconnectIns(ReconnectIns), + #[prost(message, tag = "2")] + GetPropertiesIns(GetPropertiesIns), + #[prost(message, tag = "3")] + GetParametersIns(GetParametersIns), + #[prost(message, tag = "4")] + FitIns(FitIns), + #[prost(message, tag = "5")] + EvaluateIns(EvaluateIns), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ClientMessage { + #[prost(oneof = "client_message::Msg", tags = "1, 2, 3, 4, 5")] + pub msg: ::core::option::Option, +} +/// Nested message and enum types in `ClientMessage`. +pub mod client_message { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct DisconnectRes { + #[prost(enumeration = "super::Reason", tag = "1")] + pub reason: i32, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct GetPropertiesRes { + #[prost(message, optional, tag = "1")] + pub status: ::core::option::Option, + #[prost(map = "string, message", tag = "2")] + pub properties: ::std::collections::HashMap< + ::prost::alloc::string::String, + super::Scalar, + >, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct GetParametersRes { + #[prost(message, optional, tag = "1")] + pub status: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub parameters: ::core::option::Option, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct FitRes { + #[prost(message, optional, tag = "1")] + pub status: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub parameters: ::core::option::Option, + #[prost(int64, tag = "3")] + pub num_examples: i64, + #[prost(map = "string, message", tag = "4")] + pub metrics: ::std::collections::HashMap< + ::prost::alloc::string::String, + super::Scalar, + >, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct EvaluateRes { + #[prost(message, optional, tag = "1")] + pub status: ::core::option::Option, + #[prost(float, tag = "2")] + pub loss: f32, + #[prost(int64, tag = "3")] + pub num_examples: i64, + #[prost(map = "string, message", tag = "4")] + pub metrics: ::std::collections::HashMap< + ::prost::alloc::string::String, + super::Scalar, + >, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Msg { + #[prost(message, tag = "1")] + DisconnectRes(DisconnectRes), + #[prost(message, tag = "2")] + GetPropertiesRes(GetPropertiesRes), + #[prost(message, tag = "3")] + GetParametersRes(GetParametersRes), + #[prost(message, tag = "4")] + FitRes(FitRes), + #[prost(message, tag = "5")] + EvaluateRes(EvaluateRes), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Scalar { + /// The following `oneof` contains all types that ProtoBuf considers to be + /// "Scalar Value Types". Commented-out types are listed for reference and + /// might be enabled in future releases. Source: + /// + #[prost(oneof = "scalar::Scalar", tags = "1, 8, 13, 14, 15")] + pub scalar: ::core::option::Option, +} +/// Nested message and enum types in `Scalar`. +pub mod scalar { + /// The following `oneof` contains all types that ProtoBuf considers to be + /// "Scalar Value Types". Commented-out types are listed for reference and + /// might be enabled in future releases. Source: + /// + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Scalar { + #[prost(double, tag = "1")] + Double(f64), + /// float float = 2; + /// int32 int32 = 3; + /// int64 int64 = 4; + /// uint32 uint32 = 5; + /// uint64 uint64 = 6; + /// sint32 sint32 = 7; + #[prost(sint64, tag = "8")] + Sint64(i64), + /// fixed32 fixed32 = 9; + /// fixed64 fixed64 = 10; + /// sfixed32 sfixed32 = 11; + /// sfixed64 sfixed64 = 12; + #[prost(bool, tag = "13")] + Bool(bool), + #[prost(string, tag = "14")] + String(::prost::alloc::string::String), + #[prost(bytes, tag = "15")] + Bytes(::prost::alloc::vec::Vec), + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum Code { + Ok = 0, + GetPropertiesNotImplemented = 1, + GetParametersNotImplemented = 2, + FitNotImplemented = 3, + EvaluateNotImplemented = 4, +} +impl Code { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Code::Ok => "OK", + Code::GetPropertiesNotImplemented => "GET_PROPERTIES_NOT_IMPLEMENTED", + Code::GetParametersNotImplemented => "GET_PARAMETERS_NOT_IMPLEMENTED", + Code::FitNotImplemented => "FIT_NOT_IMPLEMENTED", + Code::EvaluateNotImplemented => "EVALUATE_NOT_IMPLEMENTED", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "OK" => Some(Self::Ok), + "GET_PROPERTIES_NOT_IMPLEMENTED" => Some(Self::GetPropertiesNotImplemented), + "GET_PARAMETERS_NOT_IMPLEMENTED" => Some(Self::GetParametersNotImplemented), + "FIT_NOT_IMPLEMENTED" => Some(Self::FitNotImplemented), + "EVALUATE_NOT_IMPLEMENTED" => Some(Self::EvaluateNotImplemented), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum Reason { + Unknown = 0, + Reconnect = 1, + PowerDisconnected = 2, + WifiUnavailable = 3, + Ack = 4, +} +impl Reason { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Reason::Unknown => "UNKNOWN", + Reason::Reconnect => "RECONNECT", + Reason::PowerDisconnected => "POWER_DISCONNECTED", + Reason::WifiUnavailable => "WIFI_UNAVAILABLE", + Reason::Ack => "ACK", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "UNKNOWN" => Some(Self::Unknown), + "RECONNECT" => Some(Self::Reconnect), + "POWER_DISCONNECTED" => Some(Self::PowerDisconnected), + "WIFI_UNAVAILABLE" => Some(Self::WifiUnavailable), + "ACK" => Some(Self::Ack), + _ => None, + } + } +} +/// Generated client implementations. +pub mod flower_service_client { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + #[derive(Debug, Clone)] + pub struct FlowerServiceClient { + inner: tonic::client::Grpc, + } + impl FlowerServiceClient { + /// Attempt to create a new client by connecting to a given endpoint. + pub async fn connect(dst: D) -> Result + where + D: TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl FlowerServiceClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> FlowerServiceClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + , + >>::Error: Into + Send + Sync, + { + FlowerServiceClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + pub async fn join( + &mut self, + request: impl tonic::IntoStreamingRequest, + ) -> std::result::Result< + tonic::Response>, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/flwr.proto.FlowerService/Join", + ); + let mut req = request.into_streaming_request(); + req.extensions_mut() + .insert(GrpcMethod::new("flwr.proto.FlowerService", "Join")); + self.inner.streaming(req, path, codec).await + } + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Task { + #[prost(message, optional, tag = "1")] + pub producer: ::core::option::Option, + #[prost(message, optional, tag = "2")] + pub consumer: ::core::option::Option, + #[prost(string, tag = "3")] + pub created_at: ::prost::alloc::string::String, + #[prost(string, tag = "4")] + pub delivered_at: ::prost::alloc::string::String, + #[prost(string, tag = "5")] + pub ttl: ::prost::alloc::string::String, + #[prost(string, repeated, tag = "6")] + pub ancestry: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + #[prost(message, optional, tag = "7")] + pub sa: ::core::option::Option, + #[deprecated] + #[prost(message, optional, tag = "101")] + pub legacy_server_message: ::core::option::Option, + #[deprecated] + #[prost(message, optional, tag = "102")] + pub legacy_client_message: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TaskIns { + #[prost(string, tag = "1")] + pub task_id: ::prost::alloc::string::String, + #[prost(string, tag = "2")] + pub group_id: ::prost::alloc::string::String, + #[prost(sint64, tag = "3")] + pub workload_id: i64, + #[prost(message, optional, tag = "4")] + pub task: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct TaskRes { + #[prost(string, tag = "1")] + pub task_id: ::prost::alloc::string::String, + #[prost(string, tag = "2")] + pub group_id: ::prost::alloc::string::String, + #[prost(sint64, tag = "3")] + pub workload_id: i64, + #[prost(message, optional, tag = "4")] + pub task: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Value { + #[prost(oneof = "value::Value", tags = "1, 2, 3, 4, 5, 21, 22, 23, 24, 25")] + pub value: ::core::option::Option, +} +/// Nested message and enum types in `Value`. +pub mod value { + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct DoubleList { + #[prost(double, repeated, tag = "1")] + pub vals: ::prost::alloc::vec::Vec, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct Sint64List { + #[prost(sint64, repeated, tag = "1")] + pub vals: ::prost::alloc::vec::Vec, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct BoolList { + #[prost(bool, repeated, tag = "1")] + pub vals: ::prost::alloc::vec::Vec, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct StringList { + #[prost(string, repeated, tag = "1")] + pub vals: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct BytesList { + #[prost(bytes = "vec", repeated, tag = "1")] + pub vals: ::prost::alloc::vec::Vec<::prost::alloc::vec::Vec>, + } + #[allow(clippy::derive_partial_eq_without_eq)] + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Value { + /// Single element + #[prost(double, tag = "1")] + Double(f64), + #[prost(sint64, tag = "2")] + Sint64(i64), + #[prost(bool, tag = "3")] + Bool(bool), + #[prost(string, tag = "4")] + String(::prost::alloc::string::String), + #[prost(bytes, tag = "5")] + Bytes(::prost::alloc::vec::Vec), + /// List types + #[prost(message, tag = "21")] + DoubleList(DoubleList), + #[prost(message, tag = "22")] + Sint64List(Sint64List), + #[prost(message, tag = "23")] + BoolList(BoolList), + #[prost(message, tag = "24")] + StringList(StringList), + #[prost(message, tag = "25")] + BytesList(BytesList), + } +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct SecureAggregation { + #[prost(map = "string, message", tag = "1")] + pub named_values: ::std::collections::HashMap<::prost::alloc::string::String, Value>, +} +/// CreateNode messages +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CreateNodeRequest {} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CreateNodeResponse { + #[prost(message, optional, tag = "1")] + pub node: ::core::option::Option, +} +/// DeleteNode messages +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DeleteNodeRequest { + #[prost(message, optional, tag = "1")] + pub node: ::core::option::Option, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DeleteNodeResponse {} +/// PullTaskIns messages +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PullTaskInsRequest { + #[prost(message, optional, tag = "1")] + pub node: ::core::option::Option, + #[prost(string, repeated, tag = "2")] + pub task_ids: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PullTaskInsResponse { + #[prost(message, optional, tag = "1")] + pub reconnect: ::core::option::Option, + #[prost(message, repeated, tag = "2")] + pub task_ins_list: ::prost::alloc::vec::Vec, +} +/// PushTaskRes messages +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PushTaskResRequest { + #[prost(message, repeated, tag = "1")] + pub task_res_list: ::prost::alloc::vec::Vec, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct PushTaskResResponse { + #[prost(message, optional, tag = "1")] + pub reconnect: ::core::option::Option, + #[prost(map = "string, uint32", tag = "2")] + pub results: ::std::collections::HashMap<::prost::alloc::string::String, u32>, +} +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct Reconnect { + #[prost(uint64, tag = "1")] + pub reconnect: u64, +} +/// Generated client implementations. +pub mod fleet_client { + #![allow(unused_variables, dead_code, missing_docs, clippy::let_unit_value)] + use tonic::codegen::*; + use tonic::codegen::http::Uri; + #[derive(Debug, Clone)] + pub struct FleetClient { + inner: tonic::client::Grpc, + } + impl FleetClient { + /// Attempt to create a new client by connecting to a given endpoint. + pub async fn connect(dst: D) -> Result + where + D: TryInto, + D::Error: Into, + { + let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; + Ok(Self::new(conn)) + } + } + impl FleetClient + where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, + { + pub fn new(inner: T) -> Self { + let inner = tonic::client::Grpc::new(inner); + Self { inner } + } + pub fn with_origin(inner: T, origin: Uri) -> Self { + let inner = tonic::client::Grpc::with_origin(inner, origin); + Self { inner } + } + pub fn with_interceptor( + inner: T, + interceptor: F, + ) -> FleetClient> + where + F: tonic::service::Interceptor, + T::ResponseBody: Default, + T: tonic::codegen::Service< + http::Request, + Response = http::Response< + >::ResponseBody, + >, + >, + , + >>::Error: Into + Send + Sync, + { + FleetClient::new(InterceptedService::new(inner, interceptor)) + } + /// Compress requests with the given encoding. + /// + /// This requires the server to support it otherwise it might respond with an + /// error. + #[must_use] + pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.send_compressed(encoding); + self + } + /// Enable decompressing responses. + #[must_use] + pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self { + self.inner = self.inner.accept_compressed(encoding); + self + } + /// Limits the maximum size of a decoded message. + /// + /// Default: `4MB` + #[must_use] + pub fn max_decoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_decoding_message_size(limit); + self + } + /// Limits the maximum size of an encoded message. + /// + /// Default: `usize::MAX` + #[must_use] + pub fn max_encoding_message_size(mut self, limit: usize) -> Self { + self.inner = self.inner.max_encoding_message_size(limit); + self + } + pub async fn create_node( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/flwr.proto.Fleet/CreateNode", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("flwr.proto.Fleet", "CreateNode")); + self.inner.unary(req, path, codec).await + } + pub async fn delete_node( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/flwr.proto.Fleet/DeleteNode", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("flwr.proto.Fleet", "DeleteNode")); + self.inner.unary(req, path, codec).await + } + /// Retrieve one or more tasks, if possible + /// + /// HTTP API path: /api/v1/fleet/pull-task-ins + pub async fn pull_task_ins( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/flwr.proto.Fleet/PullTaskIns", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("flwr.proto.Fleet", "PullTaskIns")); + self.inner.unary(req, path, codec).await + } + /// Complete one or more tasks, if possible + /// + /// HTTP API path: /api/v1/fleet/push-task-res + pub async fn push_task_res( + &mut self, + request: impl tonic::IntoRequest, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + self.inner + .ready() + .await + .map_err(|e| { + tonic::Status::new( + tonic::Code::Unknown, + format!("Service was not ready: {}", e.into()), + ) + })?; + let codec = tonic::codec::ProstCodec::default(); + let path = http::uri::PathAndQuery::from_static( + "/flwr.proto.Fleet/PushTaskRes", + ); + let mut req = request.into_request(); + req.extensions_mut() + .insert(GrpcMethod::new("flwr.proto.Fleet", "PushTaskRes")); + self.inner.unary(req, path, codec).await + } + } +} diff --git a/src/rust/flwr/src/grpc_bidi.rs b/src/rust/flwr/src/grpc_bidi.rs new file mode 100644 index 000000000000..5c4c438123a8 --- /dev/null +++ b/src/rust/flwr/src/grpc_bidi.rs @@ -0,0 +1,90 @@ +use crate::flwr_proto as proto; +use std::path::Path; + +use async_channel::Sender; +use futures::StreamExt; +use tonic::transport::channel::ClientTlsConfig; +use tonic::transport::{Certificate, Channel}; +use tonic::{Request, Streaming}; + +use uuid::Uuid; + +pub struct GrpcConnection { + channel: tonic::transport::Channel, + stub: proto::flower_service_client::FlowerServiceClient, + queue: Sender, + server_message_iterator: Streaming, +} + +impl GrpcConnection { + pub async fn new( + server_address: &str, + root_certificates: Option<&Path>, + ) -> Result> { + let mut builder = Channel::builder(server_address.parse()?); + + // For now, skipping max_message_length because it's not directly supported in Tonic + // Check Tonic's documentation or source for workarounds or updates regarding this + + if let Some(root_certificates) = root_certificates { + let pem = tokio::fs::read(root_certificates).await?; + let cert = Certificate::from_pem(pem); + let tls_config = ClientTlsConfig::new().ca_certificate(cert); + builder = builder.tls_config(tls_config)?; + } + + let channel = builder.connect().await?; + let mut stub = proto::flower_service_client::FlowerServiceClient::new(channel.clone()); + + let (tx, rx) = async_channel::bounded(1); // This is our queue equivalent in async Rust + + let response = stub.join(Request::new(rx)).await?; + let server_message_stream = response.into_inner(); + + Ok(GrpcConnection { + channel, + stub, + queue: tx, + server_message_iterator: server_message_stream, + }) + } + + pub async fn receive(&mut self) -> Result> { + if let Some(Ok(server_message)) = self.server_message_iterator.next().await { + let task_ins = proto::TaskIns { + group_id: "".to_string(), + workload_id: 0, + task_id: Uuid::new_v4().to_string(), + task: Some(proto::Task { + producer: Some(proto::Node { + node_id: 0, + anonymous: true, + }), + consumer: Some(proto::Node { + node_id: 0, + anonymous: true, + }), + ancestry: vec![], + legacy_server_message: Some(server_message), + ..Default::default() + }), + }; + Ok(task_ins) + } else { + Err(Box::new(std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to get server message.", + ))) + } + } + + pub async fn send(&self, task_res: proto::TaskRes) -> Result<(), Box> { + let client_message = task_res + .task + .unwrap_or_default() + .legacy_client_message + .unwrap(); + self.queue.send(client_message).await?; + Ok(()) + } +} diff --git a/src/rust/flwr/src/grpc_rere.rs b/src/rust/flwr/src/grpc_rere.rs new file mode 100644 index 000000000000..8e61b2510d1a --- /dev/null +++ b/src/rust/flwr/src/grpc_rere.rs @@ -0,0 +1,137 @@ +use crate::flwr_proto as proto; +use crate::task_handler; +use std::path::Path; + +use tonic::transport::channel::ClientTlsConfig; +use tonic::transport::{Certificate, Channel}; + +const KEY_TASK_INS: &str = "current_task_ins"; +const KEY_NODE: &str = "node"; + +pub struct GrpcRereConnection { + stub: proto::fleet_client::FleetClient, + state: std::collections::HashMap>, + node_store: std::collections::HashMap>, +} + +impl GrpcRereConnection { + pub async fn new( + server_address: &str, + root_certificates: Option<&Path>, + ) -> Result> { + let mut builder = Channel::builder(server_address.parse()?); + + // For now, skipping max_message_length because it's not directly supported in Tonic + // Check Tonic's documentation or source for workarounds or updates regarding this + + if let Some(root_certificates) = root_certificates { + let pem = tokio::fs::read(root_certificates).await?; + let cert = Certificate::from_pem(pem); + let tls_config = ClientTlsConfig::new().ca_certificate(cert); + builder = builder.tls_config(tls_config)?; + } + + let channel = builder.connect().await?; + let stub = proto::fleet_client::FleetClient::new(channel.clone()); + let state = std::collections::HashMap::from([(KEY_TASK_INS.to_string(), None)]); + let node_store = std::collections::HashMap::from([(KEY_NODE.to_string(), None)]); + + Ok(GrpcRereConnection { + stub, + state, + node_store, + }) + } + + pub async fn create_node(&mut self) -> Result<(), Box> { + let create_node_request = proto::CreateNodeRequest::default(); + let create_node_response = self + .stub + .create_node(create_node_request) + .await? + .into_inner(); + self.node_store + .insert(KEY_NODE.to_string(), create_node_response.node); + Ok(()) + } + + pub async fn delete_node(&mut self) -> Result<(), Box> { + let node = match self.node_store.get(&KEY_NODE.to_string()) { + Some(Some(n)) => n.clone(), + _ => { + eprintln!("Node instance missing"); + return Err("Node instance missing".into()); + } + }; + + let delete_node_request = proto::DeleteNodeRequest { node: Some(node) }; + self.stub.delete_node(delete_node_request).await?; + Ok(()) + } + + pub async fn receive(&mut self) -> Result, Box> { + let node = match self.node_store.get(KEY_NODE) { + Some(Some(n)) => n.clone(), + _ => { + eprintln!("Node instance missing"); + return Err("Node instance missing".into()); + } + }; + + let request = proto::PullTaskInsRequest { + node: Some(node), + ..Default::default() + }; + let response = self.stub.pull_task_ins(request).await?.into_inner(); + + let mut task_ins = task_handler::get_task_ins(&response); + if let Some(ref ti) = task_ins { + if !task_handler::validate_task_ins(ti, true) { + task_ins = None; + } + } + + self.state + .insert(KEY_TASK_INS.to_string(), task_ins.clone()); + + Ok(task_ins) + } + + pub async fn send( + &mut self, + task_res: proto::TaskRes, + ) -> Result<(), Box> { + let node = match self.node_store.get(KEY_NODE) { + Some(Some(n)) => n.clone(), + _ => { + eprintln!("Node instance missing"); + return Err("Node instance missing".into()); + } + }; + + let task_ins = match self.state.get(KEY_TASK_INS) { + Some(Some(ti)) => ti.clone(), + _ => { + eprintln!("No current TaskIns"); + return Err("No current TaskIns".into()); + } + }; + + if !task_handler::validate_task_res(&task_res) { + self.state.insert(KEY_TASK_INS.to_string(), None); + eprintln!("TaskRes is invalid"); + return Err("TaskRes is invalid".into()); + } + + let task_res = task_handler::configure_task_res(task_res, &task_ins, node); + + let request = proto::PushTaskResRequest { + task_res_list: vec![task_res], + }; + self.stub.push_task_res(request).await?; + + self.state.insert(KEY_TASK_INS.to_string(), None); + + Ok(()) + } +} diff --git a/src/rust/flwr/src/lib.rs b/src/rust/flwr/src/lib.rs new file mode 100644 index 000000000000..d53f7a1f3ff6 --- /dev/null +++ b/src/rust/flwr/src/lib.rs @@ -0,0 +1,12 @@ +pub mod client; +pub mod grpc_bidi; +pub mod grpc_rere; +pub mod message_handler; +pub mod serde; +pub mod start; +pub mod task_handler; +pub mod typing; + +pub mod flwr_proto { + include!("flwr.proto.rs"); +} diff --git a/src/rust/flwr/src/message_handler.rs b/src/rust/flwr/src/message_handler.rs new file mode 100644 index 000000000000..f36f6c2d18f9 --- /dev/null +++ b/src/rust/flwr/src/message_handler.rs @@ -0,0 +1,104 @@ +use crate::client; +use crate::flwr_proto as proto; +use crate::serde; +use crate::task_handler; + +fn reconnect(reconnect_msg: proto::server_message::ReconnectIns) -> (proto::ClientMessage, i64) { + let mut reason = proto::Reason::Ack; + let mut sleep_duration = 0; + if reconnect_msg.seconds != 0 { + reason = proto::Reason::Reconnect; + sleep_duration = reconnect_msg.seconds; + } + + let disconnect_res = proto::client_message::DisconnectRes { + reason: reason.into(), + }; + + return ( + proto::ClientMessage { + msg: Some(proto::client_message::Msg::DisconnectRes(disconnect_res)), + }, + sleep_duration, + ); +} + +fn get_properties( + client: &dyn client::Client, + get_properties_msg: proto::server_message::GetPropertiesIns, +) -> proto::ClientMessage { + let get_properties_res = serde::get_properties_res_to_proto( + client.get_properties(serde::get_properties_ins_from_proto(get_properties_msg)), + ); + + proto::ClientMessage { + msg: Some(proto::client_message::Msg::GetPropertiesRes( + get_properties_res, + )), + } +} + +fn get_parameters( + client: &dyn client::Client, + // get_parameters_msg: proto::server_message::GetParametersIns, +) -> proto::ClientMessage { + let res = serde::parameter_res_to_proto(client.get_parameters()); + proto::ClientMessage { + msg: Some(proto::client_message::Msg::GetParametersRes(res)), + } +} + +fn fit( + client: &dyn client::Client, + fit_msg: proto::server_message::FitIns, +) -> proto::ClientMessage { + let res = serde::fit_res_to_proto(client.fit(serde::fit_ins_from_proto(fit_msg))); + proto::ClientMessage { + msg: Some(proto::client_message::Msg::FitRes(res)), + } +} + +fn evaluate( + client: &dyn client::Client, + evaluate_msg: proto::server_message::EvaluateIns, +) -> proto::ClientMessage { + let res = + serde::evaluate_res_to_proto(client.evaluate(serde::evaluate_ins_from_proto(evaluate_msg))); + proto::ClientMessage { + msg: Some(proto::client_message::Msg::EvaluateRes(res)), + } +} + +fn handle_legacy_message( + client: &dyn client::Client, + server_msg: proto::ServerMessage, +) -> Result<(proto::ClientMessage, i64, bool), &str> { + match server_msg.msg { + Some(proto::server_message::Msg::ReconnectIns(reconnect_ins)) => { + let rec = reconnect(reconnect_ins); + Ok((rec.0, rec.1, false)) + } + Some(proto::server_message::Msg::GetParametersIns(_)) => { + Ok((get_parameters(client), 0, true)) + } + Some(proto::server_message::Msg::FitIns(fit_ins)) => Ok((fit(client, fit_ins), 0, true)), + Some(proto::server_message::Msg::EvaluateIns(evaluate_ins)) => { + Ok((evaluate(client, evaluate_ins), 0, true)) + } + _ => Err("Unknown server message"), + } +} + +pub fn handle( + client: &dyn client::Client, + task_ins: proto::TaskIns, +) -> Result<(proto::TaskRes, i64, bool), &str> { + let server_msg = task_handler::get_server_message_from_task_ins(&task_ins, false); + if server_msg.is_none() { + return Err("Not implemented"); + } + let (client_msg, sleep_duration, keep_going) = + handle_legacy_message(client, server_msg.unwrap())?; + let task_res = task_handler::wrap_client_message_in_task_res(client_msg); + Ok((task_res, sleep_duration, keep_going)) +} diff --git a/src/rust/flwr/src/serde.rs b/src/rust/flwr/src/serde.rs new file mode 100644 index 000000000000..72fb88135966 --- /dev/null +++ b/src/rust/flwr/src/serde.rs @@ -0,0 +1,165 @@ +use crate::flwr_proto as proto; +use crate::typing as local; + +pub fn parameters_to_proto(parameters: local::Parameters) -> proto::Parameters { + return proto::Parameters { + tensors: parameters.tensors, + tensor_type: parameters.tensor_type, + }; +} + +pub fn parameters_from_proto(params_msg: proto::Parameters) -> local::Parameters { + return local::Parameters { + tensors: params_msg.tensors, + tensor_type: params_msg.tensor_type, + }; +} +pub fn scalar_to_proto(scalar: local::Scalar) -> proto::Scalar { + match scalar { + local::Scalar::Bool(value) => proto::Scalar { + scalar: Some(proto::scalar::Scalar::Bool(value)), + }, + local::Scalar::Bytes(value) => proto::Scalar { + scalar: Some(proto::scalar::Scalar::Bytes(value)), + }, + local::Scalar::Float(value) => proto::Scalar { + scalar: Some(proto::scalar::Scalar::Double(value as f64)), + }, + local::Scalar::Int(value) => proto::Scalar { + scalar: Some(proto::scalar::Scalar::Sint64(value as i64)), + }, + local::Scalar::Str(value) => proto::Scalar { + scalar: Some(proto::scalar::Scalar::String(value)), + }, + } +} + +pub fn scalar_from_proto(scalar_msg: proto::Scalar) -> local::Scalar { + match &scalar_msg.scalar { + Some(proto::scalar::Scalar::Double(value)) => local::Scalar::Float(*value as f32), + Some(proto::scalar::Scalar::Sint64(value)) => local::Scalar::Int(*value as i32), + Some(proto::scalar::Scalar::Bool(value)) => local::Scalar::Bool(*value), + Some(proto::scalar::Scalar::String(value)) => local::Scalar::Str(value.clone()), + Some(proto::scalar::Scalar::Bytes(value)) => local::Scalar::Bytes(value.clone()), + None => panic!("Error scalar type"), + } +} + +pub fn metrics_to_proto( + metrics: local::Metrics, +) -> std::collections::HashMap { + let mut proto_metrics = std::collections::HashMap::new(); + + for (key, value) in metrics.iter() { + proto_metrics.insert(key.clone(), scalar_to_proto(value.clone())); + } + + return proto_metrics; +} + +pub fn metrics_from_proto( + proto_metrics: std::collections::HashMap, +) -> local::Metrics { + let mut metrics = local::Metrics::new(); + + for (key, value) in proto_metrics.iter() { + metrics.insert(key.clone(), scalar_from_proto(value.clone())); + } + + return metrics; +} + +pub fn parameter_res_to_proto( + res: local::GetParametersRes, +) -> proto::client_message::GetParametersRes { + return proto::client_message::GetParametersRes { + parameters: Some(parameters_to_proto(res.parameters)), + status: Some(status_to_proto(res.status)), + }; +} + +pub fn fit_ins_from_proto(fit_ins_msg: proto::server_message::FitIns) -> local::FitIns { + local::FitIns { + parameters: parameters_from_proto(fit_ins_msg.parameters.unwrap()), + config: metrics_from_proto(fit_ins_msg.config.into_iter().collect()), + } +} + +pub fn fit_res_to_proto(res: local::FitRes) -> proto::client_message::FitRes { + return proto::client_message::FitRes { + parameters: Some(parameters_to_proto(res.parameters)), + num_examples: res.num_examples.into(), + metrics: if res.metrics.len() > 0 { + metrics_to_proto(res.metrics) + } else { + Default::default() + }, + status: Some(status_to_proto(res.status)), + }; +} + +pub fn evaluate_ins_from_proto( + evaluate_ins_msg: proto::server_message::EvaluateIns, +) -> local::EvaluateIns { + local::EvaluateIns { + parameters: parameters_from_proto(evaluate_ins_msg.parameters.unwrap()), + config: metrics_from_proto(evaluate_ins_msg.config.into_iter().collect()), + } +} + +pub fn evaluate_res_to_proto(res: local::EvaluateRes) -> proto::client_message::EvaluateRes { + return proto::client_message::EvaluateRes { + loss: res.loss.into(), + num_examples: res.num_examples.into(), + metrics: if res.metrics.len() > 0 { + metrics_to_proto(res.metrics) + } else { + Default::default() + }, + status: Default::default(), + }; +} + +fn status_to_proto(status: local::Status) -> proto::Status { + return proto::Status { + code: status.code as i32, + message: status.message, + }; +} + +pub fn get_properties_ins_from_proto( + get_properties_msg: proto::server_message::GetPropertiesIns, +) -> local::GetPropertiesIns { + return local::GetPropertiesIns { + config: properties_from_proto(get_properties_msg.config), + }; +} + +pub fn get_properties_res_to_proto( + res: local::GetPropertiesRes, +) -> proto::client_message::GetPropertiesRes { + return proto::client_message::GetPropertiesRes { + properties: properties_to_proto(res.properties), + status: Some(status_to_proto(res.status)), + }; +} + +fn properties_from_proto( + proto: std::collections::HashMap, +) -> local::Properties { + let mut properties = std::collections::HashMap::new(); + for (k, v) in proto.iter() { + properties.insert(k.clone(), scalar_from_proto(v.clone())); + } + return properties; +} + +fn properties_to_proto( + properties: local::Properties, +) -> std::collections::HashMap { + let mut proto = std::collections::HashMap::new(); + for (k, v) in properties.iter() { + proto.insert(k.clone(), scalar_to_proto(v.clone())); + } + return proto; +} diff --git a/src/rust/flwr/src/start.rs b/src/rust/flwr/src/start.rs new file mode 100644 index 000000000000..a78843ac3b8c --- /dev/null +++ b/src/rust/flwr/src/start.rs @@ -0,0 +1,88 @@ +use std::path::Path; +use std::time::Duration; + +use crate::grpc_bidi as bidi; +use crate::grpc_rere as rere; +use crate::message_handler as handler; + +use crate::client; + +pub async fn start_client( + address: &str, + client: &C, + root_certificates: Option<&Path>, + transport: Option<&str>, +) -> Result<(), Box> +where + C: client::Client, +{ + loop { + let mut sleep_duration: i64 = 0; + if transport.is_some() && transport == Some("rere") { + let mut conn = rere::GrpcRereConnection::new(address, root_certificates).await?; + // Register node + conn.create_node().await?; + loop { + match conn.receive().await { + Ok(Some(task_ins)) => match handler::handle(client, task_ins) { + Ok((task_res, new_sleep_duration, keep_going)) => { + println!("Task received! {}", task_res.task.is_some()); + sleep_duration = new_sleep_duration; + conn.send(task_res).await?; + if !keep_going { + break; + } + } + Err(e) => { + eprintln!("Error: {}", e); + return Err("Couldn't handle task".into()); + } + }, + Ok(None) => { + println!("No task received"); + tokio::time::sleep(Duration::from_secs(3)).await; // Wait for 3s before asking again + } + Err(e) => { + eprintln!("Error: {}", e); + return Err("Couldn't receive task".into()); + } + } + } + // Unregister node + conn.delete_node().await?; + } else { + let mut conn = bidi::GrpcConnection::new(address, root_certificates).await?; + loop { + match conn.receive().await { + Ok(task_ins) => match handler::handle(client, task_ins) { + Ok((task_res, new_sleep_duration, keep_going)) => { + sleep_duration = new_sleep_duration; + conn.send(task_res).await?; + if !keep_going { + break; + } + } + Err(e) => { + eprintln!("Error: {}", e); + } + }, + Err(_) => { + tokio::time::sleep(Duration::from_secs(3)).await; // Wait for 3s before asking again + } + } + } + } + + if sleep_duration == 0 { + println!("Disconnect and shut down"); + break; + } + + println!( + "Disconnect, then re-establish connection after {} second(s)", + sleep_duration + ); + tokio::time::sleep(Duration::from_secs(sleep_duration as u64)).await; + } + Ok(()) +} diff --git a/src/rust/flwr/src/task_handler.rs b/src/rust/flwr/src/task_handler.rs new file mode 100644 index 000000000000..5675fc1fef31 --- /dev/null +++ b/src/rust/flwr/src/task_handler.rs @@ -0,0 +1,99 @@ +use crate::flwr_proto as proto; + +pub fn validate_task_ins(task_ins: &proto::TaskIns, discard_reconnect_ins: bool) -> bool { + match &task_ins.task { + Some(task) => { + let has_legacy_server_msg = task.legacy_server_message.as_ref().map_or(false, |lsm| { + if discard_reconnect_ins { + !matches!(lsm.msg, Some(proto::server_message::Msg::ReconnectIns(..))) + } else { + true + } + }); + + has_legacy_server_msg || task.sa.is_some() + } + None => false, + } +} + +pub fn validate_task_res(task_res: &proto::TaskRes) -> bool { + // Check for initialization of fields in TaskRes + let task_res_is_uninitialized = + task_res.task_id.is_empty() && task_res.group_id.is_empty() && task_res.workload_id == 0; + + // Check for initialization of fields in Task, if Task is present + let task_is_uninitialized = task_res.task.as_ref().map_or(true, |task| { + task.producer.is_none() && task.consumer.is_none() && task.ancestry.is_empty() + }); + + task_res_is_uninitialized && task_is_uninitialized +} + +pub fn get_task_ins(pull_task_ins_response: &proto::PullTaskInsResponse) -> Option { + if pull_task_ins_response.task_ins_list.is_empty() { + return None; + } + + let task_ins = pull_task_ins_response.task_ins_list.first(); + return task_ins.cloned(); +} + +pub fn get_server_message_from_task_ins( + task_ins: &proto::TaskIns, + exclude_reconnect_ins: bool, +) -> Option { + if !validate_task_ins(task_ins, exclude_reconnect_ins) { + return None; + } + + match &task_ins.task { + Some(task) => { + if let Some(legacy_server_message) = &task.legacy_server_message { + Some(legacy_server_message.clone()) + } else { + None + } + } + None => None, + } +} + +pub fn wrap_client_message_in_task_res(client_message: proto::ClientMessage) -> proto::TaskRes { + return proto::TaskRes { + task_id: "".to_string(), + group_id: "".to_string(), + workload_id: 0, + task: Some(proto::Task { + ancestry: vec![], + legacy_client_message: Some(client_message), + ..Default::default() + }), + }; +} + +pub fn configure_task_res( + mut task_res: proto::TaskRes, + ref_task_ins: &proto::TaskIns, + producer: proto::Node, +) -> proto::TaskRes { + // Set group_id and workload_id + task_res.group_id = ref_task_ins.group_id.clone(); + task_res.workload_id = ref_task_ins.workload_id; + + // Check if task_res has a task field set; if not, initialize it. + if task_res.task.is_none() { + task_res.task = Some(proto::Task::default()); + } + + // Assuming the task is now Some, unwrap it and set its fields. + if let Some(ref mut task) = task_res.task { + task.producer = Some(producer); + task.consumer = ref_task_ins.task.as_ref().and_then(|t| t.producer.clone()); + + // Set ancestry to contain just ref_task_ins.task_id + task.ancestry = vec![ref_task_ins.task_id.clone()]; + } + + task_res +} diff --git a/src/rust/flwr/src/typing.rs b/src/rust/flwr/src/typing.rs new file mode 100644 index 000000000000..31b482ff2031 --- /dev/null +++ b/src/rust/flwr/src/typing.rs @@ -0,0 +1,127 @@ +use std::collections::HashMap; + +// Scalar and Value types as described for ProtoBuf +pub type Metrics = HashMap; +pub type MetricsAggregationFn = fn(Vec<(i32, Metrics)>) -> Metrics; +pub type Config = HashMap; +pub type Properties = HashMap; + +#[derive(Debug, Clone)] +pub enum Scalar { + Bool(bool), + Bytes(Vec), + Float(f32), + Int(i32), + Str(String), +} + +#[derive(Debug, Clone)] +pub enum Value { + Bool(bool), + Bytes(Vec), + Float(f32), + Int(i32), + Str(String), + ListBool(Vec), + ListBytes(Vec>), + ListFloat(Vec), + ListInt(Vec), + ListStr(Vec), +} + +#[derive(Debug, Clone, Copy)] +pub enum Code { + OK = 0, + GET_PROPERTIES_NOT_IMPLEMENTED = 1, + GET_PARAMETERS_NOT_IMPLEMENTED = 2, + FIT_NOT_IMPLEMENTED = 3, + EVALUATE_NOT_IMPLEMENTED = 4, +} + +#[derive(Debug, Clone)] +pub struct Status { + pub code: Code, + pub message: String, +} + +#[derive(Debug, Clone)] +pub struct Parameters { + pub tensors: Vec>, + pub tensor_type: String, +} + +#[derive(Debug, Clone)] +pub struct GetParametersIns { + pub config: Config, +} + +#[derive(Debug, Clone)] +pub struct GetParametersRes { + pub status: Status, + pub parameters: Parameters, +} + +#[derive(Debug, Clone)] +pub struct FitIns { + pub parameters: Parameters, + pub config: Config, +} + +#[derive(Debug, Clone)] +pub struct FitRes { + pub status: Status, + pub parameters: Parameters, + pub num_examples: i32, + pub metrics: Metrics, +} + +#[derive(Debug, Clone)] +pub struct EvaluateIns { + pub parameters: Parameters, + pub config: Config, +} + +#[derive(Debug, Clone)] +pub struct EvaluateRes { + pub status: Status, + pub loss: f32, + pub num_examples: i32, + pub metrics: Metrics, +} + +#[derive(Debug, Clone)] +pub struct GetPropertiesIns { + pub config: Config, +} + +#[derive(Debug, Clone)] +pub struct GetPropertiesRes { + pub status: Status, + pub properties: Properties, +} + +#[derive(Debug, Clone)] +pub struct ReconnectIns { + pub seconds: Option, +} + +#[derive(Debug, Clone)] +pub struct DisconnectRes { + pub reason: String, +} + +#[derive(Debug, Clone)] +pub struct ServerMessage { + pub get_properties_ins: Option, + pub get_parameters_ins: Option, + pub fit_ins: Option, + pub evaluate_ins: Option, +} + +#[derive(Debug, Clone)] +pub struct ClientMessage { + pub get_properties_res: Option, + pub get_parameters_res: Option, + pub fit_res: Option, + pub evaluate_res: Option, +}