diff --git a/doc/data-structures.rst b/doc/data-structures.rst index 97b716653f3..74617a43340 100644 --- a/doc/data-structures.rst +++ b/doc/data-structures.rst @@ -115,10 +115,6 @@ If you create a ``DataArray`` by supplying a pandas df xr.DataArray(df) -Xarray supports labeling coordinate values with a :py:class:`pandas.MultiIndex`. -While it handles multi-indexes with unnamed levels, it is recommended that you -explicitly set the names of the levels. - DataArray properties ~~~~~~~~~~~~~~~~~~~~ @@ -532,6 +528,41 @@ dimension and whose the values are ``Index`` objects: ds.indexes +MultiIndex coordinates +~~~~~~~~~~~~~~~~~~~~~~ + +Xarray supports labeling coordinate values with a :py:class:`pandas.MultiIndex`: + +.. ipython:: python + + midx = pd.MultiIndex.from_arrays([['R', 'R', 'V', 'V'], [.1, .2, .7, .9]], + names=('band', 'wn')) + mda = xr.DataArray(np.random.rand(4), coords={'spec': midx}, dims='spec') + mda + +For convenience, multi-index levels are directly accessible as "virtual" or +"derived" coordinates (marked by ``-`` when printing a dataset or data array): + +.. ipython:: python + + mda['band'] + mda.wn + +Indexing with multi-index levels is also possible using the ``sel`` method +(see :ref:`multi-level indexing`). + +Unlike other coordinates, "virtual" level coordinates are not stored in +the ``coords`` attribute of ``DataArray`` and ``Dataset`` objects +(although they are shown when printing the ``coords`` attribute). +Consequently, most of the coordinate-related methods don't apply to them. +The ``coords`` attribute also can't be used to replace one particular level. + +Because in a ``DataArray`` or ``Dataset`` object each multi-index level is +accessible as a "virtual" coordinate, its name must not conflict with the names +of the other levels, coordinates and data variables of the same object. +Even though Xarray sets default names for multi-indexes with unnamed levels, +it is recommended that you explicitly set the names of the levels. + ..
[1] Latitude and longitude are 2D arrays because the dataset uses `projected coordinates`__. ``reference_time`` refers to the reference time at which the forecast was made, rather than ``time`` which is the valid time diff --git a/doc/indexing.rst b/doc/indexing.rst index d21adda2c8e..8e8783ef6f8 100644 --- a/doc/indexing.rst +++ b/doc/indexing.rst @@ -325,11 +325,25 @@ Additionally, xarray supports dictionaries: .. ipython:: python mda.sel(x={'one': 'a', 'two': 0}) - mda.loc[{'one': 'a'}, ...] + +For convenience, ``sel`` also accepts multi-index levels directly +as keyword arguments: + +.. ipython:: python + + mda.sel(one='a', two=0) + +Note that with ``sel`` it is not possible to mix a dimension +indexer with level indexers for that dimension +(e.g., ``mda.sel(x={'one': 'a'}, two=0)`` will raise a ``ValueError``). Like pandas, xarray handles partial selection on multi-index (level drop). -As shown in the last example above, it also renames the dimension / coordinate -when the multi-index is reduced to a single index. +As shown below, it also renames the dimension / coordinate when the +multi-index is reduced to a single index. + +.. ipython:: python + + mda.loc[{'one': 'a'}, ...] Unlike pandas, xarray does not guess whether you provide index levels or dimensions when using ``loc`` in some ambiguous cases. For example, for diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0f2d0535dc0..9795926240c 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,6 +40,13 @@ Deprecations Enhancements ~~~~~~~~~~~~ +- Multi-index levels are now accessible as "virtual" coordinate variables, + e.g., ``ds['time']`` can pull out the ``'time'`` level of a multi-index + (see :ref:`coordinates`). ``sel`` also accepts multi-index levels + as keyword arguments, e.g., ``ds.sel(time='2000-01')`` + (see :ref:`multi-level indexing`). + By `Benoit Bovy <https://github.com/benbovy>`_.
+ Bug fixes ~~~~~~~~~ diff --git a/xarray/core/coordinates.py b/xarray/core/coordinates.py index 5a5d4d519d0..30240c16b27 100644 --- a/xarray/core/coordinates.py +++ b/xarray/core/coordinates.py @@ -215,6 +215,27 @@ def __delitem__(self, key): del self._data._coords[key] +class DataArrayLevelCoordinates(AbstractCoordinates): + """Dictionary like container for DataArray MultiIndex level coordinates. + + Used for attribute style lookup. Not returned directly by any + public methods. + """ + def __init__(self, dataarray): + self._data = dataarray + + @property + def _names(self): + return set(self._data._level_coords) + + @property + def variables(self): + level_coords = OrderedDict( + (k, self._data[v].variable.get_level_variable(k)) + for k, v in self._data._level_coords.items()) + return Frozen(level_coords) + + class Indexes(Mapping, formatting.ReprMixin): """Ordered Mapping[str, pandas.Index] for xarray objects. """ diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 55b254fa358..7fb9b928823 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -14,11 +14,13 @@ from . 
import utils from .alignment import align from .common import AbstractArray, BaseDataObject, squeeze -from .coordinates import DataArrayCoordinates, Indexes +from .coordinates import (DataArrayCoordinates, DataArrayLevelCoordinates, + Indexes) from .dataset import Dataset from .pycompat import iteritems, basestring, OrderedDict, zip from .variable import (as_variable, Variable, as_compatible_data, IndexVariable, - default_index_coordinate) + default_index_coordinate, + assert_unique_multiindex_level_names) from .formatting import format_item @@ -82,6 +84,8 @@ def _infer_coords_and_dims(shape, coords, dims): 'length %s on the data but length %s on ' 'coordinate %r' % (d, sizes[d], s, k)) + assert_unique_multiindex_level_names(new_coords) + return new_coords, dims @@ -417,6 +421,20 @@ def _item_key_to_dict(self, key): key = indexing.expanded_indexer(key, self.ndim) return dict(zip(self.dims, key)) + @property + def _level_coords(self): + """Return a mapping of all MultiIndex levels and their corresponding + coordinate name. 
+ """ + level_coords = OrderedDict() + for cname, var in self._coords.items(): + if var.ndim == 1: + level_names = var.to_index_variable().level_names + if level_names is not None: + dim, = var.dims + level_coords.update({lname: dim for lname in level_names}) + return level_coords + def __getitem__(self, key): if isinstance(key, basestring): from .dataset import _get_virtual_variable @@ -424,7 +442,8 @@ def __getitem__(self, key): try: var = self._coords[key] except KeyError: - _, key, var = _get_virtual_variable(self._coords, key) + _, key, var = _get_virtual_variable( + self._coords, key, self._level_coords) return self._replace_maybe_drop_dims(var, name=key) else: @@ -444,7 +463,7 @@ def __delitem__(self, key): @property def _attr_sources(self): """List of places to look-up items for attribute-style access""" - return [self.coords, self.attrs] + return [self.coords, DataArrayLevelCoordinates(self), self.attrs] def __contains__(self, key): return key in self._coords diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6f34107686c..c8f93e9263a 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -33,34 +33,48 @@ 'quarter'] -def _get_virtual_variable(variables, key): - """Get a virtual variable (e.g., 'time.year') from a dict of - xarray.Variable objects (if possible) +def _get_virtual_variable(variables, key, level_vars={}): + """Get a virtual variable (e.g., 'time.year' or a MultiIndex level) + from a dict of xarray.Variable objects (if possible) """ if not isinstance(key, basestring): raise KeyError(key) split_key = key.split('.', 1) - if len(split_key) != 2: + if len(split_key) == 2: + ref_name, var_name = split_key + elif len(split_key) == 1: + ref_name, var_name = key, None + else: raise KeyError(key) - ref_name, var_name = split_key - ref_var = variables[ref_name] - if ref_var.ndim == 1: - date = ref_var.to_index() - elif ref_var.ndim == 0: - date = pd.Timestamp(ref_var.values) + if ref_name in level_vars: + dim_var = 
variables[level_vars[ref_name]] + ref_var = dim_var.to_index_variable().get_level_variable(ref_name) else: - raise KeyError(key) + ref_var = variables[ref_name] - if var_name == 'season': - # TODO: move 'season' into pandas itself - seasons = np.array(['DJF', 'MAM', 'JJA', 'SON']) - month = date.month - data = seasons[(month // 3) % 4] + if var_name is None: + virtual_var = ref_var + var_name = key else: - data = getattr(date, var_name) - return ref_name, var_name, Variable(ref_var.dims, data) + if ref_var.ndim == 1: + date = ref_var.to_index() + elif ref_var.ndim == 0: + date = pd.Timestamp(ref_var.values) + else: + raise KeyError(key) + + if var_name == 'season': + # TODO: move 'season' into pandas itself + seasons = np.array(['DJF', 'MAM', 'JJA', 'SON']) + month = date.month + data = seasons[(month // 3) % 4] + else: + data = getattr(date, var_name) + virtual_var = Variable(ref_var.dims, data) + + return ref_name, var_name, virtual_var def calculate_dimensions(variables): @@ -424,6 +438,21 @@ def _subset_with_all_valid_coords(self, variables, coord_names, attrs): return self._construct_direct(variables, coord_names, dims, attrs) + @property + def _level_coords(self): + """Return a mapping of all MultiIndex levels and their corresponding + coordinate name. + """ + level_coords = OrderedDict() + for cname in self._coord_names: + var = self.variables[cname] + if var.ndim == 1: + level_names = var.to_index_variable().level_names + if level_names is not None: + dim, = var.dims + level_coords.update({lname: dim for lname in level_names}) + return level_coords + def _copy_listed(self, names): """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. 
@@ -436,7 +465,7 @@ def _copy_listed(self, names): variables[name] = self._variables[name] except KeyError: ref_name, var_name, var = _get_virtual_variable( - self._variables, name) + self._variables, name, self._level_coords) variables[var_name] = var if ref_name in self._coord_names: coord_names.add(var_name) @@ -452,7 +481,8 @@ def _construct_dataarray(self, name): try: variable = self._variables[name] except KeyError: - _, name, variable = _get_virtual_variable(self._variables, name) + _, name, variable = _get_virtual_variable( + self._variables, name, self._level_coords) coords = OrderedDict() needed_dims = set(variable.dims) @@ -521,6 +551,7 @@ def __setitem__(self, key, value): if utils.is_dict_like(key): raise NotImplementedError('cannot yet use a dictionary as a key ' 'to set Dataset values') + self.update({key: value}) def __delitem__(self, key): diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index e8b45c97e0f..e6a33989935 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -200,9 +200,24 @@ def _summarize_var_or_coord(name, var, col_width, show_values=True, values_str = format_array_flat(var, max_width - len(front_str)) else: values_str = u'...' + return front_str + values_str +def _summarize_coord_multiindex(coord, col_width, marker): + first_col = pretty_print(u' %s %s ' % (marker, coord.name), col_width) + return u'%s(%s) MultiIndex' % (first_col, unicode_type(coord.dims[0])) + + +def _summarize_coord_levels(coord, col_width, marker=u'-'): + relevant_coord = coord[:30] + return u'\n'.join( + [_summarize_var_or_coord(lname, + relevant_coord.get_level_variable(lname), + col_width, marker=marker) + for lname in coord.level_names]) + + def _not_remote(var): """Helper function to identify if array is positively identifiable as coming from a remote source. 
@@ -222,6 +237,12 @@ def summarize_coord(name, var, col_width): is_index = name in var.dims show_values = is_index or _not_remote(var) marker = u'*' if is_index else u' ' + if is_index: + coord = var.variable.to_index_variable() + if coord.level_names is not None: + return u'\n'.join( + [_summarize_coord_multiindex(coord, col_width, marker), + _summarize_coord_levels(coord, col_width)]) return _summarize_var_or_coord(name, var, col_width, show_values, marker) @@ -233,9 +254,26 @@ def summarize_attr(key, value, col_width=None): EMPTY_REPR = u' *empty*' -def _calculate_col_width(mapping): - max_name_length = (max(len(unicode_type(k)) for k in mapping) - if mapping else 0) +def _get_col_items(mapping): + """Get all column items to format, including both keys of `mapping` + and MultiIndex levels if any. + """ + from .variable import IndexVariable + + col_items = [] + for k, v in mapping.items(): + col_items.append(k) + var = getattr(v, 'variable', v) + if isinstance(var, IndexVariable): + level_names = var.to_index_variable().level_names + if level_names is not None: + col_items += list(level_names) + return col_items + + +def _calculate_col_width(col_items): + max_name_length = (max(len(unicode_type(s)) for s in col_items) + if col_items else 0) col_width = max(max_name_length, 7) + 6 return col_width @@ -251,10 +289,6 @@ def _mapping_repr(mapping, title, summarizer, col_width=None): return u'\n'.join(summary) -coords_repr = functools.partial(_mapping_repr, title=u'Coordinates', - summarizer=summarize_coord) - - vars_repr = functools.partial(_mapping_repr, title=u'Data variables', summarizer=summarize_var) @@ -263,6 +297,13 @@ def _mapping_repr(mapping, title, summarizer, col_width=None): summarizer=summarize_attr) +def coords_repr(coords, col_width=None): + if col_width is None: + col_width = _calculate_col_width(_get_col_items(coords)) + return _mapping_repr(coords, title=u'Coordinates', + summarizer=summarize_coord, col_width=col_width) + + def 
indexes_repr(indexes): summary = [] for k, v in indexes.items(): @@ -302,7 +343,7 @@ def array_repr(arr): def dataset_repr(ds): summary = [u'' % type(ds).__name__] - col_width = _calculate_col_width(ds) + col_width = _calculate_col_width(_get_col_items(ds)) dims_start = pretty_print(u'Dimensions:', col_width) all_dim_strings = [u'%s: %s' % (k, v) for k, v in iteritems(ds.dims)] diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index a182cae92c4..7b925af2105 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -283,7 +283,7 @@ def _yield_binary_applied(self, func, other): raise TypeError('GroupBy objects only support binary ops ' 'when the other argument is a Dataset or ' 'DataArray') - except KeyError: + except (KeyError, ValueError): if self.group.name not in other.dims: raise ValueError('incompatible dimensions for a grouped ' 'binary operation: the group variable %r ' diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 2fbcb317c37..bfdc6d305ad 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1,4 +1,5 @@ from datetime import timedelta +from collections import defaultdict import numpy as np import pandas as pd @@ -214,6 +215,39 @@ def convert_label_indexer(index, label, index_name='', method=None, return indexer, new_index +def get_dim_indexers(data_obj, indexers): + """Given a xarray data object and label based indexers, return a mapping + of label indexers with only dimension names as keys. + + It groups multiple level indexers given on a multi-index dimension + into a single, dictionary indexer for that dimension (Raise a ValueError + if it is not possible). 
+ """ + invalid = [k for k in indexers + if k not in data_obj.dims and k not in data_obj._level_coords] + if invalid: + raise ValueError("dimensions or multi-index levels %r do not exist" + % invalid) + + level_indexers = defaultdict(dict) + dim_indexers = {} + for key, label in iteritems(indexers): + dim, = data_obj[key].dims + if key != dim: + # assume here multi-index level indexer + level_indexers[dim][key] = label + else: + dim_indexers[key] = label + + for dim, level_labels in iteritems(level_indexers): + if dim_indexers.get(dim, False): + raise ValueError("cannot combine multi-index level indexers " + "with an indexer for dimension %s" % dim) + dim_indexers[dim] = level_labels + + return dim_indexers + + def remap_label_indexers(data_obj, indexers, method=None, tolerance=None): """Given an xarray data object and label based indexers, return a mapping of equivalent location based indexers. Also return a mapping of updated @@ -223,7 +257,7 @@ def remap_label_indexers(data_obj, indexers, method=None, tolerance=None): raise TypeError('``method`` must be a string') pos_indexers, new_indexes = {}, {} - for dim, label in iteritems(indexers): + for dim, label in iteritems(get_dim_indexers(data_obj, indexers)): index = data_obj[dim].to_index() idxr, new_idx = convert_label_indexer(index, label, dim, method, tolerance) diff --git a/xarray/core/merge.py b/xarray/core/merge.py index 28eb8ecea84..8cab70f9fff 100644 --- a/xarray/core/merge.py +++ b/xarray/core/merge.py @@ -2,7 +2,8 @@ from .alignment import align from .utils import Frozen, is_dict_like -from .variable import as_variable, default_index_coordinate +from .variable import (as_variable, default_index_coordinate, + assert_unique_multiindex_level_names) from .pycompat import (basestring, OrderedDict) @@ -110,7 +111,7 @@ def merge_variables( If provided, variables are always taken from this dict in preference to the input variable dictionaries, without checking for conflicts. 
compat : {'identical', 'equals', 'broadcast_equals', 'minimal'}, optional - Type of equality check to use wben checking for conflicts. + Type of equality check to use when checking for conflicts. Returns ------- @@ -278,6 +279,7 @@ def merge_coords_without_align(objs, priority_vars=None): """ expanded = expand_variable_dicts(objs) variables = merge_variables(expanded, priority_vars) + assert_unique_multiindex_level_names(variables) return variables @@ -370,6 +372,7 @@ def merge_coords(objs, compat='minimal', join='outer', priority_arg=None, expanded = expand_variable_dicts(aligned) priority_vars = _get_priority_vars(aligned, priority_arg, compat=compat) variables = merge_variables(expanded, priority_vars, compat=compat) + assert_unique_multiindex_level_names(variables) return variables @@ -431,6 +434,7 @@ def merge_core(objs, compat='broadcast_equals', join='outer', priority_arg=None, priority_vars = _get_priority_vars(aligned, priority_arg, compat=compat) variables = merge_variables(expanded, priority_vars, compat=compat) + assert_unique_multiindex_level_names(variables) dims = calculate_dimensions(variables) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index d29137fb61b..755cd0b9aee 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1,4 +1,5 @@ from datetime import timedelta +from collections import defaultdict import functools import itertools import warnings @@ -1071,12 +1072,13 @@ class IndexVariable(Variable): of a NumPy array. Hence, their values are immutable and must always be one- dimensional. - They also have a name property, which is the name of their sole dimension. + They also have a name property, which is the name of their sole dimension + unless another name is given. 
""" - def __init__(self, name, data, attrs=None, encoding=None, fastpath=False): - super(IndexVariable, self).__init__( - name, data, attrs, encoding, fastpath) + def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False): + super(IndexVariable, self).__init__(dims, data, attrs, encoding, + fastpath) if self.ndim != 1: raise ValueError('%s objects must be 1-dimensional' % type(self).__name__) @@ -1092,8 +1094,8 @@ def __getitem__(self, key): if not hasattr(values, 'ndim') or values.ndim == 0: return Variable((), values, self._attrs, self._encoding) else: - return type(self)(self.dims, values, self._attrs, self._encoding, - fastpath=True) + return type(self)(self.dims, values, self._attrs, + self._encoding, fastpath=True) def __setitem__(self, key, value): raise TypeError('%s values cannot be modified' % type(self).__name__) @@ -1145,8 +1147,8 @@ def copy(self, deep=True): # there is no need to copy the index values here even if deep=True # since pandas.Index objects are immutable data = PandasIndexAdapter(self) if deep else self._data - return type(self)(self.dims, data, self._attrs, self._encoding, - fastpath=True) + return type(self)(self.dims, data, self._attrs, + self._encoding, fastpath=True) def _data_equals(self, other): return self.to_index().equals(other.to_index()) @@ -1166,13 +1168,31 @@ def to_index(self): if isinstance(index, pd.MultiIndex): # set default names for multi-index unnamed levels so that # we can safely rename dimension / coordinate later - valid_level_names = [name or '{}_level_{}'.format(self.name, i) + valid_level_names = [name or '{}_level_{}'.format(self.dims[0], i) for i, name in enumerate(index.names)] index = index.set_names(valid_level_names) else: index = index.set_names(self.name) return index + @property + def level_names(self): + """Return MultiIndex level names or None if this IndexVariable has no + MultiIndex. 
+ """ + index = self.to_index() + if isinstance(index, pd.MultiIndex): + return index.names + else: + return None + + def get_level_variable(self, level): + """Return a new IndexVariable from a given MultiIndex level.""" + if self.level_names is None: + raise ValueError("IndexVariable %r has no MultiIndex" % self.name) + index = self.to_index() + return type(self)(self.dims, index.get_level_values(level)) + @property def name(self): return self.dims[0] @@ -1276,3 +1296,29 @@ def concat(variables, dim='concat_dim', positions=None, shortcut=False): return IndexVariable.concat(variables, dim, positions, shortcut) else: return Variable.concat(variables, dim, positions, shortcut) + + +def assert_unique_multiindex_level_names(variables): + """Check for uniqueness of MultiIndex level names in all given + variables. + + Not public API. Used for checking consistency of DataArray and Dataset + objects. + """ + level_names = defaultdict(list) + for var_name, var in variables.items(): + if isinstance(var._data, PandasIndexAdapter): + idx_level_names = var.to_index_variable().level_names + if idx_level_names is not None: + for n in idx_level_names: + level_names[n].append('%r (%s)' % (n, var_name)) + + for k, v in level_names.items(): + if k in variables: + v.append('(%s)' % k) + + duplicate_names = [v for v in level_names.values() if len(v) > 1] + if duplicate_names: + conflict_str = '\n'.join([', '.join(v) for v in duplicate_names]) + raise ValueError('conflicting MultiIndex level name(s):\n%s' + % conflict_str) diff --git a/xarray/test/test_dataarray.py b/xarray/test/test_dataarray.py index 2e789516574..f613259f128 100644 --- a/xarray/test/test_dataarray.py +++ b/xarray/test/test_dataarray.py @@ -24,6 +24,10 @@ def setUp(self): self.ds = Dataset({'foo': self.v}) self.dv = self.ds['foo'] + self.mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], + names=('level_1', 'level_2')) + self.mda = DataArray([0, 1, 2, 3], coords={'x': self.mindex}, dims='x') + def 
test_repr(self): v = Variable(['time', 'x'], [[1, 2, 3], [4, 5, 6]], {'foo': 'bar'}) data_array = DataArray(v, {'other': np.int64(0)}, name='my_variable') @@ -39,6 +43,16 @@ def test_repr(self): foo: bar""") self.assertEqual(expected, repr(data_array)) + def test_repr_multiindex(self): + expected = dedent("""\ + + array([0, 1, 2, 3]) + Coordinates: + * x (x) MultiIndex + - level_1 (x) object 'a' 'a' 'b' 'b' + - level_2 (x) int64 1 2 1 2""") + self.assertEqual(expected, repr(self.mda)) + def test_properties(self): self.assertVariableEqual(self.dv.variable, self.v) self.assertArrayEqual(self.dv.values, self.v.values) @@ -236,6 +250,11 @@ def test_constructor_invalid(self): with self.assertRaisesRegexp(ValueError, 'conflicting sizes for dim'): DataArray([1, 2], coords={'x': [0, 1], 'y': ('x', [1])}, dims='x') + with self.assertRaisesRegexp(ValueError, 'conflicting MultiIndex'): + DataArray(np.random.rand(4, 4), + [('x', self.mindex), ('y', self.mindex)]) + DataArray(np.random.rand(4, 4), + [('x', mindex), ('level_1', range(4))]) def test_constructor_from_self_described(self): data = [[-0.1, 21], [0, 2]] @@ -405,6 +424,11 @@ def test_getitem_coords(self): dims='x') self.assertDataArrayIdentical(expected, actual) + def test_attr_sources_multiindex(self): + # make sure attr-style access for multi-index levels + # returns DataArray objects + self.assertIsInstance(self.mda.level_1, DataArray) + def test_pickle(self): data = DataArray(np.random.random((3, 3)), dims=('id', 'time')) roundtripped = pickle.loads(pickle.dumps(data)) @@ -543,7 +567,7 @@ def test_loc_single_boolean(self): self.assertEqual(data.loc[True], 0) self.assertEqual(data.loc[False], 1) - def test_multiindex(self): + def test_selection_multiindex(self): mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2], [-1, -2]], names=('one', 'two', 'three')) mdata = DataArray(range(8), [('x', mindex)]) @@ -579,11 +603,12 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False, mdata.sel(x=('a', 1))) 
self.assertDataArrayIdentical(mdata.loc[{'one': 'a'}, ...], mdata.sel(x={'one': 'a'})) - with self.assertRaises(KeyError): - mdata.loc[{'one': 'a'}] with self.assertRaises(IndexError): mdata.loc[('a', 1)] + self.assertDataArrayIdentical(mdata.sel(x={'one': 'a', 'two': 1}), + mdata.sel(one='a', two=1)) + def test_time_components(self): dates = pd.date_range('2000-01-01', periods=10) da = DataArray(np.arange(1, 11), [('time', dates)]) @@ -626,6 +651,10 @@ def test_coords(self): with self.assertRaisesRegexp(ValueError, 'cannot delete'): del da.coords['x'] + with self.assertRaisesRegexp(ValueError, 'conflicting MultiIndex'): + self.mda['level_1'] = np.arange(4) + self.mda.coords['level_1'] = np.arange(4) + def test_coord_coords(self): orig = DataArray([10, 20], {'x': [1, 2], 'x2': ('x', ['a', 'b']), 'z': 4}, @@ -706,6 +735,9 @@ def test_assign_coords(self): expected.coords['d'] = ('x', [1.5, 1.5, 3.5, 3.5]) self.assertDataArrayIdentical(actual, expected) + with self.assertRaisesRegexp(ValueError, 'conflicting MultiIndex'): + self.mda.assign_coords(level_1=range(4)) + def test_coords_alignment(self): lhs = DataArray([1, 2, 3], [('x', [0, 1, 2])]) rhs = DataArray([2, 3, 4], [('x', [1, 2, 3])]) diff --git a/xarray/test/test_dataset.py b/xarray/test/test_dataset.py index 61170f85a71..a1da10b4ca5 100644 --- a/xarray/test/test_dataset.py +++ b/xarray/test/test_dataset.py @@ -45,6 +45,12 @@ def create_test_data(seed=None): return obj +def create_test_multiindex(): + mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], + names=('level_1', 'level_2')) + return Dataset({}, {'x': mindex}) + + class InaccessibleVariableDataStore(backends.InMemoryDataStore): def get_variables(self): def lazy_inaccessible(x): @@ -110,6 +116,39 @@ def test_repr(self): data = Dataset(attrs={'foo': 'bar' * 1000}) self.assertTrue(len(repr(data)) < 1000) + def test_repr_multiindex(self): + data = create_test_multiindex() + expected = dedent("""\ + + Dimensions: (x: 4) + Coordinates: + * x (x) 
MultiIndex + - level_1 (x) object 'a' 'a' 'b' 'b' + - level_2 (x) int64 1 2 1 2 + Data variables: + *empty*""") + actual = '\n'.join(x.rstrip() for x in repr(data).split('\n')) + print(actual) + self.assertEqual(expected, actual) + + # verify that long level names are not truncated + mindex = pd.MultiIndex.from_product( + [['a', 'b'], [1, 2]], + names=('a_quite_long_level_name', 'level_2')) + data = Dataset({}, {'x': mindex}) + expected = dedent("""\ + + Dimensions: (x: 4) + Coordinates: + * x (x) MultiIndex + - a_quite_long_level_name (x) object 'a' 'a' 'b' 'b' + - level_2 (x) int64 1 2 1 2 + Data variables: + *empty*""") + actual = '\n'.join(x.rstrip() for x in repr(data).split('\n')) + print(actual) + self.assertEqual(expected, actual) + def test_repr_period_index(self): data = create_test_data(seed=456) data.coords['time'] = pd.period_range('2000-01-01', periods=20, freq='B') @@ -288,6 +327,12 @@ def test_constructor_with_coords(self): self.assertFalse(ds.data_vars) self.assertItemsEqual(ds.coords.keys(), ['x', 'a']) + mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], + names=('level_1', 'level_2')) + with self.assertRaisesRegexp(ValueError, 'conflicting MultiIndex'): + Dataset({}, {'x': mindex, 'y': mindex}) + Dataset({}, {'x': mindex, 'level_1': range(4)}) + def test_properties(self): ds = create_test_data() self.assertEqual(ds.dims, @@ -466,6 +511,11 @@ def test_coords_setitem_with_new_dimension(self): expected = Dataset(coords={'foo': ('x', [1, 2, 3])}) self.assertDatasetIdentical(expected, actual) + def test_coords_setitem_multiindex(self): + data = create_test_multiindex() + with self.assertRaisesRegexp(ValueError, 'conflicting MultiIndex'): + data.coords['level_1'] = range(4) + def test_coords_set(self): one_coord = Dataset({'x': ('x', [0]), 'yy': ('x', [1]), @@ -876,7 +926,7 @@ def test_loc(self): with self.assertRaises(TypeError): data.loc[dict(dim3='a')] = 0 - def test_multiindex(self): + def test_selection_multiindex(self): mindex = 
pd.MultiIndex.from_product([['a', 'b'], [1, 2], [-1, -2]], names=('one', 'two', 'three')) mdata = Dataset(data_vars={'var': ('x', range(8))}, @@ -916,8 +966,9 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False, mdata.sel(x=('a', 1))) self.assertDatasetIdentical(mdata.loc[{'x': ('a', 1, -1)}], mdata.sel(x=('a', 1, -1))) - with self.assertRaises(KeyError): - mdata.loc[{'one': 'a'}] + + self.assertDatasetIdentical(mdata.sel(x={'one': 'a', 'two': 1}), + mdata.sel(one='a', two=1)) def test_reindex_like(self): data = create_test_data() @@ -1473,6 +1524,22 @@ def test_virtual_variable_same_name(self): expected = DataArray(times.time, [('time', times)], name='time') self.assertDataArrayIdentical(actual, expected) + def test_virtual_variable_multiindex(self): + # access multi-index levels as virtual variables + data = create_test_multiindex() + expected = DataArray(['a', 'a', 'b', 'b'], name='level_1', + coords=[data['x'].to_index()], dims='x') + self.assertDataArrayIdentical(expected, data['level_1']) + + # combine multi-index level and datetime + dr_index = pd.date_range('1/1/2011', periods=4, freq='H') + mindex = pd.MultiIndex.from_arrays([['a', 'a', 'b', 'b'], dr_index], + names=('level_str', 'level_date')) + data = Dataset({}, {'x': mindex}) + expected = DataArray(mindex.get_level_values('level_date').hour, + name='hour', coords=[mindex], dims='x') + self.assertDataArrayIdentical(expected, data['level_date.hour']) + def test_time_season(self): ds = Dataset({'t': pd.date_range('2000-01-01', periods=12, freq='M')}) expected = ['DJF'] * 2 + ['MAM'] * 3 + ['JJA'] * 3 + ['SON'] * 3 + ['DJF'] @@ -1589,6 +1656,12 @@ def test_assign(self): expected = expected.set_coords('c') self.assertDatasetIdentical(actual, expected) + def test_assign_multiindex_level(self): + data = create_test_multiindex() + with self.assertRaisesRegexp(ValueError, 'conflicting MultiIndex'): + data.assign(level_1=range(4)) + data.assign_coords(level_1=range(4)) + def 
test_setitem_original_non_unique_index(self): # regression test for GH943 original = Dataset({'data': ('x', np.arange(5))}, @@ -1618,6 +1691,11 @@ def test_setitem_both_non_unique_index(self): actual['second'] = array self.assertDatasetIdentical(expected, actual) + def test_setitem_multiindex_level(self): + data = create_test_multiindex() + with self.assertRaisesRegexp(ValueError, 'conflicting MultiIndex'): + data['level_1'] = range(4) + def test_delitem(self): data = create_test_data() all_items = set(data) @@ -1744,9 +1822,9 @@ def test_groupby_math(self): actual = zeros + grouped self.assertDatasetEqual(expected, actual) - with self.assertRaisesRegexp(ValueError, 'dimensions .* do not exist'): + with self.assertRaisesRegexp(ValueError, 'incompat.* grouped binary'): grouped + ds - with self.assertRaisesRegexp(ValueError, 'dimensions .* do not exist'): + with self.assertRaisesRegexp(ValueError, 'incompat.* grouped binary'): ds + grouped with self.assertRaisesRegexp(TypeError, 'only support binary ops'): grouped + 1 diff --git a/xarray/test/test_indexing.py b/xarray/test/test_indexing.py index 1dca99ec99a..7ed3f5bc372 100644 --- a/xarray/test/test_indexing.py +++ b/xarray/test/test_indexing.py @@ -107,8 +107,22 @@ def test_convert_unsorted_datetime_index_raises(self): # slice is always a view. 
indexing.convert_label_indexer(index, slice('2001', '2002')) + def test_get_dim_indexers(self): + mindex = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], + names=('one', 'two')) + mdata = DataArray(range(4), [('x', mindex)]) + + dim_indexers = indexing.get_dim_indexers(mdata, {'one': 'a', 'two': 1}) + self.assertEqual(dim_indexers, {'x': {'one': 'a', 'two': 1}}) + + with self.assertRaisesRegexp(ValueError, 'cannot combine'): + _ = indexing.get_dim_indexers(mdata, {'x': 'a', 'two': 1}) + + with self.assertRaisesRegexp(ValueError, 'do not exist'): + _ = indexing.get_dim_indexers(mdata, {'y': 'a'}) + _ = indexing.get_dim_indexers(data, {'four': 1}) + def test_remap_label_indexers(self): - # TODO: fill in more tests! def test_indexer(data, x, expected_pos, expected_idx=None): pos, idx = indexing.remap_label_indexers(data, {'x': x}) self.assertArrayEqual(pos.get('x'), expected_pos) diff --git a/xarray/test/test_variable.py b/xarray/test/test_variable.py index 10e360f5322..d6a61975659 100644 --- a/xarray/test/test_variable.py +++ b/xarray/test/test_variable.py @@ -1038,6 +1038,24 @@ def test_name(self): with self.assertRaises(AttributeError): coord.name = 'y' + def test_level_names(self): + midx = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], + names=['level_1', 'level_2']) + x = IndexVariable('x', midx) + self.assertEqual(x.level_names, midx.names) + + self.assertIsNone(IndexVariable('y', [10.0]).level_names) + + def test_get_level_variable(self): + midx = pd.MultiIndex.from_product([['a', 'b'], [1, 2]], + names=['level_1', 'level_2']) + x = IndexVariable('x', midx) + level_1 = IndexVariable('x', midx.get_level_values('level_1')) + self.assertVariableIdentical(x.get_level_variable('level_1'), level_1) + + with self.assertRaisesRegexp(ValueError, 'has no MultiIndex'): + IndexVariable('y', [10.0]).get_level_variable('level') + def test_concat_periods(self): periods = pd.period_range('2000-01-01', periods=10) coords = [IndexVariable('t', periods[:5]), 
IndexVariable('t', periods[5:])]