add stop_times cache, to be tested
Browse files Browse the repository at this point in the history
  • Loading branch information
bdamokos committed Jan 3, 2025
1 parent b4ebeac commit f31fb25
Showing 1 changed file with 149 additions and 35 deletions.
184 changes: 149 additions & 35 deletions app/transit_providers/be/sncb/api.py
@@ -86,6 +86,10 @@
 _trips_lru_cache = {}
 _trips_lru_cache_max_size = 100000
 
+# Add after other global cache variables
+_stop_times_cache = {}  # Format: {trip_id: [{"stop_id": str, "stop_sequence": int}]}
+_stop_times_cache_update = None
+
 
 def get_directory_size(directory: Path) -> float:
     """Calculate total size of a directory in megabytes."""
@@ -219,6 +223,12 @@ async def _initialize_caches():
         logger.error(f"Failed to initialize trips cache: {e}", exc_info=True)
         return
 
+    try:
+        _load_stop_times_cache()
+    except Exception as e:
+        logger.error(f"Failed to initialize stop times cache: {e}", exc_info=True)
+        return
+
     # Initialize waiting times cache
     _last_waiting_times_result = {"stops_data": {}, "_metadata": {}}
     _last_waiting_times_update = time.time()
@@ -576,42 +586,28 @@ def _get_trip_route(trip_id: str) -> List[Dict[str, str]]:
         - stop_sequence: int
     """
     try:
-        gtfs_path = _get_current_gtfs_path()
-        if not gtfs_path:
-            return []
+        # Update cache if needed
+        if not _stop_times_cache_update:
+            _load_stop_times_cache()
 
-        stop_times_file = gtfs_path / "stop_times.txt"
-        if not stop_times_file.exists():
+        # Get stops from cache
+        stops = _stop_times_cache.get(trip_id, [])
+        if not stops:
             return []
 
-        # Read stop_times.txt and collect all stops for this trip
-        stops: Dict[str, dict] = {}  # Use dict for deduplication
-        with open(stop_times_file, "r", encoding="utf-8") as f:
-            header = next(f).strip().split(",")
-            trip_id_index = header.index("trip_id")
-            stop_id_index = header.index("stop_id")
-            stop_sequence_index = header.index("stop_sequence")
-
-            for line in f:
-                fields = line.strip().split(",")
-                current_trip_id = fields[trip_id_index].strip('"')
-
-                # Use exact trip ID match
-                if current_trip_id == trip_id:
-                    stop_id = fields[stop_id_index].strip('"')
-                    # Remove any suffix after underscore (e.g. 8814001_7 -> 8814001)
-                    base_stop_id = stop_id.split("_")[0]
-
-                    if base_stop_id not in stops:
-                        stop_info = _get_stop_info(base_stop_id)
-                        stops[base_stop_id] = {
-                            "stop_id": base_stop_id,
-                            "stop_name": stop_info["name"],
-                            "stop_sequence": int(fields[stop_sequence_index]),
-                        }
+        # Add stop names
+        result = []
+        for stop in stops:
+            stop_info = _get_stop_info(stop["stop_id"])
+            result.append(
+                {
+                    "stop_id": stop["stop_id"],
+                    "stop_name": stop_info["name"],
+                    "stop_sequence": stop["stop_sequence"],
+                }
+            )
 
-        # Sort stops by sequence
-        return sorted(list(stops.values()), key=lambda x: x["stop_sequence"])
+        return result
 
     except Exception as e:
         logger.error(f"Error getting trip route for {trip_id}: {e}")
@@ -907,7 +903,7 @@ async def get_waiting_times(stop_id: Union[str, List[str]] = None) -> Dict:
                     )
 
                     # Skip if realtime is more than 2 minutes in the past
-                    if realtime_minutes < -200:
+                    if realtime_minutes < -2:
                         continue
 
                     delay_seconds = time_entry.get("delay")
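
For scale: the old threshold of -200 minutes is more than three hours in the past, so stale departures were effectively never skipped; -2 matches the two-minute cutoff the comment describes.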
@@ -1233,8 +1229,77 @@ async def get_line_info() -> Dict[str, Dict[str, Any]]:
 # Initialize caches at module load
 async def _ensure_caches_initialized():
     """Ensure caches are initialized"""
-    if not _caches_initialized:
-        await _initialize_caches()
+    global _caches_initialized
+    if _caches_initialized:
+        return
+
+    try:
+        logger.info("Initializing SNCB provider caches...")
+
+        # First ensure GTFS data is available
+        if not gtfs_manager:
+            logger.error("GTFSManager not initialized")
+            return
+
+        logger.info("Ensuring GTFS data is available...")
+        gtfs_path = await gtfs_manager.ensure_gtfs_data()
+        if not gtfs_path:
+            logger.error("Failed to ensure GTFS data")
+            return
+
+        # Wait for the GTFS file to exist and be non-empty
+        max_retries = 60
+        retry_delay = 1  # seconds
+        for i in range(max_retries):
+            if gtfs_path.exists() and gtfs_path.stat().st_size > 0:
+                logger.info(
+                    f"GTFS data available at {gtfs_path} (size: {get_directory_size(gtfs_path):.2f} MB)"
+                )
+                break
+            if i < max_retries - 1:
+                logger.warning(
+                    f"GTFS file not ready yet, retrying in {retry_delay} seconds... ({i + 1}/{max_retries})"
+                )
+                await asyncio.sleep(retry_delay)
+            else:
+                logger.error(f"GTFS file not available after {max_retries} retries")
+                _caches_initialized = False
+                return
+
+        # Now that we have GTFS data, load the caches
+        try:
+            _load_stops_cache()
+        except Exception as e:
+            logger.error(f"Failed to initialize stops cache: {e}", exc_info=True)
+            return
+
+        try:
+            _load_routes_cache()
+        except Exception as e:
+            logger.error(f"Failed to initialize routes cache: {e}", exc_info=True)
+            return
+
+        try:
+            _load_trips_cache()
+        except Exception as e:
+            logger.error(f"Failed to initialize trips cache: {e}", exc_info=True)
+            return
+
+        try:
+            _load_stop_times_cache()
+        except Exception as e:
+            logger.error(f"Failed to initialize stop times cache: {e}", exc_info=True)
+            return
+
+        # Initialize waiting times cache
+        _last_waiting_times_result = {"stops_data": {}, "_metadata": {}}
+        _last_waiting_times_update = time.time()
+
+        _caches_initialized = True
+        logger.info("SNCB provider caches initialized successfully")
+    except Exception as e:
+        logger.error(f"Error initializing SNCB provider caches: {e}")
+        _caches_initialized = False
 
 
 # Create event loop and run initialization
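
The expanded initializer now blocks on the GTFS download before loading any cache, polling up to 60 times at one-second intervals. The wait-for-data pattern in isolation, as a condensed sketch with made-up timings:

    import asyncio
    from pathlib import Path

    async def wait_for_data(path: Path, retries: int = 5, delay: float = 1.0) -> bool:
        """Poll until path exists and is non-empty, or give up."""
        for _ in range(retries):
            if path.exists() and path.stat().st_size > 0:
                return True
            await asyncio.sleep(delay)
        return False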
@@ -1244,3 +1309,52 @@ async def _ensure_caches_initialized():
 loop = asyncio.new_event_loop()
 asyncio.set_event_loop(loop)
 loop.run_until_complete(_ensure_caches_initialized())
+
+
+def _load_stop_times_cache() -> None:
+    """Load stop times from GTFS data into cache"""
+    global _stop_times_cache, _stop_times_cache_update
+
+    try:
+        gtfs_path = _get_current_gtfs_path()
+        if not gtfs_path:
+            return
+
+        stop_times_file = gtfs_path / "stop_times.txt"
+        if not stop_times_file.exists():
+            return
+
+        new_cache = {}
+        with open(stop_times_file, "r", encoding="utf-8") as f:
+            header = next(f).strip().split(",")
+            trip_id_index = header.index("trip_id")
+            stop_id_index = header.index("stop_id")
+            stop_sequence_index = header.index("stop_sequence")
+
+            for line in f:
+                fields = line.strip().split(",")
+                trip_id = fields[trip_id_index].strip('"')
+                stop_id = fields[stop_id_index].strip('"')
+                # Remove any suffix after underscore (e.g. 8814001_7 -> 8814001)
+                base_stop_id = stop_id.split("_")[0]
+
+                if trip_id not in new_cache:
+                    new_cache[trip_id] = []
+
+                new_cache[trip_id].append(
+                    {
+                        "stop_id": base_stop_id,
+                        "stop_sequence": int(fields[stop_sequence_index]),
+                    }
+                )
+
+        # Sort stops by sequence for each trip
+        for trip_id in new_cache:
+            new_cache[trip_id].sort(key=lambda x: x["stop_sequence"])
+
+        _stop_times_cache = new_cache
+        _stop_times_cache_update = datetime.now(timezone.utc)
+        logger.info(f"Updated stop times cache with {len(_stop_times_cache)} trips")
+
+    except Exception as e:
+        logger.error(f"Error loading stop times cache: {e}")
