Skip to content

Commit 3f490a3

Browse files
authored
New properties Dataset.sizes and DataArray.sizes (#1076)
This allows for consistent access to dimension lengths on ``Dataset`` and ``DataArray`` xref #921 (doesn't resolve it 100%, but should help significantly)
1 parent a4f5ec2 commit 3f490a3

10 files changed

+101
-112
lines changed

doc/api.rst

+2
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Attributes
3939
:toctree: generated/
4040

4141
Dataset.dims
42+
Dataset.sizes
4243
Dataset.data_vars
4344
Dataset.coords
4445
Dataset.attrs
@@ -187,6 +188,7 @@ Attributes
187188
DataArray.data
188189
DataArray.coords
189190
DataArray.dims
191+
DataArray.sizes
190192
DataArray.name
191193
DataArray.attrs
192194
DataArray.encoding

doc/whats-new.rst

+5
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ Enhancements
7777
:py:meth:`~xarray.DataArray.cumprod`. By `Phillip J. Wolfram
7878
<https://github.com/pwolfram>`_.
7979

80+
- New properties :py:attr:`Dataset.sizes` and :py:attr:`DataArray.sizes` for
81+
providing consistent access to dimension length on both ``Dataset`` and
82+
``DataArray`` (:issue:`921`).
83+
By `Stephan Hoyer <https://github.com/shoyer>`_.
84+
8085
Bug fixes
8186
~~~~~~~~~
8287
- ``groupby_bins`` now restores empty bins by default (:issue:`1019`).

xarray/core/common.py

+54-16
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import numpy as np
22
import pandas as pd
33

4-
from .pycompat import basestring, iteritems, suppress, dask_array_type, bytes_type
4+
from .pycompat import (basestring, iteritems, suppress, dask_array_type,
5+
OrderedDict)
56
from . import formatting
6-
from .utils import SortedKeysDict, not_implemented
7+
from .utils import SortedKeysDict, not_implemented, Frozen
78

89

910
class ImplementsArrayReduce(object):
@@ -124,6 +125,8 @@ def wrapped_func(self, **kwargs):
124125

125126

126127
class AbstractArray(ImplementsArrayReduce, formatting.ReprMixin):
128+
"""Shared base class for DataArray and Variable."""
129+
127130
def __bool__(self):
128131
return bool(self.values)
129132

@@ -186,6 +189,18 @@ def _get_axis_num(self, dim):
186189
raise ValueError("%r not found in array dimensions %r" %
187190
(dim, self.dims))
188191

192+
@property
193+
def sizes(self):
194+
"""Ordered mapping from dimension names to lengths.
195+
196+
Immutable.
197+
198+
See also
199+
--------
200+
Dataset.sizes
201+
"""
202+
return Frozen(OrderedDict(zip(self.dims, self.shape)))
203+
189204

190205
class AttrAccessMixin(object):
191206
"""Mixin class that allows getting keys with attribute access
@@ -231,7 +246,43 @@ def __dir__(self):
231246
return sorted(set(dir(type(self)) + extra_attrs))
232247

233248

234-
class BaseDataObject(AttrAccessMixin):
249+
class SharedMethodsMixin(object):
250+
"""Shared methods for Dataset, DataArray and Variable."""
251+
252+
def squeeze(self, dim=None):
253+
"""Return a new object with squeezed data.
254+
255+
Parameters
256+
----------
257+
dim : None or str or tuple of str, optional
258+
Selects a subset of the length one dimensions. If a dimension is
259+
selected with length greater than one, an error is raised. If
260+
None, all length one dimensions are squeezed.
261+
262+
Returns
263+
-------
264+
squeezed : same type as caller
265+
This object, but with with all or a subset of the dimensions of
266+
length 1 removed.
267+
268+
See Also
269+
--------
270+
numpy.squeeze
271+
"""
272+
if dim is None:
273+
dim = [d for d, s in self.sizes.items() if s == 1]
274+
else:
275+
if isinstance(dim, basestring):
276+
dim = [dim]
277+
if any(self.sizes[k] > 1 for k in dim):
278+
raise ValueError('cannot select a dimension to squeeze out '
279+
'which has length greater than one')
280+
return self.isel(**{d: 0 for d in dim})
281+
282+
283+
class BaseDataObject(SharedMethodsMixin, AttrAccessMixin):
284+
"""Shared base class for Dataset and DataArray."""
285+
235286
def _calc_assign_results(self, kwargs):
236287
results = SortedKeysDict()
237288
for k, v in kwargs.items():
@@ -615,19 +666,6 @@ def __exit__(self, exc_type, exc_value, traceback):
615666
__or__ = __div__ = __eq__ = __ne__ = not_implemented
616667

617668

618-
def squeeze(xarray_obj, dims, dim=None):
619-
"""Squeeze the dims of an xarray object."""
620-
if dim is None:
621-
dim = [d for d, s in iteritems(dims) if s == 1]
622-
else:
623-
if isinstance(dim, basestring):
624-
dim = [dim]
625-
if any(dims[k] > 1 for k in dim):
626-
raise ValueError('cannot select a dimension to squeeze out '
627-
'which has length greater than one')
628-
return xarray_obj.isel(**dict((d, 0) for d in dim))
629-
630-
631669
def _maybe_promote(dtype):
632670
"""Simpler equivalent of pandas.core.common._maybe_promote"""
633671
# N.B. these casting rules should match pandas

xarray/core/dataarray.py

+7-30
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import contextlib
21
import functools
32
import warnings
43

@@ -13,7 +12,7 @@
1312
from . import ops
1413
from . import utils
1514
from .alignment import align
16-
from .common import AbstractArray, BaseDataObject, squeeze
15+
from .common import AbstractArray, BaseDataObject
1716
from .coordinates import (DataArrayCoordinates, LevelCoordinates,
1817
Indexes)
1918
from .dataset import Dataset
@@ -411,7 +410,12 @@ def to_index(self):
411410

412411
@property
413412
def dims(self):
414-
"""Dimension names associated with this array."""
413+
"""Tuple of dimension names associated with this array.
414+
415+
Note that the type of this property is inconsistent with `Dataset.dims`.
416+
See `Dataset.sizes` and `DataArray.sizes` for consistently named
417+
properties.
418+
"""
415419
return self.variable.dims
416420

417421
@dims.setter
@@ -911,33 +915,6 @@ def transpose(self, *dims):
911915
variable = self.variable.transpose(*dims)
912916
return self._replace(variable)
913917

914-
def squeeze(self, dim=None):
915-
"""Return a new DataArray object with squeezed data.
916-
917-
Parameters
918-
----------
919-
dim : None or str or tuple of str, optional
920-
Selects a subset of the length one dimensions. If a dimension is
921-
selected with length greater than one, an error is raised. If
922-
None, all length one dimensions are squeezed.
923-
924-
Returns
925-
-------
926-
squeezed : DataArray
927-
This array, but with with all or a subset of the dimensions of
928-
length 1 removed.
929-
930-
Notes
931-
-----
932-
Although this operation returns a view of this array's data, it is
933-
not lazy -- the data will be fully loaded.
934-
935-
See Also
936-
--------
937-
numpy.squeeze
938-
"""
939-
return squeeze(self, dict(zip(self.dims, self.shape)), dim)
940-
941918
def drop(self, labels, dim=None):
942919
"""Drop coordinates or index labels from this DataArray.
943920

xarray/core/dataset.py

+20-29
Original file line numberDiff line numberDiff line change
@@ -291,11 +291,29 @@ def attrs(self, value):
291291
def dims(self):
292292
"""Mapping from dimension names to lengths.
293293
294-
This dictionary cannot be modified directly, but is updated when adding
295-
new variables.
294+
Cannot be modified directly, but is updated when adding new variables.
295+
296+
Note that type of this object differs from `DataArray.dims`.
297+
See `Dataset.sizes` and `DataArray.sizes` for consistently named
298+
properties.
296299
"""
297300
return Frozen(SortedKeysDict(self._dims))
298301

302+
@property
303+
def sizes(self):
304+
"""Mapping from dimension names to lengths.
305+
306+
Cannot be modified directly, but is updated when adding new variables.
307+
308+
This is an alias for `Dataset.dims` provided for the benefit of
309+
consistency with `DataArray.sizes`.
310+
311+
See also
312+
--------
313+
DataArray.sizes
314+
"""
315+
return self.dims
316+
299317
def load(self):
300318
"""Manually trigger loading of this dataset's data from disk or a
301319
remote source into memory and return this dataset.
@@ -1584,33 +1602,6 @@ def transpose(self, *dims):
15841602
def T(self):
15851603
return self.transpose()
15861604

1587-
def squeeze(self, dim=None):
1588-
"""Returns a new dataset with squeezed data.
1589-
1590-
Parameters
1591-
----------
1592-
dim : None or str or tuple of str, optional
1593-
Selects a subset of the length one dimensions. If a dimension is
1594-
selected with length greater than one, an error is raised. If
1595-
None, all length one dimensions are squeezed.
1596-
1597-
Returns
1598-
-------
1599-
squeezed : Dataset
1600-
This dataset, but with with all or a subset of the dimensions of
1601-
length 1 removed.
1602-
1603-
Notes
1604-
-----
1605-
Although this operation returns a view of each variable's data, it is
1606-
not lazy -- all variable data will be fully loaded.
1607-
1608-
See Also
1609-
--------
1610-
numpy.squeeze
1611-
"""
1612-
return common.squeeze(self, self.dims, dim)
1613-
16141605
def dropna(self, dim, how='any', thresh=None, subset=None):
16151606
"""Returns a new dataset with dropped labels for missing values along
16161607
the provided dimension.

xarray/core/groupby.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -178,10 +178,7 @@ def __init__(self, obj, group, squeeze=False, grouper=None, bins=None,
178178
raise ValueError("`group` must have a 'dims' attribute")
179179
group_dim, = group.dims
180180

181-
try:
182-
expected_size = obj.dims[group_dim]
183-
except TypeError:
184-
expected_size = obj.shape[obj.get_axis_num(group_dim)]
181+
expected_size = obj.sizes[group_dim]
185182
if group.size != expected_size:
186183
raise ValueError('the group variable\'s length does not '
187184
'match the length of this variable along its '

xarray/core/variable.py

+3-33
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from collections import defaultdict
33
import functools
44
import itertools
5-
import warnings
65

76
import numpy as np
87
import pandas as pd
@@ -192,7 +191,8 @@ def _as_array_or_item(data):
192191
return data
193192

194193

195-
class Variable(common.AbstractArray, utils.NdimSizeLenMixin):
194+
class Variable(common.AbstractArray, common.SharedMethodsMixin,
195+
utils.NdimSizeLenMixin):
196196

197197
"""A netcdf-like variable consisting of dimensions, data and attributes
198198
which describe a single Array. A single Variable object is not fully
@@ -678,34 +678,6 @@ def transpose(self, *dims):
678678
data = ops.transpose(self.data, axes)
679679
return type(self)(dims, data, self._attrs, self._encoding, fastpath=True)
680680

681-
def squeeze(self, dim=None):
682-
"""Return a new Variable object with squeezed data.
683-
684-
Parameters
685-
----------
686-
dim : None or str or tuple of str, optional
687-
Selects a subset of the length one dimensions. If a dimension is
688-
selected with length greater than one, an error is raised. If
689-
None, all length one dimensions are squeezed.
690-
691-
Returns
692-
-------
693-
squeezed : Variable
694-
This array, but with with all or a subset of the dimensions of
695-
length 1 removed.
696-
697-
Notes
698-
-----
699-
Although this operation returns a view of this variable's data, it is
700-
not lazy -- the data will be fully loaded.
701-
702-
See Also
703-
--------
704-
numpy.squeeze
705-
"""
706-
dims = dict(zip(self.dims, self.shape))
707-
return common.squeeze(self, dims, dim)
708-
709681
def expand_dims(self, dims, shape=None):
710682
"""Return a new variable with expanded dimensions.
711683
@@ -814,8 +786,7 @@ def _unstack_once(self, dims, old_dim):
814786
raise ValueError('cannot create a new dimension with the same '
815787
'name as an existing dimension')
816788

817-
axis = self.get_axis_num(old_dim)
818-
if np.prod(new_dim_sizes) != self.shape[axis]:
789+
if np.prod(new_dim_sizes) != self.sizes[old_dim]:
819790
raise ValueError('the product of the new dimension sizes must '
820791
'equal the size of the old dimension')
821792

@@ -914,7 +885,6 @@ def reduce(self, func, dim=None, axis=None, keep_attrs=False,
914885
dims = [adim for n, adim in enumerate(self.dims)
915886
if n not in removed_axes]
916887

917-
918888
attrs = self._attrs if keep_attrs else None
919889

920890
return Variable(dims, data, attrs=attrs)

xarray/test/test_dataarray.py

+7
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,13 @@ def test_dims(self):
159159
with self.assertRaisesRegexp(AttributeError, 'you cannot assign'):
160160
arr.dims = ('w', 'z')
161161

162+
def test_sizes(self):
163+
array = DataArray(np.zeros((3, 4)), dims=['x', 'y'])
164+
self.assertEqual(array.sizes, {'x': 3, 'y': 4})
165+
self.assertEqual(tuple(array.sizes), array.dims)
166+
with self.assertRaises(TypeError):
167+
array.sizes['foo'] = 5
168+
162169
def test_encoding(self):
163170
expected = {'foo': 'bar'}
164171
self.dv.encoding['foo'] = 'bar'

xarray/test/test_dataset.py

+1
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@ def test_properties(self):
338338
self.assertEqual(ds.dims,
339339
{'dim1': 8, 'dim2': 9, 'dim3': 10, 'time': 20})
340340
self.assertEqual(list(ds.dims), sorted(ds.dims))
341+
self.assertEqual(ds.sizes, ds.dims)
341342

342343
# These exact types aren't public API, but this makes sure we don't
343344
# change them inadvertently:

xarray/test/test_variable.py

+1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def test_properties(self):
2626
self.assertEqual(v.dtype, float)
2727
self.assertEqual(v.shape, (10,))
2828
self.assertEqual(v.size, 10)
29+
self.assertEqual(v.sizes, {'time': 10})
2930
self.assertEqual(v.nbytes, 80)
3031
self.assertEqual(v.ndim, 1)
3132
self.assertEqual(len(v), 10)

0 commit comments

Comments
 (0)