Skip to content

Commit

Permalink
simplified tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ZohebShaikh committed Nov 27, 2024
1 parent ea471d8 commit 27c12ce
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 103 deletions.
27 changes: 4 additions & 23 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _make_token(


@pytest.fixture
def expired_cache(
def cached_valid_refresh(
tmp_path: Path, expired_token: dict[str, Any], oidc_config: OIDCConfig
) -> Path:
cache_path = tmp_path / CACHE_FILE
Expand All @@ -153,33 +153,14 @@ def expired_cache(


@pytest.fixture
def cached_invalid_token(
tmp_path: Path, expired_token: dict[str, Any], oidc_config: OIDCConfig
) -> Path:
cache_path = tmp_path / CACHE_FILE
cache = Cache(
oidc_config=oidc_config,
access_token="Invalid Token",
refresh_token=expired_token["refresh_token"],
id_token=expired_token["id_token"],
)
cache_json = cache.model_dump_json()
cache_base64 = base64.b64encode(cache_json.encode("utf-8"))

with open(cache_path, "xb") as cache_file:
cache_file.write(cache_base64)
return cache_path


@pytest.fixture
def cached_invalid_refresh(
def cached_expired_refresh(
tmp_path: Path, expired_token: dict[str, Any], oidc_config: OIDCConfig
) -> Path:
cache_path = tmp_path / CACHE_FILE
cache = Cache(
oidc_config=oidc_config,
access_token=expired_token["access_token"],
refresh_token="invalid_refresh",
refresh_token="expired_refresh",
id_token=expired_token["id_token"],
)
cache_json = cache.model_dump_json()
Expand Down Expand Up @@ -303,7 +284,7 @@ def mock_authn_server(
{
"client_id": oidc_config.client_id,
"grant_type": "refresh_token",
"refresh_token": "invalid_refresh",
"refresh_token": "expired_refresh",
},
)
],
Expand Down
68 changes: 0 additions & 68 deletions tests/unit_tests/client/test_rest.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,16 @@
from pathlib import Path
from unittest.mock import Mock, patch

import pytest
import responses
from pydantic import BaseModel

from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError
from blueapi.core.bluesky_types import Plan
from blueapi.service.authentication import (
SessionCacheManager,
SessionManager,
)
from blueapi.service.model import OIDCConfig, PlanModel, PlanResponse


@pytest.fixture
def rest() -> BlueapiRestClient:
return BlueapiRestClient()


@pytest.fixture
def rest_with_auth(oidc_config: OIDCConfig, tmp_path) -> BlueapiRestClient:
return BlueapiRestClient(
session_manager=SessionManager(
server_config=oidc_config,
cache_manager=SessionCacheManager(tmp_path / "blueapi_cache"),
)
)


@pytest.fixture
def mock_authn_server_with_plan(mock_authn_server):
plan = Plan(name="my-plan", model=MyModel)
mock_authn_server.stop()
mock_get_plans = mock_authn_server.get(
"http://localhost:8000/plans",
json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(),
)
return mock_get_plans


@pytest.mark.parametrize(
"code,expected_exception",
[
Expand All @@ -65,42 +36,3 @@ def test_rest_error_code(

class MyModel(BaseModel):
id: str


def test_auth_request_functionality(
rest_with_auth: BlueapiRestClient,
mock_authn_server: responses.RequestsMock,
cached_valid_token: Path,
):
plan = Plan(name="my-plan", model=MyModel)
mock_authn_server.stop() # Cannot use multiple RequestsMock context manager
mock_authn_server.get(
"http://localhost:8000/plans",
json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(),
)
result = None
with mock_authn_server:
result = rest_with_auth.get_plans()
assert result == PlanResponse(plans=[PlanModel.from_plan(plan)])


def test_refresh_if_signature_expired(
rest_with_auth: BlueapiRestClient,
mock_authn_server: responses.RequestsMock,
expired_cache: Path,
):
plan = Plan(name="my-plan", model=MyModel)
mock_authn_server.stop()
mock_get_plans = (
mock_authn_server.get( # Cannot use multiple RequestsMock context manager
"http://localhost:8000/plans",
json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(),
)
)
result = None
with mock_authn_server:
result = rest_with_auth.get_plans()
assert result == PlanResponse(plans=[PlanModel.from_plan(plan)])
calls = mock_get_plans.calls
assert len(calls) == 1
# assert calls[0].request.headers["Authorization"] == "Bearer new_token"
8 changes: 4 additions & 4 deletions tests/unit_tests/service/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_logout(
def test_refresh_auth_token(
mock_authn_server: responses.RequestsMock,
session_manager: SessionManager,
expired_cache: Path,
cached_valid_refresh: Path,
):
token = session_manager.get_valid_access_token()
assert token == "new_token"
Expand All @@ -61,12 +61,12 @@ def test_get_empty_token_if_no_cache(session_manager: SessionManager):
def test_get_empty_token_if_refresh_fails(
mock_authn_server: responses.RequestsMock,
session_manager: SessionManager,
cached_invalid_refresh: Path,
cached_expired_refresh: Path,
):
assert cached_invalid_refresh.exists()
assert cached_expired_refresh.exists()
token = session_manager.get_valid_access_token()
assert token == ""
assert not cached_invalid_refresh.exists()
assert not cached_expired_refresh.exists()


def test_poll_for_token(
Expand Down
16 changes: 8 additions & 8 deletions tests/unit_tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ def test_login_with_refresh_token(
runner: CliRunner,
config_with_auth: str,
mock_authn_server: responses.RequestsMock,
expired_cache: Path,
cached_valid_refresh: Path,
):
result = runner.invoke(main, ["-c", config_with_auth, "login"])

Expand All @@ -660,7 +660,7 @@ def test_login_when_cached_token_decode_fails(
runner: CliRunner,
config_with_auth: str,
mock_authn_server: responses.RequestsMock,
cached_invalid_token: Path,
cached_expired_refresh: Path,
):
result = runner.invoke(main, ["-c", config_with_auth, "login"])
assert (
Expand All @@ -675,24 +675,24 @@ def test_login_when_cached_token_decode_fails(
def test_logout_success(
runner: CliRunner,
config_with_auth: str,
expired_cache: Path,
cached_valid_refresh: Path,
mock_authn_server: responses.RequestsMock,
):
assert expired_cache.exists()
assert cached_valid_refresh.exists()
result = runner.invoke(main, ["-c", config_with_auth, "logout"])
assert "Logged out" in result.output
assert not expired_cache.exists()
assert not cached_valid_refresh.exists()


def test_local_cache_cleared_oidc_unavailable(
runner: CliRunner,
config_with_auth: str,
expired_cache: Path,
cached_valid_refresh: Path,
):
assert expired_cache.exists()
assert cached_valid_refresh.exists()
result = runner.invoke(main, ["-c", config_with_auth, "logout"])
assert (
"An unexpected error occurred while attempting to log out from the server."
in result.output
)
assert not expired_cache.exists()
assert not cached_valid_refresh.exists()

0 comments on commit 27c12ce

Please sign in to comment.