Skip to content

Commit

Permalink
basic "calling at handling"
Browse files Browse the repository at this point in the history
  • Loading branch information
bdamokos committed Jan 2, 2025
1 parent 6070105 commit b4ebeac
Showing 1 changed file with 86 additions and 7 deletions.
93 changes: 86 additions & 7 deletions app/transit_providers/be/sncb/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,61 @@ def _get_fallback_destination(stop_sequence: Optional[List[str]]) -> str:
return stop_info.get("name", "")


def _get_trip_route(trip_id: str) -> List[Dict[str, str]]:
"""Get the complete route for a trip_id from GTFS data.
Args:
trip_id: The trip ID to look up
Returns:
List of dicts containing stop information in sequence, each with:
- stop_id: str
- stop_name: str
- stop_sequence: int
"""
try:
gtfs_path = _get_current_gtfs_path()
if not gtfs_path:
return []

stop_times_file = gtfs_path / "stop_times.txt"
if not stop_times_file.exists():
return []

# Read stop_times.txt and collect all stops for this trip
stops: Dict[str, dict] = {} # Use dict for deduplication
with open(stop_times_file, "r", encoding="utf-8") as f:
header = next(f).strip().split(",")
trip_id_index = header.index("trip_id")
stop_id_index = header.index("stop_id")
stop_sequence_index = header.index("stop_sequence")

for line in f:
fields = line.strip().split(",")
current_trip_id = fields[trip_id_index].strip('"')

# Use exact trip ID match
if current_trip_id == trip_id:
stop_id = fields[stop_id_index].strip('"')
# Remove any suffix after underscore (e.g. 8814001_7 -> 8814001)
base_stop_id = stop_id.split("_")[0]

if base_stop_id not in stops:
stop_info = _get_stop_info(base_stop_id)
stops[base_stop_id] = {
"stop_id": base_stop_id,
"stop_name": stop_info["name"],
"stop_sequence": int(fields[stop_sequence_index]),
}

# Sort stops by sequence
return sorted(list(stops.values()), key=lambda x: x["stop_sequence"])

except Exception as e:
logger.error(f"Error getting trip route for {trip_id}: {e}")
return []


async def get_waiting_times(stop_id: Union[str, List[str]] = None) -> Dict:
"""Get waiting times for stops.
Expand Down Expand Up @@ -735,14 +790,14 @@ async def get_waiting_times(stop_id: Union[str, List[str]] = None) -> Dict:
if not stop_id and monitored_lines and route_id not in monitored_lines:
continue

# Collect stop sequence for this trip
stop_sequence = []
for update in trip.stop_time_update:
stop_sequence.append(update.stop_id)
# Get complete route information
route_stops = _get_trip_route(trip_id)
if not route_stops:
continue

destination = _get_destination_from_trip(trip_id, stop_sequence)
destination = route_stops[-1]["stop_name"] if route_stops else None
if not destination:
continue # Skip if we couldn't get a destination even with fallback
continue

route_info = route_info_cache.get(route_id, _get_route_info(route_id))

Expand All @@ -754,6 +809,26 @@ async def get_waiting_times(stop_id: Union[str, List[str]] = None) -> Dict:
if stop_ids and stop_id not in stop_ids:
continue

# Find current stop's sequence number
current_stop_sequence = next(
(
s["stop_sequence"]
for s in route_stops
if s["stop_id"] == stop_id
),
None,
)

# Get remaining stops in sequence
if current_stop_sequence is not None:
remaining_stops = [
s["stop_name"]
for s in route_stops
if s["stop_sequence"] > current_stop_sequence
]
else:
remaining_stops = []

# Initialize line data if needed
if route_id not in formatted_data["stops_data"][stop_id]["lines"]:
formatted_data["stops_data"][stop_id]["lines"][route_id] = {
Expand Down Expand Up @@ -794,6 +869,7 @@ async def get_waiting_times(stop_id: Union[str, List[str]] = None) -> Dict:
"arrival_timestamp": arrival_time,
"delay": delay_seconds,
"provider": "sncb",
"remaining_stops": remaining_stops,
}
)

Expand Down Expand Up @@ -831,7 +907,7 @@ async def get_waiting_times(stop_id: Union[str, List[str]] = None) -> Dict:
)

# Skip if realtime is more than 2 minutes in the past
if realtime_minutes < -2:
if realtime_minutes < -200:
continue

delay_seconds = time_entry.get("delay")
Expand All @@ -846,6 +922,9 @@ async def get_waiting_times(stop_id: Union[str, List[str]] = None) -> Dict:
"realtime_minutes": f"{realtime_minutes}'",
"realtime_time": arrival_time.strftime("%H:%M"),
"provider": "sncb",
"remaining_stops": time_entry.get(
"remaining_stops", []
),
}
)

Expand Down

0 comments on commit b4ebeac

Please sign in to comment.