-
Notifications
You must be signed in to change notification settings - Fork 1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Chore/add tests #46
Chore/add tests #46
Changes from 11 commits
53962a4
c01c638
802e3f9
557b7a6
a1f555b
7d6bc16
4fcf966
9226423
e6b8bfc
0037a7c
b571b15
f46f376
401a301
305094a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
{ | ||
// Use IntelliSense to learn about possible attributes. | ||
// Hover to view descriptions of existing attributes. | ||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 | ||
"version": "0.2.0", | ||
"configurations": [ | ||
{ | ||
"name": "Python: Debug Tests", | ||
"type": "debugpy", | ||
"request": "launch", | ||
"program": "${file}", | ||
"purpose": [ | ||
"debug-test" | ||
], | ||
"console": "integratedTerminal", | ||
"justMyCode": false, | ||
"env": { | ||
"PYTEST_ADDOPTS": "--no-cov", | ||
"PYTHONPATH": "${workspaceFolder}", | ||
"LOG_LEVEL": "DEBUG" | ||
} | ||
}, | ||
{ | ||
"name": "Python Debugger: example.py", | ||
"type": "debugpy", | ||
"request": "launch", | ||
"program": "${workspaceFolder}/example.py", | ||
"console": "integratedTerminal" | ||
} | ||
] | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from pydantic_tfl_api.client import Client, ApiToken | ||
|
||
app_id = 'APPLICATION ID' | ||
app_key = 'APPLICATION KEY' | ||
|
||
# token = ApiToken(app_id, app_key) | ||
token = None # only need a token if > 1 request per second | ||
|
||
client = Client(token) | ||
print (client.get_line_meta_modes()) | ||
print (client.get_lines(mode="bus")[0].model_dump_json()) | ||
print (client.get_lines(line_id="victoria")[0].model_dump_json()) | ||
print (client.get_route_by_line_id_with_direction(line_id="northern", direction="all").model_dump_json()) |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -64,11 +64,11 @@ def _get_s_maxage_from_cache_control_header(self, response: Response) -> int | N | |||||
directives = cache_control.split(" ") | ||||||
# e.g. ['public,', 'must-revalidate,', 'max-age=43200,', 's-maxage=86400'] | ||||||
directives = {d.split("=")[0]: d.split("=")[1] for d in directives if "=" in d} | ||||||
return None if "s-maxage" not in directives else int(directives["s-maxage"]) | ||||||
return None if "s-maxage" not in directives else int(directives["s-maxage"]) | ||||||
|
||||||
def _get_result_expiry(self, response: Response) -> datetime | None: | ||||||
s_maxage = self._get_s_maxage_from_cache_control_header(response) | ||||||
request_datetime = parsedate_to_datetime(response.headers.get("date")) | ||||||
request_datetime = parsedate_to_datetime(response.headers.get("date")) if "date" in response.headers else None | ||||||
if s_maxage and request_datetime: | ||||||
return request_datetime + timedelta(seconds=s_maxage) | ||||||
return None | ||||||
|
@@ -90,7 +90,7 @@ def _get_model(self, model_name: str) -> BaseModel: | |||||
|
||||||
def _create_model_instance( | ||||||
self, Model: BaseModel, response_json: Any, result_expiry: datetime | None | ||||||
): | ||||||
) -> BaseModel | List[BaseModel]: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion: Return type annotation The return type annotation
Suggested change
|
||||||
if isinstance(response_json, dict): | ||||||
return self._create_model_with_expiry(Model, response_json, result_expiry) | ||||||
else: | ||||||
|
@@ -111,150 +111,126 @@ def _deserialize_error(self, response: Response) -> models.ApiError: | |||||
if response.headers.get("content-type") == "application/json": | ||||||
return self._deserialize("ApiError", response) | ||||||
return models.ApiError( | ||||||
timestampUtc=response.headers.get("date"), | ||||||
timestampUtc=parsedate_to_datetime(response.headers.get("date")), | ||||||
exceptionType="Unknown", | ||||||
httpStatusCode=response.status_code, | ||||||
httpStatus=response.reason, | ||||||
relativeUri=response.url, | ||||||
message=response.text, | ||||||
) | ||||||
|
||||||
|
||||||
def _send_request_and_deserialize( | ||||||
self, endpoint: str, model_name: str, endpoint_args: dict = None | ||||||
) -> BaseModel | List[BaseModel] | models.ApiError: | ||||||
response = self.client.send_request(endpoint, endpoint_args) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize(model_name, response) | ||||||
|
||||||
def get_stop_points_by_line_id( | ||||||
self, line_id: str | ||||||
) -> models.StopPoint | List[models.StopPoint] | models.ApiError: | ||||||
response = self.client.send_request( | ||||||
endpoints["stopPointsByLineId"].format(line_id) | ||||||
return self._send_request_and_deserialize( | ||||||
endpoints["stopPointsByLineId"].format(line_id), "StopPoint" | ||||||
) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("StopPoint", response) | ||||||
|
||||||
def get_line_meta_modes(self) -> models.Mode | models.ApiError: | ||||||
response = self.client.send_request(endpoints["lineMetaModes"]) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("Mode", response) | ||||||
return self._send_request_and_deserialize(endpoints["lineMetaModes"], "Mode") | ||||||
|
||||||
def get_lines( | ||||||
self, line_id: str | None = None, mode: str | None = None | ||||||
) -> models.Line | List[models.Line] | models.ApiError: | ||||||
Comment on lines
+133
to
142
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. issue: Refactoring to use _send_request_and_deserialize The refactoring to use |
||||||
if line_id is None and mode is None: | ||||||
raise Exception( | ||||||
raise ValueError( | ||||||
"Either the --line_id argument or the --mode argument needs to be specified." | ||||||
) | ||||||
if line_id is not None: | ||||||
endpoint = endpoints["linesByLineId"].format(line_id) | ||||||
else: | ||||||
endpoint = endpoints["linesByMode"].format(mode) | ||||||
response = self.client.send_request(endpoint) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("Line", response) | ||||||
return self._send_request_and_deserialize(endpoint, "Line") | ||||||
|
||||||
def get_line_status( | ||||||
self, line: str, include_details: bool = None | ||||||
) -> models.Line | List[models.Line] | models.ApiError: | ||||||
response = self.client.send_request( | ||||||
endpoints["lineStatus"].format(line), {"detail": include_details is True} | ||||||
return self._send_request_and_deserialize( | ||||||
endpoints["lineStatus"].format(line), "Line", {"detail": include_details} | ||||||
) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("Line", response) | ||||||
|
||||||
def get_line_status_severity( | ||||||
self, severity: str | ||||||
) -> models.Line | List[models.Line] | models.ApiError: | ||||||
response = self.client.send_request( | ||||||
endpoints["lineStatusBySeverity"].format(severity) | ||||||
return self._send_request_and_deserialize( | ||||||
endpoints["lineStatusBySeverity"].format(severity), "Line" | ||||||
) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("Line", response) | ||||||
|
||||||
def get_line_status_by_mode( | ||||||
self, mode: str | ||||||
) -> models.Line | List[models.Line] | models.ApiError: | ||||||
response = self.client.send_request(endpoints["lineStatusByMode"].format(mode)) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("Line", response) | ||||||
return self._send_request_and_deserialize( | ||||||
endpoints["lineStatusByMode"].format(mode), "Line" | ||||||
) | ||||||
|
||||||
def get_route_by_line_id( | ||||||
self, line_id: str | ||||||
) -> models.Line | List[models.Line] | models.ApiError: | ||||||
response = self.client.send_request(endpoints["routeByLineId"].format(line_id)) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("Line", response) | ||||||
return self._send_request_and_deserialize( | ||||||
endpoints["routeByLineId"].format(line_id), "Line" | ||||||
) | ||||||
|
||||||
def get_route_by_mode( | ||||||
self, mode: str | ||||||
) -> models.Line | List[models.Line] | models.ApiError: | ||||||
response = self.client.send_request(endpoints["routeByMode"].format(mode)) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("Line", response) | ||||||
return self._send_request_and_deserialize( | ||||||
endpoints["routeByMode"].format(mode), "Line" | ||||||
) | ||||||
|
||||||
def get_route_by_line_id_with_direction( | ||||||
self, line_id: str, direction: Literal["inbound", "outbound", "all"] | ||||||
) -> models.RouteSequence | List[models.RouteSequence] | models.ApiError: | ||||||
response = self.client.send_request( | ||||||
endpoints["routeByLineIdWithDirection"].format(line_id, direction) | ||||||
return self._send_request_and_deserialize( | ||||||
endpoints["routeByLineIdWithDirection"].format(line_id, direction), | ||||||
"RouteSequence", | ||||||
) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("RouteSequence", response) | ||||||
|
||||||
def get_line_disruptions_by_line_id( | ||||||
self, line_id: str | ||||||
) -> models.Disruption | List[models.Disruption] | models.ApiError: | ||||||
response = self.client.send_request( | ||||||
endpoints["lineDisruptionsByLineId"].format(line_id) | ||||||
return self._send_request_and_deserialize( | ||||||
endpoints["lineDisruptionsByLineId"].format(line_id), "Disruption" | ||||||
) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("Disruption", response) | ||||||
|
||||||
def get_line_disruptions_by_mode( | ||||||
self, mode: str | ||||||
) -> models.Disruption | List[models.Disruption] | models.ApiError: | ||||||
response = self.client.send_request( | ||||||
endpoints["lineDisruptionsByMode"].format(mode) | ||||||
return self._send_request_and_deserialize( | ||||||
endpoints["lineDisruptionsByMode"].format(mode), "Disruption" | ||||||
) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("Disruption", response) | ||||||
|
||||||
def get_stop_points_by_id( | ||||||
self, id: str | ||||||
) -> models.StopPoint | List[models.StopPoint] | models.ApiError: | ||||||
response = self.client.send_request(endpoints["stopPointById"].format(id)) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("StopPoint", response) | ||||||
return self._send_request_and_deserialize( | ||||||
endpoints["stopPointById"].format(id), "StopPoint" | ||||||
) | ||||||
|
||||||
def get_stop_points_by_mode( | ||||||
self, mode: str | ||||||
) -> models.StopPointsResponse | List[models.StopPointsResponse] | models.ApiError: | ||||||
response = self.client.send_request(endpoints["stopPointByMode"].format(mode)) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("StopPointsResponse", response) | ||||||
return self._send_request_and_deserialize( | ||||||
endpoints["stopPointByMode"].format(mode), "StopPointsResponse" | ||||||
) | ||||||
|
||||||
def get_stop_point_meta_modes( | ||||||
self, | ||||||
) -> models.Mode | List[models.Mode] | models.ApiError: | ||||||
response = self.client.send_request(endpoints["stopPointMetaModes"]) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("Mode", response) | ||||||
return self._send_request_and_deserialize( | ||||||
endpoints["stopPointMetaModes"], "Mode" | ||||||
) | ||||||
|
||||||
def get_arrivals_by_line_id( | ||||||
self, line_id: str | ||||||
) -> models.Prediction | List[models.Prediction] | models.ApiError: | ||||||
response = self.client.send_request( | ||||||
endpoints["arrivalsByLineId"].format(line_id) | ||||||
return self._send_request_and_deserialize( | ||||||
endpoints["arrivalsByLineId"].format(line_id), "Prediction" | ||||||
) | ||||||
if response.status_code != 200: | ||||||
return self._deserialize_error(response) | ||||||
return self._deserialize("Prediction", response) |
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,5 +1,6 @@ | ||||||||||||||||
from pydantic import BaseModel, Field, field_validator | ||||||||||||||||
from datetime import datetime | ||||||||||||||||
from email.utils import parsedate_to_datetime | ||||||||||||||||
|
||||||||||||||||
class ApiError(BaseModel): | ||||||||||||||||
timestamp_utc: datetime = Field(alias='timestampUtc') | ||||||||||||||||
|
@@ -11,6 +12,7 @@ class ApiError(BaseModel): | |||||||||||||||
|
||||||||||||||||
@field_validator('timestamp_utc', mode='before') | ||||||||||||||||
def parse_timestamp(cls, v): | ||||||||||||||||
return datetime.strptime(v, '%a, %d %b %Y %H:%M:%S %Z') | ||||||||||||||||
return v if isinstance(v, datetime) else parsedate_to_datetime(v) | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion: Field validator logic The field validator now handles both datetime objects and string dates. Ensure that
Suggested change
|
||||||||||||||||
# return datetime.strptime(v, '%a, %d %b %Y %H:%M:%S %Z') | ||||||||||||||||
|
||||||||||||||||
model_config = {'populate_by_name': True} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
issue: Handling of missing 'date' header
Returning
None
when the 'date' header is missing might lead to issues downstream ifrequest_datetime
is expected to be a datetime object. Consider raising an exception or handling this case more explicitly.