From 6c5a4eeed5705ff71d0cfa026ea340092da71aa1 Mon Sep 17 00:00:00 2001 From: Rob Aleck Date: Tue, 23 Jul 2024 10:10:24 +0100 Subject: [PATCH] Feat/include both maxage and s maxage headers (#60) adds support for handling both 'max-age' and 's-maxage' headers in cache control, introduces a new 'shared_expires' attribute to various models, and updates the test suite to cover these changes. --- pydantic_tfl_api/client.py | 60 +++-- pydantic_tfl_api/models/disruption.py | 1 + pydantic_tfl_api/models/line.py | 1 + pydantic_tfl_api/models/mode.py | 1 + pydantic_tfl_api/models/prediction.py | 1 + pydantic_tfl_api/models/route_sequence.py | 1 + pydantic_tfl_api/models/stop_point.py | 1 + .../models/stop_points_response.py | 1 + tests/test_client.py | 253 +++++++++++++++--- 9 files changed, 272 insertions(+), 48 deletions(-) diff --git a/pydantic_tfl_api/client.py b/pydantic_tfl_api/client.py index 368ecf1..26a1d8f 100644 --- a/pydantic_tfl_api/client.py +++ b/pydantic_tfl_api/client.py @@ -25,7 +25,7 @@ from .config import endpoints from .rest_client import RestClient from importlib import import_module -from typing import Any, Literal, List, Optional +from typing import Any, Literal, List, Optional, Tuple from requests import Response import pkgutil from pydantic import BaseModel @@ -55,29 +55,51 @@ def _load_models(self): # print(models_dict) return models_dict - def _get_s_maxage_from_cache_control_header(self, response: Response) -> int | None: + @staticmethod + def _parse_int_or_none(value: str) -> int | None: + try: + return int(value) + except (TypeError, ValueError): + return None + + @staticmethod + def _get_maxage_headers_from_cache_control_header(response: Response) -> Tuple[Optional[int], Optional[int]]: cache_control = response.headers.get("cache-control") # e.g. 'public, must-revalidate, max-age=43200, s-maxage=86400' if cache_control is None: - return None - directives = cache_control.split(" ") - # e.g. ['public,', 'must-revalidate,', 'max-age=43200,', 's-maxage=86400'] + return None, None + directives = cache_control.split(", ") + # e.g. ['public', 'must-revalidate', 'max-age=43200', 's-maxage=86400'] directives = {d.split("=")[0]: d.split("=")[1] for d in directives if "=" in d} - return None if "s-maxage" not in directives else int(directives["s-maxage"]) + smaxage = Client._parse_int_or_none(directives.get("s-maxage", "")) + maxage = Client._parse_int_or_none(directives.get("max-age", "")) + return smaxage, maxage + - def _get_result_expiry(self, response: Response) -> datetime | None: - s_maxage = self._get_s_maxage_from_cache_control_header(response) + + @staticmethod + def _parse_timedelta(value: Optional[int], base_time: Optional[datetime]) -> Optional[datetime]: + try: + return base_time + timedelta(seconds=value) if value is not None and base_time is not None else None + except (TypeError, ValueError): + return None + + @staticmethod + def _get_result_expiry(response: Response) -> Tuple[ datetime | None, datetime | None]: + s_maxage, maxage = Client._get_maxage_headers_from_cache_control_header(response) request_datetime = parsedate_to_datetime(response.headers.get("date")) if "date" in response.headers else None - if s_maxage and request_datetime: - return request_datetime + timedelta(seconds=s_maxage) - return None + + s_maxage_expiry = Client._parse_timedelta(s_maxage, request_datetime) + maxage_expiry = Client._parse_timedelta(maxage, request_datetime) + + return s_maxage_expiry, maxage_expiry def _deserialize(self, model_name: str, response: Response) -> Any: - result_expiry = self._get_result_expiry(response) + shared_expiry, result_expiry = self._get_result_expiry(response) Model = self._get_model(model_name) data = response.json() - result = self._create_model_instance(Model, data, result_expiry) + result = self._create_model_instance(Model, data, result_expiry, shared_expiry) return result @@ -88,21 +110,25 @@ def _get_model(self, model_name: str) -> BaseModel: return Model def _create_model_instance( - self, Model: BaseModel, response_json: Any, result_expiry: datetime | None + self, Model: BaseModel, + response_json: Any, + result_expiry: datetime | None, + shared_expiry: datetime | None ) -> BaseModel | List[BaseModel]: if isinstance(response_json, dict): - return self._create_model_with_expiry(Model, response_json, result_expiry) + return self._create_model_with_expiry(Model, response_json, result_expiry, shared_expiry) else: return [ - self._create_model_with_expiry(Model, item, result_expiry) + self._create_model_with_expiry(Model, item, result_expiry, shared_expiry) for item in response_json ] def _create_model_with_expiry( - self, Model: BaseModel, response_json: Any, result_expiry: datetime | None + self, Model: BaseModel, response_json: Any, result_expiry: Optional[datetime], shared_expiry: Optional[datetime] ): instance = Model(**response_json) instance.content_expires = result_expiry + instance.shared_expires = shared_expiry return instance def _deserialize_error(self, response: Response) -> models.ApiError: diff --git a/pydantic_tfl_api/models/disruption.py b/pydantic_tfl_api/models/disruption.py index 8968d25..424fd37 100644 --- a/pydantic_tfl_api/models/disruption.py +++ b/pydantic_tfl_api/models/disruption.py @@ -19,6 +19,7 @@ class Disruption(BaseModel): affected_stops: List[StopPoint] = Field(alias='affectedStops') closure_text: str = Field(alias='closureText') content_expires: Optional[datetime] = Field(None) + shared_expires: Optional[datetime] = Field(None) model_config = {'populate_by_name': True} diff --git a/pydantic_tfl_api/models/line.py b/pydantic_tfl_api/models/line.py index ef0d99d..0b6ea1d 100644 --- a/pydantic_tfl_api/models/line.py +++ b/pydantic_tfl_api/models/line.py @@ -21,5 +21,6 @@ class Line(BaseModel): service_types: Optional[List[ServiceType]] = Field(None, alias='serviceTypes') crowding: Optional[Crowding] = Field(None, alias='crowding') content_expires: Optional[datetime] = Field(None) + shared_expires: Optional[datetime] = Field(None) model_config = {'populate_by_name': True} diff --git a/pydantic_tfl_api/models/mode.py b/pydantic_tfl_api/models/mode.py index c7f1d87..e13570d 100644 --- a/pydantic_tfl_api/models/mode.py +++ b/pydantic_tfl_api/models/mode.py @@ -8,5 +8,6 @@ class Mode(BaseModel): is_scheduled_service: bool = Field(alias='isScheduledService') mode_name: str = Field(alias='modeName') content_expires: Optional[datetime] = Field(None) + shared_expires: Optional[datetime] = Field(None) model_config = {'populate_by_name': True} \ No newline at end of file diff --git a/pydantic_tfl_api/models/prediction.py b/pydantic_tfl_api/models/prediction.py index 7bca41e..4925ec8 100644 --- a/pydantic_tfl_api/models/prediction.py +++ b/pydantic_tfl_api/models/prediction.py @@ -26,5 +26,6 @@ class Prediction(BaseModel): mode_name: str = Field(alias='modeName') timing: PredictionTiming = Field(alias='timing') content_expires: Optional[datetime] = Field(None) + shared_expires: Optional[datetime] = Field(None) model_config = {'populate_by_name': True} \ No newline at end of file diff --git a/pydantic_tfl_api/models/route_sequence.py b/pydantic_tfl_api/models/route_sequence.py index fc0ead0..24b3d51 100644 --- a/pydantic_tfl_api/models/route_sequence.py +++ b/pydantic_tfl_api/models/route_sequence.py @@ -16,5 +16,6 @@ class RouteSequence(BaseModel): service_type: Optional[str] = Field(None, alias='serviceType') ordered_line_routes: list[OrderedRoute] = Field(alias='orderedLineRoutes') content_expires: Optional[datetime] = Field(None) + shared_expires: Optional[datetime] = Field(None) model_config = {'populate_by_name': True} diff --git a/pydantic_tfl_api/models/stop_point.py b/pydantic_tfl_api/models/stop_point.py index c5d6577..2864794 100644 --- a/pydantic_tfl_api/models/stop_point.py +++ b/pydantic_tfl_api/models/stop_point.py @@ -40,6 +40,7 @@ class StopPoint(BaseModel): lat: float = Field(alias="lat") lon: float = Field(alias="lon") content_expires: Optional[datetime] = Field(None) + shared_expires: Optional[datetime] = Field(None) model_config = {"populate_by_name": True} diff --git a/pydantic_tfl_api/models/stop_points_response.py b/pydantic_tfl_api/models/stop_points_response.py index ef23e57..658723a 100644 --- a/pydantic_tfl_api/models/stop_points_response.py +++ b/pydantic_tfl_api/models/stop_points_response.py @@ -10,5 +10,6 @@ class StopPointsResponse(BaseModel): total: int = Field(alias='total') page: int = Field(alias='page') content_expires: Optional[datetime] = Field(None) + shared_expires: Optional[datetime] = Field(None) model_config = {'populate_by_name': True} diff --git a/tests/test_client.py b/tests/test_client.py index f558722..c34ff62 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -24,36 +24,61 @@ class PydanticTestModel(BaseModel): name: str age: int content_expires: datetime | None = None + shared_expires: datetime | None = None @pytest.mark.parametrize( - "Model, response_json, result_expiry, expected_name, expected_age, expected_expiry", + "Model, response_json, result_expiry, shared_expiry, expected_name, expected_age, expected_expiry, expected_shared_expiry", [ # Happy path tests ( PydanticTestModel, {"name": "Alice", "age": 30}, datetime(2023, 12, 31), + datetime(2024, 12, 31, 23, 59, 59), "Alice", 30, datetime(2023, 12, 31), + datetime(2024, 12, 31, 23, 59, 59), + ), + ( + PydanticTestModel, + {"name": "Bob", "age": 25}, + None, + None, + "Bob", + 25, + None, + None, ), - (PydanticTestModel, {"name": "Bob", "age": 25}, None, "Bob", 25, None), # Edge cases ( PydanticTestModel, {"name": "", "age": 0}, datetime(2023, 12, 31), + datetime(2024, 12, 31, 23, 59, 59), "", 0, datetime(2023, 12, 31), + datetime(2024, 12, 31, 23, 59, 59), + ), + ( + PydanticTestModel, + {"name": "Charlie", "age": -1}, + None, + None, + "Charlie", + -1, + None, + None, ), - (PydanticTestModel, {"name": "Charlie", "age": -1}, None, "Charlie", -1, None), # Error cases ( PydanticTestModel, {"name": "Alice"}, datetime(2023, 12, 31), + datetime(2024, 12, 31, 23, 59, 59), + None, None, None, None, @@ -62,6 +87,8 @@ class PydanticTestModel(BaseModel): PydanticTestModel, {"age": 30}, datetime(2023, 12, 31), + datetime(2024, 12, 31, 23, 59, 59), + None, None, None, None, @@ -70,6 +97,8 @@ class PydanticTestModel(BaseModel): PydanticTestModel, {"name": "Alice", "age": "thirty"}, datetime(2023, 12, 31), + datetime(2024, 12, 31, 23, 59, 59), + None, None, None, None, @@ -86,22 +115,23 @@ class PydanticTestModel(BaseModel): ], ) def test_create_model_with_expiry( - Model, response_json, result_expiry, expected_name, expected_age, expected_expiry + Model, response_json, result_expiry, shared_expiry, expected_name, expected_age, expected_expiry, expected_shared_expiry ): # Act if expected_name is None: with pytest.raises(ValidationError): - Client._create_model_with_expiry(None, Model, response_json, result_expiry) + Client._create_model_with_expiry(None, Model, response_json, result_expiry, shared_expiry) else: instance = Client._create_model_with_expiry( - None, Model, response_json, result_expiry + None, Model, response_json, result_expiry, shared_expiry ) # Assert assert instance.name == expected_name assert instance.age == expected_age assert instance.content_expires == expected_expiry + assert instance.shared_expires == expected_shared_expiry @pytest.mark.parametrize( @@ -178,21 +208,70 @@ def test_load_models(models_dict, expected_result): @pytest.mark.parametrize( "cache_control_header, expected_result", [ + # s-maxage present and valid ( "public, must-revalidate, max-age=43200, s-maxage=86400", - 86400, + (86400, 43200), ), + # s-maxage absent, only max-age present ( "public, must-revalidate, max-age=43200", - None, + (None, 43200), ), + # No cache-control header ( None, - None, + (None, None), ), + # Negative s-maxage value ( "public, must-revalidate, max-age=43200, s-maxage=-1", - -1, + (-1, 43200), + ), + # No max-age or s-maxage present + ( + "public, must-revalidate", + (None, None), + ), + # Only s-maxage present + ( + "public, s-maxage=86400", + (86400, None), + ), + # Both max-age and s-maxage zero + ( + "public, max-age=0, s-maxage=0", + (0, 0), + ), + # Malformed max-age directive + ( + "public, must-revalidate, max-age=foo, s-maxage=86400", + (86400, None), + ), + # Malformed s-maxage directive + ( + "public, must-revalidate, max-age=43200, s-maxage=bar", + (None, 43200), + ), + # Only s-maxage without a value + ( + "public, must-revalidate, s-maxage=", + (None, None), + ), + # Only max-age without a value + ( + "public, must-revalidate, max-age=", + (None, None), + ), + # max-age and s-maxage with additional spaces + ( + "public, max-age= 3600 , s-maxage= 7200 ", + (7200, 3600), + ), + # Complex header with multiple spaces and ordering + ( + "must-revalidate, public, s-maxage=7200, max-age=3600", + (7200, 3600), ), ], ids=[ @@ -200,17 +279,27 @@ def test_load_models(models_dict, expected_result): "s-maxage_absent", "no_cache_control_header", "negative_s-maxage_value", + "no_max-age_or_s-maxage", + "only_s-maxage_present", + "both_max-age_and_s-maxage_zero", + "malformed_max-age", + "malformed_s-maxage", + "s-maxage_no_value", + "max-age_no_value", + "max-age_and_s-maxage_with_spaces", + "complex_header", ], ) -def test_get_s_maxage_from_cache_control_header(cache_control_header, expected_result): +def test_get_maxage_headers_from_cache_control_header(cache_control_header, expected_result): # Mock Response response = Response() - response.headers = {"cache-control": cache_control_header} + if cache_control_header is not None: + response.headers = {"cache-control": cache_control_header} + else: + response.headers = {} # Act - from pydantic_tfl_api.client import Client - - result = Client._get_s_maxage_from_cache_control_header(None, response) + result = Client._get_maxage_headers_from_cache_control_header(response) # Assert assert result == expected_result @@ -238,15 +327,16 @@ def test_get_s_maxage_from_cache_control_header(cache_control_header, expected_r def test_deserialize(model_name, response_content, expected_result): # Mock Response Response_Object = MagicMock(Response) - Response_Object.json.return_value = json.dumps(response_content) + Response_Object.json.return_value = response_content # json.dumps(response_content) # Act client = Client() return_datetime = datetime(2024, 7, 12, 13, 00, 00) + return_datetime_2 = datetime(2025, 7, 12, 13, 00, 00) with patch.object( - client, "_get_result_expiry", return_value=return_datetime + client, "_get_result_expiry", return_value=(return_datetime_2, return_datetime) ), patch.object( client, "_get_model", return_value=MockModel ) as mock_get_model, patch.object( @@ -259,33 +349,125 @@ def test_deserialize(model_name, response_content, expected_result): assert result == expected_result mock_get_model.assert_called_with(model_name) mock_create_model_instance.assert_called_with( - MockModel, Response_Object.json.return_value, return_datetime + MockModel, Response_Object.json.return_value, return_datetime, return_datetime_2 ) +@pytest.mark.parametrize( + "value, base_time, expected_result", + [ + # Valid timedelta + ( + 86400, + datetime(2023, 11, 15, 12, 45, 26), + datetime(2023, 11, 16, 12, 45, 26), + ), + # None value for timedelta + ( + None, + datetime(2023, 11, 15, 12, 45, 26), + None, + ), + # None value for base_time + ( + 86400, + None, + None, + ), + # Both value and base_time are None + ( + None, + None, + None, + ), + # Edge case: zero timedelta + ( + 0, + datetime(2023, 11, 15, 12, 45, 26), + datetime(2023, 11, 15, 12, 45, 26), + ), + # Negative timedelta + ( + -86400, + datetime(2023, 11, 15, 12, 45, 26), + datetime(2023, 11, 14, 12, 45, 26), + ), + ], + ids=[ + "valid_timedelta", + "none_value", + "none_base_time", + "both_none", + "zero_timedelta", + "negative_timedelta", + ], +) +def test_parse_timedelta(value, base_time, expected_result): + # Act + result = Client._parse_timedelta(value, base_time) + + # Assert + assert result == expected_result + @pytest.mark.parametrize( - "s_maxage, date_header, expected_result", + "s_maxage, maxage, date_header, expected_result", [ ( 86400, + 43200, + {"date": "Tue, 15 Nov 1994 12:45:26 GMT"}, + ( + parsedate_to_datetime("Tue, 15 Nov 1994 12:45:26 GMT") + timedelta(seconds=86400), + parsedate_to_datetime("Tue, 15 Nov 1994 12:45:26 GMT") + timedelta(seconds=43200) + ), + ), + ( + None, + 43200, {"date": "Tue, 15 Nov 1994 12:45:26 GMT"}, - parsedate_to_datetime("Tue, 15 Nov 1994 12:45:26 GMT") - + timedelta(seconds=86400), + ( + None, + parsedate_to_datetime("Tue, 15 Nov 1994 12:45:26 GMT") + timedelta(seconds=43200) + ), ), ( + 86400, None, {"date": "Tue, 15 Nov 1994 12:45:26 GMT"}, + ( + parsedate_to_datetime("Tue, 15 Nov 1994 12:45:26 GMT") + timedelta(seconds=86400), + None + ), + ), + ( + None, None, + {"date": "Tue, 15 Nov 1994 12:45:26 GMT"}, + (None, None), ), ( 86400, + 43200, {}, + (None, None), + ), + ( None, + 43200, + {}, + (None, None), ), ( + 86400, None, {}, + (None, None), + ), + ( + None, None, + {}, + (None, None), ), ], ids=[ @@ -293,9 +475,13 @@ def test_deserialize(model_name, response_content, expected_result): "s_maxage_absent", "date_absent", "s_maxage_and_date_absent", + "both_present_no_date", + "maxage_present_no_date", + "smaxage_present_no_date", + "neither_present_no_date", ], ) -def test_get_result_expiry(s_maxage, date_header, expected_result): +def test_get_result_expiry(s_maxage, maxage, date_header, expected_result): # Mock Response response = Response() response.headers.update(date_header) @@ -303,10 +489,13 @@ def test_get_result_expiry(s_maxage, date_header, expected_result): # Act client = Client() - with patch.object( - client, "_get_s_maxage_from_cache_control_header", return_value=s_maxage - ): - result = client._get_result_expiry(response) + # Act + with patch('pydantic_tfl_api.client.Client._get_maxage_headers_from_cache_control_header', return_value=(s_maxage, maxage)), \ + patch('pydantic_tfl_api.client.Client._parse_timedelta', side_effect=[expected_result[0], expected_result[1]]): + result = Client._get_result_expiry(response) + + # Assert + assert result == expected_result # Assert assert result == expected_result @@ -351,12 +540,13 @@ def __init__(self, models_to_set): @pytest.mark.parametrize( - "Model, response_json, result_expiry, create_model_mock_return, expected_return", + "Model, response_json, result_expiry, shared_expiry, create_model_mock_return, expected_return", [ ( MockModel, {"name": "Alice", "age": 30}, datetime(2023, 12, 31), + datetime(2024, 12, 31), "TestReturn1", "TestReturn1", ), @@ -364,6 +554,7 @@ def __init__(self, models_to_set): MockModel, [{"name": "Bob", "age": 30}, {"name": "Charlie", "age": 25}], datetime(2023, 12, 31), + datetime(2024, 12, 31), "TestReturn2", ["TestReturn2", "TestReturn2"], ), @@ -374,7 +565,7 @@ def __init__(self, models_to_set): ], ) def test_create_model_instance( - Model, response_json, result_expiry, create_model_mock_return, expected_return + Model, response_json, result_expiry, shared_expiry, create_model_mock_return, expected_return ): # Mock Client client = Client() @@ -385,17 +576,17 @@ def test_create_model_instance( ) as mock_create_model_with_expiry: # Act - result = client._create_model_instance(Model, response_json, result_expiry) + result = client._create_model_instance(Model, response_json, result_expiry, shared_expiry) # Assert assert result == expected_return if isinstance(response_json, dict): mock_create_model_with_expiry.assert_called_with( - Model, response_json, result_expiry + Model, response_json, result_expiry, shared_expiry ) else: for item in response_json: - mock_create_model_with_expiry.assert_any_call(Model, item, result_expiry) + mock_create_model_with_expiry.assert_any_call(Model, item, result_expiry, shared_expiry) datetime_object_with_time_and_tz_utc = datetime(2023, 12, 31, 1, 2, 3, tzinfo=timezone.utc)