diff --git a/xarray_filters/multi_index.py b/xarray_filters/multi_index.py index e2cd01b..ed727d8 100644 --- a/xarray_filters/multi_index.py +++ b/xarray_filters/multi_index.py @@ -53,7 +53,9 @@ def multi_index_to_coords(arr, axis=0): return coords, (dim,) if any(name is None for name in multi.names): raise ValueError('Expected MultiIndex with named components (found {})'.format(multi.names)) - np_arrs = (np.unique(x) for x in np.array(multi.tolist()).T) + cols = np.array(multi.tolist()).T + ascend = [(1 if c[0] < c[-1] else -1) for c in cols] + np_arrs = (np.unique(c)[::a] for c, a in zip(cols, ascend)) coords = OrderedDict(zip(multi.names, np_arrs)) dims = tuple(coords) return coords, dims diff --git a/xarray_filters/reshape.py b/xarray_filters/reshape.py index 7fd07e1..2e14f61 100644 --- a/xarray_filters/reshape.py +++ b/xarray_filters/reshape.py @@ -26,6 +26,8 @@ 'to_xy_arrays'] +RAVEL_ORDER = 'C' + def has_features(dset, raise_err=True, features_layer=None): '''Check if an MLDataset has a DataArray called "features" with dimensions (space, layer) @@ -174,7 +176,7 @@ def to_features(dset, layers=None, row_dim=None, index = getattr(arr, row_dim) else: index = create_multi_index(arr) - val = val.ravel()[:, np.newaxis] + val = val.ravel(order=RAVEL_ORDER)[:, np.newaxis] coords = OrderedDict([(row_dim, index), (col_dim, [layer])]) new_dims = (row_dim, col_dim) @@ -251,10 +253,11 @@ def from_features(arr, axis=0): val = np.full(shp, np.nan) feature_idx = arr.indexes[arr.dims[0]] dim_coord_pairs = feature_idx.tolist() + row, col = coords[dims[0]], coords[dims[1]] for idx, dim_coord_pair in enumerate(dim_coord_pairs): - val[dim_coord_pair] = arr_val[idx] + val[(dim_coord_pair[0] == row), (dim_coord_pair[1] == col)] = arr_val[idx] else: - val = arr_val.reshape(shp) + val = arr_val.reshape(shp, order=RAVEL_ORDER) layer = simple_np_arr[j] dset[layer] = xr.DataArray(val, coords=coords, dims=dims) return MLDataset(dset) @@ -290,7 +293,8 @@ def to_xy_arrays(dset=None, y=None, features_layer=None, ykw = {col_dim: yname} X = arr.isel(**xkw) if y is None: - if yname in getattr(arr, arr.dims[-1], pd.Series([]).values): + col_names = getattr(arr, arr.dims[-1]) + if yname in col_names.values: y = arr.isel(**ykw) if as_np: y = y.values