Commit

bad type check fix
bnb32 committed Aug 8, 2024
1 parent d8e0fe2 commit acea804
Showing 5 changed files with 23 additions and 28 deletions.
11 changes: 7 additions & 4 deletions sup3r/postprocessing/collectors/nc.py
@@ -79,17 +79,18 @@ def collect(
             logger.info(f'overwrite=True, removing {out_file}.')
             os.remove(out_file)
 
-        if not os.path.exists(out_file):
+        tmp_file = out_file + '.tmp'
+        if not os.path.exists(tmp_file):
             res_kwargs = res_kwargs or {}
             out = xr.open_mfdataset(collector.flist, **res_kwargs)
             features = list(out.data_vars) if features == 'all' else features
             features = set(features).intersection(_lowered(out.data_vars))
             for feat in features:
-                mode = 'a' if os.path.exists(out_file) else 'w'
+                mode = 'a' if os.path.exists(tmp_file) else 'w'
                 out[feat].load().to_netcdf(
-                    out_file, mode=mode, engine='h5netcdf', format='NETCDF4'
+                    tmp_file, mode=mode, engine='h5netcdf', format='NETCDF4'
                 )
-                logger.info(f'Finished writing {feat} to {out_file}.')
+                logger.info(f'Finished writing {feat} to {tmp_file}.')
 
         if write_status and job_name is not None:
             status = {
@@ -102,6 +103,8 @@ def collect(
             Status.make_single_job_file(
                 os.path.dirname(out_file), 'collect', job_name, status
             )
+        os.replace(tmp_file, out_file)
+        logger.info('Moved %s to %s.', tmp_file, out_file)
 
         logger.info('Finished file collection.')

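The two hunks above amount to a write-then-rename pattern: every feature is appended to a temp file, and only a complete file is moved to the final path. Below is a minimal standalone sketch of that pattern; the helper name and signature are illustrative and not from the repo, only the to_netcdf call and the os.replace step mirror the diff.

    import os

    import xarray as xr


    def write_features_atomically(ds: xr.Dataset, features, out_file: str):
        """Write each feature to a temp file, then move the file into place
        so a partially written file never appears at ``out_file``."""
        tmp_file = out_file + '.tmp'
        for feat in features:
            # 'w' creates the file for the first feature, 'a' appends the rest
            mode = 'a' if os.path.exists(tmp_file) else 'w'
            ds[feat].load().to_netcdf(
                tmp_file, mode=mode, engine='h5netcdf', format='NETCDF4'
            )
        os.replace(tmp_file, out_file)  # atomic rename on the same filesystem
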
11 changes: 6 additions & 5 deletions sup3r/preprocessing/accessor.py
@@ -17,7 +17,7 @@
     _lowered,
     _mem_check,
     dims_array_tuple,
-    is_strings,
+    is_type_of,
     ordered_array,
     ordered_dims,
     parse_ellipsis,
@@ -97,7 +97,7 @@ def parse_keys(self, keys):
         dataset that can be passed to isel and transposed to standard dimension
         order."""
         keys = keys if isinstance(keys, tuple) else (keys,)
-        has_feats = is_strings(keys[0])
+        has_feats = is_type_of(keys[0], str)
         just_coords = keys[0] == []
         features = (
             list(self.coords)
@@ -132,7 +132,7 @@ def __getitem__(
         out = self._ds[features]
         out = self.ordered(out) if single_feat else type(self)(out)
         slices = {k: v for k, v in slices.items() if k in out.dims}
-        no_slices = is_strings(keys)
+        no_slices = is_type_of(keys, str)
         just_coords = all(f in self.coords for f in parse_to_list(features))
         is_fancy = self._needs_fancy_indexing(slices.values())
 
@@ -170,7 +170,7 @@ def __setitem__(self, keys, data):
             then this is expected to have a trailing dimension with length
             equal to the length of the list.
         """
-        if is_strings(keys):
+        if is_type_of(keys, str):
             if isinstance(keys, (list, tuple)) and hasattr(data, 'data_vars'):
                 data_dict = {v: data[v] for v in keys}
             elif isinstance(keys, (list, tuple)):
@@ -314,8 +314,9 @@ def sample(self, idx):
         ``(slice(0, 3), slice(1, 10), slice(None), ['u_10m', 'v_10m'])``"""
         isel_kwargs = dict(zip(Dimension.dims_3d(), idx[:-1]))
         features = (
-            self.features if not is_strings(idx[-1]) else _lowered(idx[-1])
+            _lowered(idx[-1]) if is_type_of(idx[-1], str) else self.features
         )
+
         out = self._ds[features].isel(**isel_kwargs)
         return self.ordered(out.to_array()).data

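To make the updated check in sample concrete, here is a short hedged sketch; the index tuple is the docstring example shown in the hunk above, nothing else is from the repo.

    from sup3r.preprocessing.utilities import is_type_of

    # three dimension slices followed by an optional feature list
    idx = (slice(0, 3), slice(1, 10), slice(None), ['u_10m', 'v_10m'])

    # a trailing feature list means just those features are selected ...
    assert is_type_of(idx[-1], str)

    # ... anything else in the last position falls back to all features
    assert not is_type_of(slice(None), str)
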
6 changes: 2 additions & 4 deletions sup3r/preprocessing/base.py
@@ -20,7 +20,7 @@
 
 import sup3r.preprocessing.accessor  # noqa: F401 # pylint: disable=W0611
 from sup3r.preprocessing.accessor import Sup3rX
-from sup3r.preprocessing.utilities import composite_info
+from sup3r.preprocessing.utilities import composite_info, is_type_of
 
 logger = logging.getLogger(__name__)
 
@@ -359,9 +359,7 @@ def wrap(self, data):
         if data is None:
             return data
 
-        check_sup3rds = all(isinstance(d, Sup3rDataset) for d in data)
-        check_sup3rds = check_sup3rds or isinstance(data, Sup3rDataset)
-        if check_sup3rds:
+        if is_type_of(data, Sup3rDataset):
             return data
 
         if isinstance(data, tuple) and len(data) == 2:
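The single is_type_of call covers both shapes that wrap previously checked for in two steps. A hedged sketch with a stand-in class (Thing is illustrative, not from the repo):

    from sup3r.preprocessing.utilities import is_type_of


    class Thing:  # stand-in for Sup3rDataset in this illustration
        pass


    t = Thing()

    # a bare instance passes via the direct isinstance check ...
    assert is_type_of(t, Thing)

    # ... and so does a tuple of instances, which the old
    # all(isinstance(d, ...) for d in data) line was written to cover
    assert is_type_of((t, t), Thing)

    # anything else falls through to the remaining branches of wrap
    assert not is_type_of((t, 'not a dataset'), Thing)
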
17 changes: 5 additions & 12 deletions sup3r/preprocessing/utilities.py
@@ -381,23 +381,16 @@ def contains_ellipsis(vals):
     )
 
 
-def is_strings(vals):
-    """Check if vals is a string or iterable of all strings."""
-    return isinstance(vals, str) or (
+def is_type_of(vals, vtype):
+    """Check if vals is an instance of type or group of that type."""
+    return isinstance(vals, vtype) or (
         isinstance(vals, (set, tuple, list))
-        and all(isinstance(v, str) for v in vals)
+        and all(isinstance(v, vtype) for v in vals)
     )
 
 
 def _get_strings(vals):
-    return [v for v in vals if is_strings(v)]
-
-
-def _is_ints(vals):
-    return isinstance(vals, int) or (
-        isinstance(vals, (list, tuple, np.ndarray))
-        and all(isinstance(v, int) for v in vals)
-    )
+    return [v for v in vals if is_type_of(v, str)]
 
 
 def _lowered(features):
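A few hedged usage examples of the generalized helper; the feature names are illustrative.

    from sup3r.preprocessing.utilities import is_type_of

    # single values and homogeneous collections both pass
    assert is_type_of('u_10m', str)
    assert is_type_of(['u_10m', 'v_10m'], str)
    assert is_type_of((1, 2, 3), int)

    # mixed or non-matching inputs do not
    assert not is_type_of(['u_10m', 10], str)
    assert not is_type_of('u_10m', int)

    # note: unlike the removed _is_ints, a numpy integer array is not treated
    # as a collection here, since only set, tuple, and list are checked
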
6 changes: 3 additions & 3 deletions tests/batch_queues/test_bq_general.py
@@ -126,7 +126,7 @@ def test_dual_batch_queue():
     ]
     sampler_pairs = [
         DualSampler(
-            Sup3rDataset((lr.data, hr.data)),
+            Sup3rDataset(low_res=lr.data, high_res=hr.data),
             hr_sample_shape,
             s_enhance=2,
             t_enhance=2,
@@ -179,7 +179,7 @@ def test_pair_batch_queue_with_lr_only_features():
     ]
     sampler_pairs = [
         DualSampler(
-            Sup3rDataset((lr, hr)),
+            Sup3rDataset(low_res=lr.data, high_res=hr.data),
             hr_sample_shape,
             s_enhance=2,
             t_enhance=2,
@@ -234,7 +234,7 @@ def test_bad_enhancement_factors():
     with pytest.raises(AssertionError):
         sampler_pairs = [
             DualSampler(
-                Sup3rDataset((lr, hr)),
+                Sup3rDataset(low_res=lr.data, high_res=hr.data),
                 hr_sample_shape,
                 s_enhance=s_enhance,
                 t_enhance=t_enhance,
