Skip to content

Commit

Permalink
shared weights for speed
Browse files (browse the repository at this point in the history)
  • Loading branch information
JordanLaserGit committed Feb 8, 2024
1 parent 19edd9e commit 9e00058
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 94 deletions.
127 changes: 43 additions & 84 deletions forcingprocessor/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,88 +1,47 @@
affine==2.4.0
asciitree==0.3.3
attrs==23.1.0
beautifulsoup4==4.12.2
bokeh==3.2.0
boto3==1.28.3
botocore==1.31.3
Bottleneck==1.3.7
certifi==2023.5.7
cftime==1.6.2
charset-normalizer==3.2.0
click==8.1.5
click-plugins==1.1.1
cligj==0.7.2
cloudpickle==2.2.1
contourpy==1.1.0
cycler==0.11.0
dask==2023.7.0
distributed==2023.7.0
docopt==0.6.2
entrypoints==0.4
fasteners==0.18
Fiona==1.9.4.post1
flox==0.7.2
fonttools==4.43.0
fsspec==2023.6.0
geopandas==0.13.2
h5netcdf==1.2.0
h5py==3.9.0
idna==3.4
importlib-metadata==6.8.0
importlib-resources==6.0.0
Jinja2==3.1.3
aiobotocore==2.11.2
aiohttp==3.9.3
aioitertools==0.11.0
aiosignal==1.3.1
attrs==23.2.0
boto3==1.34.36
botocore==1.34.34
cachetools==5.3.2
certifi==2024.2.2
charset-normalizer==3.3.2
decorator==5.1.1
frozenlist==1.4.1
fsspec==2024.2.0
gcsfs==2024.2.0
google-api-core==2.16.2
google-auth==2.27.0
google-auth-oauthlib==1.2.0
google-cloud-core==2.4.1
google-cloud-storage==2.14.0
google-crc32c==1.5.0
google-resumable-media==2.7.0
googleapis-common-protos==1.62.0
idna==3.6
jmespath==1.0.1
kiwisolver==1.4.4
llvmlite==0.40.1
locket==1.0.0
lz4==4.3.2
MarkupSafe==2.1.3
matplotlib==3.7.2
msgpack==1.0.5
nc-time-axis==1.4.1
netCDF4==1.6.4
numba==0.57.1
numbagg==0.2.2
numcodecs==0.11.0
numpy==1.24.4
numpy-groupies==0.9.22
nwmurl==0.1.5
packaging==23.1
pandas==2.0.3
partd==1.4.0
pathlib==1.0.1
Pillow==10.0.0
platformdirs==3.8.1
pooch==1.7.0
psutil==5.9.5
pyarrow==14.0.1
pydap==3.4.1
pyparsing==3.0.9
pyproj==3.6.0
multidict==6.0.5
numpy==1.26.4
oauthlib==3.2.2
packaging==23.2
pandas==2.2.0
protobuf==4.25.2
psutil==5.9.8
pyarrow==15.0.0
pyasn1==0.5.1
pyasn1-modules==0.3.0
python-dateutil==2.8.2
pytz==2023.3
PyYAML==6.0
rasterio==1.3.8
pytz==2024.1
requests==2.31.0
rioxarray==0.14.1
s3fs==0.4.2
s3transfer==0.6.1
scipy==1.11.1
seaborn==0.12.2
shapely==2.0.1
requests-oauthlib==1.3.1
rsa==4.9
s3fs==2024.2.0
s3transfer==0.10.0
six==1.16.0
snuggs==1.4.7
sortedcontainers==2.4.0
soupsieve==2.4.1
tblib==2.0.0
toolz==0.12.0
tornado==6.3.2
tqdm==4.66.1
tzdata==2023.3
urllib3==1.26.16
WebOb==1.8.7
xarray==2023.6.0
xyzservices==2023.7.0
zarr==2.15.0
zict==3.0.0
zipp==3.16.1
tzdata==2023.4
urllib3==2.0.7
wrapt==1.16.0
xarray==2024.1.1
yarl==1.9.4
20 changes: 10 additions & 10 deletions forcingprocessor/src/forcingprocessor/forcingprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,31 +120,32 @@ def multiprocess_data_extract(files,nprocs,crosswalk_dict,fs):
Sets up the multiprocessing pool for forcing_grid2catchment and returns the data and time axis ordered in time
"""
launch_time = 2.5
cycle_time = 48
launch_time = 0.1
cycle_time = 35
files_per_cycle = 1
files_per_proc = distribute_work(files,nprocs)
files_per_proc = load_balance(files_per_proc,launch_time,cycle_time,files_per_cycle)
nprocs = len(files_per_proc)

start = 0
nfiles = len(files)
crosswalk_dict_list = []
files_list = []
fs_list = []
for i in range(nprocs):
end = min(start + files_per_proc[i],nfiles)
crosswalk_dict_list.append(crosswalk_dict)
files_list.append(files[start:end])
fs_list.append(fs)
start = end

def init_pool(the_data):
    """
    Pool initializer that publishes the shared weights as a global.

    Binds the given object to the global name `weights_json`, so code that
    runs later in the same process can read it by name instead of receiving
    it as a call argument.
    Inputs:
    the_data: object to expose under the global name `weights_json`
    """
    # Rebind the global rather than creating a function-local variable.
    global weights_json
    weights_json = the_data

data_ax = []
t_ax_local = []
with cf.ProcessPoolExecutor(max_workers=nprocs) as pool:
with cf.ProcessPoolExecutor(max_workers=nprocs, initializer=init_pool, initargs=(crosswalk_dict,)) as pool:
for results in pool.map(
forcing_grid2catchment,
crosswalk_dict_list,
files_list,
fs_list
):
Expand All @@ -159,12 +160,11 @@ def multiprocess_data_extract(files,nprocs,crosswalk_dict,fs):

return data_array, t_ax_local

def forcing_grid2catchment(crosswalk_dict: dict, nwm_files: list, fs=None):
def forcing_grid2catchment(nwm_files: list, fs=None):
"""
General function to retrieve catchment level data from national water model files
Inputs:
crosswalk_dict: dict of catchments to use as indices
nwm_files: list of filenames (urls for remote, local paths otherwise),
fs: an optional file system for cloud storage reads
Expand Down Expand Up @@ -219,10 +219,10 @@ def forcing_grid2catchment(crosswalk_dict: dict, nwm_files: list, fs=None):

t0 = time.perf_counter()
data_allvars = data_allvars.reshape(nvar, shp[1] * shp[2])
ncatch = len(crosswalk_dict)
ncatch = len(weights_json)
data_array = np.zeros((nvar,ncatch), dtype=np.float32)
jcatch = 0
for key, value in crosswalk_dict.items():
for key, value in weights_json.items():
weights = value[0]
coverage = np.array(value[1])
coverage_mat = np.repeat(coverage[None,:],nvar,axis=0)
Expand Down

0 comments on commit 9e00058

Please sign in to comment.