Skip to content

Commit 3efeb16

Browse files
committed
implement coders
1 parent 4448828 commit 3efeb16

File tree

2 files changed

+168
-14
lines changed

2 files changed

+168
-14
lines changed

xarray/coding/variables.py

+159
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,25 @@ def _choose_float_dtype(dtype: np.dtype, has_offset: bool) -> type[np.floating[A
251251
return np.float64
252252

253253

254+
class DefaultFillvalueCoder(VariableCoder):
    """Encode default _FillValue if needed.

    Encoding adds a NaN ``_FillValue`` attribute to floating-point
    variables that do not already declare one in their attributes or
    encoding. Decoding is not supported by this coder.
    """

    def encode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Return ``variable`` rebuilt, with a NaN ``_FillValue`` if one is missing.

        Parameters
        ----------
        variable : Variable
            Variable to encode.
        name : T_Name, optional
            Unused by this coder.

        Returns
        -------
        Variable
            A variable built from the unpacked dims, data, attrs and
            encoding of ``variable``.
        """
        dims, data, attrs, encoding = unpack_for_encoding(variable)
        # make NaN the fill value for float types
        if (
            "_FillValue" not in attrs
            and "_FillValue" not in encoding
            and np.issubdtype(variable.dtype, np.floating)
        ):
            # Build the NaN with the variable's own scalar type so the fill
            # value has the same precision as the data.
            attrs["_FillValue"] = variable.dtype.type(np.nan)

        return Variable(dims, data, attrs, encoding, fastpath=True)

    def decode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Not supported.

        Raises
        ------
        NotImplementedError
            Always.
        """
        # Raise rather than return the exception class: a returned class
        # would be passed on by callers as if it were a decoded Variable.
        raise NotImplementedError()
271+
272+
254273
class CFScaleOffsetCoder(VariableCoder):
255274
"""Scale and offset variables according to CF conventions.
256275
@@ -349,3 +368,143 @@ def decode(self, variable: Variable, name: T_Name = None) -> Variable:
349368
return Variable(dims, data, attrs, encoding, fastpath=True)
350369
else:
351370
return variable
371+
372+
373+
class BoolTypeArray(indexing.ExplicitlyIndexedNDArrayMixin):
    """Lazily present an integer-typed array as a boolean array.

    Handy for boolean data that was stored in integer typed netCDF
    variables: the conversion happens only for the elements that are
    actually indexed.

    >>> x = np.array([1, 0, 1, 1, 0], dtype="i1")

    >>> x.dtype
    dtype('int8')

    >>> BoolTypeArray(x).dtype
    dtype('bool')

    >>> indexer = indexing.BasicIndexer((slice(None),))
    >>> BoolTypeArray(x)[indexer].dtype
    dtype('bool')
    """

    __slots__ = ("array",)

    def __init__(self, array):
        # Wrap the input once so that later indexing goes through the
        # indexable adapter returned by ``indexing.as_indexable``.
        wrapped = indexing.as_indexable(array)
        self.array = wrapped

    @property
    def dtype(self):
        """The dtype reported to consumers: always boolean."""
        return np.dtype("bool")

    def __getitem__(self, key):
        # Index the wrapped array first, then cast only the selection.
        selection = self.array[key]
        return np.asarray(selection, dtype=self.dtype)
403+
404+
405+
class BooleanCoder(VariableCoder):
    """Round-trip boolean variables through an ``i1`` representation."""

    def encode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Store boolean data as ``i1`` and tag it with ``attrs["dtype"] = "bool"``.

        Variables that are not boolean, or that already have a ``"dtype"``
        entry in their encoding or attributes, are returned unchanged.
        """
        if not (variable.dtype == bool):
            return variable
        if "dtype" in variable.encoding or "dtype" in variable.attrs:
            return variable

        dims, data, attrs, encoding = unpack_for_encoding(variable)
        # The attribute is the marker ``decode`` looks for.
        attrs["dtype"] = "bool"
        as_int8 = duck_array_ops.astype(data, dtype="i1", copy=True)
        return Variable(dims, as_int8, attrs, encoding, fastpath=True)

    def decode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Wrap data tagged with ``attrs["dtype"] == "bool"`` in a ``BoolTypeArray``.

        The ``"dtype"`` attribute is removed from the returned variable.
        Variables without the tag are returned unchanged.
        """
        tagged_as_bool = variable.attrs.get("dtype", False) == "bool"
        if not tagged_as_bool:
            return variable

        dims, data, attrs, encoding = unpack_for_decoding(variable)
        attrs.pop("dtype")
        return Variable(dims, BoolTypeArray(data), attrs, encoding, fastpath=True)
430+
431+
432+
class NativeEndiannessArray(indexing.ExplicitlyIndexedNDArrayMixin):
    """Lazily present a non-native-endian array in native byte order.

    netCDF3 files are all big endian; converting their arrays to native
    endianness lets them be used with Cython functions, such as those
    found in bottleneck and pandas.

    >>> x = np.arange(5, dtype=">i2")

    >>> x.dtype
    dtype('>i2')

    >>> NativeEndiannessArray(x).dtype
    dtype('int16')

    >>> indexer = indexing.BasicIndexer((slice(None),))
    >>> NativeEndiannessArray(x)[indexer].dtype
    dtype('int16')
    """

    __slots__ = ("array",)

    def __init__(self, array):
        # Wrap the input once so that later indexing goes through the
        # indexable adapter returned by ``indexing.as_indexable``.
        wrapped = indexing.as_indexable(array)
        self.array = wrapped

    @property
    def dtype(self):
        """Dtype with the wrapped array's kind and item size, without a byte-order prefix."""
        source = self.array.dtype
        # A "<kind><itemsize>" spec carries no byte-order character.
        return np.dtype(f"{source.kind}{source.itemsize}")

    def __getitem__(self, key):
        # Index the wrapped array first, then cast only the selection.
        selection = self.array[key]
        return np.asarray(selection, dtype=self.dtype)
463+
464+
465+
class EndianCoder(VariableCoder):
    """Decode Endianness to native.

    Only decoding is supported; ``encode`` always raises.
    """

    def encode(self):
        """Not supported.

        Raises
        ------
        NotImplementedError
            Always.
        """
        # Raise rather than return the exception class, so that an
        # accidental call fails loudly instead of yielding a bogus result.
        raise NotImplementedError()

    def decode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Wrap non-native-endian data in a ``NativeEndiannessArray``.

        Parameters
        ----------
        variable : Variable
            Variable to decode.
        name : T_Name, optional
            Unused by this coder.

        Returns
        -------
        Variable
            A new variable whose data is wrapped in ``NativeEndiannessArray``
            when the data dtype is not native; otherwise ``variable`` itself.
        """
        dims, data, attrs, encoding = unpack_for_decoding(variable)
        if not data.dtype.isnative:
            data = NativeEndiannessArray(data)
            return Variable(dims, data, attrs, encoding, fastpath=True)
        else:
            return variable
478+
479+
480+
class NonStringCoder(VariableCoder):
    """Encode NonString variables if dtypes differ.

    Only encoding is supported; ``decode`` always raises.
    """

    def encode(self, variable: Variable, name: T_Name = None) -> Variable:
        """Cast the data to ``encoding["dtype"]`` when that dtype is not a string dtype.

        Parameters
        ----------
        variable : Variable
            Variable to encode.
        name : T_Name, optional
            Variable name; used only in the warning message.

        Returns
        -------
        Variable
            ``variable`` itself when its encoding has no ``"dtype"`` or the
            requested dtype is ``"S1"`` / ``str``. Otherwise a new variable
            with ``"dtype"`` removed from the encoding and, if the requested
            dtype differs from the current one, with the data cast to it.

        Warns
        -----
        SerializationWarning
            When floating point data is cast to an integer dtype and neither
            ``_FillValue`` nor ``missing_value`` is present in the attributes.
        """
        if "dtype" in variable.encoding and variable.encoding["dtype"] not in (
            "S1",
            str,
        ):
            dims, data, attrs, encoding = unpack_for_encoding(variable)
            dtype = np.dtype(encoding.pop("dtype"))
            if dtype != variable.dtype:
                if np.issubdtype(dtype, np.integer):
                    if (
                        np.issubdtype(variable.dtype, np.floating)
                        and "_FillValue" not in variable.attrs
                        and "missing_value" not in variable.attrs
                    ):
                        warnings.warn(
                            f"saving variable {name} with floating "
                            "point data as an integer dtype without "
                            "any _FillValue to use for NaNs",
                            SerializationWarning,
                            stacklevel=10,
                        )
                    # Round before the cast to an integer dtype.
                    data = np.around(data)
                data = data.astype(dtype=dtype)
            variable = Variable(dims, data, attrs, encoding, fastpath=True)
        return variable

    def decode(self):
        """Not supported.

        Raises
        ------
        NotImplementedError
            Always.
        """
        # Raise rather than return the exception class, so that an
        # accidental call fails loudly instead of yielding a bogus result.
        raise NotImplementedError()

xarray/conventions.py

+9-14
Original file line numberDiff line numberDiff line change
@@ -292,13 +292,12 @@ def encode_cf_variable(
292292
variables.CFScaleOffsetCoder(),
293293
variables.CFMaskCoder(),
294294
variables.UnsignedIntegerCoder(),
295+
variables.NonStringCoder(),
296+
variables.DefaultFillvalueCoder(),
297+
variables.BooleanCoder(),
295298
]:
296299
var = coder.encode(var, name=name)
297300

298-
# TODO(shoyer): convert all of these to use coders, too:
299-
var = maybe_encode_nonstring_dtype(var, name=name)
300-
var = maybe_default_fill_value(var)
301-
var = maybe_encode_bools(var)
302301
var = ensure_dtype_not_object(var, name=name)
303302

304303
for attr_name in CF_RELATED_DATA:
@@ -389,19 +388,15 @@ def decode_cf_variable(
389388
if decode_times:
390389
var = times.CFDatetimeCoder(use_cftime=use_cftime).decode(var, name=name)
391390

392-
dimensions, data, attributes, encoding = variables.unpack_for_decoding(var)
393-
# TODO(shoyer): convert everything below to use coders
391+
if decode_endianness and not var.dtype.isnative:
392+
var = variables.EndianCoder().decode(var)
393+
original_dtype = var.dtype
394394

395-
if decode_endianness and not data.dtype.isnative:
396-
# do this last, so it's only done if we didn't already unmask/scale
397-
data = NativeEndiannessArray(data)
398-
original_dtype = data.dtype
395+
var = variables.BooleanCoder().decode(var)
399396

400-
encoding.setdefault("dtype", original_dtype)
397+
dimensions, data, attributes, encoding = variables.unpack_for_decoding(var)
401398

402-
if "dtype" in attributes and attributes["dtype"] == "bool":
403-
del attributes["dtype"]
404-
data = BoolTypeArray(data)
399+
encoding.setdefault("dtype", original_dtype)
405400

406401
if not is_duck_dask_array(data):
407402
data = indexing.LazilyIndexedArray(data)

0 commit comments

Comments
 (0)