From 7e899446633bb1fe6c20f4cea6a02ac47bf3792d Mon Sep 17 00:00:00 2001
From: Jan Haacker <152862650+j-haacker@users.noreply.github.com>
Date: Mon, 11 Mar 2024 10:44:39 +0100
Subject: [PATCH] feat(L3)!: switch to using dask

The implemented routine acts on groups of months as independent
entities. One might expect a more pandas-native approach like resample
to be a better fit, but this seems not to be the case, probably because
the data I used was date-sorted, such that the data of each grid cell
is scattered throughout the entire input dataset. (A small pandas-only
illustration of the shifted-window grouping is appended after the
diff.)
---
 cryoswath/l2.py |  85 ++++++++++++++++++++++++------
 cryoswath/l3.py | 135 +++++++++++++++++++++++++++++-------------------
 2 files changed, 151 insertions(+), 69 deletions(-)

diff --git a/cryoswath/l2.py b/cryoswath/l2.py
index 06d2b66..4f35e02 100644
--- a/cryoswath/l2.py
+++ b/cryoswath/l2.py
@@ -1,3 +1,4 @@
+from dask import dataframe as dd
 import geopandas as gpd
 from multiprocessing import Pool
 import numpy as np
@@ -18,6 +19,7 @@ def from_id(track_idx: pd.DatetimeIndex|str, *,
             reprocess: bool = True,
             save_or_return: str = "both",
+            cache: str = None,
             cores: int = len(os.sched_getaffinity(0)),
             **kwargs) -> tuple[gpd.GeoDataFrame]:
     # this function collects processed data and processes the remaining.
 
@@ -28,13 +30,13 @@ def from_id(track_idx: pd.DatetimeIndex|str, *,
     if not isinstance(track_idx, pd.DatetimeIndex):
         track_idx = pd.DatetimeIndex(track_idx if isinstance(track_idx, list) else [track_idx])
     if track_idx.tz == None:
-        track_idx.tz_localize("UTC")
+        track_idx = track_idx.tz_localize("UTC")
     # somehow the download thread prevents the processing of tracks. it may
     # be due to GIL lock. for now, it is just disabled, so one has to
     # download in advance. on the fly is always possible, however, with
     # parallel processing this can lead to issues because ESA blocks ftp
     # connections if there are too many.
-    print("Note that you can speed up processing substantially by previously downloading the L1b data.")
+    print("[note] You can speed up processing substantially by previously downloading the L1b data.")
     # stop_event = Event()
     # download_thread = Thread(target=l1b.download_wrapper,
     #                          kwargs=dict(track_idx=track_idx, num_processes=8, stop_event=stop_event),
@@ -43,10 +45,36 @@ def from_id(track_idx: pd.DatetimeIndex|str, *,
     # download_thread.start()
     try:
         start_datetime, end_datetime = track_idx.sort_values()[[0,-1]]
+        # ! below will not return data that is cached, even if save_or_return="both"
+        # this is a flaw in the current logic. rework.
+        if cache is not None and save_or_return != "return":
+            try:
+                with pd.HDFStore(cache, "r") as hdf:
+                    cached = hdf.select("poca", columns=[])
+                # for better performance: reduce indices to two per month
+                sample_rate_ns = int(15*(24*60*60)*1e9)
+                tmp = cached.index.astype("int64")//sample_rate_ns
+                tmp = pd.arrays.DatetimeArray(np.append(np.unique(tmp)*sample_rate_ns,
+                                                        # adding first and last element
+                                                        # included for debugging. by default, at least the last
+                                                        # index should not be added, to prevent missing data
+                                                        cached.index[[0,-1]].astype("int64")))
+                skip_months = np.unique(tmp.normalize()+pd.DateOffset(day=1))
+                # print(skip_months)
+                del cached
+            except (OSError, KeyError) as err:
+                if isinstance(err, KeyError):
+                    warnings.warn(f"Removed cache because of KeyError (\"{str(err)}\").")
+                    os.remove(cache)
+                skip_months = np.empty(0)
         swath_list = []
         poca_list = []
         kwargs["cs_full_file_names"] = load_cs_full_file_names(update="no")
-        for current_month in pd.date_range(start_datetime.normalize()-pd.offsets.MonthBegin(), end_datetime, freq="MS"):
+        for current_month in pd.date_range(start_datetime.normalize()-pd.DateOffset(day=1),
+                                           end_datetime, freq="MS"):
+            if cache is not None and save_or_return != "return" and current_month.tz_localize(None) in skip_months:
+                print("Skipping cached month", current_month.strftime("%Y-%m"))
+                continue
             current_subdir = current_month.strftime(f"%Y{os.path.sep}%m")
             l2_paths = pd.DataFrame(columns=["swath", "poca"])
             for l2_type in ["swath", "poca"]:
@@ -58,22 +86,45 @@ def from_id(track_idx: pd.DatetimeIndex|str, *,
             else:
                 os.makedirs(os.path.join(data_path, f"L2_{l2_type}", current_subdir))
             print("start processing", current_month)
-            with Pool(processes=cores) as p:
-                # function is defined at the bottom of this module
-                collective_swath_poca_list = p.starmap(
-                    process_track,
-                    [(idx, reprocess, l2_paths, save_or_return, data_path, current_subdir, kwargs)
-                     for idx
-                     # indices per month with work-around :/ should be easier
-                     in pd.Series(index=track_idx).loc[current_month:current_month+pd.offsets.MonthBegin(1)].index],
-                    chunksize=1)
+            if cores > 1:
+                with Pool(processes=cores) as p:
+                    # function is defined at the bottom of this module
+                    collective_swath_poca_list = p.starmap(
+                        process_track,
+                        [(idx, reprocess, l2_paths, save_or_return, current_subdir, kwargs) for idx
+                         # indices per month with work-around :/ should be easier
+                         in pd.Series(index=track_idx).loc[current_month:current_month+pd.offsets.MonthBegin(1)].index],
+                        chunksize=1)
+            else:
+                collective_swath_poca_list = []
+                for idx in pd.Series(index=track_idx).loc[current_month:current_month+pd.offsets.MonthBegin(1)].index:
+                    collective_swath_poca_list.append(process_track(idx, reprocess, l2_paths, save_or_return,
+                                                                    current_subdir, kwargs))
+            if cache is not None:
+                for l2_type, i in zip(["swath", "poca"], [0, 1]):
+                    l2_data = pd.concat([item[i] for item in collective_swath_poca_list])
+                    if l2_type == "swath":
+                        l2_data.index = l2_data.index.get_level_values(0).astype(np.int64) \
+                                        + l2_data.index.get_level_values(1)
+                    l2_data.rename_axis("time", inplace=True)
+                    l2_data = pd.DataFrame(index=l2_data.index,
+                                           data=pd.concat([l2_data.h_diff, l2_data.geometry.get_coordinates()],
+                                                          axis=1, copy=False))
+                    l2_data.astype(dict(h_diff=np.float32, x=np.int32, y=np.int32)).to_hdf(cache, key=l2_type, mode="a", append=True, format="table")
             if save_or_return != "save":
-                for swath_poca_tuple in collective_swath_poca_list: # .get()
-                    swath_list.append(swath_poca_tuple[0])
-                    poca_list.append(swath_poca_tuple[1])
+                swath_list.append(pd.concat([item[0] for item in collective_swath_poca_list]))
+                poca_list.append(pd.concat([item[1] for item in collective_swath_poca_list]))
             print("done processing", current_month)
         if save_or_return != "save":
-            return pd.concat(swath_list), pd.concat(poca_list)
+            if swath_list == []:
+                swath_list = pd.DataFrame()
+            else:
+                swath_list = pd.concat(swath_list)
+            if poca_list == []:
+                poca_list = pd.DataFrame()
+            else:
+                poca_list = pd.concat(poca_list)
+            return swath_list, poca_list
     except:
         # print("Waiting for download threads to join.")
         # stop_event.set()
@@ -181,7 +232,7 @@ def process_and_save(region_of_interest: str|shapely.Polygon,
 
 
 # local helper function. can't be defined where it is needed because of namespace issues
-def process_track(idx, reprocess, l2_paths, save_or_return, data_path, current_subdir, kwargs):
+def process_track(idx, reprocess, l2_paths, save_or_return, current_subdir, kwargs):
     print("getting", idx)
     # print("kwargs", wargs)
     try:
diff --git a/cryoswath/l3.py b/cryoswath/l3.py
index dab8e61..529b856 100644
--- a/cryoswath/l3.py
+++ b/cryoswath/l3.py
@@ -1,5 +1,6 @@
+import dask.dataframe
 from dateutil.relativedelta import relativedelta
-import geopandas as gpd
+# import numba
 import numpy as np
 import os
 import pandas as pd
@@ -9,56 +10,109 @@ from .misc import *
 
 __all__ = list()
+
+
+# numba does not help here easily. using the numpy functions is as fast as it gets.
+def med_iqr_cnt(data):
+    quartiles = np.quantile(data, [.25, .5, .75])
+    return pd.DataFrame([[quartiles[1], quartiles[2]-quartiles[0], len(data)]], columns=["_median", "_iqr", "_count"])
+__all__.append("med_iqr_cnt")
 
 
 def build_dataset(region_of_interest: str|shapely.Polygon, start_datetime: str|pd.Timestamp,
                   end_datetime: str|pd.Timestamp, *,
-                  aggregation_period: relativedelta = relativedelta(months=3),
-                  timestep: relativedelta = relativedelta(months=1),
+                  l2_type: str = "swath",
+                  timestep_months: int = 1,
+                  window_ntimesteps: int = 3,
                   spatial_res_meter: float = 500,
-                  **kwargs):
+                  agg_func_and_meta: tuple[callable, dict] = (med_iqr_cnt,
+                                                              {"_median": "f8", "_iqr": "f8", "_count": "i8"}),
+                  **l2_from_id_kwargs):
+    if window_ntimesteps%2 - 1:
+        old_window = window_ntimesteps
+        window_ntimesteps = (window_ntimesteps//2+1)
+        warnings.warn(f"The window should be an odd number of time steps. You asked for {old_window}, but it has "+
+                      f"been changed to {window_ntimesteps}.")
+    # ! end time step should be included.
     start_datetime, end_datetime = pd.to_datetime([start_datetime, end_datetime])
-    print("Building a gridded dataset of elevation estimates for the region",
-          f"{region_of_interest} from {start_datetime} to {end_datetime} for",
-          f"a rolling window of {aggregation_period} every {timestep}.")
-    # if len(aggregation_period.kwds.keys()) != 1 \
-    #         or len(timestep.kwds.keys()) != 1 \
-    #         or list(aggregation_period.kwds.keys())[0] not in ["years", "months", "days"] \
-    #         or list(timestep.kwds.keys())[0] not in ["years", "months", "days"]:
-    #     raise Exception("Only use one of years, months, days for agg_time and timestep.")
+    print("Building a gridded dataset of elevation estimates for",
+          "the region "+region_of_interest if isinstance(region_of_interest, str) else "a custom area",
+          f"from {start_datetime} to {end_datetime} every {timestep_months} month(s) for",
+          f"a rolling window of {window_ntimesteps} time steps.")
     if "buffer_region_by" not in locals():
         # buffer_by defaults to 30 km to not miss any tracks. Usually,
         # 10 km should do.
         buffer_region_by = 30_000
-    time_buffer = (aggregation_period-timestep)/2
+    time_buffer_months = (window_ntimesteps*timestep_months)//2
+    ext_t_axis = pd.date_range(start_datetime-pd.DateOffset(months=time_buffer_months),
+                               end_datetime+pd.DateOffset(months=time_buffer_months),
+                               freq=f"{timestep_months}MS",
+                               ).astype("int64")
     cs_tracks = load_cs_ground_tracks(region_of_interest, start_datetime, end_datetime,
-                                      buffer_period_by=time_buffer,buffer_region_by=buffer_region_by)
+                                      buffer_period_by=relativedelta(months=time_buffer_months),
+                                      buffer_region_by=buffer_region_by)
     print("First and last available ground tracks are on",
           f"{cs_tracks.index[0]} and {cs_tracks.index[-1]}, respectively.,",
-          f"{cs_tracks.shape[0]} tracks in total.")
-    print("Run update_cs_ground_tracks, optionally with `full=True` or",
+          f"{cs_tracks.shape[0]} tracks in total."
+          "\n[note] Run update_cs_ground_tracks, optionally with `full=True` or",
          "`incremental=True`, if you local ground tracks store is not up to",
          "date. Consider pulling the latest version from the repository.")
-    # I believe passing loading l2 data to the function prevents copying
-    # on .drop. an alternative would be to define l2_data nonlocal
-    # within the gridding function
-    l3_data = med_mad_cnt_grid(l2.from_id(cs_tracks.index, **filter_kwargs(l2.from_id, kwargs)),
-                               start_datetime=start_datetime, end_datetime=end_datetime,
-                               aggregation_period=aggregation_period, timestep=timestep,
-                               spatial_res_meter=spatial_res_meter)
-    l3_data.to_netcdf(build_path(region_of_interest, timestep, spatial_res_meter, aggregation_period))
+
+    print("Storing the essential L2 data in hdf5, downloading and",
+          "processing L1b files if not available...")
+    if isinstance(region_of_interest, str):
+        region_id = region_of_interest
+    else:
+        region_id = "_".join([str(region_of_interest.centroid.x), str(region_of_interest.centroid.y)])
+    cache_path = os.path.join(data_path, "tmp", region_id)
+    l2.from_id(cs_tracks.index, save_or_return="save", cache=cache_path,
+               **filter_kwargs(l2.from_id, l2_from_id_kwargs, blacklist=["save_or_return", "cache"]))
+
+    print("Gridding the data...")
+    # one could drop some of the data before gridding. however, excluding
+    # off-glacier data is expensive, and filtering large differences to the
+    # DEM can hide issues, while statistics like the median and the IQR
+    # should be fairly robust.
+    l2_ddf = dask.dataframe.read_hdf(cache_path, l2_type, sorted_index=True)
+    l2_ddf = l2_ddf.loc[ext_t_axis[0]:ext_t_axis[-1]]
+    l2_ddf = l2_ddf.repartition(npartitions=3*len(os.sched_getaffinity(0)))
+
+    l2_ddf[["x", "y"]] = l2_ddf[["x", "y"]]//spatial_res_meter*spatial_res_meter
+    l2_ddf["roll_0"] = l2_ddf.index.map_partitions(pd.cut, bins=ext_t_axis, right=False, labels=False, include_lowest=True)
+    for i in range(1, window_ntimesteps):
+        l2_ddf[f"roll_{i}"] = l2_ddf.map_partitions(lambda df, i=i: df.roll_0-i).persist()
+    for i in range(window_ntimesteps):
+        l2_ddf[f"roll_{i}"] = l2_ddf[f"roll_{i}"].map_partitions(lambda series: series.astype("i4")//window_ntimesteps)
+
+    roll_res = [None]*window_ntimesteps
+    for i in range(window_ntimesteps):
+        roll_res[i] = l2_ddf.rename(columns={f"roll_{i}": "time_idx"}).groupby(["time_idx", "x", "y"], sort=False).h_diff.apply(agg_func_and_meta[0], meta=agg_func_and_meta[1]).persist()
+    for i in range(window_ntimesteps):
+        roll_res[i] = roll_res[i].compute().droplevel(3, axis=0)
+        roll_res[i].index = roll_res[i].index.set_levels(
+            (roll_res[i].index.levels[0]*window_ntimesteps+i+1), level=0).rename("time", level=0)
+
+    l3_data = pd.concat(roll_res).sort_index()\
+                .loc[(slice(0,len(ext_t_axis)-1),slice(None),slice(None)),:]
+    l3_data.index = l3_data.index.remove_unused_levels()
+    l3_data.index = l3_data.index.set_levels(
+        ext_t_axis[l3_data.index.levels[0]].astype("datetime64[ns]"), level=0)
+    l3_data = l3_data.query(f"time >= '{start_datetime}' and time <= '{end_datetime}'")
+    l3_data.to_xarray().to_netcdf(build_path(region_id, timestep_months, spatial_res_meter))
     return l3_data
 __all__.append("build_dataset")
 
 
-def build_path(region_of_interest, timestep, spatial_res_meter, aggregation_period):
-    region_id = find_region_id(region_of_interest)
-    if list(timestep.kwds.values())[0]!=1:
-        timestep_str = str(list(timestep.kwds.values())[0])+"-"
+def build_path(region_of_interest, timestep_months, spatial_res_meter, aggregation_period=None):
+    if not isinstance(region_of_interest, str):
+        region_id = find_region_id(region_of_interest)
+    else:
+        region_id = region_of_interest
+    if timestep_months != 1:
+        timestep_str = str(timestep_months)+"-"
     else:
         timestep_str = ""
-    timestep_str += list(timestep.kwds.keys())[0][:-1]+"ly"
+    timestep_str += "monthly"
     if spatial_res_meter == 1000:
         spatial_res_str = "1km"
     elif np.floor(spatial_res_meter/1000) < 2:
@@ -69,29 +123,6 @@ def build_path(region_of_interest, timestep, spatial_res_meter, aggregation_peri
     return os.path.join(data_path, "L3", "_".join(
         [region_id, timestep_str, spatial_res_str+".nc"]))
 __all__.append("build_path")
-
-
-def med_mad_cnt_grid(l2_data: gpd.GeoDataFrame, *,
-                     start_datetime: pd.Timestamp,
-                     end_datetime: pd.Timestamp,
-                     aggregation_period: relativedelta,
-                     timestep: relativedelta,
-                     spatial_res_meter: float):
-    def stats(data: pd.Series) -> pd.Series:
-        median = data.median()
-        mad = np.abs(data-median).median()
-        return pd.Series([median, mad, data.shape[0]])
-    time_axis = pd.date_range(start_datetime+pd.offsets.MonthBegin(0), end_datetime, freq=timestep)
-    if time_axis.tz == None: time_axis = time_axis.tz_localize("UTC")
-    # if l2_data.index[0].tz == None: l2_data.index = l2_data.index.tz_localize("UTC")
-    def rolling_stats(data):
-        results_list = [None]*aggregation_period.months
-        for i in range(aggregation_period.months):
-            results_list[i] = data.groupby(subset.index.get_level_values("time")-pd.offsets.QuarterBegin(1, normalize=True)+pd.DateOffset(months=i)).apply(stats)
-        result = pd.concat(results_list).unstack().sort_index().rename(columns={0: "med_elev_diff", 1: "mad_elev_diff", 2: "cnt_elev_diff"})#, inplace=True
-        return result.loc[time_axis.join(result.index, how="inner")]
-    return l2.grid(l2_data, spatial_res_meter, rolling_stats).to_xarray()
-__all__.append("med_mad_cnt_grid")
 
 
 __all__ = sorted(__all__)
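
Appendix (illustration only, not applied by the patch): a minimal pandas-only
sketch of the shifted-label windowing that build_dataset uses above in place of
resample. The toy frame and the names `obs`, `bin`, `shift`, and `window` are
invented for this sketch; only the label arithmetic (group*window + shift + 1,
which for the default three-step window is the centre bin) mirrors the patch.

    import numpy as np
    import pandas as pd

    window = 3  # aggregate three monthly bins per estimate, stepping by one bin
    rng = np.random.default_rng(0)
    obs = pd.DataFrame({
        "bin": rng.integers(0, 12, 1000),   # monthly time bin of each observation
        "h_diff": rng.normal(size=1000),    # elevation difference to the reference DEM
    })

    results = []
    for shift in range(window):
        # each shift assigns every observation to one window of `window` consecutive bins
        group = (obs["bin"] - shift) // window
        agg = obs.groupby(group)["h_diff"].median()
        # map the window label back onto the monthly time axis
        agg.index = agg.index * window + shift + 1
        results.append(agg)

    rolling_median = pd.concat(results).sort_index()
    # drop labels that fall outside the time axis (the patch does this via .loc on ext_t_axis)
    rolling_median = rolling_median[(rolling_median.index >= 0) & (rolling_median.index < 12)]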