From f31fb256f9940dcbe94aa72e891f17cd8fb4e037 Mon Sep 17 00:00:00 2001
From: bdamokos <163609735+bdamokos@users.noreply.github.com>
Date: Fri, 3 Jan 2025 01:03:14 +0100
Subject: [PATCH] Add stop_times cache, to be tested

---
 app/transit_providers/be/sncb/api.py | 186 ++++++++++++++++++++++-----
 1 file changed, 151 insertions(+), 35 deletions(-)

diff --git a/app/transit_providers/be/sncb/api.py b/app/transit_providers/be/sncb/api.py
index 9452cd6..e6517bb 100644
--- a/app/transit_providers/be/sncb/api.py
+++ b/app/transit_providers/be/sncb/api.py
@@ -86,6 +86,10 @@
 _trips_lru_cache = {}
 _trips_lru_cache_max_size = 100000
 
+# In-memory stop_times cache, built by _load_stop_times_cache()
+_stop_times_cache = {}  # Format: {trip_id: [{"stop_id": str, "stop_sequence": int}]}
+_stop_times_cache_update = None
+
 
 def get_directory_size(directory: Path) -> float:
     """Calculate total size of a directory in megabytes."""
@@ -219,6 +223,12 @@ async def _initialize_caches():
         logger.error(f"Failed to initialize trips cache: {e}", exc_info=True)
         return
 
+    try:
+        _load_stop_times_cache()
+    except Exception as e:
+        logger.error(f"Failed to initialize stop times cache: {e}", exc_info=True)
+        return
+
     # Initialize waiting times cache
     _last_waiting_times_result = {"stops_data": {}, "_metadata": {}}
     _last_waiting_times_update = time.time()
@@ -576,42 +586,28 @@ def _get_trip_route(trip_id: str) -> List[Dict[str, str]]:
         - stop_sequence: int
     """
     try:
-        gtfs_path = _get_current_gtfs_path()
-        if not gtfs_path:
-            return []
+        # Load the cache on first use
+        if not _stop_times_cache_update:
+            _load_stop_times_cache()
 
-        stop_times_file = gtfs_path / "stop_times.txt"
-        if not stop_times_file.exists():
+        # Get this trip's stops from the cache
+        stops = _stop_times_cache.get(trip_id, [])
+        if not stops:
             return []
 
-        # Read stop_times.txt and collect all stops for this trip
-        stops: Dict[str, dict] = {}  # Use dict for deduplication
-        with open(stop_times_file, "r", encoding="utf-8") as f:
-            header = next(f).strip().split(",")
-            trip_id_index = header.index("trip_id")
-            stop_id_index = header.index("stop_id")
-            stop_sequence_index = header.index("stop_sequence")
-
-            for line in f:
-                fields = line.strip().split(",")
-                current_trip_id = fields[trip_id_index].strip('"')
-
-                # Use exact trip ID match
-                if current_trip_id == trip_id:
-                    stop_id = fields[stop_id_index].strip('"')
-                    # Remove any suffix after underscore (e.g. 8814001_7 -> 8814001)
-                    base_stop_id = stop_id.split("_")[0]
-
-                    if base_stop_id not in stops:
-                        stop_info = _get_stop_info(base_stop_id)
-                        stops[base_stop_id] = {
-                            "stop_id": base_stop_id,
-                            "stop_name": stop_info["name"],
-                            "stop_sequence": int(fields[stop_sequence_index]),
-                        }
+        # Add stop names (cache entries are already sorted by stop_sequence)
+        result = []
+        for stop in stops:
+            stop_info = _get_stop_info(stop["stop_id"])
+            result.append(
+                {
+                    "stop_id": stop["stop_id"],
+                    "stop_name": stop_info["name"],
+                    "stop_sequence": stop["stop_sequence"],
+                }
+            )
 
-        # Sort stops by sequence
-        return sorted(list(stops.values()), key=lambda x: x["stop_sequence"])
+        return result
 
     except Exception as e:
         logger.error(f"Error getting trip route for {trip_id}: {e}")
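For reference, the new lookup reduces to a dictionary read plus a name join; a minimal, self-contained sketch of the behaviour, where the trip ID, the stop IDs, and the `_get_stop_info` stub are hypothetical stand-ins for the real module:

    # Sketch only: hypothetical IDs, stubbed _get_stop_info
    _stop_times_cache = {
        "IC1234": [
            {"stop_id": "8814001", "stop_sequence": 1},
            {"stop_id": "8821006", "stop_sequence": 2},
        ]
    }

    def _get_stop_info(stop_id):
        # Stand-in for the real stop lookup in api.py
        return {"name": f"Stop {stop_id}"}

    def get_trip_route(trip_id):
        # Entries were sorted by stop_sequence when the cache was built,
        # so the lookup only joins in the stop names.
        return [
            {
                "stop_id": s["stop_id"],
                "stop_name": _get_stop_info(s["stop_id"])["name"],
                "stop_sequence": s["stop_sequence"],
            }
            for s in _stop_times_cache.get(trip_id, [])
        ]

    print(get_trip_route("IC1234"))  # two stops, in sequence order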
@@ -907,7 +903,7 @@ async def get_waiting_times(stop_id: Union[str, List[str]] = None) -> Dict:
                         )
 
                         # Skip if realtime is more than 2 minutes in the past
-                        if realtime_minutes < -200:
+                        if realtime_minutes < -2:
                             continue
 
                         delay_seconds = time_entry.get("delay")
@@ -1233,8 +1229,77 @@ async def get_line_info() -> Dict[str, Dict[str, Any]]:
 # Initialize caches at module load
 async def _ensure_caches_initialized():
     """Ensure caches are initialized"""
-    if not _caches_initialized:
-        await _initialize_caches()
+    global _caches_initialized, _last_waiting_times_result, _last_waiting_times_update
+    if _caches_initialized:
+        return
+
+    try:
+        logger.info("Initializing SNCB provider caches...")
+
+        # First ensure GTFS data is available
+        if not gtfs_manager:
+            logger.error("GTFSManager not initialized")
+            return
+
+        logger.info("Ensuring GTFS data is available...")
+        gtfs_path = await gtfs_manager.ensure_gtfs_data()
+        if not gtfs_path:
+            logger.error("Failed to ensure GTFS data")
+            return
+
+        # Wait for the GTFS file to exist and be non-empty
+        max_retries = 60
+        retry_delay = 1  # seconds
+        for i in range(max_retries):
+            if gtfs_path.exists() and gtfs_path.stat().st_size > 0:
+                logger.info(
+                    f"GTFS data available at {gtfs_path} (size: {get_directory_size(gtfs_path):.2f} MB)"
+                )
+                break
+            if i < max_retries - 1:
+                logger.warning(
+                    f"GTFS file not ready yet, retrying in {retry_delay} seconds... ({i + 1}/{max_retries})"
+                )
+                await asyncio.sleep(retry_delay)
+            else:
+                logger.error(f"GTFS file not available after {max_retries} retries")
+                _caches_initialized = False
+                return
+
+        # Now that we have GTFS data, load the caches
+        try:
+            _load_stops_cache()
+        except Exception as e:
+            logger.error(f"Failed to initialize stops cache: {e}", exc_info=True)
+            return
+
+        try:
+            _load_routes_cache()
+        except Exception as e:
+            logger.error(f"Failed to initialize routes cache: {e}", exc_info=True)
+            return
+
+        try:
+            _load_trips_cache()
+        except Exception as e:
+            logger.error(f"Failed to initialize trips cache: {e}", exc_info=True)
+            return
+
+        try:
+            _load_stop_times_cache()
+        except Exception as e:
+            logger.error(f"Failed to initialize stop times cache: {e}", exc_info=True)
+            return
+
+        # Initialize waiting times cache
+        _last_waiting_times_result = {"stops_data": {}, "_metadata": {}}
+        _last_waiting_times_update = time.time()
+
+        _caches_initialized = True
+        logger.info("SNCB provider caches initialized successfully")
+    except Exception as e:
+        logger.error(f"Error initializing SNCB provider caches: {e}", exc_info=True)
+        _caches_initialized = False
 
 
 # Create event loop and run initialization
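The readiness wait above is plain filesystem polling. As a self-contained illustration of the pattern (the path and timings below are hypothetical, not the provider's actual layout):

    import asyncio
    from pathlib import Path

    async def wait_for_data(path, max_retries=60, retry_delay=1.0):
        # Poll until `path` exists and is non-empty, or give up after max_retries.
        for attempt in range(max_retries):
            if path.exists() and path.stat().st_size > 0:
                return True
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
        return False

    # asyncio.run(wait_for_data(Path("cache/sncb/gtfs")))  # hypothetical path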
@@ -1244,3 +1309,54 @@
 loop = asyncio.new_event_loop()
 asyncio.set_event_loop(loop)
 loop.run_until_complete(_ensure_caches_initialized())
+
+
+def _load_stop_times_cache() -> None:
+    """Load stop times from GTFS data into cache"""
+    global _stop_times_cache, _stop_times_cache_update
+
+    try:
+        gtfs_path = _get_current_gtfs_path()
+        if not gtfs_path:
+            return
+
+        stop_times_file = gtfs_path / "stop_times.txt"
+        if not stop_times_file.exists():
+            return
+
+        new_cache = {}
+        with open(stop_times_file, "r", encoding="utf-8") as f:
+            header = next(f).strip().split(",")
+            trip_id_index = header.index("trip_id")
+            stop_id_index = header.index("stop_id")
+            stop_sequence_index = header.index("stop_sequence")
+
+            for line in f:
+                fields = line.strip().split(",")
+                if len(fields) <= 1:  # skip blank/malformed lines
+                    continue
+                trip_id = fields[trip_id_index].strip('"')
+                stop_id = fields[stop_id_index].strip('"')
+                # Remove any suffix after underscore (e.g. 8814001_7 -> 8814001)
+                base_stop_id = stop_id.split("_")[0]
+
+                if trip_id not in new_cache:
+                    new_cache[trip_id] = []
+
+                new_cache[trip_id].append(
+                    {
+                        "stop_id": base_stop_id,
+                        "stop_sequence": int(fields[stop_sequence_index]),
+                    }
+                )
+
+        # Sort stops by sequence for each trip
+        for trip_id in new_cache:
+            new_cache[trip_id].sort(key=lambda x: x["stop_sequence"])
+
+        _stop_times_cache = new_cache
+        _stop_times_cache_update = datetime.now(timezone.utc)
+        logger.info(f"Updated stop times cache with {len(_stop_times_cache)} trips")
+
+    except Exception as e:
+        logger.error(f"Error loading stop times cache: {e}")
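The loader splits each row on commas directly, which matches the rest of the module but assumes stop_times.txt never quotes a field containing a comma. If that assumption ever breaks, a quote-aware variant of the parsing loop using the standard csv module would look roughly like this (a sketch, not part of the patch):

    import csv
    from pathlib import Path

    def load_stop_times(stop_times_file):
        # Quote-aware equivalent of _load_stop_times_cache's inner loop.
        cache = {}
        with open(stop_times_file, "r", encoding="utf-8", newline="") as f:
            for row in csv.DictReader(f):
                # Same normalisation as the patch: 8814001_7 -> 8814001
                base_stop_id = row["stop_id"].split("_")[0]
                cache.setdefault(row["trip_id"], []).append(
                    {"stop_id": base_stop_id, "stop_sequence": int(row["stop_sequence"])}
                )
        # Sort each trip's stops by sequence, as the loader above does
        for stops in cache.values():
            stops.sort(key=lambda s: s["stop_sequence"])
        return cache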