Skip to content

Commit

Permalink
Consistency in names
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Nov 11, 2024
1 parent 9bde45c commit b8a7421
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 238 deletions.
19 changes: 12 additions & 7 deletions src/blueapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Field,
TypeAdapter,
ValidationError,
field_serializer,
)

from blueapi.utils import BlueapiBaseModel, InvalidConfigError
Expand Down Expand Up @@ -100,17 +101,17 @@ def _config_from_oidc_url(self) -> dict[str, Any]:
return response.json()

@cached_property
def device_auth_url(self) -> str:
def device_authorization_endpoint(self) -> str:
return cast(
str, self._config_from_oidc_url.get("device_authorization_endpoint")
)

@cached_property
def auth_url(self) -> str:
def authorization_endpoint(self) -> str:
return cast(str, self._config_from_oidc_url.get("authorization_endpoint"))

@cached_property
def token_url(self) -> str:
def token_endpoint(self) -> str:
return cast(str, self._config_from_oidc_url.get("token_endpoint"))

@cached_property
Expand All @@ -122,19 +123,23 @@ def jwks_uri(self) -> str:
return cast(str, self._config_from_oidc_url.get("jwks_uri"))

@cached_property
def logout_url(self) -> str:
def end_session_endpoint(self) -> str:
return cast(str, self._config_from_oidc_url.get("end_session_endpoint"))

@cached_property
def signing_algos(self) -> list[str]:
def id_token_signing_alg_values_supported(self) -> list[str]:
return cast(
list[str],
self._config_from_oidc_url.get("id_token_signing_alg_values_supported"),
)


class CLIClientConfig(OIDCConfig):
token_file_path: Path = Path("~/token")
token_file_path: Path = Field(Path("~/token"))

@field_serializer("token_file_path")
def serialize_token_file_path(self, token_file_path: Path, _info):
return f"{token_file_path}"


class ApplicationConfig(BlueapiBaseModel):
Expand All @@ -148,7 +153,7 @@ class ApplicationConfig(BlueapiBaseModel):
logging: LoggingConfig = Field(default_factory=LoggingConfig)
api: RestConfig = Field(default_factory=RestConfig)
scratch: ScratchConfig | None = None
oidc_config: OIDCConfig | None = None
oidc_config: OIDCConfig | CLIClientConfig | None = None

