From b4ebeac3c7782cee1bf76c04dcf1abe481ce991a Mon Sep 17 00:00:00 2001 From: bdamokos <163609735+bdamokos@users.noreply.github.com> Date: Fri, 3 Jan 2025 00:57:06 +0100 Subject: [PATCH] basic "calling at handling" --- app/transit_providers/be/sncb/api.py | 93 +++++++++++++++++++++++++--- 1 file changed, 86 insertions(+), 7 deletions(-) diff --git a/app/transit_providers/be/sncb/api.py b/app/transit_providers/be/sncb/api.py index 005124a..9452cd6 100644 --- a/app/transit_providers/be/sncb/api.py +++ b/app/transit_providers/be/sncb/api.py @@ -563,6 +563,61 @@ def _get_fallback_destination(stop_sequence: Optional[List[str]]) -> str: return stop_info.get("name", "") +def _get_trip_route(trip_id: str) -> List[Dict[str, str]]: + """Get the complete route for a trip_id from GTFS data. + + Args: + trip_id: The trip ID to look up + + Returns: + List of dicts containing stop information in sequence, each with: + - stop_id: str + - stop_name: str + - stop_sequence: int + """ + try: + gtfs_path = _get_current_gtfs_path() + if not gtfs_path: + return [] + + stop_times_file = gtfs_path / "stop_times.txt" + if not stop_times_file.exists(): + return [] + + # Read stop_times.txt and collect all stops for this trip + stops: Dict[str, dict] = {} # Use dict for deduplication + with open(stop_times_file, "r", encoding="utf-8") as f: + header = next(f).strip().split(",") + trip_id_index = header.index("trip_id") + stop_id_index = header.index("stop_id") + stop_sequence_index = header.index("stop_sequence") + + for line in f: + fields = line.strip().split(",") + current_trip_id = fields[trip_id_index].strip('"') + + # Use exact trip ID match + if current_trip_id == trip_id: + stop_id = fields[stop_id_index].strip('"') + # Remove any suffix after underscore (e.g. 8814001_7 -> 8814001) + base_stop_id = stop_id.split("_")[0] + + if base_stop_id not in stops: + stop_info = _get_stop_info(base_stop_id) + stops[base_stop_id] = { + "stop_id": base_stop_id, + "stop_name": stop_info["name"], + "stop_sequence": int(fields[stop_sequence_index]), + } + + # Sort stops by sequence + return sorted(list(stops.values()), key=lambda x: x["stop_sequence"]) + + except Exception as e: + logger.error(f"Error getting trip route for {trip_id}: {e}") + return [] + + async def get_waiting_times(stop_id: Union[str, List[str]] = None) -> Dict: """Get waiting times for stops. @@ -735,14 +790,14 @@ async def get_waiting_times(stop_id: Union[str, List[str]] = None) -> Dict: if not stop_id and monitored_lines and route_id not in monitored_lines: continue - # Collect stop sequence for this trip - stop_sequence = [] - for update in trip.stop_time_update: - stop_sequence.append(update.stop_id) + # Get complete route information + route_stops = _get_trip_route(trip_id) + if not route_stops: + continue - destination = _get_destination_from_trip(trip_id, stop_sequence) + destination = route_stops[-1]["stop_name"] if route_stops else None if not destination: - continue # Skip if we couldn't get a destination even with fallback + continue route_info = route_info_cache.get(route_id, _get_route_info(route_id)) @@ -754,6 +809,26 @@ async def get_waiting_times(stop_id: Union[str, List[str]] = None) -> Dict: if stop_ids and stop_id not in stop_ids: continue + # Find current stop's sequence number + current_stop_sequence = next( + ( + s["stop_sequence"] + for s in route_stops + if s["stop_id"] == stop_id + ), + None, + ) + + # Get remaining stops in sequence + if current_stop_sequence is not None: + remaining_stops = [ + s["stop_name"] + for s in route_stops + if s["stop_sequence"] > current_stop_sequence + ] + else: + remaining_stops = [] + # Initialize line data if needed if route_id not in formatted_data["stops_data"][stop_id]["lines"]: formatted_data["stops_data"][stop_id]["lines"][route_id] = { @@ -794,6 +869,7 @@ async def get_waiting_times(stop_id: Union[str, List[str]] = None) -> Dict: "arrival_timestamp": arrival_time, "delay": delay_seconds, "provider": "sncb", + "remaining_stops": remaining_stops, } ) @@ -831,7 +907,7 @@ async def get_waiting_times(stop_id: Union[str, List[str]] = None) -> Dict: ) # Skip if realtime is more than 2 minutes in the past - if realtime_minutes < -2: + if realtime_minutes < -200: continue delay_seconds = time_entry.get("delay") @@ -846,6 +922,9 @@ async def get_waiting_times(stop_id: Union[str, List[str]] = None) -> Dict: "realtime_minutes": f"{realtime_minutes}'", "realtime_time": arrival_time.strftime("%H:%M"), "provider": "sncb", + "remaining_stops": time_entry.get( + "remaining_stops", [] + ), } )