Skip to content

Commit

Permalink
Test changes for mocks
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Nov 11, 2024
1 parent b8a7421 commit cf632d0
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 43 deletions.
5 changes: 4 additions & 1 deletion src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,10 @@ def login(obj: dict) -> None:
if isinstance(config.oidc_config, CLIClientConfig):
print("Logging in")
auth: SessionManager = SessionManager(config.oidc_config)
auth.start_device_flow()
try:
auth.start_device_flow()
except Exception:
print("Failed to login")
else:
print("Please provide configuration to login!")

Expand Down
5 changes: 2 additions & 3 deletions src/blueapi/service/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def save_token(self, token: dict[str, Any]) -> None:
token_base64: bytes = base64.b64encode(token_json.encode("utf-8"))
with open(self._file_path(), "wb") as token_file:
token_file.write(token_base64)
print("Logged in and cached new token")

def load_token(self) -> dict[str, Any]:
file_path = self._file_path()
Expand Down Expand Up @@ -160,6 +159,7 @@ def _do_device_flow(self) -> None:
device_code, interval, expires_in
)
self._token_manager.save_token(auth_token_json)
print("Logged in and cached new token")

def start_device_flow(self) -> None:
try:
Expand All @@ -173,7 +173,6 @@ def start_device_flow(self) -> None:
return
except FileNotFoundError:
self._do_device_flow()
except Exception as e:
print(e)
except Exception:
print("Problem with cached token, starting new session")
self._token_manager.delete_token()
61 changes: 39 additions & 22 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from pathlib import Path
from typing import Any, cast
from unittest.mock import Mock, patch

# Based on https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option # noqa: E501
import jwt
Expand Down Expand Up @@ -59,15 +60,17 @@ def exporter(provider: TracerProvider) -> JsonObjectSpanExporter:

@pytest.fixture
def valid_oidc_url() -> str:
return "https://auth.example.com/realms/sample/.well-known/openid-configuration"
return (
"https://auth.example.com/realms/master/oidc/.well-known/openid-configuration"
)


@pytest.fixture
def oidc_config(valid_oidc_url: str, tmp_path: Path) -> CLIClientConfig:
return CLIClientConfig(
well_known_url=valid_oidc_url,
client_id="example-client",
client_audience="example",
client_id="blueapi-client",
client_audience="blueapi",
token_file_path=tmp_path / "token",
)

Expand All @@ -89,14 +92,14 @@ def valid_oidc_config() -> dict[str, Any]:
"token_endpoint": "https://example.com/token",
"issuer": "https://example.com",
"jwks_uri": "https://example.com/realms/master/protocol/openid-connect/certs",
"end_session_endpoint": "https://example.com/logout",
"id_token_signing_alg_values_supported": ["RS256", "RS384", "RS512"],
"end_session_endpoint": "https://example.com/end_session",
"id_token_signing_alg_values_supported": ["RS256"],
}


@pytest.fixture(scope="session")
def json_web_keyset() -> JWK:
return JWK.generate(kty="RSA", size=1024, kid="secret", use="sig", alg="RSA256")
return JWK.generate(kty="RSA", size=1024, kid="secret", use="sig", alg="RS256")


@pytest.fixture(scope="session")
Expand All @@ -105,24 +108,30 @@ def rsa_private_key(json_web_keyset: JWK) -> str:


def _make_token(
name: str, issued_in: float, expires_in: float, tmp_path: Path, rsa_private_key: str
name: str, issued_in: float, expires_in: float, rsa_private_key: str
) -> dict[str, str]:
now = time.time()

id_token = {
"aud": "default-demo",
"aud": "blueapi",
"exp": now + expires_in,
"iat": now + issued_in,
"iss": "https://example.com",
"sub": "jd1",
"name": "Jane Doe",
"fedid": "jd1",
}
id_token_encoded = jwt.encode(
id_token,
key=rsa_private_key,
algorithm="RS256",
headers={"kid": "secret"},
)
response = {
"access_token": name,
"token_type": "Bearer",
"refresh_token": "refresh_token",
"id_token": f"{jwt.encode(id_token, key=rsa_private_key, algorithm="RS256", headers={"kid": "secret"})}",
"id_token": id_token_encoded,
}
return response