def __eq__(self, other: object) -> bool:
if isinstance(other, ApplicationConfig):
Expand Down
54 changes: 26 additions & 28 deletions src/blueapi/service/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def decode_jwt(self, token: str) -> dict[str, str]:
return jwt.decode(
token,
signing_key.key,
algorithms=self._server_config.signing_algos,
algorithms=self._server_config.id_token_signing_alg_values_supported,
verify=True,
audience=self._server_config.client_audience,
issuer=self._server_config.issuer,
Expand Down Expand Up @@ -105,9 +105,8 @@ def logout(self) -> None:
def refresh_auth_token(self) -> None:
token = self._token_manager.load_token()
response = requests.post(
self._server_config.token_url,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
self._server_config.token_endpoint,
json={
"client_id": self._server_config.client_id,
"grant_type": "refresh_token",
"refresh_token": token["refresh_token"],
Expand All @@ -124,9 +123,8 @@ def poll_for_token(
expiry_time: float = time.time() + expires_in
while time.time() < expiry_time:
response = requests.post(
self._server_config.token_url,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
self._server_config.token_endpoint,
json={
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code,
"client_id": self._server_config.client_id,
Expand All @@ -138,24 +136,10 @@ def poll_for_token(

raise TimeoutError("Polling timed out")

def start_device_flow(self) -> None:
token = self._token_manager.load_token()
try:
self.authenticator.decode_jwt(token["access_token"])
print("Cached token still valid, skipping flow")
return
except jwt.ExpiredSignatureError:
token = self.refresh_auth_token()
print("Refreshed cached token, skipping flow")
return
except Exception:
print("Problem with cached token, starting new session")
self._token_manager.delete_token()

def _do_device_flow(self) -> None:
response: requests.Response = requests.post(
self._server_config.device_auth_url,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
self._server_config.device_authorization_endpoint,
json={
"client_id": self._server_config.client_id,
"scope": "openid profile offline_access",
"audience": self._server_config.client_audience,
Expand All @@ -175,7 +159,21 @@ def start_device_flow(self) -> None:
auth_token_json: dict[str, Any] = self.poll_for_token(
device_code, interval, expires_in
)
decoded_token: dict[str, Any] = self.authenticator.decode_jwt(
auth_token_json["access_token"]
)
self._token_manager.save_token(decoded_token)
self._token_manager.save_token(auth_token_json)

def start_device_flow(self) -> None:
try:
token = self._token_manager.load_token()
self.authenticator.decode_jwt(token["id_token"])
print("Cached token still valid, skipping flow")
return
except jwt.ExpiredSignatureError:
token = self.refresh_auth_token()
print("Refreshed cached token, skipping flow")
return
except FileNotFoundError:
self._do_device_flow()
except Exception as e:
print(e)
print("Problem with cached token, starting new session")
self._token_manager.delete_token()
6 changes: 3 additions & 3 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,9 @@ def get_app(config: ApplicationConfig | None = None):

def verify_access_token(config: OIDCConfig):
oauth_scheme = OAuth2AuthorizationCodeBearer(
authorizationUrl=config.auth_url,
tokenUrl=config.token_url,
refreshUrl=config.token_url,
authorizationUrl=config.authorization_endpoint,
tokenUrl=config.token_endpoint,
refreshUrl=config.token_endpoint,
)
authenticator = Authenticator(config)

Expand Down
139 changes: 104 additions & 35 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@
import jwt
import pytest
import responses
import responses.matchers
import yaml
from bluesky import RunEngine
from bluesky.run_engine import TransitionError
from jwcrypto.jwk import JWK
from observability_utils.tracing import setup_tracing
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.trace import get_tracer_provider

from blueapi.config import CLIClientConfig
from blueapi.config import ApplicationConfig, CLIClientConfig
from tests.unit_tests.utils.test_tracing import JsonObjectSpanExporter


Expand Down Expand Up @@ -70,17 +73,12 @@ def oidc_config(valid_oidc_url: str, tmp_path: Path) -> CLIClientConfig:


@pytest.fixture
def valid_auth_config(tmp_path: Path, valid_oidc_url: str) -> str:
config: str = f"""
oidc_config:
well_known_url: {valid_oidc_url}
client_id: "blueapi"
client_audience: "blueapi-cli"
token_file_path: {tmp_path / "auth_token"}
"""
with open(tmp_path / "auth_config.yaml", mode="w") as valid_auth_config_file:
valid_auth_config_file.write(config)
return valid_auth_config_file.name
def valid_auth_config(tmp_path: Path, oidc_config: CLIClientConfig) -> Path:
config = ApplicationConfig(oidc_config=oidc_config)
config_path = tmp_path / "auth_config.yaml"
with open(config_path, mode="w") as valid_auth_config_file:
valid_auth_config_file.write(yaml.dump(config.model_dump()))
return config_path


@pytest.fixture
Expand All @@ -96,20 +94,20 @@ def valid_oidc_config() -> dict[str, Any]:
}


def _make_token(name: str, issued_in: float, expires_in: float, tmp_path: Path) -> Path:
token_path = tmp_path / "token"
@pytest.fixture(scope="session")
def json_web_keyset() -> JWK:
return JWK.generate(kty="RSA", size=1024, kid="secret", use="sig", alg="RSA256")

now = time.time()

RSA_key = """-----BEGIN RSA PRIVATE KEY-----
MIIBOgIBAAJBAKj34GkxFhD90vcNLYLInFEX6Ppy1tPf9Cnzj4p4WGeKLs1Pt8Qu
KUpRKfFLfRYC9AIKjbJTWit+CqvjWYzvQwECAwEAAQJAIJLixBy2qpFoS4DSmoEm
o3qGy0t6z09AIJtH+5OeRV1be+N4cDYJKffGzDa88vQENZiRm0GRq6a+HPGQMd2k
TQIhAKMSvzIBnni7ot/OSie2TmJLY4SwTQAevXysE2RbFDYdAiEBCUEaRQnMnbp7
9mxDXDf6AU0cN/RPBjb9qSHDcWZHGzUCIG2Es59z8ugGrDY+pxLQnwfotadxd+Uy
v/Ow5T0q5gIJAiEAyS4RaI9YG8EWx/2w0T67ZUVAw8eOMB6BIUg0Xcu+3okCIBOs
/5OiPgoTdSy7bcF9IGpSE8ZgGKzgYQVZeN97YE00
-----END RSA PRIVATE KEY-----"""
@pytest.fixture(scope="session")
def rsa_private_key(json_web_keyset: JWK) -> str:
return json_web_keyset.export_to_pem("private_key", password=None).decode("utf-8")


def _make_token(
name: str, issued_in: float, expires_in: float, tmp_path: Path, rsa_private_key: str
) -> dict[str, str]:
now = time.time()

id_token = {
"aud": "default-demo",
Expand All @@ -124,28 +122,99 @@ def _make_token(name: str, issued_in: float, expires_in: float, tmp_path: Path)
"access_token": name,
"token_type": "Bearer",
"refresh_token": "refresh_token",
"id_token": f"{jwt.encode(id_token, key=RSA_key, algorithm="RS256")}",
"id_token": f"{jwt.encode(id_token, key=rsa_private_key, algorithm="RS256", headers={"kid": "secret"})}",
}
return response


@pytest.fixture
def cached_expired_token(tmp_path: Path, expired_token: dict[str, Any]) -> Path:
token_path = tmp_path / "token"
token_json = json.dumps(expired_token)
with open(token_path, "w") as token_file:
token_file.write(base64.b64encode(token_json.encode("utf-8")).decode("utf-8"))
return token_path


@pytest.fixture
def cached_valid_token(tmp_path: Path, valid_token: dict[str, Any]) -> Path:
token_path = tmp_path / "token"
token_json = json.dumps(valid_token)
with open(token_path, "w") as token_file:
token_file.write(
base64.b64encode(json.dumps(response).encode("utf-8")).decode("utf-8")
)
token_file.write(base64.b64encode(token_json.encode("utf-8")).decode("utf-8"))
return token_path


@pytest.fixture
def expired_token(tmp_path: Path) -> Path:
return _make_token("expired_token", -3600, -1800, tmp_path)
def expired_token(tmp_path: Path, rsa_private_key: str) -> dict[str, Any]:
return _make_token("expired_token", -3600, -1800, tmp_path, rsa_private_key)


@pytest.fixture
def valid_token(tmp_path: Path, rsa_private_key: str) -> dict[str, Any]:
return _make_token("valid_token", -900, +900, tmp_path, rsa_private_key)


@pytest.fixture
def valid_token(tmp_path: Path) -> Path:
return _make_token("expired_token", -1800, +1800, tmp_path)
def new_token(tmp_path: Path, rsa_private_key: str) -> dict[str, Any]:
return _make_token("new_token", -100, +1700, tmp_path, rsa_private_key)


@pytest.fixture
def mock_authn_server(valid_oidc_url: str, valid_oidc_config: dict[str, Any]):
requests_mock = responses.RequestsMock(assert_all_requests_are_fired=True)
def mock_authn_server(
valid_oidc_url: str,
valid_oidc_config: dict[str, Any],
oidc_config: CLIClientConfig,
valid_token: dict[str, Any],
json_web_keyset: JWK,
new_token: dict[str, Any],
):
requests_mock = responses.RequestsMock(assert_all_requests_are_fired=False)
# Fetch well-known OIDC flow URLs from server
requests_mock.get(valid_oidc_url, json=valid_oidc_config)
requests_mock.get(valid_oidc_config["jwks_uri"], json="")
requests_mock.get(
valid_oidc_config["jwks_uri"],
json={"keys": [json_web_keyset.export_public(as_dict=True)]},
)
# When device flow begins, return a device_code
device_code = "ff83j3dk"
requests_mock.post(
valid_oidc_config["device_authorization_endpoint"],
json={
"device_code": device_code,
"verification_uri_complete": valid_oidc_config["issuer"] + "/verify",
"expires_in": 30,
"interval": 5,
},
)

# When polled with device_code return token
requests_mock.post(
valid_oidc_config["token_endpoint"],
json=valid_token,
match=[
responses.matchers.json_params_matcher(
{
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
"device_code": device_code,
"client_id": oidc_config.client_id,
}
),
],
)
# When asked to refresh with refresh_token return refreshed token
requests_mock.post(
valid_oidc_config["token_endpoint"],
json=new_token,
match=[
responses.matchers.json_params_matcher(
{
"client_id": oidc_config.client_id,
"grant_type": "refresh_token",
"refresh_token": "refresh_token",
},
)
],
)

return requests_mock
1 change: 0 additions & 1 deletion tests/system_tests/test_blueapi_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from blueapi.client.event_bus import AnyEvent
from blueapi.config import (
ApplicationConfig,
CLIClientConfig,
OIDCConfig,
StompConfig,
)
Expand Down
Loading

0 comments on commit b8a7421

Please sign in to comment.