
Commit

feat(L3)!: switch to using dask
The implemented routine acts on groups of months as independent entities. While one could expect a more pandas-native approach like resample to work better, this seems not to be the case; probably because the data I used was date-sorted, such that the data of each grid cell is scattered throughout the entire input dataset.
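A minimal, self-contained sketch of this month-wise grouping idea (a hypothetical illustration on synthetic data, not the commit's actual code; the column names x, y, h_diff and the med_iqr_cnt statistics mirror the diff below):

import dask.dataframe as dd
import numpy as np
import pandas as pd

def med_iqr_cnt(data):
    # same statistics as l3.py's med_iqr_cnt: median, interquartile range, count
    quartiles = np.quantile(data, [.25, .5, .75])
    return pd.DataFrame([[quartiles[1], quartiles[2]-quartiles[0], len(data)]],
                        columns=["_median", "_iqr", "_count"])

# synthetic stand-in for the cached L2 point data (columns are assumptions)
rng = np.random.default_rng(0)
n = 10_000
pdf = pd.DataFrame({"x": rng.integers(0, 10, n)*500,
                    "y": rng.integers(0, 10, n)*500,
                    "h_diff": rng.normal(size=n)},
                   index=pd.date_range("2020-01-01", periods=n, freq="h", name="time"))
ddf = dd.from_pandas(pdf, npartitions=8)

# label every record with its calendar month and aggregate each
# (month, x, y) grid cell as an independent group
ddf["month"] = ddf.map_partitions(
    lambda df: pd.Series(df.index.to_period("M").astype(str), index=df.index))
stats = (ddf.groupby(["month", "x", "y"], sort=False).h_diff
            .apply(med_iqr_cnt, meta={"_median": "f8", "_iqr": "f8", "_count": "i8"})
            .compute()
            .droplevel(3))  # drop the dummy row index introduced by apply

Because every (month, x, y) group is aggregated on its own, the work parallelizes across dask partitions regardless of how the records are ordered in the input.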
j-haacker committed Mar 11, 2024
1 parent 62a27c4 commit 7e89944
Showing 2 changed files with 151 additions and 69 deletions.
85 changes: 68 additions & 17 deletions cryoswath/l2.py
@@ -1,3 +1,4 @@
from dask import dataframe as dd
import geopandas as gpd
from multiprocessing import Pool
import numpy as np
@@ -18,6 +19,7 @@
def from_id(track_idx: pd.DatetimeIndex|str, *,
reprocess: bool = True,
save_or_return: str = "both",
cache: str = None,
cores: int = len(os.sched_getaffinity(0)),
**kwargs) -> tuple[gpd.GeoDataFrame]:
# this function collects processed data and processes the remaining.
@@ -28,13 +30,13 @@ def from_id(track_idx: pd.DatetimeIndex|str, *,
if not isinstance(track_idx, pd.DatetimeIndex):
track_idx = pd.DatetimeIndex(track_idx if isinstance(track_idx, list) else [track_idx])
if track_idx.tz == None:
track_idx.tz_localize("UTC")
track_idx = track_idx.tz_localize("UTC")
# somehow the download thread prevents the processing of tracks. it may
# be due to GIL lock. for now, it is just disabled, so one has to
# download in advance. on the fly is always possible, however, with
# parallel processing this can lead to issues because ESA blocks ftp
# connections if there are too many.
print("Note that you can speed up processing substantially by previously downloading the L1b data.")
print("[note] You can speed up processing substantially by previously downloading the L1b data.")
# stop_event = Event()
# download_thread = Thread(target=l1b.download_wrapper,
# kwargs=dict(track_idx=track_idx, num_processes=8, stop_event=stop_event),
@@ -43,10 +45,36 @@ def from_id(track_idx: pd.DatetimeIndex|str, *,
# download_thread.start()
try:
start_datetime, end_datetime = track_idx.sort_values()[[0,-1]]
# ! below will not return data that is cached, even if save_or_return="both"
# this is a flaw in the current logic. rework.
if cache is not None and save_or_return != "return":
try:
with pd.HDFStore(cache, "r") as hdf:
cached = hdf.select("poca", columns=[])
# for better performance: reduce indices to two per month
sample_rate_ns = int(15*(24*60*60)*1e9)
tmp = cached.index.astype("int64")//sample_rate_ns
tmp = pd.arrays.DatetimeArray(np.append(np.unique(tmp)*sample_rate_ns,
# adding first and last element
# included for debugging. by default, at least the last
# index should not be added, to prevent missing data
cached.index[[0,-1]].astype("int64")))
skip_months = np.unique(tmp.normalize()+pd.DateOffset(day=1))
# print(skip_months)
del cached
except (OSError, KeyError) as err:
if isinstance(err, KeyError):
warnings.warn(f"Removed cache because of KeyError (\"{str(err)}\").")
os.remove(cache)
skip_months = np.empty(0)
swath_list = []
poca_list = []
kwargs["cs_full_file_names"] = load_cs_full_file_names(update="no")
for current_month in pd.date_range(start_datetime.normalize()-pd.offsets.MonthBegin(), end_datetime, freq="MS"):
for current_month in pd.date_range(start_datetime.normalize()-pd.DateOffset(day=1),
end_datetime, freq="MS"):
if cache is not None and save_or_return != "return" and current_month.tz_localize(None) in skip_months:
print("Skipping cached month", current_month.strftime("%Y-%m"))
continue
current_subdir = current_month.strftime(f"%Y{os.path.sep}%m")
l2_paths = pd.DataFrame(columns=["swath", "poca"])
for l2_type in ["swath", "poca"]:
@@ -58,22 +86,45 @@ def from_id(track_idx: pd.DatetimeIndex|str, *,
else:
os.makedirs(os.path.join(data_path, f"L2_{l2_type}", current_subdir))
print("start processing", current_month)
with Pool(processes=cores) as p:
# function is defined at the bottom of this module
collective_swath_poca_list = p.starmap(
process_track,
[(idx, reprocess, l2_paths, save_or_return, data_path, current_subdir, kwargs)
for idx
# indices per month with work-around :/ should be easier
in pd.Series(index=track_idx).loc[current_month:current_month+pd.offsets.MonthBegin(1)].index],
chunksize=1)
if cores > 1:
with Pool(processes=cores) as p:
# function is defined at the bottom of this module
collective_swath_poca_list = p.starmap(
process_track,
[(idx, reprocess, l2_paths, save_or_return, current_subdir, kwargs) for idx
# indices per month with work-around :/ should be easier
in pd.Series(index=track_idx).loc[current_month:current_month+pd.offsets.MonthBegin(1)].index],
chunksize=1)
else:
collective_swath_poca_list = []
for idx in pd.Series(index=track_idx).loc[current_month:current_month+pd.offsets.MonthBegin(1)].index:
collective_swath_poca_list.append(process_track(idx, reprocess, l2_paths, save_or_return,
current_subdir, kwargs))
if cache is not None:
for l2_type, i in zip(["swath", "poca"], [0, 1]):
l2_data = pd.concat([item[i] for item in collective_swath_poca_list])
if l2_type == "swath":
l2_data.index = l2_data.index.get_level_values(0).astype(np.int64) \
+ l2_data.index.get_level_values(1)
l2_data.rename_axis("time", inplace=True)
l2_data = pd.DataFrame(index=l2_data.index,
data=pd.concat([l2_data.h_diff, l2_data.geometry.get_coordinates()],
axis=1, copy=False))
l2_data.astype(dict(h_diff=np.float32, x=np.int32, y=np.int32)).to_hdf(cache, key=l2_type, mode="a", append=True, format="table")
if save_or_return != "save":
for swath_poca_tuple in collective_swath_poca_list: # .get()
swath_list.append(swath_poca_tuple[0])
poca_list.append(swath_poca_tuple[1])
swath_list.append(pd.concat([item[0] for item in collective_swath_poca_list]))
poca_list.append(pd.concat([item[1] for item in collective_swath_poca_list]))
print("done processing", current_month)
if save_or_return != "save":
return pd.concat(swath_list), pd.concat(poca_list)
if swath_list == []:
swath_list = pd.DataFrame()
else:
swath_list = pd.concat(swath_list)
if poca_list == []:
poca_list = pd.DataFrame()
else:
poca_list = pd.concat(poca_list)
return swath_list, poca_list
except:
# print("Waiting for download threads to join.")
# stop_event.set()
@@ -181,7 +232,7 @@ def process_and_save(region_of_interest: str|shapely.Polygon,


# local helper function. can't be defined where it is needed because of namespace issues
def process_track(idx, reprocess, l2_paths, save_or_return, data_path, current_subdir, kwargs):
def process_track(idx, reprocess, l2_paths, save_or_return, current_subdir, kwargs):
print("getting", idx)
# print("kwargs", wargs)
try:
135 changes: 83 additions & 52 deletions cryoswath/l3.py
@@ -1,5 +1,6 @@
import dask.dataframe
from dateutil.relativedelta import relativedelta
import geopandas as gpd
# import numba
import numpy as np
import os
import pandas as pd
@@ -9,56 +10,109 @@
from .misc import *

__all__ = list()


# numba does not help here easily. using the numpy functions is as fast as it gets.
def med_iqr_cnt(data):
quartiles = np.quantile(data, [.25, .5, .75])
return pd.DataFrame([[quartiles[1], quartiles[2]-quartiles[0], len(data)]], columns=["_median", "_iqr", "_count"])
__all__.append("med_iqr_cnt")


def build_dataset(region_of_interest: str|shapely.Polygon,
start_datetime: str|pd.Timestamp,
end_datetime: str|pd.Timestamp, *,
aggregation_period: relativedelta = relativedelta(months=3),
timestep: relativedelta = relativedelta(months=1),
l2_type: str = "swath",
timestep_months: int = 1,
window_ntimesteps: int = 3,
spatial_res_meter: float = 500,
**kwargs):
agg_func_and_meta: tuple[callable, dict] = (med_iqr_cnt,
{"_median": "f8", "_iqr": "f8", "_count": "i8"}),
**l2_from_id_kwargs):
if window_ntimesteps%2 - 1:
old_window = window_ntimesteps
window_ntimesteps = (window_ntimesteps//2+1)
warnings.warn(f"The window should be a uneven number of time steps. You asked for {old_window}, but it has "+ f"been changed to {window_ntimesteps}.")
# ! end time step should be included.
start_datetime, end_datetime = pd.to_datetime([start_datetime, end_datetime])
print("Building a gridded dataset of elevation estimates for the region",
f"{region_of_interest} from {start_datetime} to {end_datetime} for",
f"a rolling window of {aggregation_period} every {timestep}.")
# if len(aggregation_period.kwds.keys()) != 1 \
# or len(timestep.kwds.keys()) != 1 \
# or list(aggregation_period.kwds.keys())[0] not in ["years", "months", "days"] \
# or list(timestep.kwds.keys())[0] not in ["years", "months", "days"]:
# raise Exception("Only use one of years, months, days for agg_time and timestep.")
print("Building a gridded dataset of elevation estimates for",
"the region "+region_of_interest if isinstance(region_of_interest, str) else "a custom area",
f"from {start_datetime} to {end_datetime} every {timestep_months} for",
f"a rolling window of {window_ntimesteps} time steps.")
if "buffer_region_by" not in locals():
# buffer_by defaults to 30 km to not miss any tracks. Usually,
# 10 km should do.
buffer_region_by = 30_000
time_buffer = (aggregation_period-timestep)/2
time_buffer_months = (window_ntimesteps*timestep_months)//2
ext_t_axis = pd.date_range(start_datetime-pd.DateOffset(months=time_buffer_months),
end_datetime+pd.DateOffset(months=time_buffer_months),
freq=f"{timestep_months}MS",
).astype("int64")
cs_tracks = load_cs_ground_tracks(region_of_interest, start_datetime, end_datetime,
buffer_period_by=time_buffer,buffer_region_by=buffer_region_by)
buffer_period_by=relativedelta(months=time_buffer_months),
buffer_region_by=buffer_region_by)
print("First and last available ground tracks are on",
f"{cs_tracks.index[0]} and {cs_tracks.index[-1]}, respectively.,",
f"{cs_tracks.shape[0]} tracks in total.")
print("Run update_cs_ground_tracks, optionally with `full=True` or",
f"{cs_tracks.shape[0]} tracks in total."
"\n[note] Run update_cs_ground_tracks, optionally with `full=True` or",
"`incremental=True`, if your local ground tracks store is not up to",
"date. Consider pulling the latest version from the repository.")
# I believe passing the loading of the l2 data to the function prevents copying
# on .drop. an alternative would be to define l2_data nonlocal
# within the gridding function
l3_data = med_mad_cnt_grid(l2.from_id(cs_tracks.index, **filter_kwargs(l2.from_id, kwargs)),
start_datetime=start_datetime, end_datetime=end_datetime,
aggregation_period=aggregation_period, timestep=timestep,
spatial_res_meter=spatial_res_meter)
l3_data.to_netcdf(build_path(region_of_interest, timestep, spatial_res_meter, aggregation_period))

print("Storing the essential L2 data in hdf5, downloading and",
"processing L1b files if not available...")
if isinstance(region_of_interest, str):
region_id = region_of_interest
else:
region_id = "_".join([region_of_interest.centroid.x, region_of_interest.centroid.y])
cache_path = os.path.join(data_path, "tmp", region_id)
l2.from_id(cs_tracks.index, save_or_return="save", cache=cache_path,
**filter_kwargs(l2.from_id, l2_from_id_kwargs, blacklist=["save_or_return", "cache"]))

print("Gridding the data...")
# one could drop some of the data before gridding. however, excluding
# off-glacier data is expensive and filtering large differences to the
# DEM can hide issues while statistics like the median and the IQR
# should be fairly robust.
l2_ddf = dask.dataframe.read_hdf(cache_path, l2_type, sorted_index=True)
l2_ddf = l2_ddf.loc[ext_t_axis[0]:ext_t_axis[-1]]
l2_ddf = l2_ddf.repartition(npartitions=3*len(os.sched_getaffinity(0)))

l2_ddf[["x", "y"]] = l2_ddf[["x", "y"]]//spatial_res_meter*spatial_res_meter
l2_ddf["roll_0"] = l2_ddf.index.map_partitions(pd.cut, bins=ext_t_axis, right=False, labels=False, include_lowest=True)
for i in range(1, window_ntimesteps):
l2_ddf[f"roll_{i}"] = l2_ddf.map_partitions(lambda df: df.roll_0-i).persist()
for i in range(window_ntimesteps):
l2_ddf[f"roll_{i}"] = l2_ddf[f"roll_{i}"].map_partitions(lambda series: series.astype("i4")//window_ntimesteps)

roll_res = [None]*window_ntimesteps
for i in range(window_ntimesteps):
roll_res[i] = l2_ddf.rename(columns={f"roll_{i}": "time_idx"}).groupby(["time_idx", "x", "y"], sort=False).h_diff.apply(agg_func_and_meta[0], meta=agg_func_and_meta[1]).persist()
for i in range(window_ntimesteps):
roll_res[i] = roll_res[i].compute().droplevel(3, axis=0)
roll_res[i].index = roll_res[i].index.set_levels(
(roll_res[i].index.levels[0]*window_ntimesteps+i+1), level=0).rename("time", level=0)

l3_data = pd.concat(roll_res).sort_index()\
.loc[(slice(0,len(ext_t_axis)-1),slice(None),slice(None)),:]
l3_data.index = l3_data.index.remove_unused_levels()
l3_data.index = l3_data.index.set_levels(
ext_t_axis[l3_data.index.levels[0]].astype("datetime64[ns]"), level=0)
l3_data = l3_data.query(f"time >= '{start_datetime}' and time <= '{end_datetime}'")
l3_data.to_xarray().to_netcdf(build_path(region_id, timestep_months, spatial_res_meter))
return l3_data
__all__.append("build_dataset")


def build_path(region_of_interest, timestep, spatial_res_meter, aggregation_period):
region_id = find_region_id(region_of_interest)
if list(timestep.kwds.values())[0]!=1:
timestep_str = str(list(timestep.kwds.values())[0])+"-"
def build_path(region_of_interest, timestep_months, spatial_res_meter, aggregation_period):
if not isinstance(region_of_interest, str):
region_id = find_region_id(region_of_interest)
else:
region_id = region_of_interest
if timestep_months != 1:
timestep_str = str(timestep_months)+"-"
else:
timestep_str = ""
timestep_str += list(timestep.kwds.keys())[0][:-1]+"ly"
timestep_str += "monthly"
if spatial_res_meter == 1000:
spatial_res_str = "1km"
elif np.floor(spatial_res_meter/1000) < 2:
@@ -69,29 +123,6 @@ def build_path(region_of_interest, timestep, spatial_res_meter, aggregation_peri
return os.path.join(data_path, "L3", "_".join(
[region_id, timestep_str, spatial_res_str+".nc"]))
__all__.append("build_path")


def med_mad_cnt_grid(l2_data: gpd.GeoDataFrame, *,
start_datetime: pd.Timestamp,
end_datetime: pd.Timestamp,
aggregation_period: relativedelta,
timestep: relativedelta,
spatial_res_meter: float):
def stats(data: pd.Series) -> pd.Series:
median = data.median()
mad = np.abs(data-median).median()
return pd.Series([median, mad, data.shape[0]])
time_axis = pd.date_range(start_datetime+pd.offsets.MonthBegin(0), end_datetime, freq=timestep)
if time_axis.tz == None: time_axis = time_axis.tz_localize("UTC")
# if l2_data.index[0].tz == None: l2_data.index = l2_data.index.tz_localize("UTC")
def rolling_stats(data):
results_list = [None]*aggregation_period.months
for i in range(aggregation_period.months):
results_list[i] = data.groupby(subset.index.get_level_values("time")-pd.offsets.QuarterBegin(1, normalize=True)+pd.DateOffset(months=i)).apply(stats)
result = pd.concat(results_list).unstack().sort_index().rename(columns={0: "med_elev_diff", 1: "mad_elev_diff", 2: "cnt_elev_diff"})#, inplace=True
return result.loc[time_axis.join(result.index, how="inner")]
return l2.grid(l2_data, spatial_res_meter, rolling_stats).to_xarray()
__all__.append("med_mad_cnt_grid")


__all__ = sorted(__all__)
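
The rolling-window gridding introduced in l3.py labels every record with one shifted time-bin index per window offset (the roll_0, roll_1, ... columns), so that each record contributes to all overlapping windows it falls into. A simplified, hypothetical pandas-only illustration of that labelling step (synthetic daily data, monthly bins, a three-step window; names and data are assumptions, not the commit's code):

import numpy as np
import pandas as pd

window = 3  # window_ntimesteps
times = pd.date_range("2020-01-01", "2020-12-31", freq="D")
df = pd.DataFrame({"h_diff": np.random.default_rng(1).normal(size=len(times))},
                  index=times)

# monthly bin edges as int64 nanoseconds, analogous to ext_t_axis in the diff
edges = pd.date_range("2020-01-01", "2021-01-01", freq="MS").astype("int64")
bin_idx = pd.cut(df.index.astype("int64"), bins=edges,
                 right=False, labels=False, include_lowest=True)

windowed = []
for offset in range(window):
    # shift the bin index, then collapse every `window` consecutive bins
    group = (bin_idx - offset)//window
    stats = df.groupby(group).h_diff.median()
    # map the collapsed group id back onto positions of the monthly time axis
    stats.index = stats.index*window + offset + 1
    windowed.append(stats)
rolling_median = pd.concat(windowed).sort_index()

Groups at the start and end of the record span fewer than window months; the commit handles those edges by slicing the concatenated result back to the valid range of the extended time axis.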