Expand All @@ -146,18 +155,23 @@ def cached_valid_token(tmp_path: Path, valid_token: dict[str, Any]) -> Path:


@pytest.fixture
def expired_token(tmp_path: Path, rsa_private_key: str) -> dict[str, Any]:
return _make_token("expired_token", -3600, -1800, tmp_path, rsa_private_key)
def expired_token(rsa_private_key: str) -> dict[str, Any]:
return _make_token("expired_token", -3600, -1800, rsa_private_key)


@pytest.fixture
def valid_token(rsa_private_key: str) -> dict[str, Any]:
return _make_token("valid_token", -900, +900, rsa_private_key)


@pytest.fixture
def valid_token(tmp_path: Path, rsa_private_key: str) -> dict[str, Any]:
return _make_token("valid_token", -900, +900, tmp_path, rsa_private_key)
def new_token(rsa_private_key: str) -> dict[str, Any]:
return _make_token("new_token", -100, +1700, rsa_private_key)


@pytest.fixture
def new_token(tmp_path: Path, rsa_private_key: str) -> dict[str, Any]:
return _make_token("new_token", -100, +1700, tmp_path, rsa_private_key)
def device_code() -> str:
return "ff83j3dk"


@pytest.fixture
Expand All @@ -166,18 +180,14 @@ def mock_authn_server(
valid_oidc_config: dict[str, Any],
oidc_config: CLIClientConfig,
valid_token: dict[str, Any],
json_web_keyset: JWK,
new_token: dict[str, Any],
device_code: str,
mock_jwks_fetch,
):
requests_mock = responses.RequestsMock(assert_all_requests_are_fired=False)
# Fetch well-known OIDC flow URLs from server
requests_mock.get(valid_oidc_url, json=valid_oidc_config)
requests_mock.get(
valid_oidc_config["jwks_uri"],
json={"keys": [json_web_keyset.export_public(as_dict=True)]},
)
# When device flow begins, return a device_code
device_code = "ff83j3dk"
requests_mock.post(
valid_oidc_config["device_authorization_endpoint"],
json={
Expand Down Expand Up @@ -217,4 +227,11 @@ def mock_authn_server(
],
)

return requests_mock
with mock_jwks_fetch, requests_mock:
yield requests_mock


@pytest.fixture
def mock_jwks_fetch(json_web_keyset: JWK):
mock = Mock(return_value={"keys": [json_web_keyset.export_public(as_dict=True)]})
return patch("jwt.PyJWKClient.fetch_data", mock)
5 changes: 3 additions & 2 deletions tests/unit_tests/client/test_rest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from pathlib import Path
from typing import Any
from unittest.mock import Mock, patch

import pytest
Expand Down Expand Up @@ -55,6 +54,7 @@ def test_auth_request_functionality(
cached_valid_token: Path,
):
plan = Plan(name="my-plan", model=MyModel)
mock_authn_server.stop() # Cannot use multiple RequestsMock context manager
mock_authn_server.get(
"http://localhost:8000/plans",
json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(),
Expand All @@ -71,8 +71,9 @@ def test_refresh_if_signature_expired(
):
plan = Plan(name="my-plan", model=MyModel)

mock_authn_server.stop()
mock_get_plans = (
mock_authn_server.get( # Cannot use more than 1 RequestsMock context manager
mock_authn_server.get( # Cannot use multiple RequestsMock context manager
"http://localhost:8000/plans",
json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(),
)
Expand Down
16 changes: 8 additions & 8 deletions tests/unit_tests/service/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ def test_refresh_auth_token(
token = session_manager.get_token()
assert token and token["access_token"] == "expired_token"

with mock_authn_server:
session_manager.refresh_auth_token()
session_manager.refresh_auth_token()
token = session_manager.get_token()
assert token and token["access_token"] == "new_token"

Expand All @@ -47,9 +46,9 @@ def test_poll_for_token(
mock_authn_server: responses.RequestsMock,
session_manager: SessionManager,
valid_token: dict[str, Any],
device_code: str,
):
with mock_authn_server:
token = session_manager.poll_for_token("device_code", 1, 2)
token = session_manager.poll_for_token(device_code, 1, 2)
assert token == valid_token


Expand All @@ -58,23 +57,24 @@ def test_poll_for_token_timeout(
mock_sleep,
mock_authn_server: responses.RequestsMock,
session_manager: SessionManager,
device_code: str,
):
mock_authn_server.stop()
mock_authn_server.post(
url="https://example.com/token",
json={"error": "authorization_pending"},
status=HTTP_403_FORBIDDEN,
)
with pytest.raises(TimeoutError), mock_authn_server:
session_manager.poll_for_token("device_code", 1, 2)
session_manager.poll_for_token(device_code, 1, 2)


def test_valid_token_access_granted(
oidc_config: OIDCConfig,
mock_authn_server: responses.RequestsMock,
valid_token: dict[str, Any],
):
with mock_authn_server:
main.verify_access_token(oidc_config)(valid_token["access_token"])
main.verify_access_token(oidc_config)(valid_token["id_token"])


def test_invalid_token_no_access(
Expand All @@ -92,5 +92,5 @@ def test_expired_token_no_access(
expired_token: dict[str, Any],
):
with pytest.raises(HTTPException) as exec, mock_authn_server:
main.verify_access_token(oidc_config)(expired_token["access_token"])
main.verify_access_token(oidc_config)(expired_token["id_token"])
assert exec.value.status_code == HTTPStatus.UNAUTHORIZED
15 changes: 8 additions & 7 deletions tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from collections.abc import Callable, Mapping
from collections.abc import Mapping
from dataclasses import dataclass
from io import StringIO
from pathlib import Path
Expand Down Expand Up @@ -637,8 +637,7 @@ def test_login_success(
valid_auth_config: Path,
mock_authn_server: responses.RequestsMock,
):
with mock_authn_server:
result = runner.invoke(main, ["-c", valid_auth_config, "login"])
result = runner.invoke(main, ["-c", valid_auth_config, "login"])
assert (
"Logging in\n"
"Please login from this URL:- https://example.com/verify\n"
Expand All @@ -653,8 +652,7 @@ def test_token_login_early_exit(
mock_authn_server: responses.RequestsMock,
cached_valid_token: Path,
):
with mock_authn_server:
result = runner.invoke(main, ["-c", valid_auth_config, "login"])
result = runner.invoke(main, ["-c", valid_auth_config, "login"])
assert "Logging in\nCached token still valid, skipping flow\n" == result.output
assert result.exit_code == 0

Expand All @@ -665,8 +663,7 @@ def test_login_with_refresh_token(
mock_authn_server: responses.RequestsMock,
cached_expired_token: Path,
):
with mock_authn_server:
result = runner.invoke(main, ["-c", valid_auth_config, "login"])
result = runner.invoke(main, ["-c", valid_auth_config, "login"])

assert "Logging in\nRefreshed cached token, skipping flow\n" == result.output
assert result.exit_code == 0
Expand All @@ -678,6 +675,10 @@ def test_login_edge_cases(
mock_authn_server: responses.RequestsMock,
valid_oidc_config: dict[str, Any],
):
mock_authn_server.stop()
mock_authn_server.remove(
responses.POST, url=valid_oidc_config["device_authorization_endpoint"]
)
mock_authn_server.post(
valid_oidc_config["device_authorization_endpoint"],
json={"details": "not found"},
Expand Down

0 comments on commit cf632d0

Please sign in to comment.