Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chore/add tests #46

Merged
merged 14 commits into from
Jul 15, 2024
Merged
31 changes: 31 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Debug Tests",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"purpose": [
"debug-test"
],
"console": "integratedTerminal",
"justMyCode": false,
"env": {
"PYTEST_ADDOPTS": "--no-cov",
"PYTHONPATH": "${workspaceFolder}",
"LOG_LEVEL": "DEBUG"
}
},
{
"name": "Python Debugger: example.py",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/example.py",
"console": "integratedTerminal"
}
]
}
13 changes: 13 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pydantic_tfl_api.client import Client, ApiToken

app_id = 'APPLICATION ID'
app_key = 'APPLICATION KEY'

# token = ApiToken(app_id, app_key)
token = None # only need a token if > 1 request per second

client = Client(token)
print (client.get_line_meta_modes())
print (client.get_lines(mode="bus")[0].model_dump_json())
print (client.get_lines(line_id="victoria")[0].model_dump_json())
print (client.get_route_by_line_id_with_direction(line_id="northern", direction="all").model_dump_json())
122 changes: 49 additions & 73 deletions pydantic_tfl_api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,11 @@ def _get_s_maxage_from_cache_control_header(self, response: Response) -> int | N
directives = cache_control.split(" ")
# e.g. ['public,', 'must-revalidate,', 'max-age=43200,', 's-maxage=86400']
directives = {d.split("=")[0]: d.split("=")[1] for d in directives if "=" in d}
return None if "s-maxage" not in directives else int(directives["s-maxage"])
return None if "s-maxage" not in directives else int(directives["s-maxage"])

def _get_result_expiry(self, response: Response) -> datetime | None:
s_maxage = self._get_s_maxage_from_cache_control_header(response)
request_datetime = parsedate_to_datetime(response.headers.get("date"))
request_datetime = parsedate_to_datetime(response.headers.get("date")) if "date" in response.headers else None
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue: Handling of missing 'date' header

Returning None when the 'date' header is missing might lead to issues downstream if request_datetime is expected to be a datetime object. Consider raising an exception or handling this case more explicitly.

if s_maxage and request_datetime:
return request_datetime + timedelta(seconds=s_maxage)
return None
Expand All @@ -90,7 +90,7 @@ def _get_model(self, model_name: str) -> BaseModel:

def _create_model_instance(
self, Model: BaseModel, response_json: Any, result_expiry: datetime | None
):
) -> BaseModel | List[BaseModel]:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Return type annotation

The return type annotation BaseModel | List[BaseModel] might be too broad. If possible, consider specifying more precise types to improve type safety and readability.

Suggested change
) -> BaseModel | List[BaseModel]:
) -> Union[Model, List[Model]]:

if isinstance(response_json, dict):
return self._create_model_with_expiry(Model, response_json, result_expiry)
else:
Expand All @@ -111,150 +111,126 @@ def _deserialize_error(self, response: Response) -> models.ApiError:
if response.headers.get("content-type") == "application/json":
return self._deserialize("ApiError", response)
return models.ApiError(
timestampUtc=response.headers.get("date"),
timestampUtc=parsedate_to_datetime(response.headers.get("date")),
exceptionType="Unknown",
httpStatusCode=response.status_code,
httpStatus=response.reason,
relativeUri=response.url,
message=response.text,
)


def _send_request_and_deserialize(
self, endpoint: str, model_name: str, endpoint_args: dict = None
) -> BaseModel | List[BaseModel] | models.ApiError:
response = self.client.send_request(endpoint, endpoint_args)
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize(model_name, response)

