From 908bf95b9a16e33a3c88540ea69e762eca4cd299 Mon Sep 17 00:00:00 2001 From: "Mohar, Boaz" Date: Thu, 2 Jun 2016 23:42:12 -0400 Subject: [PATCH] Fix a case of advanced indexing if the list of indices is not sorted. Does not handle sorting keys with dim > 1 --- bolt/spark/array.py | 9 +++++++-- test/spark/test_spark_getting.py | 11 +++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/bolt/spark/array.py b/bolt/spark/array.py index 588279e..5bde63a 100644 --- a/bolt/spark/array.py +++ b/bolt/spark/array.py @@ -7,7 +7,7 @@ from bolt.spark.stack import StackedArray from bolt.spark.utils import zip_with_index from bolt.spark.statcounter import StatCounter -from bolt.utils import slicify, listify, tupleize, argpack, inshape, istransposeable, isreshapeable +from bolt.utils import slicify, listify, tupleize, argpack, inshape, istransposeable, isreshapeable, allclose class BoltArraySpark(BoltArray): @@ -536,8 +536,10 @@ def _getadvanced(self, index): def key_check(key): return key in key_tuples + sorted_index = argsort(index[0]) + def key_func(key): - return unravel_index(key, shape) + return unravel_index(sorted_index[key], shape) # filter records based on key targets filtered = self._rdd.filter(lambda kv: key_check(kv[0])) @@ -664,6 +666,9 @@ def __getitem__(self, index): # if any key indices used negative steps, records are no longer ordered if self._ordered is False or any([isinstance(s, slice) and s.step<0 for s in index[:self.split]]): ordered = False + # if any keys are not in order the records are no longer ordered + elif any([isinstance(s, ndarray) and len(s.shape) == 1 and not allclose(array(sorted(s)), s) for s in index[:self.split]]): + ordered = False else: ordered = True diff --git a/test/spark/test_spark_getting.py b/test/spark/test_spark_getting.py index 64f0358..018a8f4 100644 --- a/test/spark/test_spark_getting.py +++ b/test/spark/test_spark_getting.py @@ -77,6 +77,10 @@ def test_getitem_list(sc): assert allclose(b[[0, 1], [0, 2], [0, 
3]].toarray(), x[[0, 1], [0, 2], [0, 3]]) assert allclose(b[[0, 1, 2], [0, 2, 1], [0, 3, 1]].toarray(), x[[0, 1, 2], [0, 2, 1], [0, 3, 1]]) + assert allclose(b[[1, 0], [0, 1], [0, 2]].toarray(), x[[1, 0], [0, 1], [0, 2]]) + assert allclose(b[[1, 0], [0, 2], [0, 3]].toarray(), x[[1, 0], [0, 2], [0, 3]]) + assert allclose(b[[1, 0, 2], [0, 2, 1], [0, 3, 1]].toarray(), x[[1, 0, 2], [0, 2, 1], [0, 3, 1]]) + b = array(x, sc, axis=(0,1)) assert allclose(b[[0, 1], [0, 1], [0, 2]].toarray(), x[[0, 1], [0, 1], [0, 2]]) assert allclose(b[[0, 1], [0, 2], [0, 3]].toarray(), x[[0, 1], [0, 2], [0, 3]]) @@ -108,6 +112,13 @@ def test_getitem_mixed(sc): assert allclose(b[:, :, i, :].toarray(), x[:, :, i, :]) assert allclose(b[s, s, i, s].toarray(), x[s, s, i, s]) + i = [1, 0] + s = slice(1, 3) + assert allclose(b[i, :, :, :].toarray(), x[i, :, :, :]) + assert allclose(b[i, s, s, s].toarray(), x[i, s, s, s]) + assert allclose(b[:, :, i, :].toarray(), x[:, :, i, :]) + assert allclose(b[s, s, i, s].toarray(), x[s, s, i, s]) + i = [1] assert allclose(b[i, :, :, :].toarray(), x[i, :, :, :]) assert allclose(b[:, :, i, :].toarray(), x[:, :, i, :])