Skip to content

Commit ca5a710

Browse files
author
Diptorup Deb
authored
Merge pull request #283 from Hardcode84/pairwise-fixes
Improve `pairwise_distance` workloads
2 parents d5d0fc4 + de90464 commit ca5a710

File tree

4 files changed

+20
-19
lines changed

4 files changed

+20
-19
lines changed

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_dpex_k.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@
99
@dpex.kernel
1010
def _pairwise_distance_kernel(X1, X2, D):
1111
i = dpex.get_global_id(0)
12+
j = dpex.get_global_id(1)
1213

13-
X2_rows = X2.shape[0]
1414
X1_cols = X1.shape[1]
15-
for j in range(X2_rows):
16-
d = X1.dtype.type(0.0)
17-
for k in range(X1_cols):
18-
tmp = X1[i, k] - X2[j, k]
19-
d += tmp * tmp
20-
D[i, j] = np.sqrt(d)
15+
16+
d = X1.dtype.type(0.0)
17+
for k in range(X1_cols):
18+
tmp = X1[i, k] - X2[j, k]
19+
d += tmp * tmp
20+
D[i, j] = np.sqrt(d)
2121

2222

2323
def pairwise_distance(X1, X2, D):
24-
_pairwise_distance_kernel[X1.shape[0],](X1, X2, D)
24+
_pairwise_distance_kernel[dpex.Range(X1.shape[0], X2.shape[0])](X1, X2, D)

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_mlir_k.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,18 @@
99
@nb.kernel(gpu_fp64_truncate="auto")
1010
def _pairwise_distance_kernel(X1, X2, D):
1111
i = nb.get_global_id(0)
12+
j = nb.get_global_id(1)
1213

13-
X2_rows = X2.shape[0]
1414
X1_cols = X1.shape[1]
15-
for j in range(X2_rows):
16-
d = 0.0
17-
for k in range(X1_cols):
18-
tmp = X1[i, k] - X2[j, k]
19-
d += tmp * tmp
20-
D[i, j] = np.sqrt(d)
15+
16+
d = 0.0
17+
for k in range(X1_cols):
18+
tmp = X1[i, k] - X2[j, k]
19+
d += tmp * tmp
20+
D[i, j] = np.sqrt(d)
2121

2222

2323
def pairwise_distance(X1, X2, D):
24-
_pairwise_distance_kernel[X1.shape[0], nb.DEFAULT_LOCAL_SIZE](X1, X2, D)
24+
_pairwise_distance_kernel[
25+
(X1.shape[0], X2.shape[0]), nb.DEFAULT_LOCAL_SIZE
26+
](X1, X2, D)

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_mlir_n.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ def _pairwise_distance(X1, X2, D):
1111
x1 = np.sum(np.square(X1), axis=1)
1212
x2 = np.sum(np.square(X2), axis=1)
1313
np.dot(X1, X2.T, D)
14-
# D *= -2 TODO: inplace ops doesn't work as intended
15-
D[:] = D * -2
14+
D *= -2
1615
x3 = x1.reshape(x1.size, 1)
1716
np.add(D, x3, D)
1817
np.add(D, x2, D)

dpbench/benchmarks/pairwise_distance/pairwise_distance_numba_mlir_p.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _pairwise_distance(X1, X2, D):
2525
# Outermost parallel loop over the matrix X1
2626
for i in numba.prange(X1_rows):
2727
# Loop over the matrix X2
28-
for j in range(X2_rows):
28+
for j in numba.prange(X2_rows):
2929
d = 0.0
3030
# Compute exclidean distance
3131
for k in range(X1_cols):

0 commit comments

Comments
 (0)