Skip to content

Commit

Permalink
update with main
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 committed Jan 2, 2025
1 parent c28bcc2 commit 73fa4a6
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 20 deletions.
6 changes: 2 additions & 4 deletions src/py/flwr/client/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from flwr.common.logger import log, warn_deprecated_feature
from flwr.common.message import Error
from flwr.common.retry_invoker import RetryInvoker, RetryState, exponential
from flwr.common.typing import Fab, Run, RunNotRunningException, UserConfig
from flwr.common.typing import Run, RunNotRunningException, UserConfig
from flwr.proto.clientappio_pb2_grpc import add_ClientAppIoServicer_to_server
from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
from flwr.server.superlink.linkstate.utils import generate_rand_int_from_bytes
Expand Down Expand Up @@ -411,9 +411,7 @@ def _on_backoff(retry_state: RetryState) -> None:
# Call create_node fn to register node
# and store node_id in state
if (node_id := conn.create_node()) is None:
raise ValueError(
"Failed to register SuperNode with the SuperLink"
)
raise ValueError("Failed to register SuperNode with the SuperLink")
state.set_node_id(node_id)
run_info_store = DeprecatedRunInfoStore(
node_id=state.get_node_id(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def send(self, message: Message) -> None:
def get_run(self, run_id: int) -> Run:
"""Get run info."""
log(DEBUG, "GetRun API is not supported by GrpcBidiConnection.")
return Run(run_id, "", "", "", {})
return Run.create_empty(run_id)

def get_fab(self, fab_hash: str, run_id: int) -> Fab:
"""Get FAB file."""
Expand Down
35 changes: 20 additions & 15 deletions src/py/flwr/client/connection/rere_fleet_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,8 @@
from flwr.common.logger import log
from flwr.common.message import Message, Metadata
from flwr.common.retry_invoker import RetryInvoker
from flwr.common.serde import (
message_from_taskins,
message_to_taskres,
user_config_from_proto,
)
from flwr.common.typing import Fab, Run
from flwr.common.serde import message_from_taskins, message_to_taskres, run_from_proto
from flwr.common.typing import Fab, Run, RunNotRunningException
from flwr.proto.fab_pb2 import GetFabRequest, GetFabResponse # pylint: disable=E0611
from flwr.proto.fleet_pb2 import ( # pylint: disable=E0611
CreateNodeRequest,
Expand Down Expand Up @@ -78,6 +74,21 @@ def __init__( # pylint: disable=R0913, R0914, R0915, R0917
) = None,
) -> None:
"""Initialize the RereFleetConnection."""

def _should_giveup_fn(e: Exception) -> bool:
if not isinstance(e, grpc.RpcError):
return False
if e.code() == grpc.StatusCode.PERMISSION_DENIED:
raise RunNotRunningException
if e.code() == grpc.StatusCode.UNAVAILABLE:
return False
return True

# Restrict retries to cases where the status code is UNAVAILABLE
# If the status code is PERMISSION_DENIED,
# additionally raise RunNotRunningException
retry_invoker.should_giveup = _should_giveup_fn

super().__init__(
server_address=server_address,
insecure=insecure,
Expand Down Expand Up @@ -222,18 +233,12 @@ def get_run(self, run_id: int) -> Run:
)

# Return fab_id and fab_version
return Run(
run_id,
res.run.fab_id,
res.run.fab_version,
res.run.fab_hash,
user_config_from_proto(res.run.override_config),
)
return run_from_proto(res.run)

def get_fab(self, fab_hash: str) -> Fab:
def get_fab(self, fab_hash: str, run_id: int) -> Fab:
"""Get FAB file."""
# Call FleetAPI
req = GetFabRequest(node=self.node, hash_str=fab_hash)
req = GetFabRequest(node=self.node, hash_str=fab_hash, run_id=run_id)
res: GetFabResponse = self.retry_invoker.invoke(
self.api.GetFab,
request=req,
Expand Down

0 comments on commit 73fa4a6

Please sign in to comment.