tarball algorithm change
JordanLaserGit committed Jan 5, 2024
1 parent b66623d commit d9355cc
Showing 1 changed file with 61 additions and 53 deletions.
114 changes: 61 additions & 53 deletions forcingprocessor/src/forcingprocessor/forcingprocessor.py
@@ -154,11 +154,11 @@ def multiprocess_data_extract(files,nprocs,crosswalk_dict,fs):

def forcing_grid2catchment(crosswalk_dict: dict, nwm_files: list, fs=None):
"""
General function to read either remote or local nwm forcing files.
General function to retrieve catchment-level data from National Water Model files.
Inputs:
wgt_file: a path to the weights json,
filelist: list of filenames (urls for remote, local paths otherwise),
crosswalk_dict: dict of catchments to use as indices
nwm_files: list of filenames (urls for remote, local paths otherwise),
fs: an optional file system for cloud storage reads
Outputs:
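
The revised docstring describes forcing_grid2catchment as a crosswalk-driven extraction: each catchment id in crosswalk_dict maps to a set of grid indices, and the gridded NWM values at those indices are reduced to a single value per catchment. A minimal sketch of that idea, assuming flat grid indices and a plain mean; the real module may apply weights, so treat the names and the reduction as illustrative only:

import numpy as np

def catchment_means(grid: np.ndarray, crosswalk_dict: dict) -> dict:
    # crosswalk_dict maps catchment id -> flat indices into the forcing grid
    flat = grid.ravel()
    return {cat_id: float(flat[idx].mean()) for cat_id, idx in crosswalk_dict.items()}

# Hypothetical usage with a 3x3 grid and two catchments
grid = np.arange(9, dtype=float).reshape(3, 3)
crosswalk_dict = {"cat-1": np.array([0, 1, 3]), "cat-2": np.array([4, 8])}
print(catchment_means(grid, crosswalk_dict))  # {'cat-1': 1.333..., 'cat-2': 6.0}
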
@@ -226,9 +226,9 @@ def forcing_grid2catchment(crosswalk_dict: dict, nwm_files: list, fs=None):
if ii_verbose: print(f'{id} completed data extraction, returning data to primary process')
return [data_list, t_list]

def threaded_write_fun(data,t_ax,catchments,nprocs,output_bucket,out_path,ii_append):
def multiprocess_write(data,t_ax,catchments,nprocs,output_bucket,out_path,ii_append):
"""
Sets up the thread pool for write_data
Sets up the process pool for write_data
"""

@@ -248,7 +248,7 @@ def threaded_write_fun(data,t_ax,catchments,nprocs,output_bucket,out_path,ii_append):
append_list = []
print_list = []
bucket_list = []
nthread_list = []
nprocess_list = []
worker_time_list = []
worker_data_list = []
worker_catchment_list = []
@@ -273,7 +273,7 @@ def threaded_write_fun(data,t_ax,catchments,nprocs,output_bucket,out_path,ii_append):
append_list.append(ii_append)
print_list.append(ii_print)
worker_time_list.append(t_ax)
nthread_list.append(nprocs)
nprocess_list.append(nprocs)
bucket_list.append(output_bucket)

worker_catchments = {}
@@ -294,7 +294,7 @@ def threaded_write_fun(data,t_ax,catchments,nprocs,output_bucket,out_path,ii_append):
out_path_list,
append_list,
print_list,
nthread_list
nprocess_list
):
ids.append(results[0])
dfs.append(results[1])
@@ -318,7 +318,7 @@ def write_data(
out_path,
ii_append,
ii_print,
nthreads
nprocesses
):
s3_client = boto3.session.Session().client("s3")
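
Each write_data worker builds its own boto3 client rather than sharing one across the pool, since sessions are generally created per process. The actual upload call sits outside this hunk; a minimal sketch of the kind of put that presumably happens when storage_type is "s3", with the bucket, key, and CSV serialization as assumptions:

import boto3
import pandas as pd
from io import StringIO

def put_catchment_csv(df: pd.DataFrame, bucket: str, key: str) -> None:
    # Serialize one catchment's forcing table in memory and upload it to S3.
    s3_client = boto3.session.Session().client("s3")
    buf = StringIO()
    df.to_csv(buf, index=False)
    s3_client.put_object(Bucket=bucket, Key=key, Body=buf.getvalue().encode())
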

@@ -393,7 +393,7 @@ def write_data(
t_accum = time.perf_counter() - t00
rate = ((j+1)/t_accum)
bytes2bits = 8
bandwidth_Mbps = rate * file_size_MB *nthreads * bytes2bits
bandwidth_Mbps = rate * file_size_MB * nprocesses * bytes2bits
estimate_total_time = nfiles / rate
report_usage()
msg = f"\n{j+1} files written out of {nfiles}\n"
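
The aggregate bandwidth estimate multiplies one worker's observed write rate (files per second) by the per-file size in MB, the number of write workers, and 8 bits per byte, so it assumes every worker is progressing at roughly the same pace. A worked example with assumed numbers:

# Assumed values, for illustration only
rate = 2.0           # files/s for this worker, (j+1)/t_accum
file_size_MB = 1.5   # size of one catchment forcing file
nprocesses = 16      # write workers
bytes2bits = 8

bandwidth_Mbps = rate * file_size_MB * nprocesses * bytes2bits   # 384.0 Mbps
estimate_total_time = 1000 / rate                                # 500.0 s for nfiles = 1000
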
@@ -403,7 +403,7 @@ def write_data(
if storage_type.lower() == "s3": msg += f"put {t_put:.2f}s\n"
msg += f"estimated total write time {estimate_total_time:.2f}s\n"
msg += f"progress {(j+1)/nfiles*100:.2f}%\n"
msg += f"Bandwidth (all threads) {bandwidth_Mbps:.2f} Mbps"
msg += f"Bandwidth (all processs) {bandwidth_Mbps:.2f} Mbps"
print(msg)

return forcing_cat_ids, dfs, filenames
Expand Down Expand Up @@ -456,12 +456,12 @@ def prep_ngen_data(conf):
ii_verbose = conf["run"].get("verbose",False)
ii_collect_stats = conf["run"].get("collect_stats",True)
ii_tar = conf["run"].get("ii_tar",True)
proc_threads = conf["run"].get("proc_threads",None)
write_threads = conf["run"].get("write_threads",None)
proc_processes = conf["run"].get("proc_processes",None)
write_processes = conf["run"].get("write_processes",None)
nfile_chunk = conf["run"].get("nfile_chunk",None)

if proc_threads is None: proc_threads = int(os.cpu_count() * 0.8)
if write_threads is None: write_threads = os.cpu_count()
if proc_processes is None: proc_processes = int(os.cpu_count() * 0.8)
if write_processes is None: write_processes = os.cpu_count()
if nfile_chunk is None: nfile_chunk = 100000

if ii_verbose:
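
With these defaults, extraction uses 80% of the available cores and writing uses all of them, while nfile_chunk caps how many NWM files one processing loop handles. For example, assuming os.cpu_count() returns 8:

import os

ncpu = os.cpu_count() or 8        # assume 8 cores for this example
proc_processes = int(ncpu * 0.8)  # 6 extraction workers
write_processes = ncpu            # 8 write workers
nfile_chunk = 100000              # files per processing loop
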
@@ -581,21 +581,22 @@ def prep_ngen_data(conf):
t0 = time.perf_counter()
if ii_verbose: print(f'Entering data extraction...\n')
# [data_array, t_ax] = forcing_grid2catchment(crosswalk_dict, jnwm_files, fs)
data_array, t_ax = multiprocess_data_extract(jnwm_files,proc_threads,crosswalk_dict,fs)
data_array, t_ax = multiprocess_data_extract(jnwm_files,proc_processes,crosswalk_dict,fs)
t_extract = time.perf_counter() - t0
complexity = (nfiles_tot * ncatchments) / 10000
score = complexity / t_extract
if ii_verbose: print(f'Data extract threads: {proc_threads:.2f}\nExtract time: {t_extract:.2f}\nComplexity: {complexity:.2f}\nScore: {score:.2f}\n', end=None)
if ii_verbose: print(f'Data extract processes: {proc_processes}\nExtract time: {t_extract:.2f}\nComplexity: {complexity:.2f}\nScore: {score:.2f}\n', end=None)

t0 = time.perf_counter()
out_path = (output_path/'forcings/').resolve()
if ii_verbose: print(f'Writing catchment forcings to {output_bucket} at {out_path}!', end=None)
forcing_cat_ids, dfs, filenames = threaded_write_fun(data_array,t_ax,crosswalk_dict.keys(),write_threads,output_bucket,out_path,ii_append)
forcing_cat_ids, dfs, filenames = multiprocess_write(data_array,t_ax,crosswalk_dict.keys(),write_processes,output_bucket,out_path,ii_append)


ii_append = True
write_time += time.perf_counter() - t0
write_rate = ncatchments / write_time
if ii_verbose: print(f'\n\nWrite threads: {write_threads}\nWrite time: {write_time:.2f}\nWrite rate {write_rate:.2f} files/second\n', end=None)
if ii_verbose: print(f'\n\nWrite processes: {write_processes}\nWrite time: {write_time:.2f}\nWrite rate {write_rate:.2f} files/second\n', end=None)

loop_time = time.perf_counter() - t00
if ii_verbose and nloops > 1: print(f'One loop took {loop_time:.2f} seconds. Estimated time to completion: {loop_time * (nloops - jloop):.2f}')
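
The verbose report derives a rough throughput metric from each loop: complexity scales with files times catchments (divided by 10000 to keep the number small), score is complexity per second of extraction, and write_rate is catchments written per second. A worked example with assumed sizes:

# Assumed sizes, for illustration only
nfiles_tot = 24        # e.g. one day of hourly NWM files
ncatchments = 50_000
t_extract = 60.0       # seconds in multiprocess_data_extract
write_time = 100.0     # seconds in multiprocess_write

complexity = (nfiles_tot * ncatchments) / 10000   # 120.0
score = complexity / t_extract                    # 2.0
write_rate = ncatchments / write_time             # 500.0 files/second
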
@@ -703,7 +704,9 @@ def prep_ngen_data(conf):

data_med = np.median(data_array,axis=0)
med_df = pd.DataFrame(data_med.T,columns=ngen_variables)
med_df.insert(0,"catchment id",forcing_cat_ids)
med_df.insert(0,"catchment id",forcing_cat_ids)

del data_array

# Save input config file and script commit
metadata_df = pd.DataFrame.from_dict(metadata)
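
The medians are taken across axis 0 of data_array and labelled with the catchment ids, mirroring the existing catchments_avg table, and data_array is then deleted to free memory before the tarball is built. A minimal sketch of that reduction, assuming a (time, variable, catchment) layout for data_array; the true shape is not shown in this hunk, and the variable names are placeholders:

import numpy as np
import pandas as pd

ngen_variables = ["APCP_surface", "TMP_2maboveground"]   # illustrative subset
forcing_cat_ids = ["cat-1", "cat-2", "cat-3"]
data_array = np.random.rand(4, len(ngen_variables), len(forcing_cat_ids))

data_med = np.median(data_array, axis=0)                 # (variables, catchments)
med_df = pd.DataFrame(data_med.T, columns=ngen_variables)
med_df.insert(0, "catchment id", forcing_cat_ids)

del data_array   # release the extracted forcings before tarring
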
Expand Down Expand Up @@ -751,43 +754,48 @@ def prep_ngen_data(conf):
if storage_type.lower() == 's3':
path = "/metadata/forcings_metadata/"
combined_tar_filename = 'forcings.tar.gz'
with tarfile.open(combined_tar_filename, 'w:gz') as combined_tar:
if ii_collect_stats:
buf = BytesIO()

filename = f"metadata." + output_file_type
metadata_df.to_csv(buf, index=False)
buf.seek(0)
tarinfo = tarfile.TarInfo(name=path + filename)
tarinfo.size = len(buf.getvalue())
combined_tar.addfile(tarinfo, fileobj=buf)

filename = f"catchments_avg." + output_file_type
avg_df.to_csv(buf, index=False)
buf.seek(0)
tarinfo = tarfile.TarInfo(name=path + filename)
tarinfo.size = len(buf.getvalue())
combined_tar.addfile(tarinfo, fileobj=buf)

filename = f"catchments_median." + output_file_type
med_df.to_csv(buf, index=False)
buf.seek(0)
tarinfo = tarfile.TarInfo(name=path + filename)
tarinfo.size = len(buf.getvalue())
combined_tar.addfile(tarinfo, fileobj=buf)

for j, jdf in enumerate(dfs):
jfilename = filenames[j]
with tempfile.NamedTemporaryFile() as tmpfile:
if output_file_type == "parquet":
jdf.to_parquet(tmpfile.name, index=False)
elif output_file_type == "csv":
jdf.to_csv(tmpfile.name, index=False)

combined_tar.add(tmpfile.name, arcname=jfilename)
else:
del dfs
path = str(metaf_path)
combined_tar_filename = str(forcing_path) + '/forcings.tar.gz'
with tarfile.open(combined_tar_filename, 'w:gz') as combined_tar:
if ii_collect_stats:
buf = BytesIO()

filename = f"metadata." + output_file_type
metadata_df.to_csv(buf, index=False)
buf.seek(0)
tarinfo = tarfile.TarInfo(name=path + filename)
tarinfo.size = len(buf.getvalue())
combined_tar.addfile(tarinfo, fileobj=buf)
tar_cmd = f'tar -czvf {combined_tar_filename} -C {forcing_path} .'
if ii_collect_stats: tar_cmd += f' -C {metaf_path} .'
os.system(tar_cmd)

filename = f"catchments_avg." + output_file_type
avg_df.to_csv(buf, index=False)
buf.seek(0)
tarinfo = tarfile.TarInfo(name=path + filename)
tarinfo.size = len(buf.getvalue())
combined_tar.addfile(tarinfo, fileobj=buf)

filename = f"catchments_median." + output_file_type
med_df.to_csv(buf, index=False)
buf.seek(0)
tarinfo = tarfile.TarInfo(name=path + filename)
tarinfo.size = len(buf.getvalue())
combined_tar.addfile(tarinfo, fileobj=buf)

for j, jdf in enumerate(dfs):
jfilename = filenames[j]
with tempfile.NamedTemporaryFile() as tmpfile:
if output_file_type == "parquet":
jdf.to_parquet(tmpfile.name, index=False)
elif output_file_type == "csv":
jdf.to_csv(tmpfile.name, index=False)

combined_tar.add(tmpfile.name, arcname=jfilename)
tar_time = time.perf_counter() - t0000

if storage_type == 'S3':
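
This hunk is the tarball algorithm change named in the commit title. The S3 branch keeps assembling forcings.tar.gz in memory with tarfile and BytesIO, while the local branch deletes the in-memory dfs and packs the already-written forcings directory (plus the metadata directory when stats are collected) with a single system tar command. A minimal, self-contained sketch of the new local path, using subprocess.run instead of os.system and assumed directory arguments:

import subprocess
from pathlib import Path

def tar_forcings(forcing_path: Path, metaf_path: Path, ii_collect_stats: bool) -> Path:
    # Mirrors the new tar -czvf {tarball} -C {forcing_path} . command.
    combined_tar_filename = forcing_path / "forcings.tar.gz"
    cmd = ["tar", "-czf", str(combined_tar_filename), "-C", str(forcing_path), "."]
    if ii_collect_stats:
        cmd += ["-C", str(metaf_path), "."]   # append the metadata directory
    subprocess.run(cmd, check=True)
    return combined_tar_filename

Because write_data has already written each catchment file to disk, the local branch no longer needs dfs in memory, which is why the new code can del dfs before tarring.
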
