Skip to content

Commit

Permalink
validate with multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
JordanLaserGit committed Jan 19, 2024
1 parent 50062c0 commit 8afa2e1
Showing 1 changed file with 57 additions and 15 deletions.
72 changes: 57 additions & 15 deletions python/run_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@
import geopandas
import pandas as pd
from datetime import datetime

def validate(catchments,realization_file=None):

import concurrent.futures as cf

def validate_realization(realization_file):
"""
Validates
1) Realization file meets the pydantic model as defined in ngen-cal
2) Paths given in file exist
"""
relative_dir = os.path.dirname(os.path.dirname(realization_file))

print(f'Done\nValidating {realization_file}')
Expand All @@ -16,16 +21,22 @@ def validate(catchments,realization_file=None):
val = validate_paths(serialized_realization)
if len(val) > 0:
raise Exception(f'{val[0].value} does not exist!')

return serialized_realization, relative_dir

def validate_forcings(forcing_files, catchments):
"""
Validates
1) forcing file names match realization file description
2) forcing files exist for each catchment in geojson
3) forcing start/end times and interval match realization file
"""

print(f'Done\nValidating individual catchment forcing paths')
foring_dir = os.path.join(relative_dir,serialized_realization.global_config.forcing.path)
forcing_files = sorted([x for _,_,x in os.walk(foring_dir)][0])
ncatchments = len(catchments)
catchments = sorted(catchments)
ncatchments = len(catchments)
write_int = 5000
for j, jcatch in enumerate(catchments):
if (j + 1) % write_int == 0:
print(f'{j/ncatchments:.1f}%', end = "\r")
print(f'{100*j/ncatchments:.1f}%', end = "\r")
jid = re.findall(r'\d+', jcatch)[0]
pattern = serialized_realization.global_config.forcing.file_pattern
jcatch_pattern = pattern.replace('{{id}}',jid)
Expand All @@ -38,7 +49,7 @@ def validate(catchments,realization_file=None):
start_time = serialized_realization.time.start_time
end_time = serialized_realization.time.end_time
dt_s = serialized_realization.time.output_interval
full_path = os.path.join(foring_dir,forcing_files[0])
full_path = os.path.join(forcing_dir,forcing_files[0])
df = pd.read_csv(full_path)
forcings_start = datetime.strptime(df['time'].iloc[0],'%Y-%m-%d %H:%M:%S')
forcings_end = datetime.strptime(df['time'].iloc[-1],'%Y-%m-%d %H:%M:%S')
Expand All @@ -47,8 +58,6 @@ def validate(catchments,realization_file=None):
assert end_time == forcings_end, f"Realization end time {end_time} does not match forcing end time {forcings_end}"
assert dt_s == dt_forcings_s, f"Realization output_interval {dt_s} does not match forcing time axis {dt_forcings_s}"

print(f'\nNGen run folder is valid\n')

def validate_data_dir(data_dir):

forcing_files = []
Expand Down Expand Up @@ -94,9 +103,42 @@ def validate_data_dir(data_dir):
print(f'Configurations found! Retrieving catchment data...')

catchments = geopandas.read_file(geopackage_file, layer='divides')
catchment_list = list(catchments['divide_id'])

validate(catchment_list,realization_file)
catchment_list = sorted(list(catchments['divide_id']))

global serialized_realization
serialized_realization, relative_dir = validate_realization(realization_file)

print(f'Done\nValidating individual catchment forcing paths')
global forcing_dir
forcing_dir = os.path.join(relative_dir,serialized_realization.global_config.forcing.path)
forcing_files = sorted([x for _,_,x in os.walk(forcing_dir)][0])

nprocs = os.cpu_count()
forcing_files_list = []
catchment_list_list = []
realization_file_list = []
ncatchments = len(catchment_list)
nper = ncatchments // nprocs
nleft = ncatchments - (nper * nprocs)
i = 0
k = 0
for _ in range(nprocs):
realization_file_list.append(realization_file)
k = nper + i + nleft
jfiles = forcing_files[i:k]
jcatchments = catchment_list[i:k]
forcing_files_list.append(jfiles)
catchment_list_list.append(jcatchments)
i = k

with cf.ProcessPoolExecutor() as pool:
for results in pool.map(
validate_forcings,
forcing_files_list,
catchment_list_list):
pass

print(f'\nNGen run folder is valid\n')

if __name__ == "__main__":

Expand Down

0 comments on commit 8afa2e1

Please sign in to comment.