Fix file write concurrency #578

Open. Wants to merge 2 commits into master.
14 changes: 5 additions & 9 deletions backend/models/arrival_history.py
@@ -174,9 +174,8 @@ def get_by_date(agency_id: str, route_id: str, d: date, version = DefaultVersion
         mtime = os.stat(cache_path).st_mtime
         now = time.time()
         if now - mtime < 86400:
-            with open(cache_path, "r") as f:
-                text = f.read()
-                return ArrivalHistory.from_data(json.loads(text))
+            text = util.read_from_file(cache_path)
+            return ArrivalHistory.from_data(json.loads(text))
     except FileNotFoundError as err:
         pass

@@ -199,9 +198,7 @@ def get_by_date(agency_id: str, route_id: str, d: date, version = DefaultVersion
     if not cache_dir.exists():
         cache_dir.mkdir(parents = True, exist_ok = True)
 
-    with open(cache_path, "w") as f:
-        f.write(r.text)
-
+    util.write_to_file(cache_path, r.text)
     return ArrivalHistory.from_data(data)
 
 def save_for_date(history: ArrivalHistory, d: date, s3=False):
@@ -217,8 +214,7 @@ def save_for_date(history: ArrivalHistory, d: date, s3=False):
     if not cache_dir.exists():
         cache_dir.mkdir(parents = True, exist_ok = True)
 
-    with open(cache_path, "w") as f:
-        f.write(data_str)
+    util.write_to_file(cache_path, data_str)
 
     if s3:
         s3 = boto3.resource('s3')
@@ -231,4 +227,4 @@
             ContentType='application/json',
             ContentEncoding='gzip',
             ACL='public-read'
-        )
\ No newline at end of file
+        )
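
The same substitution repeats in the three files below. A condensed sketch of the recurring call-site change (illustrative, not an exact excerpt from the diff):

    # before: plain open/write; concurrent writers can interleave or clobber
    with open(cache_path, "w") as f:
        f.write(data_str)

    # after: a shared helper that holds an exclusive flock while writing
    util.write_to_file(cache_path, data_str)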
12 changes: 4 additions & 8 deletions backend/models/precomputed_stats.py
@@ -91,9 +91,8 @@ def get_precomputed_stats(agency_id, stat_id: str, d: date, start_time_str = Non
     cache_path = get_cache_path(agency_id, stat_id, d, start_time_str, end_time_str, scheduled, version)
 
     try:
-        with open(cache_path, "r") as f:
-            text = f.read()
-            return PrecomputedStats(json.loads(text))
+        text = util.read_from_file(cache_path)
+        return PrecomputedStats(json.loads(text))
     except FileNotFoundError as err:
         pass
@@ -116,9 +115,7 @@ def get_precomputed_stats(agency_id, stat_id: str, d: date, start_time_str = Non
     if not cache_dir.exists():
         cache_dir.mkdir(parents = True, exist_ok = True)
 
-    with open(cache_path, "w") as f:
-        f.write(r.text)
-
+    util.write_to_file(cache_path, r.text)
     return PrecomputedStats(data)
 
 def get_time_range_path(start_time_str, end_time_str):
@@ -172,9 +169,8 @@ def save_stats(agency_id, stat_id, d, start_time_str, end_time_str, scheduled, d
         cache_dir.mkdir(parents = True, exist_ok = True)
 
     print(f'saving to {cache_path}')
-    with open(cache_path, "w") as f:
-        f.write(data_str)
-
+    util.write_to_file(cache_path, data_str)
     if save_to_s3:
         s3 = boto3.resource('s3')
         s3_path = get_s3_path(agency_id, stat_id, d, start_time_str, end_time_str, scheduled)
17 changes: 7 additions & 10 deletions backend/models/routeconfig.py
@@ -140,12 +140,11 @@ def route_list_from_data(data):
         mtime = os.stat(cache_path).st_mtime
         now = time.time()
         if now - mtime < 86400:
-            with open(cache_path, mode='r', encoding='utf-8') as f:
-                data_str = f.read()
-                try:
-                    return route_list_from_data(json.loads(data_str))
-                except Exception as err:
-                    print(err)
+            data_str = util.read_from_file(cache_path, encoding='utf-8')
+            try:
+                return route_list_from_data(json.loads(data_str))
+            except Exception as err:
+                print(err)
     except FileNotFoundError as err:
         pass

@@ -168,8 +167,7 @@ def route_list_from_data(data):
     if not 'routes' in data:
         raise Exception("S3 object did not contain 'routes' key")
 
-    with open(cache_path, mode='w', encoding='utf-8') as f:
-        f.write(r.text)
+    util.write_to_file(cache_path, r.text, encoding='utf-8')
 
     return route_list_from_data(data)

@@ -187,8 +185,7 @@ def save_routes(agency_id, routes, save_to_s3=False):
 
     cache_path = get_cache_path(agency_id)
 
-    with open(cache_path, "w") as f:
-        f.write(data_str)
+    util.write_to_file(cache_path, data_str)
 
     if save_to_s3:
         s3 = boto3.resource('s3')
19 changes: 7 additions & 12 deletions backend/models/timetables.py
@@ -258,9 +258,8 @@ def match_schedule_to_actual_times(scheduled_times, actual_times, early_sec=60,
 def get_data_by_date_key(agency_id: str, route_id: str, date_key: str, version = DefaultVersion) -> Timetable:
     cache_path = get_cache_path(agency_id, route_id, date_key, version)
     try:
-        with open(cache_path, "r") as f:
-            text = f.read()
-            return json.loads(text)
+        text = util.read_from_file(cache_path)
+        return json.loads(text)
     except FileNotFoundError as err:
         pass

@@ -283,9 +282,7 @@ def get_data_by_date_key(agency_id: str, route_id: str, date_key: str, version =
     if not cache_dir.exists():
         cache_dir.mkdir(parents = True, exist_ok = True)
 
-    with open(cache_path, "w") as f:
-        f.write(r.text)
-
+    util.write_to_file(cache_path, r.text)
     return data
 
 def get_cache_path(agency_id, route_id, date_key, version = DefaultVersion):
@@ -318,9 +315,9 @@ def get_date_keys(agency_id, version = DefaultVersion):
     cache_path = get_date_keys_cache_path(agency_id, version)
 
     try:
-        with open(cache_path, "r") as f:
-            data = json.loads(f.read())
-            return data['date_keys']
+        data = util.read_from_file(cache_path)
+        data = json.loads(data)
+        return data['date_keys']
     except FileNotFoundError as err:
         pass

@@ -343,9 +340,7 @@ def get_date_keys(agency_id, version = DefaultVersion):
     if not cache_dir.exists():
         cache_dir.mkdir(parents = True, exist_ok = True)
 
-    with open(cache_path, "w") as f:
-        f.write(r.text)
-
+    util.write_to_file(cache_path, r.text)
    return data['date_keys']
 
 def get_date_keys_cache_path(agency_id, version = DefaultVersion):
28 changes: 27 additions & 1 deletion backend/models/util.py
@@ -1,8 +1,34 @@
 from datetime import datetime, date, timedelta
 import os
 import pytz
+import fcntl
 import numpy as np
 
+
+def read_from_file(filepath=None, mode='r', **kwargs) -> str:
+    with open(filepath, mode, **kwargs) as f:
+        text = f.read()
+    return text
+
+
+def write_to_file(filepath=None, data=None, mode='w', **kwargs) -> int:
+    '''
+    Concurrency safe function for writing data to the file.
+    Param mode: file open mode ('a' for appending, 'w' for writing)
+    '''
+    assert mode == 'w' or mode == 'a'
+
+    with open(filepath, mode, **kwargs) as f:
+        fcntl.flock(f, fcntl.LOCK_EX)
+        written_characters = f.write(data)
+        fcntl.flock(f, fcntl.LOCK_UN)

> Contributor: Maybe worth explicitly closing the file.

+    return written_characters
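
For context: the with block already closes the file when it exits, and closing the descriptor also releases any flock held on it. A minimal sketch of an explicit try/finally equivalent of write_to_file's body (reusing the function's filepath, mode, data and kwargs):

    f = open(filepath, mode, **kwargs)
    try:
        fcntl.flock(f, fcntl.LOCK_EX)       # exclusive advisory lock
        written_characters = f.write(data)
    finally:
        fcntl.flock(f, fcntl.LOCK_UN)       # release before closing
        f.close()                           # the with statement does this implicitly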

> Contributor, on lines +14 to +25 (write_to_file): This question comes from my lack of knowledge: if another process tries to access the file while it is locked, does it lead to an IOError? Or does it automatically wait for the lock to be released?
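
For reference: flock is advisory, so a process that merely opens, reads, or writes the file neither fails nor waits; only a competing flock() call contends. By default that call blocks until the lock is released; passing LOCK_NB makes it raise BlockingIOError instead of waiting. A minimal two-process sketch (hypothetical path):

    import fcntl

    # process A
    f = open('/tmp/demo.txt', 'a')
    fcntl.flock(f, fcntl.LOCK_EX)           # acquired immediately

    # process B, while A still holds the lock
    g = open('/tmp/demo.txt', 'a')          # open() itself is never blocked
    fcntl.flock(g, fcntl.LOCK_EX)           # blocks here until A unlocks or closes
    # or, to fail fast instead of waiting:
    try:
        fcntl.flock(g, fcntl.LOCK_EX | fcntl.LOCK_NB)
    except BlockingIOError:                 # raised when the lock is unavailable
        pass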



+
+
+def append_to_file(filepath=None, data=None, mode='a', **kwargs):

> Contributor: Since it is strictly append -- might be good to take out the kwarg and just pass mode='a' to util.write_to_file.

+    return write_to_file(filepath, data, mode, **kwargs)
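
A sketch of that suggested simplification (hypothetical, not part of this diff):

    def append_to_file(filepath=None, data=None, **kwargs):
        # hard-code append mode so callers cannot pass a conflicting one
        return write_to_file(filepath, data, mode='a', **kwargs)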


 def quantile_sorted(sorted_arr, quantile):
     # For small arrays (less than about 4000 items) np.quantile is significantly
     # slower than sorting the array and picking the quantile out by index. Computing
@@ -121,4 +147,4 @@ def get_intervals(start_time, end_time, interval_length):
         ))
         rounded_start_time = new_start_time
 
-    return time_str_intervals
\ No newline at end of file
+    return time_str_intervals
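
Finally, a hypothetical smoke test for the behavior this PR targets: several processes appending to one file, with the exclusive flock inside write_to_file serializing the writes so each line lands intact (the import path is assumed):

    from multiprocessing import Process

    import util  # assumes backend/models is on sys.path

    def worker(i):
        util.append_to_file('/tmp/concurrency_demo.log', f'line from worker {i}\n')

    if __name__ == '__main__':
        procs = [Process(target=worker, args=(i,)) for i in range(8)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()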