diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 728f1aac5..7c901acd1 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -12,6 +12,7 @@ Field, TypeAdapter, ValidationError, + field_serializer, ) from blueapi.utils import BlueapiBaseModel, InvalidConfigError @@ -100,17 +101,17 @@ def _config_from_oidc_url(self) -> dict[str, Any]: return response.json() @cached_property - def device_auth_url(self) -> str: + def device_authorization_endpoint(self) -> str: return cast( str, self._config_from_oidc_url.get("device_authorization_endpoint") ) @cached_property - def auth_url(self) -> str: + def authorization_endpoint(self) -> str: return cast(str, self._config_from_oidc_url.get("authorization_endpoint")) @cached_property - def token_url(self) -> str: + def token_endpoint(self) -> str: return cast(str, self._config_from_oidc_url.get("token_endpoint")) @cached_property @@ -122,11 +123,11 @@ def jwks_uri(self) -> str: return cast(str, self._config_from_oidc_url.get("jwks_uri")) @cached_property - def logout_url(self) -> str: + def end_session_endpoint(self) -> str: return cast(str, self._config_from_oidc_url.get("end_session_endpoint")) @cached_property - def signing_algos(self) -> list[str]: + def id_token_signing_alg_values_supported(self) -> list[str]: return cast( list[str], self._config_from_oidc_url.get("id_token_signing_alg_values_supported"), @@ -134,7 +135,11 @@ def signing_algos(self) -> list[str]: class CLIClientConfig(OIDCConfig): - token_file_path: Path = Path("~/token") + token_file_path: Path = Field(Path("~/token")) + + @field_serializer("token_file_path") + def serialize_token_file_path(self, token_file_path: Path, _info): + return f"{token_file_path}" class ApplicationConfig(BlueapiBaseModel): @@ -148,7 +153,7 @@ class ApplicationConfig(BlueapiBaseModel): logging: LoggingConfig = Field(default_factory=LoggingConfig) api: RestConfig = Field(default_factory=RestConfig) scratch: ScratchConfig | None = None - oidc_config: OIDCConfig | None = None + oidc_config: OIDCConfig | CLIClientConfig | None = None def __eq__(self, other: object) -> bool: if isinstance(other, ApplicationConfig): diff --git a/src/blueapi/service/authentication.py b/src/blueapi/service/authentication.py index da2b00566..ff466dd46 100644 --- a/src/blueapi/service/authentication.py +++ b/src/blueapi/service/authentication.py @@ -26,7 +26,7 @@ def decode_jwt(self, token: str) -> dict[str, str]: return jwt.decode( token, signing_key.key, - algorithms=self._server_config.signing_algos, + algorithms=self._server_config.id_token_signing_alg_values_supported, verify=True, audience=self._server_config.client_audience, issuer=self._server_config.issuer, @@ -105,9 +105,8 @@ def logout(self) -> None: def refresh_auth_token(self) -> None: token = self._token_manager.load_token() response = requests.post( - self._server_config.token_url, - headers={"Content-Type": "application/x-www-form-urlencoded"}, - data={ + self._server_config.token_endpoint, + json={ "client_id": self._server_config.client_id, "grant_type": "refresh_token", "refresh_token": token["refresh_token"], @@ -124,9 +123,8 @@ def poll_for_token( expiry_time: float = time.time() + expires_in while time.time() < expiry_time: response = requests.post( - self._server_config.token_url, - headers={"Content-Type": "application/x-www-form-urlencoded"}, - data={ + self._server_config.token_endpoint, + json={ "grant_type": "urn:ietf:params:oauth:grant-type:device_code", "device_code": device_code, "client_id": self._server_config.client_id, @@ -138,24 +136,10 @@ def poll_for_token( raise TimeoutError("Polling timed out") - def start_device_flow(self) -> None: - token = self._token_manager.load_token() - try: - self.authenticator.decode_jwt(token["access_token"]) - print("Cached token still valid, skipping flow") - return - except jwt.ExpiredSignatureError: - token = self.refresh_auth_token() - print("Refreshed cached token, skipping flow") - return - except Exception: - print("Problem with cached token, starting new session") - self._token_manager.delete_token() - + def _do_device_flow(self) -> None: response: requests.Response = requests.post( - self._server_config.device_auth_url, - headers={"Content-Type": "application/x-www-form-urlencoded"}, - data={ + self._server_config.device_authorization_endpoint, + json={ "client_id": self._server_config.client_id, "scope": "openid profile offline_access", "audience": self._server_config.client_audience, @@ -175,7 +159,21 @@ def start_device_flow(self) -> None: auth_token_json: dict[str, Any] = self.poll_for_token( device_code, interval, expires_in ) - decoded_token: dict[str, Any] = self.authenticator.decode_jwt( - auth_token_json["access_token"] - ) - self._token_manager.save_token(decoded_token) + self._token_manager.save_token(auth_token_json) + + def start_device_flow(self) -> None: + try: + token = self._token_manager.load_token() + self.authenticator.decode_jwt(token["id_token"]) + print("Cached token still valid, skipping flow") + return + except jwt.ExpiredSignatureError: + token = self.refresh_auth_token() + print("Refreshed cached token, skipping flow") + return + except FileNotFoundError: + self._do_device_flow() + except Exception as e: + print(e) + print("Problem with cached token, starting new session") + self._token_manager.delete_token() diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 05f903242..7eeb2a6e3 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -108,9 +108,9 @@ def get_app(config: ApplicationConfig | None = None): def verify_access_token(config: OIDCConfig): oauth_scheme = OAuth2AuthorizationCodeBearer( - authorizationUrl=config.auth_url, - tokenUrl=config.token_url, - refreshUrl=config.token_url, + authorizationUrl=config.authorization_endpoint, + tokenUrl=config.token_endpoint, + refreshUrl=config.token_endpoint, ) authenticator = Authenticator(config) diff --git a/tests/conftest.py b/tests/conftest.py index 917d5cf06..bb40fd4e1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,14 +9,17 @@ import jwt import pytest import responses +import responses.matchers +import yaml from bluesky import RunEngine from bluesky.run_engine import TransitionError +from jwcrypto.jwk import JWK from observability_utils.tracing import setup_tracing from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from opentelemetry.trace import get_tracer_provider -from blueapi.config import CLIClientConfig +from blueapi.config import ApplicationConfig, CLIClientConfig from tests.unit_tests.utils.test_tracing import JsonObjectSpanExporter @@ -70,17 +73,12 @@ def oidc_config(valid_oidc_url: str, tmp_path: Path) -> CLIClientConfig: @pytest.fixture -def valid_auth_config(tmp_path: Path, valid_oidc_url: str) -> str: - config: str = f""" -oidc_config: - well_known_url: {valid_oidc_url} - client_id: "blueapi" - client_audience: "blueapi-cli" - token_file_path: {tmp_path / "auth_token"} -""" - with open(tmp_path / "auth_config.yaml", mode="w") as valid_auth_config_file: - valid_auth_config_file.write(config) - return valid_auth_config_file.name +def valid_auth_config(tmp_path: Path, oidc_config: CLIClientConfig) -> Path: + config = ApplicationConfig(oidc_config=oidc_config) + config_path = tmp_path / "auth_config.yaml" + with open(config_path, mode="w") as valid_auth_config_file: + valid_auth_config_file.write(yaml.dump(config.model_dump())) + return config_path @pytest.fixture @@ -96,20 +94,20 @@ def valid_oidc_config() -> dict[str, Any]: } -def _make_token(name: str, issued_in: float, expires_in: float, tmp_path: Path) -> Path: - token_path = tmp_path / "token" +@pytest.fixture(scope="session") +def json_web_keyset() -> JWK: + return JWK.generate(kty="RSA", size=1024, kid="secret", use="sig", alg="RSA256") - now = time.time() - RSA_key = """-----BEGIN RSA PRIVATE KEY----- -MIIBOgIBAAJBAKj34GkxFhD90vcNLYLInFEX6Ppy1tPf9Cnzj4p4WGeKLs1Pt8Qu -KUpRKfFLfRYC9AIKjbJTWit+CqvjWYzvQwECAwEAAQJAIJLixBy2qpFoS4DSmoEm -o3qGy0t6z09AIJtH+5OeRV1be+N4cDYJKffGzDa88vQENZiRm0GRq6a+HPGQMd2k -TQIhAKMSvzIBnni7ot/OSie2TmJLY4SwTQAevXysE2RbFDYdAiEBCUEaRQnMnbp7 -9mxDXDf6AU0cN/RPBjb9qSHDcWZHGzUCIG2Es59z8ugGrDY+pxLQnwfotadxd+Uy -v/Ow5T0q5gIJAiEAyS4RaI9YG8EWx/2w0T67ZUVAw8eOMB6BIUg0Xcu+3okCIBOs -/5OiPgoTdSy7bcF9IGpSE8ZgGKzgYQVZeN97YE00 ------END RSA PRIVATE KEY-----""" +@pytest.fixture(scope="session") +def rsa_private_key(json_web_keyset: JWK) -> str: + return json_web_keyset.export_to_pem("private_key", password=None).decode("utf-8") + + +def _make_token( + name: str, issued_in: float, expires_in: float, tmp_path: Path, rsa_private_key: str +) -> dict[str, str]: + now = time.time() id_token = { "aud": "default-demo", @@ -124,28 +122,99 @@ def _make_token(name: str, issued_in: float, expires_in: float, tmp_path: Path) "access_token": name, "token_type": "Bearer", "refresh_token": "refresh_token", - "id_token": f"{jwt.encode(id_token, key=RSA_key, algorithm="RS256")}", + "id_token": f"{jwt.encode(id_token, key=rsa_private_key, algorithm="RS256", headers={"kid": "secret"})}", } + return response + + +@pytest.fixture +def cached_expired_token(tmp_path: Path, expired_token: dict[str, Any]) -> Path: + token_path = tmp_path / "token" + token_json = json.dumps(expired_token) + with open(token_path, "w") as token_file: + token_file.write(base64.b64encode(token_json.encode("utf-8")).decode("utf-8")) + return token_path + + +@pytest.fixture +def cached_valid_token(tmp_path: Path, valid_token: dict[str, Any]) -> Path: + token_path = tmp_path / "token" + token_json = json.dumps(valid_token) with open(token_path, "w") as token_file: - token_file.write( - base64.b64encode(json.dumps(response).encode("utf-8")).decode("utf-8") - ) + token_file.write(base64.b64encode(token_json.encode("utf-8")).decode("utf-8")) return token_path @pytest.fixture -def expired_token(tmp_path: Path) -> Path: - return _make_token("expired_token", -3600, -1800, tmp_path) +def expired_token(tmp_path: Path, rsa_private_key: str) -> dict[str, Any]: + return _make_token("expired_token", -3600, -1800, tmp_path, rsa_private_key) + + +@pytest.fixture +def valid_token(tmp_path: Path, rsa_private_key: str) -> dict[str, Any]: + return _make_token("valid_token", -900, +900, tmp_path, rsa_private_key) @pytest.fixture -def valid_token(tmp_path: Path) -> Path: - return _make_token("expired_token", -1800, +1800, tmp_path) +def new_token(tmp_path: Path, rsa_private_key: str) -> dict[str, Any]: + return _make_token("new_token", -100, +1700, tmp_path, rsa_private_key) @pytest.fixture -def mock_authn_server(valid_oidc_url: str, valid_oidc_config: dict[str, Any]): - requests_mock = responses.RequestsMock(assert_all_requests_are_fired=True) +def mock_authn_server( + valid_oidc_url: str, + valid_oidc_config: dict[str, Any], + oidc_config: CLIClientConfig, + valid_token: dict[str, Any], + json_web_keyset: JWK, + new_token: dict[str, Any], +): + requests_mock = responses.RequestsMock(assert_all_requests_are_fired=False) + # Fetch well-known OIDC flow URLs from server requests_mock.get(valid_oidc_url, json=valid_oidc_config) - requests_mock.get(valid_oidc_config["jwks_uri"], json="") + requests_mock.get( + valid_oidc_config["jwks_uri"], + json={"keys": [json_web_keyset.export_public(as_dict=True)]}, + ) + # When device flow begins, return a device_code + device_code = "ff83j3dk" + requests_mock.post( + valid_oidc_config["device_authorization_endpoint"], + json={ + "device_code": device_code, + "verification_uri_complete": valid_oidc_config["issuer"] + "/verify", + "expires_in": 30, + "interval": 5, + }, + ) + + # When polled with device_code return token + requests_mock.post( + valid_oidc_config["token_endpoint"], + json=valid_token, + match=[ + responses.matchers.json_params_matcher( + { + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + "device_code": device_code, + "client_id": oidc_config.client_id, + } + ), + ], + ) + # When asked to refresh with refresh_token return refreshed token + requests_mock.post( + valid_oidc_config["token_endpoint"], + json=new_token, + match=[ + responses.matchers.json_params_matcher( + { + "client_id": oidc_config.client_id, + "grant_type": "refresh_token", + "refresh_token": "refresh_token", + }, + ) + ], + ) + return requests_mock diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index 0e13529b3..644dfaf4a 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -13,7 +13,6 @@ from blueapi.client.event_bus import AnyEvent from blueapi.config import ( ApplicationConfig, - CLIClientConfig, OIDCConfig, StompConfig, ) diff --git a/tests/unit_tests/client/test_rest.py b/tests/unit_tests/client/test_rest.py index 8b99cfe00..6d2fa2098 100644 --- a/tests/unit_tests/client/test_rest.py +++ b/tests/unit_tests/client/test_rest.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Any from unittest.mock import Mock, patch import pytest @@ -6,7 +7,7 @@ from pydantic import BaseModel from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError -from blueapi.config import CLIClientConfig +from blueapi.config import OIDCConfig from blueapi.core.bluesky_types import Plan from blueapi.service.authentication import SessionManager from blueapi.service.model import PlanModel, PlanResponse @@ -18,16 +19,8 @@ def rest() -> BlueapiRestClient: @pytest.fixture -def rest_with_auth(valid_oidc_url: str, tmp_path: Path) -> BlueapiRestClient: - session_manager = SessionManager( - server_config=CLIClientConfig( - well_known_url=valid_oidc_url, - client_id="foo", - client_audience="bar", - token_file_path=tmp_path / "token", - ), - ) - return BlueapiRestClient(session_manager=session_manager) +def rest_with_auth(oidc_config: OIDCConfig) -> BlueapiRestClient: + return BlueapiRestClient(session_manager=SessionManager(oidc_config)) @pytest.mark.parametrize( @@ -58,15 +51,15 @@ class MyModel(BaseModel): def test_auth_request_functionality( rest_with_auth: BlueapiRestClient, - valid_token: Path, + mock_authn_server: responses.RequestsMock, + cached_valid_token: Path, ): plan = Plan(name="my-plan", model=MyModel) - mock_server = responses.RequestsMock() - mock_server.get( + mock_authn_server.get( "http://localhost:8000/plans", json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(), ) - with mock_server: + with mock_authn_server: result = rest_with_auth.get_plans() assert result == PlanResponse(plans=[PlanModel.from_plan(plan)]) @@ -74,12 +67,8 @@ def test_auth_request_functionality( def test_refresh_if_signature_expired( rest_with_auth: BlueapiRestClient, mock_authn_server: responses.RequestsMock, - expired_token: Path, + cached_expired_token: Path, ): - mock_authn_server.post( - "https://example.com/token", - json={"access_token": "new_token"}, - ) plan = Plan(name="my-plan", model=MyModel) mock_get_plans = ( diff --git a/tests/unit_tests/service/test_authentication.py b/tests/unit_tests/service/test_authentication.py index 0230cdee4..acf3fd452 100644 --- a/tests/unit_tests/service/test_authentication.py +++ b/tests/unit_tests/service/test_authentication.py @@ -1,6 +1,4 @@ -import base64 import os -from collections.abc import Callable from http import HTTPStatus from pathlib import Path from typing import Any @@ -11,159 +9,88 @@ from fastapi.exceptions import HTTPException from starlette.status import HTTP_403_FORBIDDEN -from blueapi.config import CLIClientConfig, OAuthClientConfig, OIDCConfig +from blueapi.config import CLIClientConfig, OIDCConfig from blueapi.service import main from blueapi.service.authentication import SessionManager @pytest.fixture -def configure_refreshed_token(mock_authn_server: responses.RequestsMock): - mock_authn_server.post( - url="https://example.com/token", - json={"access_token": "new_access_token"}, - ) - - return mock_authn_server - - -@pytest.fixture -def configure_device_flow_token(mock_authn_server: responses.RequestsMock): - mock_authn_server.post( - url="https://example.com/device_authorization", - json={"expires_in": 30, "interval": 5, "device_code": "device_code"}, - ) - mock_authn_server.post( - url="https://example.com/token", - json={"access_token": "access_token"}, - ) - - return mock_authn_server - - -@pytest.fixture -def configure_awaiting_token(mock_authn_server: responses.RequestsMock): - mock_authn_server.post( - url="https://example.com/token", - json={"error": "authorization_pending"}, - status=HTTP_403_FORBIDDEN, - ) - - return mock_authn_server - - -@pytest.fixture -def expired_token(tmp_path: Path): - token_path = tmp_path / "token" - with open(token_path, "w") as token_file: - # base64 encoded token - token_file.write( - base64.b64encode( - b'{"access_token":"expired_token","refresh_token":"refresh_token"}' - ).decode("utf-8") - ) - yield token_path - - -@pytest.fixture -def client_config(tmp_path: Path) -> OAuthClientConfig: - return CLIClientConfig( - client_id="client_id", - client_audience="client_audience", - token_file_path=tmp_path / "token", - ) - - -@pytest.fixture -def server_config(valid_oidc_url: str, mock_authn_server) -> OIDCConfig: - return OIDCConfig(well_known_url=valid_oidc_url) - - -@pytest.fixture -def session_manager( - client_config: OAuthClientConfig, server_config: OIDCConfig -) -> SessionManager: - return SessionManager(server_config, client_config) - - -@pytest.fixture -def connected_client_config(client_config: OAuthClientConfig) -> CLIClientConfig: - assert isinstance(client_config, CLIClientConfig) - with open(client_config.token_file_path, "w") as token_file: - # base64 encoded token - token_file.write( - base64.b64encode( - b'{"access_token":"token","refresh_token":"refresh_token"}' - ).decode("utf-8") - ) - return client_config +def session_manager(oidc_config: OIDCConfig) -> SessionManager: + return SessionManager(oidc_config) def test_logout( - session_manager: SessionManager, connected_client_config: CLIClientConfig + session_manager: SessionManager, + oidc_config: CLIClientConfig, + cached_valid_token: Path, ): - assert os.path.exists(connected_client_config.token_file_path) + assert os.path.exists(oidc_config.token_file_path) session_manager.logout() - assert not os.path.exists(connected_client_config.token_file_path) + assert not os.path.exists(oidc_config.token_file_path) def test_refresh_auth_token( - configure_refreshed_token: responses.RequestsMock, + mock_authn_server: responses.RequestsMock, session_manager: SessionManager, - expired_token: Path, + cached_expired_token: Path, ): token = session_manager.get_token() assert token and token["access_token"] == "expired_token" - with configure_refreshed_token: + with mock_authn_server: session_manager.refresh_auth_token() token = session_manager.get_token() - assert token and token["access_token"] == "new_access_token" + assert token and token["access_token"] == "new_token" def test_poll_for_token( mock_authn_server: responses.RequestsMock, session_manager: SessionManager, + valid_token: dict[str, Any], ): - mock_authn_server.post( - url="https://example.com/token", - json={"access_token": "access_token"}, - ) with mock_authn_server: token = session_manager.poll_for_token("device_code", 1, 2) - assert token == {"access_token": "access_token"} + assert token == valid_token @patch("time.sleep") def test_poll_for_token_timeout( mock_sleep, - configure_awaiting_token, + mock_authn_server: responses.RequestsMock, session_manager: SessionManager, ): - with pytest.raises(TimeoutError), configure_awaiting_token: + mock_authn_server.post( + url="https://example.com/token", + json={"error": "authorization_pending"}, + status=HTTP_403_FORBIDDEN, + ) + with pytest.raises(TimeoutError), mock_authn_server: session_manager.poll_for_token("device_code", 1, 2) def test_valid_token_access_granted( - server_config: OIDCConfig, + oidc_config: OIDCConfig, mock_authn_server: responses.RequestsMock, + valid_token: dict[str, Any], ): with mock_authn_server: - main.verify_access_token(server_config)("token") + main.verify_access_token(oidc_config)(valid_token["access_token"]) def test_invalid_token_no_access( - server_config: OIDCConfig, + oidc_config: OIDCConfig, mock_authn_server: responses.RequestsMock, ): with pytest.raises(HTTPException) as exec, mock_authn_server: - main.verify_access_token(server_config)("bad_token") + main.verify_access_token(oidc_config)("bad_token") assert exec.value.status_code == HTTPStatus.UNAUTHORIZED def test_expired_token_no_access( - server_config: OIDCConfig, + oidc_config: OIDCConfig, mock_authn_server: responses.RequestsMock, + expired_token: dict[str, Any], ): with pytest.raises(HTTPException) as exec, mock_authn_server: - main.verify_access_token(server_config)("expired_token") + main.verify_access_token(oidc_config)(expired_token["access_token"]) assert exec.value.status_code == HTTPStatus.UNAUTHORIZED diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index 3a726b9d6..20a530a70 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -634,25 +634,9 @@ def test_logout_missing_config(runner: CliRunner): def test_login_success( runner: CliRunner, - valid_auth_config: str, + valid_auth_config: Path, mock_authn_server: responses.RequestsMock, - valid_oidc_config: dict[str, Any], ): - mock_authn_server.post( - valid_oidc_config["device_authorization_endpoint"], - json={ - "device_code": "device_code", - "verification_uri_complete": "https://example.com/verify", - "expires_in": 30, - "interval": 5, - }, - ) - mock_authn_server.post( - valid_oidc_config["token_endpoint"], - json={ - "access_token": "token", - }, - ) with mock_authn_server: result = runner.invoke(main, ["-c", valid_auth_config, "login"]) assert ( @@ -665,28 +649,22 @@ def test_login_success( def test_token_login_early_exit( runner: CliRunner, - valid_auth_config: str, - valid_token: Path, + valid_auth_config: Path, + mock_authn_server: responses.RequestsMock, + cached_valid_token: Path, ): - result = runner.invoke(main, ["-c", valid_auth_config, "login"]) + with mock_authn_server: + result = runner.invoke(main, ["-c", valid_auth_config, "login"]) assert "Logging in\nCached token still valid, skipping flow\n" == result.output assert result.exit_code == 0 def test_login_with_refresh_token( runner: CliRunner, - valid_auth_config: str, + valid_auth_config: Path, mock_authn_server: responses.RequestsMock, - valid_oidc_config: dict[str, Any], - expired_token: Path, + cached_expired_token: Path, ): - mock_authn_server.post( - valid_oidc_config["token_endpoint"], - json={ - "access_token": "token", - }, - ) - with mock_authn_server: result = runner.invoke(main, ["-c", valid_auth_config, "login"]) @@ -696,7 +674,7 @@ def test_login_with_refresh_token( def test_login_edge_cases( runner: CliRunner, - valid_auth_config: str, + valid_auth_config: Path, mock_authn_server: responses.RequestsMock, valid_oidc_config: dict[str, Any], ): @@ -712,9 +690,13 @@ def test_login_edge_cases( assert result.exit_code == 0 -def test_logout_success(runner: CliRunner, valid_auth_config: str, expired_token: Path): - assert expired_token.exists() +def test_logout_success( + runner: CliRunner, + valid_auth_config: Path, + cached_expired_token: Path, +): + assert cached_expired_token.exists() result = runner.invoke(main, ["-c", valid_auth_config, "logout"]) assert "Logged out" in result.output assert result.exit_code == 0 - assert not expired_token.exists() + assert not cached_expired_token.exists() diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index d1972cd3c..a8665f3e0 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -375,17 +375,22 @@ def mock_server_response(valid_oidc_url: str, valid_oidc_config: dict[str, Any]) def test_oauth_config_model_post_init( valid_oidc_url: str, valid_oidc_config: dict[str, Any], + oidc_config: OIDCConfig, mock_server_response: responses.RequestsMock, ): - oauth_config = OIDCConfig(well_known_url=valid_oidc_url) - with mock_server_response: assert ( - oauth_config.device_auth_url + oidc_config.device_authorization_endpoint == valid_oidc_config["device_authorization_endpoint"] ) - assert oauth_config.auth_url == valid_oidc_config["authorization_endpoint"] - assert oauth_config.token_url == valid_oidc_config["token_endpoint"] - assert oauth_config.issuer == valid_oidc_config["issuer"] - assert oauth_config.jwks_uri == valid_oidc_config["jwks_uri"] - assert oauth_config.logout_url == valid_oidc_config["end_session_endpoint"] + assert ( + oidc_config.authorization_endpoint + == valid_oidc_config["authorization_endpoint"] + ) + assert oidc_config.token_endpoint == valid_oidc_config["token_endpoint"] + assert oidc_config.issuer == valid_oidc_config["issuer"] + assert oidc_config.jwks_uri == valid_oidc_config["jwks_uri"] + assert ( + oidc_config.end_session_endpoint + == valid_oidc_config["end_session_endpoint"] + )