def get_stop_points_by_line_id(
self, line_id: str
) -> models.StopPoint | List[models.StopPoint] | models.ApiError:
response = self.client.send_request(
endpoints["stopPointsByLineId"].format(line_id)
return self._send_request_and_deserialize(
endpoints["stopPointsByLineId"].format(line_id), "StopPoint"
)
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("StopPoint", response)

def get_line_meta_modes(self) -> models.Mode | models.ApiError:
response = self.client.send_request(endpoints["lineMetaModes"])
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("Mode", response)
return self._send_request_and_deserialize(endpoints["lineMetaModes"], "Mode")

def get_lines(
self, line_id: str | None = None, mode: str | None = None
) -> models.Line | List[models.Line] | models.ApiError:
Comment on lines +133 to 142
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue: Refactoring to use _send_request_and_deserialize

The refactoring to use _send_request_and_deserialize improves code reuse and readability. However, ensure that the new method handles all edge cases and error conditions that were previously managed in the individual methods.

if line_id is None and mode is None:
raise Exception(
raise ValueError(
"Either the --line_id argument or the --mode argument needs to be specified."
)
if line_id is not None:
endpoint = endpoints["linesByLineId"].format(line_id)
else:
endpoint = endpoints["linesByMode"].format(mode)
response = self.client.send_request(endpoint)
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("Line", response)
return self._send_request_and_deserialize(endpoint, "Line")

def get_line_status(
self, line: str, include_details: bool = None
) -> models.Line | List[models.Line] | models.ApiError:
response = self.client.send_request(
endpoints["lineStatus"].format(line), {"detail": include_details is True}
return self._send_request_and_deserialize(
endpoints["lineStatus"].format(line), "Line", {"detail": include_details}
)
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("Line", response)

def get_line_status_severity(
self, severity: str
) -> models.Line | List[models.Line] | models.ApiError:
response = self.client.send_request(
endpoints["lineStatusBySeverity"].format(severity)
return self._send_request_and_deserialize(
endpoints["lineStatusBySeverity"].format(severity), "Line"
)
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("Line", response)

def get_line_status_by_mode(
self, mode: str
) -> models.Line | List[models.Line] | models.ApiError:
response = self.client.send_request(endpoints["lineStatusByMode"].format(mode))
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("Line", response)
return self._send_request_and_deserialize(
endpoints["lineStatusByMode"].format(mode), "Line"
)

def get_route_by_line_id(
self, line_id: str
) -> models.Line | List[models.Line] | models.ApiError:
response = self.client.send_request(endpoints["routeByLineId"].format(line_id))
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("Line", response)
return self._send_request_and_deserialize(
endpoints["routeByLineId"].format(line_id), "Line"
)

def get_route_by_mode(
self, mode: str
) -> models.Line | List[models.Line] | models.ApiError:
response = self.client.send_request(endpoints["routeByMode"].format(mode))
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("Line", response)
return self._send_request_and_deserialize(
endpoints["routeByMode"].format(mode), "Line"
)

def get_route_by_line_id_with_direction(
self, line_id: str, direction: Literal["inbound", "outbound", "all"]
) -> models.RouteSequence | List[models.RouteSequence] | models.ApiError:
response = self.client.send_request(
endpoints["routeByLineIdWithDirection"].format(line_id, direction)
return self._send_request_and_deserialize(
endpoints["routeByLineIdWithDirection"].format(line_id, direction),
"RouteSequence",
)
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("RouteSequence", response)

def get_line_disruptions_by_line_id(
self, line_id: str
) -> models.Disruption | List[models.Disruption] | models.ApiError:
response = self.client.send_request(
endpoints["lineDisruptionsByLineId"].format(line_id)
return self._send_request_and_deserialize(
endpoints["lineDisruptionsByLineId"].format(line_id), "Disruption"
)
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("Disruption", response)

def get_line_disruptions_by_mode(
self, mode: str
) -> models.Disruption | List[models.Disruption] | models.ApiError:
response = self.client.send_request(
endpoints["lineDisruptionsByMode"].format(mode)
return self._send_request_and_deserialize(
endpoints["lineDisruptionsByMode"].format(mode), "Disruption"
)
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("Disruption", response)

def get_stop_points_by_id(
self, id: str
) -> models.StopPoint | List[models.StopPoint] | models.ApiError:
response = self.client.send_request(endpoints["stopPointById"].format(id))
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("StopPoint", response)
return self._send_request_and_deserialize(
endpoints["stopPointById"].format(id), "StopPoint"
)

def get_stop_points_by_mode(
self, mode: str
) -> models.StopPointsResponse | List[models.StopPointsResponse] | models.ApiError:
response = self.client.send_request(endpoints["stopPointByMode"].format(mode))
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("StopPointsResponse", response)
return self._send_request_and_deserialize(
endpoints["stopPointByMode"].format(mode), "StopPointsResponse"
)

def get_stop_point_meta_modes(
self,
) -> models.Mode | List[models.Mode] | models.ApiError:
response = self.client.send_request(endpoints["stopPointMetaModes"])
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("Mode", response)
return self._send_request_and_deserialize(
endpoints["stopPointMetaModes"], "Mode"
)

def get_arrivals_by_line_id(
self, line_id: str
) -> models.Prediction | List[models.Prediction] | models.ApiError:
response = self.client.send_request(
endpoints["arrivalsByLineId"].format(line_id)
return self._send_request_and_deserialize(
endpoints["arrivalsByLineId"].format(line_id), "Prediction"
)
if response.status_code != 200:
return self._deserialize_error(response)
return self._deserialize("Prediction", response)
4 changes: 3 additions & 1 deletion pydantic_tfl_api/models/api_error.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pydantic import BaseModel, Field, field_validator
from datetime import datetime
from email.utils import parsedate_to_datetime

class ApiError(BaseModel):
timestamp_utc: datetime = Field(alias='timestampUtc')
Expand All @@ -11,6 +12,7 @@ class ApiError(BaseModel):

@field_validator('timestamp_utc', mode='before')
def parse_timestamp(cls, v):
return datetime.strptime(v, '%a, %d %b %Y %H:%M:%S %Z')
return v if isinstance(v, datetime) else parsedate_to_datetime(v)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Field validator logic

The field validator now handles both datetime objects and string dates. Ensure that parsedate_to_datetime correctly handles all expected date formats and edge cases.

Suggested change
return v if isinstance(v, datetime) else parsedate_to_datetime(v)
if isinstance(v, datetime):
return v
try:
return parsedate_to_datetime(v)
except (TypeError, ValueError):
raise ValueError(f"Invalid date format: {v}")

# return datetime.strptime(v, '%a, %d %b %Y %H:%M:%S %Z')

model_config = {'populate_by_name': True}
Loading
Loading