Skip to content

Commit

Permalink
Merge branch 'main' into avoid-non-fleet-calls-in-node-auth
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Feb 4, 2025
2 parents ae5c173 + 0538cb6 commit 7ce4759
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 14 deletions.
4 changes: 2 additions & 2 deletions src/py/flwr/cli/auth_plugin/oidc_cli_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from flwr.common.auth_plugin import CliAuthPlugin
from flwr.common.constant import (
ACCESS_TOKEN_KEY,
AUTH_TYPE_KEY,
AUTH_TYPE_JSON_KEY,
REFRESH_TOKEN_KEY,
AuthType,
)
Expand Down Expand Up @@ -97,7 +97,7 @@ def store_tokens(self, credentials: UserAuthCredentials) -> None:
self.access_token = credentials.access_token
self.refresh_token = credentials.refresh_token
json_dict = {
AUTH_TYPE_KEY: AuthType.OIDC,
AUTH_TYPE_JSON_KEY: AuthType.OIDC,
ACCESS_TOKEN_KEY: credentials.access_token,
REFRESH_TOKEN_KEY: credentials.refresh_token,
}
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from flwr.cli.cli_user_auth_interceptor import CliUserAuthInterceptor
from flwr.common.auth_plugin import CliAuthPlugin
from flwr.common.constant import AUTH_TYPE_KEY, CREDENTIALS_DIR, FLWR_DIR
from flwr.common.constant import AUTH_TYPE_JSON_KEY, CREDENTIALS_DIR, FLWR_DIR
from flwr.common.grpc import (
GRPC_MAX_MESSAGE_LENGTH,
create_channel,
Expand Down Expand Up @@ -239,7 +239,7 @@ def try_obtain_cli_auth_plugin(
try:
with config_path.open("r", encoding="utf-8") as file:
json_file = json.load(file)
auth_type = json_file[AUTH_TYPE_KEY]
auth_type = json_file[AUTH_TYPE_JSON_KEY]
except (FileNotFoundError, KeyError):
typer.secho(
"❌ Missing or invalid credentials for user authentication. "
Expand Down
16 changes: 9 additions & 7 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,17 @@

# Constants for user authentication
CREDENTIALS_DIR = ".credentials"
AUTH_TYPE_KEY = "auth_type"
ACCESS_TOKEN_KEY = "access_token"
REFRESH_TOKEN_KEY = "refresh_token"
AUTH_TYPE_JSON_KEY = "auth-type" # For key name in JSON file
AUTH_TYPE_YAML_KEY = "auth_type" # For key name in YAML file
ACCESS_TOKEN_KEY = "flwr-oidc-access-token"
REFRESH_TOKEN_KEY = "flwr-oidc-refresh-token"

# Constants for node authentication
PUBLIC_KEY_HEADER = "public-key-bin" # Must end with "-bin" for binary data
SIGNATURE_HEADER = "signature-bin" # Must end with "-bin" for binary data
TIMESTAMP_HEADER = "timestamp"
TIMESTAMP_TOLERANCE = 10 # Tolerance for timestamp verification
PUBLIC_KEY_HEADER = "flwr-public-key-bin" # Must end with "-bin" for binary data
SIGNATURE_HEADER = "flwr-signature-bin" # Must end with "-bin" for binary data
TIMESTAMP_HEADER = "flwr-timestamp"
TIMESTAMP_TOLERANCE = 10 # General tolerance for timestamp verification
SYSTEM_TIME_TOLERANCE = 5 # Allowance for system time drift


class MessageType:
Expand Down
4 changes: 2 additions & 2 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from flwr.common.auth_plugin import ExecAuthPlugin
from flwr.common.config import get_flwr_dir, parse_config_args
from flwr.common.constant import (
AUTH_TYPE_KEY,
AUTH_TYPE_YAML_KEY,
CLIENT_OCTET,
EXEC_API_DEFAULT_SERVER_ADDRESS,
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
Expand Down Expand Up @@ -578,7 +578,7 @@ def _try_obtain_exec_auth_plugin(

# Load authentication configuration
auth_config: dict[str, Any] = config.get("authentication", {})
auth_type: str = auth_config.get(AUTH_TYPE_KEY, "")
auth_type: str = auth_config.get(AUTH_TYPE_YAML_KEY, "")

# Load authentication plugin
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from flwr.common.constant import (
PUBLIC_KEY_HEADER,
SIGNATURE_HEADER,
SYSTEM_TIME_TOLERANCE,
TIMESTAMP_HEADER,
TIMESTAMP_TOLERANCE,
)
Expand All @@ -38,6 +39,9 @@
)
from flwr.server.superlink.linkstate import LinkStateFactory

MIN_TIMESTAMP_DIFF = -SYSTEM_TIME_TOLERANCE
MAX_TIMESTAMP_DIFF = TIMESTAMP_TOLERANCE + SYSTEM_TIME_TOLERANCE


def _unary_unary_rpc_terminator(
message: str, code: Any = grpc.StatusCode.UNAUTHENTICATED
Expand Down Expand Up @@ -109,7 +113,7 @@ def intercept_service( # pylint: disable=too-many-return-statements
current = now()
time_diff = current - datetime.datetime.fromisoformat(timestamp_iso)
# Abort the RPC call if the timestamp is too old or in the future
if not 0 < time_diff.total_seconds() < TIMESTAMP_TOLERANCE:
if not MIN_TIMESTAMP_DIFF < time_diff.total_seconds() < MAX_TIMESTAMP_DIFF:
return _unary_unary_rpc_terminator("Invalid timestamp")

# Continue the RPC call
Expand Down

0 comments on commit 7ce4759

Please sign in to comment.