From acea804e3422f9076ab05821a890ff42eb89067d Mon Sep 17 00:00:00 2001 From: "Brandon N. Benton" Date: Thu, 8 Aug 2024 16:11:06 -0600 Subject: [PATCH] bad type check fix --- sup3r/postprocessing/collectors/nc.py | 11 +++++++---- sup3r/preprocessing/accessor.py | 11 ++++++----- sup3r/preprocessing/base.py | 6 ++---- sup3r/preprocessing/utilities.py | 17 +++++------------ tests/batch_queues/test_bq_general.py | 6 +++--- 5 files changed, 23 insertions(+), 28 deletions(-) diff --git a/sup3r/postprocessing/collectors/nc.py b/sup3r/postprocessing/collectors/nc.py index 8dd01eb9a..4eb6b90ba 100644 --- a/sup3r/postprocessing/collectors/nc.py +++ b/sup3r/postprocessing/collectors/nc.py @@ -79,17 +79,18 @@ def collect( logger.info(f'overwrite=True, removing {out_file}.') os.remove(out_file) - if not os.path.exists(out_file): + tmp_file = out_file + '.tmp' + if not os.path.exists(tmp_file): res_kwargs = res_kwargs or {} out = xr.open_mfdataset(collector.flist, **res_kwargs) features = list(out.data_vars) if features == 'all' else features features = set(features).intersection(_lowered(out.data_vars)) for feat in features: - mode = 'a' if os.path.exists(out_file) else 'w' + mode = 'a' if os.path.exists(tmp_file) else 'w' out[feat].load().to_netcdf( - out_file, mode=mode, engine='h5netcdf', format='NETCDF4' + tmp_file, mode=mode, engine='h5netcdf', format='NETCDF4' ) - logger.info(f'Finished writing {feat} to {out_file}.') + logger.info(f'Finished writing {feat} to {tmp_file}.') if write_status and job_name is not None: status = { @@ -102,6 +103,8 @@ def collect( Status.make_single_job_file( os.path.dirname(out_file), 'collect', job_name, status ) + os.replace(tmp_file, out_file) + logger.info('Moved %s to %s.', tmp_file, out_file) logger.info('Finished file collection.') diff --git a/sup3r/preprocessing/accessor.py b/sup3r/preprocessing/accessor.py index 444d3527d..bd1b313ce 100644 --- a/sup3r/preprocessing/accessor.py +++ b/sup3r/preprocessing/accessor.py @@ -17,7 +17,7 @@ _lowered, _mem_check, dims_array_tuple, - is_strings, + is_type_of, ordered_array, ordered_dims, parse_ellipsis, @@ -97,7 +97,7 @@ def parse_keys(self, keys): dataset that can be passed to isel and transposed to standard dimension order.""" keys = keys if isinstance(keys, tuple) else (keys,) - has_feats = is_strings(keys[0]) + has_feats = is_type_of(keys[0], str) just_coords = keys[0] == [] features = ( list(self.coords) @@ -132,7 +132,7 @@ def __getitem__( out = self._ds[features] out = self.ordered(out) if single_feat else type(self)(out) slices = {k: v for k, v in slices.items() if k in out.dims} - no_slices = is_strings(keys) + no_slices = is_type_of(keys, str) just_coords = all(f in self.coords for f in parse_to_list(features)) is_fancy = self._needs_fancy_indexing(slices.values()) @@ -170,7 +170,7 @@ def __setitem__(self, keys, data): then this is expected to have a trailing dimension with length equal to the length of the list. """ - if is_strings(keys): + if is_type_of(keys, str): if isinstance(keys, (list, tuple)) and hasattr(data, 'data_vars'): data_dict = {v: data[v] for v in keys} elif isinstance(keys, (list, tuple)): @@ -314,8 +314,9 @@ def sample(self, idx): ``(slice(0, 3), slice(1, 10), slice(None), ['u_10m', 'v_10m'])``""" isel_kwargs = dict(zip(Dimension.dims_3d(), idx[:-1])) features = ( - self.features if not is_strings(idx[-1]) else _lowered(idx[-1]) + _lowered(idx[-1]) if is_type_of(idx[-1], str) else self.features ) + out = self._ds[features].isel(**isel_kwargs) return self.ordered(out.to_array()).data diff --git a/sup3r/preprocessing/base.py b/sup3r/preprocessing/base.py index b81266fad..7edd0d9d8 100644 --- a/sup3r/preprocessing/base.py +++ b/sup3r/preprocessing/base.py @@ -20,7 +20,7 @@ import sup3r.preprocessing.accessor # noqa: F401 # pylint: disable=W0611 from sup3r.preprocessing.accessor import Sup3rX -from sup3r.preprocessing.utilities import composite_info +from sup3r.preprocessing.utilities import composite_info, is_type_of logger = logging.getLogger(__name__) @@ -359,9 +359,7 @@ def wrap(self, data): if data is None: return data - check_sup3rds = all(isinstance(d, Sup3rDataset) for d in data) - check_sup3rds = check_sup3rds or isinstance(data, Sup3rDataset) - if check_sup3rds: + if is_type_of(data, Sup3rDataset): return data if isinstance(data, tuple) and len(data) == 2: diff --git a/sup3r/preprocessing/utilities.py b/sup3r/preprocessing/utilities.py index 7dcbd9725..d37ed4db4 100644 --- a/sup3r/preprocessing/utilities.py +++ b/sup3r/preprocessing/utilities.py @@ -381,23 +381,16 @@ def contains_ellipsis(vals): ) -def is_strings(vals): - """Check if vals is a string or iterable of all strings.""" - return isinstance(vals, str) or ( +def is_type_of(vals, vtype): + """Check if vals is an instance of type or group of that type.""" + return isinstance(vals, vtype) or ( isinstance(vals, (set, tuple, list)) - and all(isinstance(v, str) for v in vals) + and all(isinstance(v, vtype) for v in vals) ) def _get_strings(vals): - return [v for v in vals if is_strings(v)] - - -def _is_ints(vals): - return isinstance(vals, int) or ( - isinstance(vals, (list, tuple, np.ndarray)) - and all(isinstance(v, int) for v in vals) - ) + return [v for v in vals if is_type_of(v, str)] def _lowered(features): diff --git a/tests/batch_queues/test_bq_general.py b/tests/batch_queues/test_bq_general.py index 7da685cbd..52b2ffb25 100644 --- a/tests/batch_queues/test_bq_general.py +++ b/tests/batch_queues/test_bq_general.py @@ -126,7 +126,7 @@ def test_dual_batch_queue(): ] sampler_pairs = [ DualSampler( - Sup3rDataset((lr.data, hr.data)), + Sup3rDataset(low_res=lr.data, high_res=hr.data), hr_sample_shape, s_enhance=2, t_enhance=2, @@ -179,7 +179,7 @@ def test_pair_batch_queue_with_lr_only_features(): ] sampler_pairs = [ DualSampler( - Sup3rDataset((lr, hr)), + Sup3rDataset(low_res=lr.data, high_res=hr.data), hr_sample_shape, s_enhance=2, t_enhance=2, @@ -234,7 +234,7 @@ def test_bad_enhancement_factors(): with pytest.raises(AssertionError): sampler_pairs = [ DualSampler( - Sup3rDataset((lr, hr)), + Sup3rDataset(low_res=lr.data, high_res=hr.data), hr_sample_shape, s_enhance=s_enhance, t_enhance=t_enhance,