Skip to content

Commit

Permalink
Merge pull request #283 from Hardcode84/pairwise-fixes
Browse files Browse the repository at this point in the history
Improve `pairwise_distance` workloads
  • Loading branch information
Diptorup Deb authored Jul 31, 2023
2 parents d5d0fc4 + de90464 commit ca5a710
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
@dpex.kernel
def _pairwise_distance_kernel(X1, X2, D):
i = dpex.get_global_id(0)
j = dpex.get_global_id(1)

X2_rows = X2.shape[0]
X1_cols = X1.shape[1]
for j in range(X2_rows):
d = X1.dtype.type(0.0)
for k in range(X1_cols):
tmp = X1[i, k] - X2[j, k]
d += tmp * tmp
D[i, j] = np.sqrt(d)

d = X1.dtype.type(0.0)
for k in range(X1_cols):
tmp = X1[i, k] - X2[j, k]
d += tmp * tmp
D[i, j] = np.sqrt(d)


def pairwise_distance(X1, X2, D):
_pairwise_distance_kernel[X1.shape[0],](X1, X2, D)
_pairwise_distance_kernel[dpex.Range(X1.shape[0], X2.shape[0])](X1, X2, D)
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@
@nb.kernel(gpu_fp64_truncate="auto")
def _pairwise_distance_kernel(X1, X2, D):
i = nb.get_global_id(0)
j = nb.get_global_id(1)

X2_rows = X2.shape[0]
X1_cols = X1.shape[1]
for j in range(X2_rows):
d = 0.0
for k in range(X1_cols):
tmp = X1[i, k] - X2[j, k]
d += tmp * tmp
D[i, j] = np.sqrt(d)

d = 0.0
for k in range(X1_cols):
tmp = X1[i, k] - X2[j, k]
d += tmp * tmp
D[i, j] = np.sqrt(d)


def pairwise_distance(X1, X2, D):
_pairwise_distance_kernel[X1.shape[0], nb.DEFAULT_LOCAL_SIZE](X1, X2, D)
_pairwise_distance_kernel[
(X1.shape[0], X2.shape[0]), nb.DEFAULT_LOCAL_SIZE
](X1, X2, D)
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ def _pairwise_distance(X1, X2, D):
x1 = np.sum(np.square(X1), axis=1)
x2 = np.sum(np.square(X2), axis=1)
np.dot(X1, X2.T, D)
# D *= -2 TODO: inplace ops doesn't work as intended
D[:] = D * -2
D *= -2
x3 = x1.reshape(x1.size, 1)
np.add(D, x3, D)
np.add(D, x2, D)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _pairwise_distance(X1, X2, D):
# Outermost parallel loop over the matrix X1
for i in numba.prange(X1_rows):
# Loop over the matrix X2
for j in range(X2_rows):
for j in numba.prange(X2_rows):
d = 0.0
# Compute exclidean distance
for k in range(X1_cols):
Expand Down

0 comments on commit ca5a710

Please sign in to comment.