diff --git a/trough/_arb.py b/trough/_arb.py
index 7553bb5..a0454d7 100644
--- a/trough/_arb.py
+++ b/trough/_arb.py
@@ -36,7 +36,7 @@ def get_arb_paths(start_date, end_date, hemisphere, processed_dir):
 def get_arb_data(start_date, end_date, hemisphere, processed_dir=None):
     if processed_dir is None:
         processed_dir = config.processed_arb_dir
-    data = xr.concat([xr.open_dataarray(file) for file in get_arb_paths(start_date, end_date, hemisphere, processed_dir)], 'time')
+    data = utils.read_netcdfs(get_arb_paths(start_date, end_date, hemisphere, processed_dir), 'time')
     return data.sel(time=slice(start_date, end_date))
diff --git a/trough/_tec.py b/trough/_tec.py
index b04134c..64c94b7 100644
--- a/trough/_tec.py
+++ b/trough/_tec.py
@@ -73,7 +73,7 @@ def get_tec_paths(start_date, end_date, hemisphere, processed_dir):
 def get_tec_data(start_date, end_date, hemisphere, processed_dir=None):
     if processed_dir is None:
         processed_dir = config.processed_tec_dir
-    data = xr.concat([xr.open_dataarray(file) for file in get_tec_paths(start_date, end_date, hemisphere, processed_dir)], 'time')
+    data = utils.read_netcdfs(get_tec_paths(start_date, end_date, hemisphere, processed_dir), 'time')
     return data.sel(time=slice(start_date, end_date))
diff --git a/trough/_trough.py b/trough/_trough.py
index f9530d5..71c0245 100644
--- a/trough/_trough.py
+++ b/trough/_trough.py
@@ -33,11 +33,11 @@ def get_model(tec_data, hemisphere, omni_file):
     logger.info(f"{kp.shape=}")
     apex = Apex(date=utils.datetime64_to_datetime(tec_data.time.values[0]))
     mlat = 65.5 * np.ones((tec_data.time.shape[0], tec_data.mlt.shape[0]))
-    if hemisphere == 'south':
-        mlat = mlat * -1
     for i in range(10):
         glat, glon = apex.convert(mlat, tec_data.mlt.values[None, :], 'mlt', 'geo', 350, tec_data.time.values[:, None])
         mlat = _model_subroutine_lat(tec_data.mlt.values[None, :], glon, kp[:, None], hemisphere)
+    if hemisphere == 'south':
+        mlat = mlat * -1
     tec_data['model'] = xr.DataArray(
         mlat,
         coords={'time': tec_data.time, 'mlt': tec_data.mlt},
@@ -311,7 +311,7 @@ def get_label_paths(start_date, end_date, hemisphere, processed_dir):
 def get_trough_labels(start_date, end_date, hemisphere, labels_dir=None):
     if labels_dir is None:
         labels_dir = config.processed_labels_dir
-    data = xr.concat([xr.open_dataarray(file) for file in get_label_paths(start_date, end_date, hemisphere, labels_dir)], 'time')
+    data = utils.read_netcdfs(get_label_paths(start_date, end_date, hemisphere, labels_dir), 'time')
     return data.sel(time=slice(start_date, end_date))
diff --git a/trough/utils.py b/trough/utils.py
index d9a5040..e43a035 100644
--- a/trough/utils.py
+++ b/trough/utils.py
@@ -2,6 +2,7 @@
 import datetime
 import warnings
 import logging
+import xarray as xr
 try:
     import h5py
     from skimage.util import view_as_windows
@@ -153,3 +154,20 @@ def check(start, end, dt, hemisphere, processed_file):
         return True
     return check
+
+
+def read_netcdfs(files, dim):
+    """https://xarray.pydata.org/en/stable/user-guide/io.html#reading-multi-file-datasets
+    """
+    def process_one_path(path):
+        # use a context manager, to ensure the file gets closed after use
+        with xr.open_dataarray(path) as ds:
+            # load all data from the transformed dataset, to ensure we can
+            # use it after closing each original file
+            ds.load()
+            return ds
+
+    paths = sorted(files)
+    datasets = [process_one_path(p) for p in paths]
+    combined = xr.concat(datasets, dim)
+    return combined
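
A minimal usage sketch of the new utils.read_netcdfs helper, illustrating why it replaces the inline xr.concat([xr.open_dataarray(f) ...]) calls: each file is opened in a context manager and loaded into memory, so no file handles are left open after the read. The temporary directory, file names, and toy values below are invented for illustration, and it assumes the trough package plus an xarray NetCDF backend (e.g. netCDF4) are installed.

import tempfile
from pathlib import Path

import numpy as np
import xarray as xr

from trough import utils

# Write two small single-DataArray NetCDF files to a temporary directory
# (hypothetical stand-ins for the per-interval processed TEC files).
tmp = Path(tempfile.mkdtemp())
paths = []
for day in ("2021-01-01", "2021-01-02"):
    times = np.array([np.datetime64(f"{day}T{h:02d}:00") for h in (0, 6, 12, 18)])
    arr = xr.DataArray(
        np.random.rand(4, 3),
        coords={"time": times, "mlt": [0.0, 6.0, 12.0]},
        dims=["time", "mlt"],
    )
    path = tmp / f"tec_{day}.nc"
    arr.to_netcdf(path)
    paths.append(path)

# read_netcdfs sorts the paths, opens each file inside a `with` block,
# calls .load() so the data survives closing the file, then concatenates
# the pieces along 'time', so no file handles stay open afterwards.
data = utils.read_netcdfs(paths, "time")

# Callers such as get_tec_data / get_arb_data / get_trough_labels then
# slice the combined array by time.
print(data.sel(time=slice("2021-01-01", "2021-01-01T18:00")).shape)  # (4, 3)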