diff --git a/backend/models/arrival_history.py b/backend/models/arrival_history.py
index 0117277a..a8b1c813 100644
--- a/backend/models/arrival_history.py
+++ b/backend/models/arrival_history.py
@@ -174,9 +174,8 @@ def get_by_date(agency_id: str, route_id: str, d: date, version = DefaultVersion
         mtime = os.stat(cache_path).st_mtime
         now = time.time()
         if now - mtime < 86400:
-            with open(cache_path, "r") as f:
-                text = f.read()
-                return ArrivalHistory.from_data(json.loads(text))
+            text = util.read_from_file(cache_path)
+            return ArrivalHistory.from_data(json.loads(text))
     except FileNotFoundError as err:
         pass
 
@@ -199,9 +198,7 @@ def get_by_date(agency_id: str, route_id: str, d: date, version = DefaultVersion
     if not cache_dir.exists():
         cache_dir.mkdir(parents = True, exist_ok = True)
 
-    with open(cache_path, "w") as f:
-        f.write(r.text)
-
+    util.write_to_file(cache_path, r.text)
     return ArrivalHistory.from_data(data)
 
 def save_for_date(history: ArrivalHistory, d: date, s3=False):
@@ -217,8 +214,7 @@ def save_for_date(history: ArrivalHistory, d: date, s3=False):
     if not cache_dir.exists():
         cache_dir.mkdir(parents = True, exist_ok = True)
 
-    with open(cache_path, "w") as f:
-        f.write(data_str)
+    util.write_to_file(cache_path, data_str)
 
     if s3:
         s3 = boto3.resource('s3')
@@ -231,4 +227,4 @@ def save_for_date(history: ArrivalHistory, d: date, s3=False):
             ContentType='application/json',
             ContentEncoding='gzip',
             ACL='public-read'
-        )
\ No newline at end of file
+        )
diff --git a/backend/models/precomputed_stats.py b/backend/models/precomputed_stats.py
index 5bb22db3..3a9be9b8 100644
--- a/backend/models/precomputed_stats.py
+++ b/backend/models/precomputed_stats.py
@@ -91,9 +91,8 @@ def get_precomputed_stats(agency_id, stat_id: str, d: date, start_time_str = Non
     cache_path = get_cache_path(agency_id, stat_id, d, start_time_str, end_time_str, scheduled, version)
 
     try:
-        with open(cache_path, "r") as f:
-            text = f.read()
-            return PrecomputedStats(json.loads(text))
+        text = util.read_from_file(cache_path)
+        return PrecomputedStats(json.loads(text))
     except FileNotFoundError as err:
         pass
 
@@ -116,9 +115,7 @@ def get_precomputed_stats(agency_id, stat_id: str, d: date, start_time_str = Non
     if not cache_dir.exists():
         cache_dir.mkdir(parents = True, exist_ok = True)
 
-    with open(cache_path, "w") as f:
-        f.write(r.text)
-
+    util.write_to_file(cache_path, r.text)
     return PrecomputedStats(data)
 
 def get_time_range_path(start_time_str, end_time_str):
@@ -172,9 +169,8 @@ def save_stats(agency_id, stat_id, d, start_time_str, end_time_str, scheduled, d
         cache_dir.mkdir(parents = True, exist_ok = True)
 
     print(f'saving to {cache_path}')
-    with open(cache_path, "w") as f:
-        f.write(data_str)
+    util.write_to_file(cache_path, data_str)
 
     if save_to_s3:
         s3 = boto3.resource('s3')
         s3_path = get_s3_path(agency_id, stat_id, d, start_time_str, end_time_str, scheduled)
diff --git a/backend/models/routeconfig.py b/backend/models/routeconfig.py
index df664c8d..1da61741 100644
--- a/backend/models/routeconfig.py
+++ b/backend/models/routeconfig.py
@@ -140,12 +140,11 @@ def route_list_from_data(data):
         mtime = os.stat(cache_path).st_mtime
         now = time.time()
         if now - mtime < 86400:
-            with open(cache_path, mode='r', encoding='utf-8') as f:
-                data_str = f.read()
-                try:
-                    return route_list_from_data(json.loads(data_str))
-                except Exception as err:
-                    print(err)
+            data_str = util.read_from_file(cache_path, encoding='utf-8')
+            try:
+                return route_list_from_data(json.loads(data_str))
+            except Exception as err:
+                print(err)
     except FileNotFoundError as err:
         pass
 
@@ -168,8 +167,7 @@ def route_list_from_data(data):
     if not 'routes' in data:
         raise Exception("S3 object did not contain 'routes' key")
 
-    with open(cache_path, mode='w', encoding='utf-8') as f:
-        f.write(r.text)
+    util.write_to_file(cache_path, r.text, encoding='utf-8')
 
     return route_list_from_data(data)
 
@@ -187,8 +185,7 @@ def save_routes(agency_id, routes, save_to_s3=False):
 
     cache_path = get_cache_path(agency_id)
 
-    with open(cache_path, "w") as f:
-        f.write(data_str)
+    util.write_to_file(cache_path, data_str)
 
     if save_to_s3:
         s3 = boto3.resource('s3')
diff --git a/backend/models/timetables.py b/backend/models/timetables.py
index 33c3eb42..d573db22 100644
--- a/backend/models/timetables.py
+++ b/backend/models/timetables.py
@@ -258,9 +258,8 @@ def match_schedule_to_actual_times(scheduled_times, actual_times, early_sec=60,
 def get_data_by_date_key(agency_id: str, route_id: str, date_key: str, version = DefaultVersion) -> Timetable:
     cache_path = get_cache_path(agency_id, route_id, date_key, version)
 
     try:
-        with open(cache_path, "r") as f:
-            text = f.read()
-            return json.loads(text)
+        text = util.read_from_file(cache_path)
+        return json.loads(text)
     except FileNotFoundError as err:
         pass
@@ -283,9 +282,7 @@ def get_data_by_date_key(agency_id: str, route_id: str, date_key: str, version =
     if not cache_dir.exists():
         cache_dir.mkdir(parents = True, exist_ok = True)
 
-    with open(cache_path, "w") as f:
-        f.write(r.text)
-
+    util.write_to_file(cache_path, r.text)
     return data
 
 def get_cache_path(agency_id, route_id, date_key, version = DefaultVersion):
@@ -318,9 +315,9 @@ def get_date_keys(agency_id, version = DefaultVersion):
     cache_path = get_date_keys_cache_path(agency_id, version)
 
     try:
-        with open(cache_path, "r") as f:
-            data = json.loads(f.read())
-            return data['date_keys']
+        data = util.read_from_file(cache_path)
+        data = json.loads(data)
+        return data['date_keys']
     except FileNotFoundError as err:
         pass
 
@@ -343,9 +340,7 @@ def get_date_keys(agency_id, version = DefaultVersion):
     if not cache_dir.exists():
         cache_dir.mkdir(parents = True, exist_ok = True)
 
-    with open(cache_path, "w") as f:
-        f.write(r.text)
-
+    util.write_to_file(cache_path, r.text)
     return data['date_keys']
 
 def get_date_keys_cache_path(agency_id, version = DefaultVersion):
diff --git a/backend/models/util.py b/backend/models/util.py
index 2dcbb255..68c847e0 100644
--- a/backend/models/util.py
+++ b/backend/models/util.py
@@ -1,8 +1,34 @@
 from datetime import datetime, date, timedelta
 import os
 import pytz
+import fcntl
 import numpy as np
+
+
+def read_from_file(filepath=None, mode='r', **kwargs) -> str:
+    with open(filepath, mode, **kwargs) as f:
+        text = f.read()
+    return text
+
+
+def write_to_file(filepath=None, data=None, mode='w', **kwargs) -> int:
+    '''
+    Concurrency safe function for writing data to the file.
+    Param mode: file open mode ('a' for appending, 'w' for writing)
+    '''
+    assert mode == 'w' or mode == 'a'
+
+    with open(filepath, mode, **kwargs) as f:
+        fcntl.flock(f, fcntl.LOCK_EX)
+        written_characters = f.write(data)
+        fcntl.flock(f, fcntl.LOCK_UN)
+    return written_characters
+
+
+def append_to_file(filepath=None, data=None, mode='a', **kwargs):
+    return write_to_file(filepath, data, mode, **kwargs)
+
 
 def quantile_sorted(sorted_arr, quantile):
     # For small arrays (less than about 4000 items) np.quantile is significantly
     # slower than sorting the array and picking the quantile out by index. Computing
@@ -121,4 +147,4 @@ def get_intervals(start_time, end_time, interval_length):
         ))
         rounded_start_time = new_start_time
 
-    return time_str_intervals
\ No newline at end of file
+    return time_str_intervals