From 27b288377574e5f85f0308322747a899fd0c871e Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 2 Nov 2016 19:11:16 -0700 Subject: [PATCH] New properties `Dataset.sizes` and `DataArray.sizes` This allows for consistent access to dimension lengths on ``Dataset`` and ``DataArray`` xref #921 (doesn't resolve it 100%, but should help significantly) --- doc/api.rst | 2 + doc/whats-new.rst | 5 +++ xarray/core/common.py | 70 +++++++++++++++++++++++++++-------- xarray/core/dataarray.py | 37 ++++-------------- xarray/core/dataset.py | 49 ++++++++++-------------- xarray/core/groupby.py | 5 +-- xarray/core/variable.py | 36 ++---------------- xarray/test/test_dataarray.py | 7 ++++ xarray/test/test_dataset.py | 1 + xarray/test/test_variable.py | 1 + 10 files changed, 101 insertions(+), 112 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index d006775ad55..2602d9f2e29 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -39,6 +39,7 @@ Attributes :toctree: generated/ Dataset.dims + Dataset.sizes Dataset.data_vars Dataset.coords Dataset.attrs @@ -187,6 +188,7 @@ Attributes DataArray.data DataArray.coords DataArray.dims + DataArray.sizes DataArray.name DataArray.attrs DataArray.encoding diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0d28346d740..d30925c8f28 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -77,6 +77,11 @@ Enhancements :py:meth:`~xarray.DataArray.cumprod`. By `Phillip J. Wolfram <https://github.com/pwolfram>`_. +- New properties :py:attr:`Dataset.sizes` and :py:attr:`DataArray.sizes` for + providing consistent access to dimension length on both ``Dataset`` and + ``DataArray`` (:issue:`921`). + By `Stephan Hoyer <https://github.com/shoyer>`_. + Bug fixes ~~~~~~~~~ - ``groupby_bins`` now restores empty bins by default (:issue:`1019`). 
diff --git a/xarray/core/common.py b/xarray/core/common.py index 38d76af22c2..a67591a7493 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1,9 +1,10 @@ import numpy as np import pandas as pd -from .pycompat import basestring, iteritems, suppress, dask_array_type, bytes_type +from .pycompat import (basestring, iteritems, suppress, dask_array_type, + OrderedDict) from . import formatting -from .utils import SortedKeysDict, not_implemented +from .utils import SortedKeysDict, not_implemented, Frozen class ImplementsArrayReduce(object): @@ -124,6 +125,8 @@ def wrapped_func(self, **kwargs): class AbstractArray(ImplementsArrayReduce, formatting.ReprMixin): + """Shared base class for DataArray and Variable.""" + def __bool__(self): return bool(self.values) @@ -186,6 +189,18 @@ def _get_axis_num(self, dim): raise ValueError("%r not found in array dimensions %r" % (dim, self.dims)) + @property + def sizes(self): + """Ordered mapping from dimension names to lengths. + + Immutable. + + See also + -------- + Dataset.sizes + """ + return Frozen(OrderedDict(zip(self.dims, self.shape))) + class AttrAccessMixin(object): """Mixin class that allows getting keys with attribute access @@ -231,7 +246,43 @@ def __dir__(self): return sorted(set(dir(type(self)) + extra_attrs)) -class BaseDataObject(AttrAccessMixin): +class SharedMethodsMixin(object): + """Shared methods for Dataset, DataArray and Variable.""" + + def squeeze(self, dim=None): + """Return a new object with squeezed data. + + Parameters + ---------- + dim : None or str or tuple of str, optional + Selects a subset of the length one dimensions. If a dimension is + selected with length greater than one, an error is raised. If + None, all length one dimensions are squeezed. + + Returns + ------- + squeezed : same type as caller + This object, but with all or a subset of the dimensions of + length 1 removed. 
+ + See Also + -------- + numpy.squeeze + """ + if dim is None: + dim = [d for d, s in self.sizes.items() if s == 1] + else: + if isinstance(dim, basestring): + dim = [dim] + if any(self.sizes[k] > 1 for k in dim): + raise ValueError('cannot select a dimension to squeeze out ' + 'which has length greater than one') + return self.isel(**{d: 0 for d in dim}) + + +class BaseDataObject(SharedMethodsMixin, AttrAccessMixin): + """Shared base class for Dataset and DataArray.""" + def _calc_assign_results(self, kwargs): results = SortedKeysDict() for k, v in kwargs.items(): @@ -615,19 +666,6 @@ def __exit__(self, exc_type, exc_value, traceback): __or__ = __div__ = __eq__ = __ne__ = not_implemented -def squeeze(xarray_obj, dims, dim=None): - """Squeeze the dims of an xarray object.""" - if dim is None: - dim = [d for d, s in iteritems(dims) if s == 1] - else: - if isinstance(dim, basestring): - dim = [dim] - if any(dims[k] > 1 for k in dim): - raise ValueError('cannot select a dimension to squeeze out ' - 'which has length greater than one') - return xarray_obj.isel(**dict((d, 0) for d in dim)) - - def _maybe_promote(dtype): """Simpler equivalent of pandas.core.common._maybe_promote""" # N.B. these casting rules should match pandas diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 71164072cc9..c34680d4380 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -1,4 +1,3 @@ -import contextlib import functools import warnings @@ -13,7 +12,7 @@ from . import ops from . import utils from .alignment import align -from .common import AbstractArray, BaseDataObject, squeeze +from .common import AbstractArray, BaseDataObject from .coordinates import (DataArrayCoordinates, LevelCoordinates, Indexes) from .dataset import Dataset @@ -411,7 +410,12 @@ def to_index(self): @property def dims(self): - """Dimension names associated with this array.""" + """Tuple of dimension names associated with this array. 
+ + Note that the type of this property is inconsistent with `Dataset.dims`. + See `Dataset.sizes` and `DataArray.sizes` for consistently named + properties. + """ return self.variable.dims @dims.setter @@ -911,33 +915,6 @@ def transpose(self, *dims): variable = self.variable.transpose(*dims) return self._replace(variable) - def squeeze(self, dim=None): - """Return a new DataArray object with squeezed data. - - Parameters - ---------- - dim : None or str or tuple of str, optional - Selects a subset of the length one dimensions. If a dimension is - selected with length greater than one, an error is raised. If - None, all length one dimensions are squeezed. - - Returns - ------- - squeezed : DataArray - This array, but with with all or a subset of the dimensions of - length 1 removed. - - Notes - ----- - Although this operation returns a view of this array's data, it is - not lazy -- the data will be fully loaded. - - See Also - -------- - numpy.squeeze - """ - return squeeze(self, dict(zip(self.dims, self.shape)), dim) - def drop(self, labels, dim=None): """Drop coordinates or index labels from this DataArray. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 90d67e8e3cf..84671814c42 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -291,11 +291,29 @@ def attrs(self, value): def dims(self): """Mapping from dimension names to lengths. - This dictionary cannot be modified directly, but is updated when adding - new variables. + Cannot be modified directly, but is updated when adding new variables. + + Note that the type of this object differs from `DataArray.dims`. + See `Dataset.sizes` and `DataArray.sizes` for consistently named + properties. """ return Frozen(SortedKeysDict(self._dims)) + @property + def sizes(self): + """Mapping from dimension names to lengths. + + Cannot be modified directly, but is updated when adding new variables. 
+ + This is an alias for `Dataset.dims` provided for the benefit of + consistency with `DataArray.sizes`. + + See also + -------- + DataArray.sizes + """ + return self.dims + def load(self): """Manually trigger loading of this dataset's data from disk or a remote source into memory and return this dataset. @@ -1584,33 +1602,6 @@ def transpose(self, *dims): def T(self): return self.transpose() - def squeeze(self, dim=None): - """Returns a new dataset with squeezed data. - - Parameters - ---------- - dim : None or str or tuple of str, optional - Selects a subset of the length one dimensions. If a dimension is - selected with length greater than one, an error is raised. If - None, all length one dimensions are squeezed. - - Returns - ------- - squeezed : Dataset - This dataset, but with with all or a subset of the dimensions of - length 1 removed. - - Notes - ----- - Although this operation returns a view of each variable's data, it is - not lazy -- all variable data will be fully loaded. - - See Also - -------- - numpy.squeeze - """ - return common.squeeze(self, self.dims, dim) - def dropna(self, dim, how='any', thresh=None, subset=None): """Returns a new dataset with dropped labels for missing values along the provided dimension. 
diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index 2c03cf5421f..c50aaa702ac 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -178,10 +178,7 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None, raise ValueError("`group` must have a 'dims' attribute") group_dim, = group.dims - try: - expected_size = obj.dims[group_dim] - except TypeError: - expected_size = obj.shape[obj.get_axis_num(group_dim)] + expected_size = obj.sizes[group_dim] if group.size != expected_size: raise ValueError('the group variable\'s length does not ' 'match the length of this variable along its ' diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 4fe967dad6c..5589d8a7660 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -2,7 +2,6 @@ from collections import defaultdict import functools import itertools -import warnings import numpy as np import pandas as pd @@ -192,7 +191,8 @@ def _as_array_or_item(data): return data -class Variable(common.AbstractArray, utils.NdimSizeLenMixin): +class Variable(common.AbstractArray, common.SharedMethodsMixin, + utils.NdimSizeLenMixin): """A netcdf-like variable consisting of dimensions, data and attributes which describe a single Array. A single Variable object is not fully @@ -678,34 +678,6 @@ def transpose(self, *dims): data = ops.transpose(self.data, axes) return type(self)(dims, data, self._attrs, self._encoding, fastpath=True) - def squeeze(self, dim=None): - """Return a new Variable object with squeezed data. - - Parameters - ---------- - dim : None or str or tuple of str, optional - Selects a subset of the length one dimensions. If a dimension is - selected with length greater than one, an error is raised. If - None, all length one dimensions are squeezed. - - Returns - ------- - squeezed : Variable - This array, but with with all or a subset of the dimensions of - length 1 removed. 
- - Notes - ----- - Although this operation returns a view of this variable's data, it is - not lazy -- the data will be fully loaded. - - See Also - -------- - numpy.squeeze - """ - dims = dict(zip(self.dims, self.shape)) - return common.squeeze(self, dims, dim) - def expand_dims(self, dims, shape=None): """Return a new variable with expanded dimensions. @@ -814,8 +786,7 @@ def _unstack_once(self, dims, old_dim): raise ValueError('cannot create a new dimension with the same ' 'name as an existing dimension') - axis = self.get_axis_num(old_dim) - if np.prod(new_dim_sizes) != self.shape[axis]: + if np.prod(new_dim_sizes) != self.sizes[old_dim]: raise ValueError('the product of the new dimension sizes must ' 'equal the size of the old dimension') @@ -914,7 +885,6 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False, dims = [adim for n, adim in enumerate(self.dims) if n not in removed_axes] - attrs = self._attrs if keep_attrs else None return Variable(dims, data, attrs=attrs) diff --git a/xarray/test/test_dataarray.py b/xarray/test/test_dataarray.py index 0c2dfd3c4c7..6a1985060e8 100644 --- a/xarray/test/test_dataarray.py +++ b/xarray/test/test_dataarray.py @@ -159,6 +159,13 @@ def test_dims(self): with self.assertRaisesRegexp(AttributeError, 'you cannot assign'): arr.dims = ('w', 'z') + def test_sizes(self): + array = DataArray(np.zeros((3, 4)), dims=['x', 'y']) + self.assertEqual(array.sizes, {'x': 3, 'y': 4}) + self.assertEqual(tuple(array.sizes), array.dims) + with self.assertRaises(TypeError): + array.sizes['foo'] = 5 + def test_encoding(self): expected = {'foo': 'bar'} self.dv.encoding['foo'] = 'bar' diff --git a/xarray/test/test_dataset.py b/xarray/test/test_dataset.py index 4fa06d8f0db..87dee7603b5 100644 --- a/xarray/test/test_dataset.py +++ b/xarray/test/test_dataset.py @@ -338,6 +338,7 @@ def test_properties(self): self.assertEqual(ds.dims, {'dim1': 8, 'dim2': 9, 'dim3': 10, 'time': 20}) self.assertEqual(list(ds.dims), sorted(ds.dims)) + 
self.assertEqual(ds.sizes, ds.dims) # These exact types aren't public API, but this makes sure we don't # change them inadvertently: diff --git a/xarray/test/test_variable.py b/xarray/test/test_variable.py index 7fe7b1fa7d7..a116540585c 100644 --- a/xarray/test/test_variable.py +++ b/xarray/test/test_variable.py @@ -26,6 +26,7 @@ def test_properties(self): self.assertEqual(v.dtype, float) self.assertEqual(v.shape, (10,)) self.assertEqual(v.size, 10) + self.assertEqual(v.sizes, {'time': 10}) self.assertEqual(v.nbytes, 80) self.assertEqual(v.ndim, 1) self.assertEqual(len(v), 10)