Skip to content

Commit

Permalink
era downloader test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bnb32 committed Sep 16, 2024
1 parent ba7ea1c commit a146672
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 28 deletions.
8 changes: 4 additions & 4 deletions sup3r/preprocessing/cachers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,10 @@ def get_chunksizes(cls, dset, data, chunks):
chunkmem = np.prod(chunksizes) * data_var.dtype.itemsize / 1e9
if chunkmem > 4:
msg = (
'Chunks cannot be larger than 4GB. Given chunksizes '
'result in %s. Will use chunksizes = None')
logger.warning(msg)
warn(msg)
'Chunks cannot be larger than 4GB. Given chunksizes %s '
'result in %sGB. Will use chunksizes = None')
logger.warning(msg, chunksizes, chunkmem)
warn(msg % (chunksizes, chunkmem))
chunksizes = None
return data_var, chunksizes

Expand Down
95 changes: 79 additions & 16 deletions sup3r/utilities/era_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,8 @@ def run_for_var(
max_workers=None,
variable=None,
product_type='reanalysis',
chunks='auto',
res_kwargs=None,
):
"""Run routine for all requested months in the requested year for the
given variable.
Expand Down Expand Up @@ -558,6 +560,9 @@ def run_for_var(
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
chunks : str | dict
Dictionary of chunksizes used when writing data to netcdf files.
Can also be 'auto'.
"""
yearly_var_file = yearly_file_pattern.format(year=year, var=variable)
if os.path.exists(yearly_var_file) and not overwrite:
Expand Down Expand Up @@ -592,7 +597,12 @@ def run_for_var(

if yearly_file_pattern is not None and len(months) == 12:
cls.make_yearly_var_file(
year, monthly_file_pattern, yearly_file_pattern, variable
year,
monthly_file_pattern,
yearly_file_pattern,
variable,
chunks=chunks,
res_kwargs=res_kwargs,
)

@classmethod
Expand All @@ -608,6 +618,9 @@ def run(
max_workers=None,
variables=None,
product_type='reanalysis',
chunks='auto',
combine_all_files=False,
res_kwargs=None,
):
"""Run routine for all requested months in the requested year.
Expand Down Expand Up @@ -637,6 +650,12 @@ def run(
product_type : str
Can be 'reanalysis', 'ensemble_mean', 'ensemble_spread',
'ensemble_members'
chunks : str | dict
Dictionary of chunksizes used when writing data to netcdf files.
Can also be 'auto'.
combine_all_files : bool
Whether to combine separate yearly variable files into a single
yearly file with all variables included
"""
for var in variables:
cls.run_for_var(
Expand All @@ -650,28 +669,57 @@ def run(
variable=var,
product_type=product_type,
max_workers=max_workers,
chunks=chunks,
res_kwargs=res_kwargs,
)

if cls.all_vars_exist(
year=year, file_pattern=yearly_file_pattern, variables=variables
if (
cls.all_vars_exist(
year=year,
file_pattern=yearly_file_pattern,
variables=variables,
)
and combine_all_files
):
cls.make_yearly_file(year, yearly_file_pattern, variables)
cls.make_yearly_file(
year,
yearly_file_pattern,
variables,
chunks=chunks,
res_kwargs=res_kwargs,
)

@classmethod
def make_yearly_var_file(
cls, year, monthly_file_pattern, yearly_file_pattern, variable
cls,
year,
monthly_file_pattern,
yearly_file_pattern,
variable,
chunks='auto',
res_kwargs=None,
):
"""Combine monthly variable files into a single yearly variable file.
Parameters
----------
year : int
Year used to download data
file_pattern : str
monthly_file_pattern : str
File pattern for monthly variable files. Must have year, month, and
var format keys. e.g. './era_{year}_{month}_{var}_combined.nc'
yearly_file_pattern : str
File pattern for yearly variable files. Must have year and var
format keys. e.g. './era_{year}_{var}_combined.nc'
variable : string
Variable name for the files to be combined.
chunks : str | dict
Dictionary of chunksizes used when writing data to netcdf files.
Can also be 'auto'.
res_kwargs : None | dict
Keyword arguments for base resource handler, like
``xr.open_mfdataset``. This is passed to a ``Loader`` object and
then used in the base loader contained by that object.
"""
files = [
monthly_file_pattern.format(
Expand All @@ -681,21 +729,26 @@ def make_yearly_var_file(
]

outfile = yearly_file_pattern.format(year=year, var=variable)
cls._combine_files(files, outfile)
cls._combine_files(
files, outfile, chunks=chunks, res_kwargs=res_kwargs
)

@classmethod
def _combine_files(cls, files, outfile, kwargs=None):
def _combine_files(cls, files, outfile, chunks='auto', res_kwargs=None):
if not os.path.exists(outfile):
logger.info(f'Combining {files} into {outfile}.')
try:
kwargs = kwargs or {}
loader = Loader(files, res_kwargs=kwargs)
res_kwargs = res_kwargs or {}
loader = Loader(files, res_kwargs=res_kwargs)
tmp_file = cls.get_tmp_file(outfile)
for ignore_var in IGNORE_VARS:
if ignore_var in loader.coords:
loader.data = loader.data.drop_vars(ignore_var)
Cacher.write_netcdf(
data=loader.data, out_file=tmp_file, max_workers=1
data=loader.data,
out_file=tmp_file,
max_workers=1,
chunks=chunks,
)
os.replace(tmp_file, outfile)
logger.info('Moved %s to %s.', tmp_file, outfile)
Expand All @@ -707,19 +760,28 @@ def _combine_files(cls, files, outfile, kwargs=None):
logger.info(f'{outfile} already exists.')

@classmethod
def make_yearly_file(cls, year, file_pattern, variables):
"""Combine monthly files into a single file.
def make_yearly_file(
cls, year, file_pattern, variables, chunks='auto', res_kwargs=None
):
"""Combine yearly variable files into a single file.
Parameters
----------
year : int
Year of monthly data to make into a yearly file.
Year for the data to make into a yearly file.
file_pattern : str
File pattern for output files. Must have year and var
format keys. e.g. './era_{year}_{var}_combined.nc'
variables : list
List of variables corresponding to the yearly variable files to
combine.
chunks : str | dict
Dictionary of chunksizes used when writing data to netcdf files.
Can also be 'auto'.
res_kwargs : None | dict
Keyword arguments for base resource handler, like
``xr.open_mfdataset``. This is passed to a ``Loader`` object and
then used in the base loader contained by that object.
"""
msg = (
f'Not all variable files with file_patten {file_pattern} for '
Expand All @@ -730,13 +792,14 @@ def make_yearly_file(cls, year, file_pattern, variables):
), msg

files = [file_pattern.format(year=year, var=var) for var in variables]
kwargs = {'combine': 'nested', 'concat_dim': 'time'}
yearly_file = (
file_pattern.replace('_{var}_', '')
.replace('_{var}', '')
.format(year=year)
)
cls._combine_files(files, yearly_file, kwargs)
cls._combine_files(
files, yearly_file, res_kwargs=res_kwargs, chunks=chunks
)

@classmethod
def run_qa(cls, file, res_kwargs=None, log_file=None):
Expand Down
8 changes: 5 additions & 3 deletions sup3r/utilities/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ def merge_datasets(files, **kwargs):
xr.open_mfdatasets fails due to different time index formats or coordinate
names, for example."""
dsets = [xr.open_mfdataset(f, **kwargs) for f in files]
time_indices = [dset.time for dset in dsets]
time_indices = []
for i, dset in enumerate(dsets):
if 'time' in dset and dset.time.size > 1:
dset['time'] = pd.DatetimeIndex(dset.time)
ti = pd.DatetimeIndex(dset.time)
dset['time'] = ti
dsets[i] = dset
time_indices.append(ti.to_series())
if 'latitude' in dset.dims:
dset = dset.swap_dims({'latitude': 'south_north'})
dsets[i] = dset
Expand All @@ -48,7 +50,7 @@ def merge_datasets(files, **kwargs):

def xr_open_mfdataset(files, **kwargs):
"""Wrapper for xr.open_mfdataset with default opening options."""
default_kwargs = {'engine': 'netcdf4', 'coords': 'minimal'}
default_kwargs = {'engine': 'netcdf4'}
default_kwargs.update(kwargs)
try:
return xr.open_mfdataset(files, **default_kwargs)
Expand Down
17 changes: 12 additions & 5 deletions tests/utilities/test_era_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_era_dl(tmpdir_factory):
month=month,
area=area,
levels=levels,
monthly_file_pattern=file_pattern,
file_pattern=file_pattern,
variables=variables,
)
for v in variables:
Expand All @@ -86,18 +86,25 @@ def test_era_dl_year(tmpdir_factory):
file_pattern = os.path.join(
tmpdir_factory.mktemp('tmp'), 'era5_{year}_{month}_{var}.nc'
)
yearly_file = os.path.join(tmpdir_factory.mktemp('tmp'), 'era5_final.nc')
EraDownloaderTester.run_year(
yearly_file_pattern = os.path.join(
tmpdir_factory.mktemp('tmp'), 'era5_{year}_{var}_final.nc'
)
EraDownloaderTester.run(
year=2000,
area=[50, -130, 23, -65],
levels=[1000, 900, 800],
variables=variables,
monthly_file_pattern=file_pattern,
yearly_file=yearly_file,
yearly_file_pattern=yearly_file_pattern,
max_workers=1,
combine_all_files=True,
res_kwargs={'compat': 'override', 'engine': 'netcdf4'},
)

tmp = xr_open_mfdataset(yearly_file)
combined_file = yearly_file_pattern.replace('_{var}_', '').format(
year=2000
)
tmp = xr_open_mfdataset(combined_file)
for v in variables:
standard_name = FEATURE_NAMES.get(v, v)
assert standard_name in tmp

0 comments on commit a146672

Please sign in to comment.