Skip to content

Options to binary ops kwargs #1065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Nov 12, 2016
13 changes: 11 additions & 2 deletions doc/computation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,8 @@ coordinates with the same name as a dimension, marked by ``*``) on objects used
in binary operations.

Similarly to pandas, this alignment is automatic for arithmetic on binary
operations. Note that unlike pandas, this the result of a binary operation is
by the *intersection* (not the union) of coordinate labels:
operations. The default result of a binary operation is by the *intersection*
(not the union) of coordinate labels:

.. ipython:: python

Expand All @@ -225,6 +225,15 @@ If the result would be empty, an error is raised instead:
In [1]: arr[:2] + arr[2:]
ValueError: no overlapping labels for some dimensions: ['x']

However, one can explicitly change this default automatic alignment type ("inner")
via :py:func:`~xarray.set_options()` in context manager:

.. ipython:: python

with xr.set_options(arithmetic_join="outer"):
arr + arr[:1]
arr + arr[:1]

Before loops or performance critical code, it's a good idea to align arrays
explicitly (e.g., by putting them in the same Dataset or using
:py:func:`~xarray.align`) to avoid the overhead of repeated alignment with each
Expand Down
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ Deprecations

Enhancements
~~~~~~~~~~~~
- Added the ability to change default automatic alignment (arithmetic_join="inner")
for binary operations via :py:func:`~xarray.set_options()`
(see :ref:`automatic alignment`).
By `Chun-Wei Yuan <https://github.com/chunweiyuan>`_.

- Add checking of ``attr`` names and values when saving to netCDF, raising useful
error messages if they are invalid. (:issue:`911`).
By `Robin Wilson <https://github.com/robintw>`_.
Expand Down
6 changes: 4 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
assert_unique_multiindex_level_names)
from .formatting import format_item
from .utils import decode_numpy_dict_values, ensure_us_time_resolution
from .options import OPTIONS


def _infer_coords_and_dims(shape, coords, dims):
Expand Down Expand Up @@ -1357,13 +1358,14 @@ def func(self, *args, **kwargs):
return func

@staticmethod
def _binary_op(f, reflexive=False, join='inner', **ignored_kwargs):
def _binary_op(f, reflexive=False, join=None, **ignored_kwargs):
@functools.wraps(f)
def func(self, other):
if isinstance(other, (Dataset, groupby.GroupBy)):
return NotImplemented
if hasattr(other, 'indexes'):
self, other = align(self, other, join=join, copy=False)
align_type = OPTIONS['arithmetic_join'] if join is None else join
self, other = align(self, other, join=align_type, copy=False)
other_variable = getattr(other, 'variable', other)
other_coords = getattr(other, 'coords', None)

Expand Down
35 changes: 22 additions & 13 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .pycompat import (iteritems, basestring, OrderedDict,
dask_array_type)
from .combine import concat
from .options import OPTIONS


# list of attributes of pd.DatetimeIndex that are ndarrays of time info
Expand Down Expand Up @@ -2012,15 +2013,17 @@ def func(self, *args, **kwargs):
return func

@staticmethod
def _binary_op(f, reflexive=False, join='inner', fillna=False):
def _binary_op(f, reflexive=False, join=None, fillna=False):
@functools.wraps(f)
def func(self, other):
if isinstance(other, groupby.GroupBy):
return NotImplemented
align_type = OPTIONS['arithmetic_join'] if join is None else join
if hasattr(other, 'indexes'):
self, other = align(self, other, join=join, copy=False)
self, other = align(self, other, join=align_type, copy=False)
g = f if not reflexive else lambda x, y: f(y, x)
ds = self._calculate_binary_op(g, other, fillna=fillna)
ds = self._calculate_binary_op(g, other, join=align_type,
fillna=fillna)
return ds
return func

Expand All @@ -2042,25 +2045,32 @@ def func(self, other):
return self
return func

def _calculate_binary_op(self, f, other, inplace=False, fillna=False):
def _calculate_binary_op(self, f, other, join='inner',
inplace=False, fillna=False):

def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars):
if fillna and join != 'left':
raise ValueError('`fillna` must be accompanied by left join')
if fillna and not set(rhs_data_vars) <= set(lhs_data_vars):
raise ValueError('all variables in the argument to `fillna` '
'must be contained in the original dataset')
if inplace and set(lhs_data_vars) != set(rhs_data_vars):
raise ValueError('datasets must have the same data variables '
'for in-place arithmetic operations: %s, %s'
% (list(lhs_data_vars), list(rhs_data_vars)))

dest_vars = OrderedDict()

