Skip to content

Commit

Permalink
Feat/include both maxage and s maxage headers (#60)
Browse files Browse the repository at this point in the history
adds support for handling both 'max-age' and 's-maxage' headers in cache control, introduces a new 'shared_expires' attribute to various models, and updates the test suite to cover these changes.
  • Loading branch information
mnbf9rca authored Jul 23, 2024
1 parent ca8423e commit 6c5a4ee
Show file tree
Hide file tree
Showing 9 changed files with 272 additions and 48 deletions.
60 changes: 43 additions & 17 deletions pydantic_tfl_api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from .config import endpoints
from .rest_client import RestClient
from importlib import import_module
from typing import Any, Literal, List, Optional
from typing import Any, Literal, List, Optional, Tuple
from requests import Response
import pkgutil
from pydantic import BaseModel
Expand Down Expand Up @@ -55,29 +55,51 @@ def _load_models(self):
# print(models_dict)
return models_dict

def _get_s_maxage_from_cache_control_header(self, response: Response) -> int | None:
@staticmethod
def _parse_int_or_none(value: str) -> int | None:
try:
return int(value)
except (TypeError, ValueError):
return None

@staticmethod
def _get_maxage_headers_from_cache_control_header(response: Response) -> Tuple[Optional[int], Optional[int]]:
cache_control = response.headers.get("cache-control")
# e.g. 'public, must-revalidate, max-age=43200, s-maxage=86400'
if cache_control is None:
return None
directives = cache_control.split(" ")
# e.g. ['public,', 'must-revalidate,', 'max-age=43200,', 's-maxage=86400']
return None, None
directives = cache_control.split(", ")
# e.g. ['public', 'must-revalidate', 'max-age=43200', 's-maxage=86400']
directives = {d.split("=")[0]: d.split("=")[1] for d in directives if "=" in d}
return None if "s-maxage" not in directives else int(directives["s-maxage"])
smaxage = Client._parse_int_or_none(directives.get("s-maxage", ""))
maxage = Client._parse_int_or_none(directives.get("max-age", ""))
return smaxage, maxage


def _get_result_expiry(self, response: Response) -> datetime | None:
s_maxage = self._get_s_maxage_from_cache_control_header(response)

@staticmethod
def _parse_timedelta(value: Optional[int], base_time: Optional[datetime]) -> Optional[datetime]:
try:
return base_time + timedelta(seconds=value) if value is not None and base_time is not None else None
except (TypeError, ValueError):
return None

@staticmethod
def _get_result_expiry(response: Response) -> Tuple[ datetime | None, datetime | None]:
s_maxage, maxage = Client._get_maxage_headers_from_cache_control_header(response)
request_datetime = parsedate_to_datetime(response.headers.get("date")) if "date" in response.headers else None
if s_maxage and request_datetime:
return request_datetime + timedelta(seconds=s_maxage)
return None

s_maxage_expiry = Client._parse_timedelta(s_maxage, request_datetime)
maxage_expiry = Client._parse_timedelta(maxage, request_datetime)

return s_maxage_expiry, maxage_expiry

def _deserialize(self, model_name: str, response: Response) -> Any:
result_expiry = self._get_result_expiry(response)
shared_expiry, result_expiry = self._get_result_expiry(response)
Model = self._get_model(model_name)
data = response.json()

result = self._create_model_instance(Model, data, result_expiry)
result = self._create_model_instance(Model, data, result_expiry, shared_expiry)

return result

Expand All @@ -88,21 +110,25 @@ def _get_model(self, model_name: str) -> BaseModel:
return Model

def _create_model_instance(
self, Model: BaseModel, response_json: Any, result_expiry: datetime | None
self, Model: BaseModel,
response_json: Any,
result_expiry: datetime | None,
shared_expiry: datetime | None
) -> BaseModel | List[BaseModel]:
if isinstance(response_json, dict):
return self._create_model_with_expiry(Model, response_json, result_expiry)
return self._create_model_with_expiry(Model, response_json, result_expiry, shared_expiry)
else:
return [
self._create_model_with_expiry(Model, item, result_expiry)
self._create_model_with_expiry(Model, item, result_expiry, shared_expiry)
for item in response_json
]

def _create_model_with_expiry(
self, Model: BaseModel, response_json: Any, result_expiry: datetime | None
self, Model: BaseModel, response_json: Any, result_expiry: Optional[datetime], shared_expiry: Optional[datetime]
):
instance = Model(**response_json)
instance.content_expires = result_expiry
instance.shared_expires = shared_expiry
return instance

def _deserialize_error(self, response: Response) -> models.ApiError:
Expand Down
1 change: 1 addition & 0 deletions pydantic_tfl_api/models/disruption.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class Disruption(BaseModel):
affected_stops: List[StopPoint] = Field(alias='affectedStops')
closure_text: str = Field(alias='closureText')
content_expires: Optional[datetime] = Field(None)
shared_expires: Optional[datetime] = Field(None)

model_config = {'populate_by_name': True}

Expand Down
1 change: 1 addition & 0 deletions pydantic_tfl_api/models/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ class Line(BaseModel):
service_types: Optional[List[ServiceType]] = Field(None, alias='serviceTypes')
crowding: Optional[Crowding] = Field(None, alias='crowding')
content_expires: Optional[datetime] = Field(None)
shared_expires: Optional[datetime] = Field(None)

model_config = {'populate_by_name': True}
1 change: 1 addition & 0 deletions pydantic_tfl_api/models/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@ class Mode(BaseModel):
is_scheduled_service: bool = Field(alias='isScheduledService')
mode_name: str = Field(alias='modeName')
content_expires: Optional[datetime] = Field(None)
shared_expires: Optional[datetime] = Field(None)

model_config = {'populate_by_name': True}
1 change: 1 addition & 0 deletions pydantic_tfl_api/models/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@ class Prediction(BaseModel):
mode_name: str = Field(alias='modeName')
timing: PredictionTiming = Field(alias='timing')
content_expires: Optional[datetime] = Field(None)
shared_expires: Optional[datetime] = Field(None)

model_config = {'populate_by_name': True}
1 change: 1 addition & 0 deletions pydantic_tfl_api/models/route_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ class RouteSequence(BaseModel):
service_type: Optional[str] = Field(None, alias='serviceType')
ordered_line_routes: list[OrderedRoute] = Field(alias='orderedLineRoutes')
content_expires: Optional[datetime] = Field(None)
shared_expires: Optional[datetime] = Field(None)

model_config = {'populate_by_name': True}
1 change: 1 addition & 0 deletions pydantic_tfl_api/models/stop_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class StopPoint(BaseModel):
lat: float = Field(alias="lat")
lon: float = Field(alias="lon")
content_expires: Optional[datetime] = Field(None)
shared_expires: Optional[datetime] = Field(None)

model_config = {"populate_by_name": True}

Expand Down
1 change: 1 addition & 0 deletions pydantic_tfl_api/models/stop_points_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ class StopPointsResponse(BaseModel):
total: int = Field(alias='total')
page: int = Field(alias='page')
content_expires: Optional[datetime] = Field(None)
shared_expires: Optional[datetime] = Field(None)

model_config = {'populate_by_name': True}
Loading

0 comments on commit 6c5a4ee

Please sign in to comment.