Skip to content

Commit

Permalink
Implement pairwise distance on 2d grid
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderKalistratov committed Oct 21, 2023
1 parent 94c2d62 commit 009989b
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

@dpex.kernel
def _pairwise_distance_kernel(X1, X2, D):
i = dpex.get_global_id(0)
j = dpex.get_global_id(1)
i = dpex.get_global_id(1)
j = dpex.get_global_id(0)

X1_cols = X1.shape[1]

Expand All @@ -21,4 +21,4 @@ def _pairwise_distance_kernel(X1, X2, D):


def pairwise_distance(X1, X2, D):
_pairwise_distance_kernel[dpex.Range(X1.shape[0], X2.shape[0])](X1, X2, D)
_pairwise_distance_kernel[dpex.Range(X2.shape[0], X1.shape[0])](X1, X2, D)
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

@nb.kernel(gpu_fp64_truncate="auto")
def _pairwise_distance_kernel(X1, X2, D):
i = nb.get_global_id(0)
j = nb.get_global_id(1)
i = nb.get_global_id(1)
j = nb.get_global_id(0)

X1_cols = X1.shape[1]

Expand All @@ -22,5 +22,5 @@ def _pairwise_distance_kernel(X1, X2, D):

def pairwise_distance(X1, X2, D):
_pairwise_distance_kernel[
(X1.shape[0], X2.shape[0]), nb.DEFAULT_LOCAL_SIZE
(X2.shape[0], X1.shape[0]), nb.DEFAULT_LOCAL_SIZE
](X1, X2, D)
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,29 @@

using namespace sycl;

#ifdef __DO_FLOAT__
#define SQRT(x) sqrtf(x)
#else
#define SQRT(x) sqrt(x)
#endif
template<typename T>
class PairwiseDistanceKernel;

template <typename FpTy>
void pairwise_distance_impl(queue Queue,
size_t npoints,
size_t x1_npoints,
size_t x2_npoints,
size_t ndims,
const FpTy *p1,
const FpTy *p2,
FpTy *distance_op)
{
Queue.submit([&](handler &h) {
h.parallel_for<class PairwiseDistanceKernel>(
range<1>{npoints}, [=](id<1> myID) {
size_t i = myID[0];
for (size_t j = 0; j < npoints; j++) {
FpTy d = 0.;
for (size_t k = 0; k < ndims; k++) {
auto tmp = p1[i * ndims + k] - p2[j * ndims + k];
d += tmp * tmp;
}
if (d != 0.0) {
distance_op[i * npoints + j] = sqrt(d);
}
h.parallel_for<PairwiseDistanceKernel<FpTy>>(
range<2>{x1_npoints, x2_npoints}, [=](id<2> myID) {
auto i = myID[0];
auto j = myID[1];
FpTy d = 0.;
for (size_t k = 0; k < ndims; k++) {
auto tmp = p1[i * ndims + k] - p2[j * ndims + k];
d += tmp * tmp;
}
distance_op[i * x2_npoints + j] = sycl::sqrt(d);
});
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,22 @@ void pairwise_distance_sync(dpctl::tensor::usm_ndarray X1,
{
sycl::event res_ev;
auto Queue = X1.get_queue();
auto ndims = 3;
auto npoints = X1.get_size() / ndims;
auto ndims = X1.get_shape(1);
auto x1_npoints = X1.get_shape(0);
auto x2_npoints = X2.get_shape(0);

if (!ensure_compatibility(X1, X2, D))
throw std::runtime_error("Input arrays are not acceptable.");

if (X1.get_typenum() != UAR_DOUBLE || X2.get_typenum() != UAR_DOUBLE) {
throw std::runtime_error("Expected a double precision FP array.");
if (X1.get_typenum() == UAR_FLOAT) {
pairwise_distance_impl(Queue, x1_npoints, x2_npoints, ndims, X1.get_data<float>(),
X2.get_data<float>(), D.get_data<float>());
} else if (X1.get_typenum() == UAR_DOUBLE) {
pairwise_distance_impl(Queue, x1_npoints, x2_npoints, ndims, X1.get_data<double>(),
X2.get_data<double>(), D.get_data<double>());
} else {
throw std::runtime_error("Expected a double or single precision FP array.");
}

pairwise_distance_impl(Queue, npoints, ndims, X1.get_data<double>(),
X2.get_data<double>(), D.get_data<double>());
}

PYBIND11_MODULE(_pairwise_distance_sycl, m)
Expand Down

0 comments on commit 009989b

Please sign in to comment.