for k in lhs_data_vars:
if k in rhs_data_vars:
dest_vars[k] = f(lhs_vars[k], rhs_vars[k])
elif inplace:
raise ValueError(
'datasets must have the same data variables '
'for in-place arithmetic operations: %s, %s'
% (list(lhs_data_vars), list(rhs_data_vars)))
elif fillna:
# this shortcuts left alignment of variables for fillna
dest_vars[k] = lhs_vars[k]
elif join in ["left", "outer"]:
dest_vars[k] = (lhs_vars[k] if fillna else
f(lhs_vars[k], np.nan))
for k in rhs_data_vars:
if k not in dest_vars and join in ["right", "outer"]:
dest_vars[k] = (rhs_vars[k] if fillna else
f(rhs_vars[k], np.nan))
return dest_vars

if utils.is_dict_like(other) and not isinstance(other, Dataset):
Expand All @@ -2080,7 +2090,6 @@ def apply_over_both(lhs_data_vars, rhs_data_vars, lhs_vars, rhs_vars):
other_variable = getattr(other, 'variable', other)
new_vars = OrderedDict((k, f(self.variables[k], other_variable))
for k in self.data_vars)

ds._variables.update(new_vars)
return ds

Expand Down
10 changes: 7 additions & 3 deletions xarray/core/options.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
OPTIONS = {'display_width': 80}
OPTIONS = {'display_width': 80,
'arithmetic_join': "inner"}


class set_options(object):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you update the docstring for this class?

"""Set options for xarray in a controlled context.

Currently, the only supported option is ``display_width``, which has a
default value of 80.
Currently, the only supported options are:
1.) display_width: maximum terminal display width of data arrays.
Default=80.
2.) arithmetic_join: dataarray/dataset alignment in binary operations.
Default='inner'.

You can use ``set_options`` either as a context manager:

Expand Down
14 changes: 14 additions & 0 deletions xarray/test/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2278,3 +2278,17 @@ def test_dot(self):
da.dot(dm.values)
with self.assertRaisesRegexp(ValueError, 'no shared dimensions'):
da.dot(DataArray(1))

def test_binary_op_join_setting(self):
dim = 'x'
align_type = "outer"
coords_l, coords_r = [0, 1, 2], [1, 2, 3]
missing_3 = xr.DataArray(coords_l, [(dim, coords_l)])
missing_0 = xr.DataArray(coords_r, [(dim, coords_r)])
with xr.set_options(arithmetic_join=align_type):
actual = missing_0 + missing_3
missing_0_aligned, missing_3_aligned = xr.align(missing_0,
missing_3,
join=align_type)
expected = xr.DataArray([np.nan, 2, 4, np.nan], [(dim, [0, 1, 2, 3])])
self.assertDataArrayEqual(actual, expected)
32 changes: 32 additions & 0 deletions xarray/test/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import numpy as np
import pandas as pd
import xarray as xr
import pytest

from xarray import (align, broadcast, concat, merge, conventions, backends,
Expand Down Expand Up @@ -2935,6 +2936,37 @@ def test_filter_by_attrs(self):
for var in new_ds.data_vars:
self.assertEqual(new_ds[var].height, '10 m')

def test_binary_op_join_setting(self):
# arithmetic_join applies to data array coordinates
missing_2 = xr.Dataset({'x':[0, 1]})
missing_0 = xr.Dataset({'x':[1, 2]})
with xr.set_options(arithmetic_join='outer'):
actual = missing_2 + missing_0
expected = xr.Dataset({'x':[0, 1, 2]})
self.assertDatasetEqual(actual, expected)

# arithmetic join also applies to data_vars
ds1 = xr.Dataset({'foo': 1, 'bar': 2})
ds2 = xr.Dataset({'bar': 2, 'baz': 3})
expected = xr.Dataset({'bar': 4}) # default is inner joining
actual = ds1 + ds2
self.assertDatasetEqual(actual, expected)

with xr.set_options(arithmetic_join='outer'):
expected = xr.Dataset({'foo': np.nan, 'bar': 4, 'baz': np.nan})
actual = ds1 + ds2
self.assertDatasetEqual(actual, expected)

with xr.set_options(arithmetic_join='left'):
expected = xr.Dataset({'foo': np.nan, 'bar': 4})
actual = ds1 + ds2
self.assertDatasetEqual(actual, expected)

with xr.set_options(arithmetic_join='right'):
expected = xr.Dataset({'bar': 4, 'baz': np.nan})
actual = ds1 + ds2
self.assertDatasetEqual(actual, expected)


### Py.test tests

Expand Down