-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
implement interp() #2104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
implement interp() #2104
Changes from 5 commits
91e6723
db89669
921ecdc
6b198bd
c4961b0
14404c9
78144e9
7004f75
b1360ee
642e6b3
3328128
dfc347e
3284ad2
c19e9dd
39a0005
6c77873
4ff8477
0807652
230aada
0a4a196
359412a
281dc7f
03ed045
2530b24
01243f1
b3c76d7
7cfa56b
82e04c5
8c29a4b
d89a1bb
ed718d9
aec3bbc
0f17044
d361508
05b4c8f
d8ca99f
7cf370f
c0d796a
7ab6eec
f9a819a
21d4390
6b8f05e
58b4c13
63aa0b3
92c4d27
193bb88
b671257
cf9351b
f2dc499
6e00999
91d92f6
86a3823
9512d13
4df36da
ec8e709
60e2ca3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,8 +13,8 @@ | |
import xarray as xr | ||
|
||
from . import ( | ||
alignment, duck_array_ops, formatting, groupby, indexing, ops, resample, | ||
rolling, utils) | ||
alignment, dtypes, duck_array_ops, formatting, groupby, indexing, ops, | ||
resample, rolling, utils) | ||
from .. import conventions | ||
from .alignment import align | ||
from .common import DataWithCoords, ImplementsDatasetReduce | ||
|
@@ -1312,6 +1312,14 @@ def _validate_indexers(self, indexers): | |
raise TypeError('cannot use a Dataset as an indexer') | ||
else: | ||
v = np.asarray(v) | ||
if v.ndim == 0: | ||
v = as_variable(v) | ||
elif v.ndim == 1: | ||
v = as_variable((k, v)) | ||
else: | ||
raise IndexError( | ||
"Unlabeled multi-dimensional array cannot be " | ||
"used for indexing: {}".format(k)) | ||
indexers_list.append((k, v)) | ||
return indexers_list | ||
|
||
|
@@ -1322,6 +1330,9 @@ def _get_indexers_coordinates(self, indexers): | |
|
||
Only coordinate with a name different from any of self.variables will | ||
be attached. | ||
|
||
If remove_dimensional_coord is True, the dimensional coordinate of | ||
indexers will be removed. | ||
""" | ||
from .dataarray import DataArray | ||
|
||
|
@@ -1775,6 +1786,123 @@ def reindex(self, indexers=None, method=None, tolerance=None, copy=True, | |
coord_names.update(indexers) | ||
return self._replace_vars_and_dims(variables, coord_names) | ||
|
||
def interpolate_at(self, method='linear', fill_value=np.nan, kwargs=None,
                   **coords):
    """Multidimensional interpolation of Dataset.

    Parameters
    ----------
    method : str, optional
        {'linear', 'nearest'} for multidimensional array,
        {'linear', 'nearest', 'zero', 'slinear', 'quadratic', 'cubic'}
        for 1-dimensional array.
    fill_value : scalar, optional
        Value handed to ``interp.interpolate`` as the fill value for
        points outside the original coordinates. Defaults to NaN.
    kwargs : dict, optional
        Additional keyword arguments handed to ``interp.interpolate``.
        ``None`` (the default) is treated as an empty dict.
    **coords : {dim: new_coordinate, ...}
        Keyword arguments with names matching dimensions and values.
        coords can be an integer, array-like or DataArray.
        If DataArrays are passed as coords, their dimensions are used
        for the broadcasting.

    Returns
    -------
    interpolated : xr.Dataset
        New dataset on the new coordinates.

    Raises
    ------
    ValueError
        If any variable that is not itself one of the indexed coordinates
        contains NaN.

    Notes
    -----
    scipy is required. If NaN is in the array, ValueError will be raised.

    See Also
    --------
    scipy.interpolate.interp1d
    scipy.interpolate.RegularGridInterpolator

    Examples
    --------
    >>> da = xr.DataArray([0, 0.1, 0.2, 0.1], dims='x',
    ...                   coords={'x': [0, 1, 2, 3]})
    >>>
    >>> da.interpolate_at(x=[0.5, 1.5])  # simple linear interpolation
    <xarray.DataArray (x: 2)>
    array([0.05, 0.15])
    Coordinates:
      * x        (x) float64 0.5 1.5
    >>>
    >>> # with cubic spline interpolation
    ... da.interpolate_at(x=[0.5, 1.5], method='cubic')
    <xarray.DataArray (x: 2)>
    array([0.0375, 0.1625])
    Coordinates:
      * x        (x) float64 0.5 1.5
    >>>
    >>> # interpolation at one single position
    ... da.interpolate_at(x=0.5)
    <xarray.DataArray ()>
    array(0.05)
    Coordinates:
        x        float64 0.5
    >>>
    >>> # interpolation with broadcasting
    ... da.interpolate_at(x=xr.DataArray([[0.5, 1.0], [1.5, 2.0]],
    ...                                  dims=['y', 'z']))
    <xarray.DataArray (y: 2, z: 2)>
    array([[0.05, 0.1 ],
           [0.15, 0.2 ]])
    Coordinates:
        x        (y, z) float64 0.5 1.0 1.5 2.0
    Dimensions without coordinates: y, z
    >>>
    >>> da = xr.DataArray([[0, 0.1, 0.2], [1.0, 1.1, 1.2]],
    ...                   dims=['x', 'y'],
    ...                   coords={'x': [0, 1], 'y': [0, 10, 20]})
    >>>
    >>> # multidimensional interpolation
    ... da.interpolate_at(x=[0.5, 1.5], y=[5, 15])
    <xarray.DataArray (x: 2, y: 2)>
    array([[0.55, 0.65],
           [ nan,  nan]])
    Coordinates:
      * x        (x) float64 0.5 1.5
      * y        (y) int64 5 15
    >>>
    >>> # multidimensional interpolation with broadcasting
    ... da.interpolate_at(x=xr.DataArray([0.5, 1.5], dims='z'),
    ...                   y=xr.DataArray([5, 15], dims='z'))
    <xarray.DataArray (z: 2)>
    array([0.55,  nan])
    Coordinates:
        x        (z) float64 0.5 1.5
        y        (z) int64 5 15
    Dimensions without coordinates: z
    """
    from . import interp

    # A fresh dict per call; a ``{}`` default would be shared across calls.
    if kwargs is None:
        kwargs = {}

    indexers_list = self._validate_indexers(coords)
    # Built once so the per-variable membership test is O(1).
    indexer_names = set(k for k, _ in indexers_list)

    variables = OrderedDict()
    for name, var in iteritems(self._variables):
        if name in indexer_names:
            # The indexed coordinate itself is replaced by the indexer
            # further below, so it is not interpolated.
            continue

        var_indexers = {k: (self._variables[k], v)
                        for k, v in indexers_list if k in var.dims}

        num_valid = duck_array_ops.count(var.data)
        if num_valid != var.size:
            raise ValueError(
                'interpolate_at can not be used for an array with '
                'nan. {} has {} nans.'.format(
                    name, var.size - int(num_valid)))
        variables[name] = interp.interpolate(
            var, var_indexers, method, fill_value, kwargs)

    coord_names = set(variables).intersection(self._coord_names)
    selected = self._replace_vars_and_dims(variables,
                                           coord_names=coord_names)
    # attach indexer as coordinate; update() accepts (key, value) pairs
    variables.update(indexers_list)

    # Extract coordinates from indexers
    coord_vars = selected._get_indexers_coordinates(coords)
    variables.update(coord_vars)
    coord_names = (set(variables)
                   .intersection(self._coord_names)
                   .union(coord_vars))
    return self._replace_vars_and_dims(variables, coord_names=coord_names)
|
||
def rename(self, name_dict, inplace=False): | ||
"""Returns a new object with renamed variables and dimensions. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
from __future__ import absolute_import, division, print_function | ||
from functools import partial | ||
|
||
import numpy as np | ||
from .computation import apply_ufunc | ||
from .pycompat import (OrderedDict, dask_array_type) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. F401 '.pycompat.OrderedDict' imported but unused |
||
from .variable import broadcast_variables | ||
|
||
|
||
def _localize(obj, index_coord): | ||
""" Speed up for linear and nearest neighbor method. | ||
Only consider a subspace that is needed for the interpolation | ||
""" | ||
for dim, [x, new_x] in index_coord.items(): | ||
try: | ||
imin = x.to_index().get_loc(np.min(new_x), method='ffill') | ||
imax = x.to_index().get_loc(np.max(new_x), method='bfill') | ||
|
||
idx = slice(np.maximum(imin-1, 0), imax+1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. E226 missing whitespace around arithmetic operator |
||
index_coord[dim] = (x[idx], new_x) | ||
obj = obj.isel(**{dim: idx}) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Currently, only a small portion of arrays will be used for the interpolation with |
||
except: | ||
pass | ||
return obj, index_coord | ||
|
||
|
||
def interpolate(obj, indexes_coords, method, fill_value, kwargs):
    """Make an interpolation of Variable.

    Parameters
    ----------
    obj : Variable
        The variable to interpolate.
    indexes_coords : dict
        Mapping from dimension name to a pair of original and new
        coordinates.
    method : string
        One of {'linear', 'nearest', 'zero', 'slinear', 'quadratic',
        'cubic'}. For multidimensional interpolation, only
        {'linear', 'nearest'} can be used.
    fill_value
        Fill value for extrapolation.
    kwargs : dict or None
        Keyword arguments to be passed to the scipy.interpolate
        interpolator (``interp1d`` or ``RegularGridInterpolator``).

    Returns
    -------
    Interpolated Variable

    Raises
    ------
    ImportError
        If scipy is not installed.
    NotImplementedError
        If ``method`` is not supported for the number of interpolated
        dimensions.
    """
    try:
        import scipy.interpolate
    except ImportError:
        raise ImportError(
            'Interpolation with method `%s` requires scipy' % method)

    if len(indexes_coords) == 0:
        return obj

    # simple speed up for the local interpolation
    if method in ('linear', 'nearest'):
        obj, indexes_coords = _localize(obj, indexes_coords)

    # target dimensions
    dims = list(indexes_coords)
    x = [indexes_coords[d][0] for d in dims]
    new_x = [indexes_coords[d][1] for d in dims]
    destination = broadcast_variables(*new_x)

    # Forward the user supplied options to scipy, as documented.
    extra = {} if kwargs is None else kwargs

    if len(dims) == 1:
        if method not in ('linear', 'nearest', 'zero', 'slinear',
                          'quadratic', 'cubic'):
            raise NotImplementedError(
                '{!r} is not a supported method for 1-dimensional '
                'interpolation.'.format(method))
        interp_func = _interpolate_1d
        func = partial(scipy.interpolate.interp1d, kind=method, axis=-1,
                       bounds_error=False, fill_value=fill_value, **extra)
    else:
        if method not in ('linear', 'nearest'):
            raise NotImplementedError(
                '{!r} is not a supported method for multidimensional '
                'interpolation.'.format(method))
        interp_func = _interpolate_nd
        func = partial(scipy.interpolate.RegularGridInterpolator,
                       method=method, bounds_error=False,
                       fill_value=fill_value, **extra)

    rslt = apply_ufunc(interp_func, obj,
                       input_core_dims=[dims],
                       output_core_dims=[destination[0].dims],
                       output_dtypes=[obj.dtype], dask='allowed',
                       kwargs={'x': x, 'new_x': destination, 'func': func},
                       keep_attrs=True)
    # When every new coordinate keeps the dims of the coordinate it
    # replaces, restore the original dimension order.
    if all(x1.dims == new_x1.dims for x1, new_x1 in zip(x, new_x)):
        return rslt.transpose(*obj.dims)
    return rslt
|
||
|
||
def _interpolate_1d(obj, x, new_x, func):
    """Interpolate ``obj`` along its last axis.

    Parameters
    ----------
    obj : array
        Data to interpolate. If it is a dask array it must consist of a
        single chunk along the last axis; the function is then mapped
        over the blocks.
    x, new_x : sequences of length 1
        Original and new coordinates.
    func : callable
        ``func(x, obj)`` must return an interpolator that is called with
        the flattened new coordinate.

    Returns
    -------
    array
        The interpolated data, with the last axis replaced by the
        dimensions of the new coordinate.
    """
    if isinstance(obj, dask_array_type):
        import dask.array as da

        _assert_single_chunks(obj, [-1])
        chunks = obj.chunks[:-len(x)] + new_x[0].shape
        drop_axis = range(obj.ndim - len(x), obj.ndim)
        new_axis = range(obj.ndim - len(x),
                         obj.ndim - len(x) + new_x[0].ndim)
        # call this function recursively
        return da.map_blocks(_interpolate_1d, obj, x, new_x, func,
                             dtype=obj.dtype, chunks=chunks,
                             new_axis=new_axis, drop_axis=drop_axis)

    # x, new_x are tuples of size 1.
    x, new_x = x[0], new_x[0]
    rslt = func(x, obj)(np.ravel(new_x))
    if new_x.ndim > 1:
        # The new coordinate was flattened above; restore its shape.
        return rslt.reshape(obj.shape[:-1] + new_x.shape)
    if new_x.ndim == 0:
        # A scalar coordinate was raveled to length 1; drop that axis.
        return rslt[..., -1]
    return rslt
|
||
|
||
def _interpolate_nd(obj, x, new_x, func):
    """dask compatible interpolation function.

    The last len(x) dimensions are used for the interpolation.

    Parameters
    ----------
    obj : array
        Data to interpolate. If it is a dask array it must consist of a
        single chunk along each of the last ``len(x)`` axes; the function
        is then mapped over the blocks.
    x, new_x : sequences
        Original and new coordinates, one entry per interpolated
        dimension. All entries of ``new_x`` are expected to share the
        shape of ``new_x[0]``.
    func : callable
        ``func(x, obj)`` must return an interpolator that is called with
        the stacked new coordinates.

    Returns
    -------
    array
        The interpolated data, with the last ``len(x)`` axes replaced by
        the dimensions of ``new_x[0]``.
    """
    if isinstance(obj, dask_array_type):
        import dask.array as da

        _assert_single_chunks(obj, range(-len(x), 0))
        chunks = obj.chunks[:-len(x)] + new_x[0].shape
        drop_axis = range(obj.ndim - len(x), obj.ndim)
        new_axis = range(obj.ndim - len(x),
                         obj.ndim - len(x) + new_x[0].ndim)
        return da.map_blocks(_interpolate_nd, obj, x, new_x, func,
                             dtype=obj.dtype, chunks=chunks,
                             new_axis=new_axis, drop_axis=drop_axis)

    # move the interpolation axes to the start position
    obj = obj.transpose(range(-len(x), obj.ndim - len(x)))
    # stack new_x to 1 vector, with reshape
    xi = np.stack([x1.values.ravel() for x1 in new_x], axis=-1)
    rslt = func(x, obj)(xi)
    # move back the interpolation axes to the last position
    rslt = rslt.transpose(range(-rslt.ndim + 1, 1))
    return rslt.reshape(rslt.shape[:-1] + new_x[0].shape)
|
||
|
||
def _assert_single_chunks(obj, axes): | ||
for axis in axes: | ||
if len(obj.chunks[axis]) > 1: | ||
raise ValueError('Chunk along the dimension to be interpolated ' | ||
'({}) is not allowed.'.format(axis)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
F401 '.dtypes' imported but unused