Skip to content

Commit

Permalink
Merge pull request #12 from florian-huber/issue_8_0
Browse files Browse the repository at this point in the history
Improve merging
  • Loading branch information
florian-huber authored Dec 2, 2022
2 parents 8a5d272 + a1d7b29 commit a8480af
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 65 deletions.
23 changes: 13 additions & 10 deletions sparsestack/StackedSparseArray.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,30 +231,33 @@ def add_dense_matrix(self, matrix: np.ndarray,
"""
if matrix is None:
self.data = np.array([])
elif len(matrix.dtype) > 1: # if structured array
for dtype_name in matrix.dtype.names:
self._add_dense_matrix(matrix[dtype_name],
f"{name}_{dtype_name}",
join_type)
else:
self._add_dense_matrix(matrix, name, join_type)

def _add_dense_matrix(self, matrix, name, join_type):
if matrix.dtype.type == np.void:
input_dtype = matrix.dtype[0]
else:
input_dtype = matrix.dtype
def get_dtype(data):
if data.dtype.type == np.void:
return data.dtype[0]
return data.dtype

# Handle 1D arrays
if matrix.ndim == 1:
matrix = matrix.reshape(-1, 1)

# Handle structured arrays > 1 dimension
if len(matrix.dtype) > 1:
dtype_data = [(f"{name}_{dtype_name}", get_dtype(matrix[dtype_name])) for dtype_name in matrix.dtype.names]
else:
dtype_data = [(name, get_dtype(matrix))]


if self.shape[2] == 0 or (self.shape[2] == 1 and name in self.score_names):
# Add first (sparse) array of scores
(idx_row, idx_col) = np.where(matrix)
self.row = idx_row
self.col = idx_col
self.data = np.array(matrix[idx_row, idx_col], dtype=[(name, input_dtype)])

self.data = np.array(matrix[idx_row, idx_col], dtype=dtype_data)
else:
# Add new stack of scores
(idx_row, idx_col) = np.where(matrix)
Expand Down
154 changes: 99 additions & 55 deletions sparsestack/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,38 @@ def _join_arrays(row1, col1, data1,
#pylint: disable=too-many-arguments
#pylint: disable=too-many-locals

idx1 = np.lexsort((col1, row1))
idx2 = np.lexsort((col2, row2))
# join types
if join_type == "left":
idx_inner_left, idx_inner_right = get_idx(row1, col1, row2, col2, join_type="inner")
idx_inner_left, idx_inner_right, _, _, _, _ = get_idx(row1, col1, row2, col2,
idx1, idx2, join_type="inner")
data_join = set_and_fill_new_array(data1, data2, name,
np.arange(0, len(row1)), np.arange(0, len(row1)),
idx_inner_right, idx_inner_left,
len(row1))
return row1, col1, data_join
if join_type == "right":
idx_inner_left, idx_inner_right = get_idx(row1, col1, row2, col2, join_type="inner")
idx_inner_left, idx_inner_right, _, _, _, _ = get_idx(row1, col1, row2, col2,
idx1, idx2, join_type="inner")
data_join = set_and_fill_new_array(data1, data2, name,
idx_inner_left, idx_inner_right,
np.arange(0, len(row2)), np.arange(0, len(row2)),
len(row2))
return row2, col2, data_join
if join_type == "inner":
idx_inner_left, idx_inner_right = get_idx(row1, col1, row2, col2, join_type="inner")
idx_inner_left, idx_inner_right, _, _, _, _ = get_idx(row1, col1, row2, col2,
idx1, idx2, join_type="inner")
data_join = set_and_fill_new_array(data1, data2, name,
idx_inner_left, np.arange(0, len(idx_inner_left)),
idx_inner_right, np.arange(0, len(idx_inner_left)),
len(idx_inner_left))
return row1[idx_inner_left], col1[idx_inner_left], data_join
if join_type == "outer":
idx_left, idx_left_new, idx_right, idx_right_new, row_new, col_new = get_idx_outer(row1, col1, row2, col2)
idx_left, idx_right, idx_left_new, idx_right_new, row_new, col_new = get_idx_outer(
row1, col1, row2, col2,
idx1, idx2
)
data_join = set_and_fill_new_array(data1, data2, name,
idx_left, idx_left_new, idx_right, idx_right_new,
len(row_new))
Expand All @@ -68,6 +76,7 @@ def set_and_fill_new_array(data1, data2, name,
"""Create new structured numpy array and fill with data1 and data2.
"""
#pylint: disable=too-many-arguments

new_dtype = [(dname, d[0]) for dname, d in data1.dtype.fields.items()]
if data2.dtype.names is None:
new_dtype += [(name, data2.dtype)]
Expand All @@ -92,69 +101,104 @@ def set_and_fill_new_array(data1, data2, name,


@numba.jit(nopython=True)
def get_idx_inner(left_row, left_col, right_row, right_col,
                  idx1, idx2):
    """Get current and new indices for inner merge.

    Walks both coordinate lists in sorted order and collects every
    (row, col) pair that occurs in the left as well as in the right input.

    Parameters
    ----------
    left_row, left_col
        Row and column indices of the left entries.
    right_row, right_col
        Row and column indices of the right entries.
    idx1, idx2
        Numpy array of pre-sorted (np.lexsort) indices for left/right arrays.

    Returns
    -------
    idx_left, idx_right
        Lists with the positions of the matching entries in the left/right
        input arrays.
    idx_left_new, idx_right_new
        Lists with the positions of those entries in the merged output.
    row_new, col_new
        Lists with the row/column indices of the merged entries.
    """
    #pylint: disable=too-many-arguments
    #pylint: disable=too-many-locals
    idx_left = []
    idx_left_new = []
    idx_right = []
    idx_right_new = []
    row_new = []
    col_new = []
    # `low` is a POSITION within idx2 (not an element of idx2): everything
    # before it has a row smaller than the current left row and can never
    # match again, because the left entries are visited in ascending row order.
    low = 0
    counter = 0
    for i in idx1:
        for pos in range(low, len(idx2)):
            j = idx2[pos]
            if (left_row[i] == right_row[j]) and (left_col[i] == right_col[j]):
                idx_left.append(i)
                idx_left_new.append(counter)
                idx_right.append(j)
                idx_right_new.append(counter)
                row_new.append(left_row[i])
                col_new.append(left_col[i])
                counter += 1
            if left_row[i] > right_row[j]:
                # Right entry lies before the current left row: skip it
                # for all following left entries as well.
                low = pos + 1
            if left_row[i] < right_row[j]:
                # Right rows are sorted, no further match is possible.
                break
    return idx_left, idx_right, idx_left_new, idx_right_new, row_new, col_new


@numba.jit(nopython=True)
def get_idx_outer(left_row, left_col, right_row, right_col,
                  idx1, idx2):
    """Get current and new indices for outer merge.

    All left entries are emitted first (in the order given by idx1), each
    paired with its matching right entry when one exists. Right entries
    without a partner in the left input are appended afterwards.

    Parameters
    ----------
    left_row, left_col
        Row and column indices of the left entries.
    right_row, right_col
        Row and column indices of the right entries.
    idx1, idx2
        Numpy array of pre-sorted (np.lexsort) indices for left/right arrays.

    Returns
    -------
    idx_left, idx_right
        Lists with the positions of the entries in the left/right input
        arrays that are part of the merged output.
    idx_left_new, idx_right_new
        Lists with the positions of those entries in the merged output.
    row_new, col_new
        Lists with the row/column indices of all merged entries.
    """
    #pylint: disable=too-many-arguments
    #pylint: disable=too-many-locals
    idx_left = []
    idx_left_new = []
    idx_right = []
    idx_right_new = []
    row_new = []
    col_new = []

    right_in_inner = []
    # `low` is a POSITION within idx2 (not an element of idx2): everything
    # before it has a row smaller than the current left row and can never
    # match again, because the left entries are visited in ascending row order.
    low = 0
    counter = 0
    for i in idx1:
        current_match = False
        for pos in range(low, len(idx2)):
            j = idx2[pos]
            if (left_row[i] == right_row[j]) and (left_col[i] == right_col[j]):
                right_in_inner.append(j)
                current_match = True
            if left_row[i] > right_row[j]:
                # Right entry lies before the current left row: skip it
                # for all following left entries as well.
                low = pos + 1
            if left_row[i] < right_row[j]:
                # Right rows are sorted, no further match is possible.
                break
        # Every left entry ends up in the output, matched or not.
        idx_left.append(i)
        idx_left_new.append(counter)
        if current_match:
            idx_right.append(right_in_inner[-1])
            idx_right_new.append(counter)
        row_new.append(left_row[i])
        col_new.append(left_col[i])
        counter += 1

    # Append the right entries that found no partner in the left input.
    for j in set(idx2).difference(set(right_in_inner)):
        idx_right.append(j)
        idx_right_new.append(counter)
        row_new.append(right_row[j])
        col_new.append(right_col[j])
        counter += 1
    return idx_left, idx_right, idx_left_new, idx_right_new, row_new, col_new


def get_idx(left_row, left_col, right_row, right_col, idx1, idx2,
            join_type="left"):
    """Dispatch to the index computation for the requested join type.

    Parameters
    ----------
    left_row, left_col
        Row and column indices of the left entries.
    right_row, right_col
        Row and column indices of the right entries.
    idx1, idx2
        Pre-sorted indices for the left/right arrays, passed on unchanged
        to the join-specific function.
    join_type
        Only "inner" and "outer" are handled here. Any other value,
        including the default "left", raises a ValueError.

    Returns
    -------
    The result of get_idx_inner (for "inner") or get_idx_outer (for "outer").

    Raises
    ------
    ValueError
        If join_type is neither "inner" nor "outer".
    """
    #pylint: disable=too-many-arguments
    if join_type == "inner":
        return get_idx_inner(left_row, left_col, right_row, right_col,
                             idx1, idx2)
    if join_type == "outer":
        return get_idx_outer(left_row, left_col, right_row, right_col,
                             idx1, idx2)
    # Keep the "Unknown join_type" prefix, but name the offending value so
    # the failure can be diagnosed from the message alone.
    raise ValueError(
        f"Unknown join_type: {join_type!r} (expected 'inner' or 'outer')"
    )
1 change: 1 addition & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def test_join_arrays(row2, col2):
assert np.allclose([x[0] for x in c], [0, 1, 2, 3, 4])
assert np.allclose([x[1] for x in c], [0, 0, 10, 0, 20])


@pytest.mark.parametrize("join_type, expected_data, expected_row", [
["left", np.array([[0, 0], [1, 0], [2, 2], [4, 0], [5, 5]]), np.array([0, 1, 2, 4, 5])],
["right", np.array([[2, 2],[0, 3], [5, 5], [0, 6], [0, 7],]), np.array([2, 3, 5, 6, 7])],
Expand Down

0 comments on commit a8480af

Please sign in to comment.