From 2dd9abb35091f98ae0c0d01667e815cc8fbb3dc5 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 5 Jan 2023 12:19:05 -0800 Subject: [PATCH] Remove faiss dependency from fused_l2_knn.cuh, selection_faiss.cuh, ball_cover.cuh and haversine_distance.cuh (#1108)

Remove the dependency on faiss from the fused_l2_knn.cuh, selection_faiss.cuh, ball_cover.cuh and haversine_distance.cuh headers. This takes a copy of the faiss BlockSelect/WarpSelect device code for top-k selection, and updates it to use raft primitives for reductions, KeyValuePair, warp shuffling, etc.

Authors: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet)

Approvers: - Corey J. Nolet (https://github.com/cjnolet) - Ray Douglass (https://github.com/raydouglass)

URL: https://github.com/rapidsai/raft/pull/1108 --- ci/checks/copyright.py | 4 +- cpp/include/raft/core/kvp.hpp | 25 +- .../raft/spatial/knn/detail/ball_cover.cuh | 7 +- .../knn/detail/ball_cover/registers.cuh | 57 +- .../knn/detail/faiss_select/Comparators.cuh | 29 + .../detail/faiss_select/MergeNetworkBlock.cuh | 277 +++++++++ .../detail/faiss_select/MergeNetworkUtils.cuh | 25 + .../MergeNetworkWarp.cuh} | 354 +++++------ .../knn/detail/faiss_select/Select.cuh | 555 ++++++++++++++++++ .../knn/detail/faiss_select/StaticUtils.h | 48 ++ .../key_value_block_select.cuh} | 46 +- .../raft/spatial/knn/detail/fused_l2_knn.cuh | 8 +- .../spatial/knn/detail/haversine_distance.cuh | 17 +- .../knn/detail/knn_brute_force_faiss.cuh | 15 +- .../spatial/knn/detail/selection_faiss.cuh | 15 +- thirdparty/LICENSES/LICENSE.faiss | 21 + 16 files changed, 1216 insertions(+), 287 deletions(-) create mode 100644 cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh rename cpp/include/raft/spatial/knn/detail/{warp_select_faiss.cuh => faiss_select/MergeNetworkWarp.cuh} (51%) create mode 100644 cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h rename cpp/include/raft/spatial/knn/detail/{block_select_faiss.cuh => faiss_select/key_value_block_select.cuh} (80%) create mode 100644 thirdparty/LICENSES/LICENSE.faiss diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py index bfef5392f5..43a4a186f8 100644 --- a/ci/checks/copyright.py +++ b/ci/checks/copyright.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -37,7 +37,7 @@ re.compile(r"setup[.]cfg$"), re.compile(r"meta[.]yaml$") ] -ExemptFiles = ["cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh"] +ExemptFiles = ["cpp/include/raft/spatial/knn/detail/faiss_select/"] # this will break starting at year 10000, which is probably OK :) CheckSimple = re.compile( diff --git a/cpp/include/raft/core/kvp.hpp b/cpp/include/raft/core/kvp.hpp index f6ea841dc4..8d3321eb77 100644 --- a/cpp/include/raft/core/kvp.hpp +++ b/cpp/include/raft/core/kvp.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
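The hunk below gives raft::KeyValuePair the strict-weak ordering (key first, value as tie-breaker) and the warp shfl_xor overload that the copied selection code relies on. A minimal host-side sketch of just the comparison semantics, using a hypothetical stand-in type rather than the real raft header:

    #include <cassert>

    // Hypothetical stand-in for raft::KeyValuePair; models only the ordering
    // added in the hunk below: compare keys first, break ties on values.
    template <typename K, typename V>
    struct Pair {
      K key;
      V value;
      bool operator<(const Pair& b) const
      {
        return (key < b.key) || ((key == b.key) && value < b.value);
      }
    };

    int main()
    {
      Pair<float, int> a{1.0f, 2}, b{1.0f, 3}, c{0.5f, 9};
      assert(a < b);   // equal keys: the values decide
      assert(c < a);   // a smaller key wins regardless of value
      return 0;
    }
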
@@ -20,6 +20,7 @@ #ifdef _RAFT_HAS_CUDA #include +#include #endif namespace raft { /** @@ -58,5 +59,27 @@ struct KeyValuePair { { return (value != b.value) || (key != b.key); } + + RAFT_INLINE_FUNCTION bool operator<(const KeyValuePair<_Key, _Value>& b) const + { + return (key < b.key) || ((key == b.key) && value < b.value); + } + + RAFT_INLINE_FUNCTION bool operator>(const KeyValuePair<_Key, _Value>& b) const + { + return (key > b.key) || ((key == b.key) && value > b.value); + } }; + +#ifdef _RAFT_HAS_CUDA +template +RAFT_INLINE_FUNCTION KeyValuePair<_Key, _Value> shfl_xor(const KeyValuePair<_Key, _Value>& input, + int laneMask, + int width = WarpSize, + uint32_t mask = 0xffffffffu) +{ + return KeyValuePair<_Key, _Value>(shfl_xor(input.key, laneMask, width, mask), + shfl_xor(input.value, laneMask, width, mask)); +} +#endif } // end namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 797dbaab50..fd0314dbcc 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,6 @@ #include "../ball_cover_types.hpp" #include "ball_cover/common.cuh" #include "ball_cover/registers.cuh" -#include "block_select_faiss.cuh" #include "haversine_distance.cuh" #include "knn_brute_force_faiss.cuh" #include "selection_faiss.cuh" @@ -31,6 +30,8 @@ #include +#include + #include #include #include @@ -38,8 +39,6 @@ #include #include -#include - #include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh index a883a1eadd..530b0d3d04 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
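The registers.cuh hunks that follow trade faiss::gpu::utils::roundDown(x, kWarpSize) for raft's Pow2<WarpSize>::roundDown(x). For a power-of-two alignment this is just masking off the low bits; a self-contained equivalent (illustrative only, not raft's actual Pow2 implementation):

    // Round an index down to a multiple of the warp size (32).
    // Pow2<WarpSize>::roundDown(x) should yield the same value for
    // non-negative integer x.
    constexpr int kWarp = 32;
    constexpr int round_down_to_warp(int x) { return x & ~(kWarp - 1); }

    static_assert(round_down_to_warp(70) == 64, "70 rounds down to 64");
    static_assert(round_down_to_warp(64) == 64, "multiples are unchanged");
    static_assert(round_down_to_warp(31) == 0, "a partial warp rounds to zero");

The kernels use this to split their loops into a warp-aligned main body, where every lane of a warp can participate in the warp-synchronous selection primitives, and a scalar tail.
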
@@ -19,7 +19,7 @@ #include "common.cuh" #include "../../ball_cover_types.hpp" -#include "../block_select_faiss.cuh" +#include "../faiss_select/key_value_block_select.cuh" #include "../haversine_distance.cuh" #include "../selection_faiss.cuh" @@ -28,9 +28,6 @@ #include -#include -#include - #include namespace raft { @@ -172,10 +169,10 @@ __global__ void compute_final_dists_registers(const value_t* X_index, dist_func dfunc, value_int* dist_counter) { - static constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; + static constexpr int kNumWarps = tpb / WarpSize; __shared__ value_t shared_memK[kNumWarps * warp_q]; - __shared__ faiss::gpu::KeyValuePair shared_memV[kNumWarps * warp_q]; + __shared__ KeyValuePair shared_memV[kNumWarps * warp_q]; const value_t* x_ptr = X + (n_cols * blockIdx.x); value_t local_x_ptr[col_q]; @@ -183,21 +180,21 @@ __global__ void compute_final_dists_registers(const value_t* X_index, local_x_ptr[j] = x_ptr[j]; } - faiss::gpu::KeyValueBlockSelect, - warp_q, - thread_q, - tpb> - heap(faiss::gpu::Limits::getMax(), - faiss::gpu::Limits::getMax(), + faiss_select::KeyValueBlockSelect, + warp_q, + thread_q, + tpb> + heap(std::numeric_limits::max(), + std::numeric_limits::max(), -1, shared_memK, shared_memV, k); - const value_int n_k = faiss::gpu::utils::roundDown(k, faiss::gpu::kWarpSize); + const value_int n_k = Pow2::roundDown(k); value_int i = threadIdx.x; for (; i < n_k; i += tpb) { value_idx ind = knn_inds[blockIdx.x * k + i]; @@ -224,7 +221,7 @@ __global__ void compute_final_dists_registers(const value_t* X_index, // Round R_size to the nearest warp threads so they can // all be computing in parallel. - const value_int limit = faiss::gpu::utils::roundDown(R_size, faiss::gpu::kWarpSize); + const value_int limit = Pow2::roundDown(R_size); i = threadIdx.x; for (; i < limit; i += tpb) { @@ -334,10 +331,10 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, distance_func dfunc, float weight = 1.0) { - static constexpr value_int kNumWarps = tpb / faiss::gpu::kWarpSize; + static constexpr value_int kNumWarps = tpb / WarpSize; __shared__ value_t shared_memK[kNumWarps * warp_q]; - __shared__ faiss::gpu::KeyValuePair shared_memV[kNumWarps * warp_q]; + __shared__ KeyValuePair shared_memV[kNumWarps * warp_q]; // TODO: Separate kernels for different widths: // 1. Very small (between 3 and 32) just use registers for columns of "blockIdx.x" @@ -352,15 +349,15 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, } // Each warp works on 1 R - faiss::gpu::KeyValueBlockSelect, - warp_q, - thread_q, - tpb> - heap(faiss::gpu::Limits::getMax(), - faiss::gpu::Limits::getMax(), + faiss_select::KeyValueBlockSelect, + warp_q, + thread_q, + tpb> + heap(std::numeric_limits::max(), + std::numeric_limits::max(), -1, shared_memK, shared_memV, @@ -390,7 +387,7 @@ __global__ void block_rbc_kernel_registers(const value_t* X_index, value_idx R_size = R_stop_offset - R_start_offset; - value_int limit = faiss::gpu::utils::roundDown(R_size, faiss::gpu::kWarpSize); + value_int limit = Pow2::roundDown(R_size); value_int i = threadIdx.x; for (; i < limit; i += tpb) { // Index and distance of current candidate's nearest landmark diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh b/cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh new file mode 100644 index 0000000000..173c06af30 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/faiss_select/Comparators.cuh @@ -0,0 +1,29 @@ +/** + * Copyright (c) Facebook, Inc. 
and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file thirdparty/LICENSES/LICENSE.faiss + */ + +#pragma once + +#include +#include + +namespace raft::spatial::knn::detail::faiss_select { + +template +struct Comparator { + __device__ static inline bool lt(T a, T b) { return a < b; } + + __device__ static inline bool gt(T a, T b) { return a > b; } +}; + +template <> +struct Comparator { + __device__ static inline bool lt(half a, half b) { return __hlt(a, b); } + + __device__ static inline bool gt(half a, half b) { return __hgt(a, b); } +}; + +} // namespace raft::spatial::knn::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh b/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh new file mode 100644 index 0000000000..d923b41ded --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkBlock.cuh @@ -0,0 +1,277 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file thirdparty/LICENSES/LICENSE.faiss + */ + +#pragma once + +#include +#include +#include + +namespace raft::spatial::knn::detail::faiss_select { + +// Merge pairs of lists smaller than blockDim.x (NumThreads) +template +inline __device__ void blockMergeSmall(K* listK, V* listV) +{ + static_assert(utils::isPowerOf2(L), "L must be a power-of-2"); + static_assert(utils::isPowerOf2(NumThreads), "NumThreads must be a power-of-2"); + static_assert(L <= NumThreads, "merge list size must be <= NumThreads"); + + // Which pair of lists we are merging + int mergeId = threadIdx.x / L; + + // Which thread we are within the merge + int tid = threadIdx.x % L; + + // listK points to a region of size N * 2 * L + listK += 2 * L * mergeId; + listV += 2 * L * mergeId; + + // It's not a bitonic merge, both lists are in the same direction, + // so handle the first swap assuming the second list is reversed + int pos = L - 1 - tid; + int stride = 2 * tid + 1; + + if (AllThreads || (threadIdx.x < N * L)) { + K ka = listK[pos]; + K kb = listK[pos + stride]; + + bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + listK[pos] = swap ? kb : ka; + listK[pos + stride] = swap ? ka : kb; + + V va = listV[pos]; + V vb = listV[pos + stride]; + listV[pos] = swap ? vb : va; + listV[pos + stride] = swap ? va : vb; + + // FIXME: is this a CUDA 9 compiler bug? + // K& ka = listK[pos]; + // K& kb = listK[pos + stride]; + + // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + // swap(s, ka, kb); + + // V& va = listV[pos]; + // V& vb = listV[pos + stride]; + // swap(s, va, vb); + } + + __syncthreads(); + +#pragma unroll + for (int stride = L / 2; stride > 0; stride /= 2) { + int pos = 2 * tid - (tid & (stride - 1)); + + if (AllThreads || (threadIdx.x < N * L)) { + K ka = listK[pos]; + K kb = listK[pos + stride]; + + bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + listK[pos] = swap ? kb : ka; + listK[pos + stride] = swap ? ka : kb; + + V va = listV[pos]; + V vb = listV[pos + stride]; + listV[pos] = swap ? vb : va; + listV[pos + stride] = swap ? va : vb; + + // FIXME: is this a CUDA 9 compiler bug? + // K& ka = listK[pos]; + // K& kb = listK[pos + stride]; + + // bool s = Dir ? 
Comp::gt(ka, kb) : Comp::lt(ka, kb); + // swap(s, ka, kb); + + // V& va = listV[pos]; + // V& vb = listV[pos + stride]; + // swap(s, va, vb); + } + + __syncthreads(); + } +} + +// Merge pairs of sorted lists larger than blockDim.x (NumThreads) +template +inline __device__ void blockMergeLarge(K* listK, V* listV) +{ + static_assert(utils::isPowerOf2(L), "L must be a power-of-2"); + static_assert(L >= WarpSize, "merge list size must be >= 32"); + static_assert(utils::isPowerOf2(NumThreads), "NumThreads must be a power-of-2"); + static_assert(L >= NumThreads, "merge list size must be >= NumThreads"); + + // For L > NumThreads, each thread has to perform more work + // per each stride. + constexpr int kLoopPerThread = L / NumThreads; + + // It's not a bitonic merge, both lists are in the same direction, + // so handle the first swap assuming the second list is reversed +#pragma unroll + for (int loop = 0; loop < kLoopPerThread; ++loop) { + int tid = loop * NumThreads + threadIdx.x; + int pos = L - 1 - tid; + int stride = 2 * tid + 1; + + K ka = listK[pos]; + K kb = listK[pos + stride]; + + bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + listK[pos] = swap ? kb : ka; + listK[pos + stride] = swap ? ka : kb; + + V va = listV[pos]; + V vb = listV[pos + stride]; + listV[pos] = swap ? vb : va; + listV[pos + stride] = swap ? va : vb; + + // FIXME: is this a CUDA 9 compiler bug? + // K& ka = listK[pos]; + // K& kb = listK[pos + stride]; + + // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + // swap(s, ka, kb); + + // V& va = listV[pos]; + // V& vb = listV[pos + stride]; + // swap(s, va, vb); + } + + __syncthreads(); + + constexpr int kSecondLoopPerThread = FullMerge ? kLoopPerThread : kLoopPerThread / 2; + +#pragma unroll + for (int stride = L / 2; stride > 0; stride /= 2) { +#pragma unroll + for (int loop = 0; loop < kSecondLoopPerThread; ++loop) { + int tid = loop * NumThreads + threadIdx.x; + int pos = 2 * tid - (tid & (stride - 1)); + + K ka = listK[pos]; + K kb = listK[pos + stride]; + + bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); + listK[pos] = swap ? kb : ka; + listK[pos + stride] = swap ? ka : kb; + + V va = listV[pos]; + V vb = listV[pos + stride]; + listV[pos] = swap ? vb : va; + listV[pos + stride] = swap ? va : vb; + + // FIXME: is this a CUDA 9 compiler bug? + // K& ka = listK[pos]; + // K& kb = listK[pos + stride]; + + // bool s = Dir ? 
Comp::gt(ka, kb) : Comp::lt(ka, kb); + // swap(s, ka, kb); + + // V& va = listV[pos]; + // V& vb = listV[pos + stride]; + // swap(s, va, vb); + } + + __syncthreads(); + } +} + +/// Class template to prevent static_assert from firing for +/// mixing smaller/larger than block cases +template +struct BlockMerge { +}; + +/// Merging lists smaller than a block +template +struct BlockMerge { + static inline __device__ void merge(K* listK, V* listV) + { + constexpr int kNumParallelMerges = NumThreads / L; + constexpr int kNumIterations = N / kNumParallelMerges; + + static_assert(L <= NumThreads, "list must be <= NumThreads"); + static_assert((N < kNumParallelMerges) || (kNumIterations * kNumParallelMerges == N), + "improper selection of N and L"); + + if (N < kNumParallelMerges) { + // We only need L threads per each list to perform the merge + blockMergeSmall(listK, listV); + } else { + // All threads participate +#pragma unroll + for (int i = 0; i < kNumIterations; ++i) { + int start = i * kNumParallelMerges * 2 * L; + + blockMergeSmall(listK + start, + listV + start); + } + } + } +}; + +/// Merging lists larger than a block +template +struct BlockMerge { + static inline __device__ void merge(K* listK, V* listV) + { + // Each pair of lists is merged sequentially +#pragma unroll + for (int i = 0; i < N; ++i) { + int start = i * 2 * L; + + blockMergeLarge(listK + start, listV + start); + } + } +}; + +template +inline __device__ void blockMerge(K* listK, V* listV) +{ + constexpr bool kSmallerThanBlock = (L <= NumThreads); + + BlockMerge::merge(listK, listV); +} + +} // namespace raft::spatial::knn::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh b/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh new file mode 100644 index 0000000000..2cb01f9199 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkUtils.cuh @@ -0,0 +1,25 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file thirdparty/LICENSES/LICENSE.faiss + */ + +#pragma once + +namespace raft::spatial::knn::detail::faiss_select { + +template +inline __device__ void swap(bool swap, T& x, T& y) +{ + T tmp = x; + x = swap ? y : x; + y = swap ? tmp : y; +} + +template +inline __device__ void assign(bool assign, T& x, T y) +{ + x = assign ? y : x; +} +} // namespace raft::spatial::knn::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh b/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh similarity index 51% rename from cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh rename to cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh index 2ce2d34cca..bce739b2d8 100644 --- a/cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/faiss_select/MergeNetworkWarp.cuh @@ -2,36 +2,31 @@ * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. 
+ * LICENSE file thirdparty/LICENSES/LICENSE.faiss */ #pragma once -#include -#include -#include -#include -#include +#include +#include -#include +#include -namespace faiss { -namespace gpu { -using raft::KeyValuePair; +namespace raft::spatial::knn::detail::faiss_select { // // This file contains functions to: // // -perform bitonic merges on pairs of sorted lists, held in -// registers. Each list contains N * kWarpSize (multiple of 32) +// registers. Each list contains N * WarpSize (multiple of 32) // elements for some N. // The bitonic merge is implemented for arbitrary sizes; -// sorted list A of size N1 * kWarpSize registers -// sorted list B of size N2 * kWarpSize registers => -// sorted list C if size (N1 + N2) * kWarpSize registers. N1 and N2 +// sorted list A of size N1 * WarpSize registers +// sorted list B of size N2 * WarpSize registers => +// sorted list C if size (N1 + N2) * WarpSize registers. N1 and N2 // are >= 1 and don't have to be powers of 2. // -// -perform bitonic sorts on a set of N * kWarpSize key/value pairs +// -perform bitonic sorts on a set of N * WarpSize key/value pairs // held in registers, by using the above bitonic merge as a // primitive. // N can be an arbitrary N >= 1; i.e., the bitonic sort here supports @@ -80,7 +75,7 @@ using raft::KeyValuePair; // performing both < and > comparisons with the variables, so I just // stick with this. -// This function merges kWarpSize / 2L lists in parallel using warp +// This function merges WarpSize / 2L lists in parallel using warp // shuffles. // It works on at most size-16 lists, as we need 32 threads for this // shuffle merge. @@ -88,22 +83,19 @@ using raft::KeyValuePair; // If IsBitonic is false, the first stage is reversed, so we don't // need to sort directionally. It's still technically a bitonic sort. template -inline __device__ void warpBitonicMergeLE16KVP(K& k, KeyValuePair& v) +inline __device__ void warpBitonicMergeLE16(K& k, V& v) { static_assert(utils::isPowerOf2(L), "L must be a power-of-2"); - static_assert(L <= kWarpSize / 2, "merge list size must be <= 16"); + static_assert(L <= WarpSize / 2, "merge list size must be <= 16"); - int laneId = getLaneId(); + int laneId = raft::laneId(); if (!IsBitonic) { // Reverse the first comparison stage. // For example, merging a list of size 8 has the exchanges: // 0 <-> 15, 1 <-> 14, ... - K otherK = shfl_xor(k, 2 * L - 1); - K otherVk = shfl_xor(v.key, 2 * L - 1); - V otherVv = shfl_xor(v.value, 2 * L - 1); - - KeyValuePair otherV = KeyValuePair(otherVk, otherVv); + K otherK = shfl_xor(k, 2 * L - 1); + V otherV = shfl_xor(v, 2 * L - 1); // Whether we are the lesser thread in the exchange bool small = !(laneId & L); @@ -114,24 +106,19 @@ inline __device__ void warpBitonicMergeLE16KVP(K& k, KeyValuePair& v) // alternatives in practice bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK); assign(s, k, otherK); - assign(s, v.key, otherV.key); - assign(s, v.value, otherV.value); + assign(s, v, otherV); } else { bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK); assign(s, k, otherK); - assign(s, v.value, otherV.value); - assign(s, v.key, otherV.key); + assign(s, v, otherV); } } #pragma unroll for (int stride = IsBitonic ? 
L : L / 2; stride > 0; stride /= 2) { - K otherK = shfl_xor(k, stride); - K otherVk = shfl_xor(v.key, stride); - V otherVv = shfl_xor(v.value, stride); - - KeyValuePair otherV = KeyValuePair(otherVk, otherVv); + K otherK = shfl_xor(k, stride); + V otherV = shfl_xor(v, stride); // Whether we are the lesser thread in the exchange bool small = !(laneId & stride); @@ -139,14 +126,12 @@ inline __device__ void warpBitonicMergeLE16KVP(K& k, KeyValuePair& v) if (Dir) { bool s = small ? Comp::gt(k, otherK) : Comp::lt(k, otherK); assign(s, k, otherK); - assign(s, v.key, otherV.key); - assign(s, v.value, otherV.value); + assign(s, v, otherV); } else { bool s = small ? Comp::lt(k, otherK) : Comp::gt(k, otherK); assign(s, k, otherK); - assign(s, v.key, otherV.key); - assign(s, v.value, otherV.value); + assign(s, v, otherV); } } } @@ -154,7 +139,7 @@ inline __device__ void warpBitonicMergeLE16KVP(K& k, KeyValuePair& v) // Template for performing a bitonic merge of an arbitrary set of // registers template -struct BitonicMergeStepKVP { +struct BitonicMergeStep { }; // @@ -163,74 +148,69 @@ struct BitonicMergeStepKVP { // All merges eventually call this template -struct BitonicMergeStepKVP { - static inline __device__ void merge(K k[1], KeyValuePair v[1]) +struct BitonicMergeStep { + static inline __device__ void merge(K k[1], V v[1]) { // Use warp shuffles - warpBitonicMergeLE16KVP(k[0], v[0]); + warpBitonicMergeLE16(k[0], v[0]); } }; template -struct BitonicMergeStepKVP { - static inline __device__ void merge(K k[N], KeyValuePair v[N]) +struct BitonicMergeStep { + static inline __device__ void merge(K k[N], V v[N]) { static_assert(utils::isPowerOf2(N), "must be power of 2"); static_assert(N > 1, "must be N > 1"); #pragma unroll for (int i = 0; i < N / 2; ++i) { - K& ka = k[i]; - KeyValuePair& va = v[i]; + K& ka = k[i]; + V& va = v[i]; - K& kb = k[i + N / 2]; - KeyValuePair& vb = v[i + N / 2]; + K& kb = k[i + N / 2]; + V& vb = v[i + N / 2]; bool s = Dir ? 
Comp::gt(ka, kb) : Comp::lt(ka, kb); swap(s, ka, kb); - swap(s, va.key, vb.key); - swap(s, va.value, vb.value); + swap(s, va, vb); } { K newK[N / 2]; - KeyValuePair newV[N / 2]; + V newV[N / 2]; #pragma unroll for (int i = 0; i < N / 2; ++i) { - newK[i] = k[i]; - newV[i].key = v[i].key; - newV[i].value = v[i].value; + newK[i] = k[i]; + newV[i] = v[i]; } - BitonicMergeStepKVP::merge(newK, newV); + BitonicMergeStep::merge(newK, newV); #pragma unroll for (int i = 0; i < N / 2; ++i) { - k[i] = newK[i]; - v[i].key = newV[i].key; - v[i].value = newV[i].value; + k[i] = newK[i]; + v[i] = newV[i]; } } { K newK[N / 2]; - KeyValuePair newV[N / 2]; + V newV[N / 2]; #pragma unroll for (int i = 0; i < N / 2; ++i) { - newK[i] = k[i + N / 2]; - newV[i].key = v[i + N / 2].key; - newV[i].value = v[i + N / 2].value; + newK[i] = k[i + N / 2]; + newV[i] = v[i + N / 2]; } - BitonicMergeStepKVP::merge(newK, newV); + BitonicMergeStep::merge(newK, newV); #pragma unroll for (int i = 0; i < N / 2; ++i) { - k[i + N / 2] = newK[i]; - v[i + N / 2].key = newV[i].key; - v[i + N / 2].value = newV[i].value; + k[i + N / 2] = newK[i]; + v[i + N / 2] = newV[i]; } } } @@ -242,8 +222,8 @@ struct BitonicMergeStepKVP { // Low recursion template -struct BitonicMergeStepKVP { - static inline __device__ void merge(K k[N], KeyValuePair v[N]) +struct BitonicMergeStep { + static inline __device__ void merge(K k[N], V v[N]) { static_assert(!utils::isPowerOf2(N), "must be non-power-of-2"); static_assert(N >= 3, "must be N >= 3"); @@ -252,77 +232,73 @@ struct BitonicMergeStepKVP { #pragma unroll for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) { - K& ka = k[i]; - KeyValuePair& va = v[i]; + K& ka = k[i]; + V& va = v[i]; - K& kb = k[i + kNextHighestPowerOf2 / 2]; - KeyValuePair& vb = v[i + kNextHighestPowerOf2 / 2]; + K& kb = k[i + kNextHighestPowerOf2 / 2]; + V& vb = v[i + kNextHighestPowerOf2 / 2]; bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); swap(s, ka, kb); - swap(s, va.key, vb.key); - swap(s, va.value, vb.value); + swap(s, va, vb); } constexpr int kLowSize = N - kNextHighestPowerOf2 / 2; constexpr int kHighSize = kNextHighestPowerOf2 / 2; { K newK[kLowSize]; - KeyValuePair newV[kLowSize]; + V newV[kLowSize]; #pragma unroll for (int i = 0; i < kLowSize; ++i) { - newK[i] = k[i]; - newV[i].key = v[i].key; - newV[i].value = v[i].value; + newK[i] = k[i]; + newV[i] = v[i]; } constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(N - kNextHighestPowerOf2 / 2); // FIXME: compiler doesn't like this expression? compiler bug? // constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(kLowSize); - BitonicMergeStepKVP::merge(newK, newV); + BitonicMergeStep::merge(newK, newV); #pragma unroll for (int i = 0; i < kLowSize; ++i) { - k[i] = newK[i]; - v[i].key = newV[i].key; - v[i].value = newV[i].value; + k[i] = newK[i]; + v[i] = newV[i]; } } { K newK[kHighSize]; - KeyValuePair newV[kHighSize]; + V newV[kHighSize]; #pragma unroll for (int i = 0; i < kHighSize; ++i) { - newK[i] = k[i + kLowSize]; - newV[i].key = v[i + kLowSize].key; - newV[i].value = v[i + kLowSize].value; + newK[i] = k[i + kLowSize]; + newV[i] = v[i + kLowSize]; } constexpr bool kHighIsPowerOf2 = utils::isPowerOf2(kNextHighestPowerOf2 / 2); // FIXME: compiler doesn't like this expression? compiler bug? 
- // constexpr bool kHighIsPowerOf2 = utils::isPowerOf2(kHighSize); - BitonicMergeStepKVP::merge(newK, newV); + // constexpr bool kHighIsPowerOf2 = + // utils::isPowerOf2(kHighSize); + BitonicMergeStep::merge(newK, newV); #pragma unroll for (int i = 0; i < kHighSize; ++i) { - k[i + kLowSize] = newK[i]; - v[i + kLowSize].key = newV[i].key; - v[i + kLowSize].value = newV[i].value; + k[i + kLowSize] = newK[i]; + v[i + kLowSize] = newV[i]; } } } @@ -330,8 +306,8 @@ struct BitonicMergeStepKVP { // High recursion template -struct BitonicMergeStepKVP { - static inline __device__ void merge(K k[N], KeyValuePair v[N]) +struct BitonicMergeStep { + static inline __device__ void merge(K k[N], V v[N]) { static_assert(!utils::isPowerOf2(N), "must be non-power-of-2"); static_assert(N >= 3, "must be N >= 3"); @@ -340,149 +316,137 @@ struct BitonicMergeStepKVP { #pragma unroll for (int i = 0; i < N - kNextHighestPowerOf2 / 2; ++i) { - K& ka = k[i]; - KeyValuePair& va = v[i]; + K& ka = k[i]; + V& va = v[i]; - K& kb = k[i + kNextHighestPowerOf2 / 2]; - KeyValuePair& vb = v[i + kNextHighestPowerOf2 / 2]; + K& kb = k[i + kNextHighestPowerOf2 / 2]; + V& vb = v[i + kNextHighestPowerOf2 / 2]; bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb); swap(s, ka, kb); - swap(s, va.key, vb.key); - swap(s, va.value, vb.value); + swap(s, va, vb); } constexpr int kLowSize = kNextHighestPowerOf2 / 2; constexpr int kHighSize = N - kNextHighestPowerOf2 / 2; { K newK[kLowSize]; - KeyValuePair newV[kLowSize]; + V newV[kLowSize]; #pragma unroll for (int i = 0; i < kLowSize; ++i) { - newK[i] = k[i]; - newV[i].key = v[i].key; - newV[i].value = v[i].value; + newK[i] = k[i]; + newV[i] = v[i]; } constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(kNextHighestPowerOf2 / 2); // FIXME: compiler doesn't like this expression? compiler bug? // constexpr bool kLowIsPowerOf2 = utils::isPowerOf2(kLowSize); - BitonicMergeStepKVP::merge(newK, newV); + BitonicMergeStep::merge(newK, newV); #pragma unroll for (int i = 0; i < kLowSize; ++i) { - k[i] = newK[i]; - v[i].key = newV[i].key; - v[i].value = newV[i].value; + k[i] = newK[i]; + v[i] = newV[i]; } } { K newK[kHighSize]; - KeyValuePair newV[kHighSize]; + V newV[kHighSize]; #pragma unroll for (int i = 0; i < kHighSize; ++i) { - newK[i] = k[i + kLowSize]; - newV[i].key = v[i + kLowSize].key; - newV[i].value = v[i + kLowSize].value; + newK[i] = k[i + kLowSize]; + newV[i] = v[i + kLowSize]; } constexpr bool kHighIsPowerOf2 = utils::isPowerOf2(N - kNextHighestPowerOf2 / 2); // FIXME: compiler doesn't like this expression? compiler bug? 
- // constexpr bool kHighIsPowerOf2 = utils::isPowerOf2(kHighSize); - BitonicMergeStepKVP::merge(newK, newV); + // constexpr bool kHighIsPowerOf2 = + // utils::isPowerOf2(kHighSize); + BitonicMergeStep::merge(newK, newV); #pragma unroll for (int i = 0; i < kHighSize; ++i) { - k[i + kLowSize] = newK[i]; - v[i + kLowSize].key = newV[i].key; - v[i + kLowSize].value = newV[i].value; + k[i + kLowSize] = newK[i]; + v[i + kLowSize] = newV[i]; } } } }; /// Merges two sets of registers across the warp of any size; -/// i.e., merges a sorted k/v list of size kWarpSize * N1 with a -/// sorted k/v list of size kWarpSize * N2, where N1 and N2 are any +/// i.e., merges a sorted k/v list of size WarpSize * N1 with a +/// sorted k/v list of size WarpSize * N2, where N1 and N2 are any /// value >= 1 template -inline __device__ void warpMergeAnyRegistersKVP(K k1[N1], - KeyValuePair v1[N1], - K k2[N2], - KeyValuePair v2[N2]) +inline __device__ void warpMergeAnyRegisters(K k1[N1], V v1[N1], K k2[N2], V v2[N2]) { constexpr int kSmallestN = N1 < N2 ? N1 : N2; #pragma unroll for (int i = 0; i < kSmallestN; ++i) { - K& ka = k1[N1 - 1 - i]; - KeyValuePair& va = v1[N1 - 1 - i]; + K& ka = k1[N1 - 1 - i]; + V& va = v1[N1 - 1 - i]; - K& kb = k2[i]; - KeyValuePair& vb = v2[i]; + K& kb = k2[i]; + V& vb = v2[i]; K otherKa; - KeyValuePair otherVa; + V otherVa; if (FullMerge) { // We need the other values - otherKa = shfl_xor(ka, kWarpSize - 1); - K otherVak = shfl_xor(va.key, kWarpSize - 1); - V otherVav = shfl_xor(va.value, kWarpSize - 1); - otherVa = KeyValuePair(otherVak, otherVav); + otherKa = shfl_xor(ka, WarpSize - 1); + otherVa = shfl_xor(va, WarpSize - 1); } - K otherKb = shfl_xor(kb, kWarpSize - 1); - K otherVbk = shfl_xor(vb.key, kWarpSize - 1); - V otherVbv = shfl_xor(vb.value, kWarpSize - 1); + K otherKb = shfl_xor(kb, WarpSize - 1); + V otherVb = shfl_xor(vb, WarpSize - 1); // ka is always first in the list, so we needn't use our lane // in this comparison bool swapa = Dir ? Comp::gt(ka, otherKb) : Comp::lt(ka, otherKb); assign(swapa, ka, otherKb); - assign(swapa, va.key, otherVbk); - assign(swapa, va.value, otherVbv); + assign(swapa, va, otherVb); // kb is always second in the list, so we needn't use our lane // in this comparison if (FullMerge) { bool swapb = Dir ? 
Comp::lt(kb, otherKa) : Comp::gt(kb, otherKa); assign(swapb, kb, otherKa); - assign(swapb, vb.key, otherVa.key); - assign(swapb, vb.value, otherVa.value); + assign(swapb, vb, otherVa); } else { // We don't care about updating elements in the second list } } - BitonicMergeStepKVP::merge(k1, v1); + BitonicMergeStep::merge(k1, v1); if (FullMerge) { // Only if we care about N2 do we need to bother merging it fully - BitonicMergeStepKVP::merge(k2, v2); + BitonicMergeStep::merge(k2, v2); } } // Recursive template that uses the above bitonic merge to perform a // bitonic sort template -struct BitonicSortStepKVP { - static inline __device__ void sort(K k[N], KeyValuePair v[N]) +struct BitonicSortStep { + static inline __device__ void sort(K k[N], V v[N]) { static_assert(N > 1, "did not hit specialized case"); @@ -491,71 +455,67 @@ struct BitonicSortStepKVP { constexpr int kSizeB = N - kSizeA; K aK[kSizeA]; - KeyValuePair aV[kSizeA]; + V aV[kSizeA]; #pragma unroll for (int i = 0; i < kSizeA; ++i) { - aK[i] = k[i]; - aV[i].key = v[i].key; - aV[i].value = v[i].value; + aK[i] = k[i]; + aV[i] = v[i]; } - BitonicSortStepKVP::sort(aK, aV); + BitonicSortStep::sort(aK, aV); K bK[kSizeB]; - KeyValuePair bV[kSizeB]; + V bV[kSizeB]; #pragma unroll for (int i = 0; i < kSizeB; ++i) { - bK[i] = k[i + kSizeA]; - bV[i].key = v[i + kSizeA].key; - bV[i].value = v[i + kSizeA].value; + bK[i] = k[i + kSizeA]; + bV[i] = v[i + kSizeA]; } - BitonicSortStepKVP::sort(bK, bV); + BitonicSortStep::sort(bK, bV); // Merge halves - warpMergeAnyRegistersKVP(aK, aV, bK, bV); + warpMergeAnyRegisters(aK, aV, bK, bV); #pragma unroll for (int i = 0; i < kSizeA; ++i) { - k[i] = aK[i]; - v[i].key = aV[i].key; - v[i].value = aV[i].value; + k[i] = aK[i]; + v[i] = aV[i]; } #pragma unroll for (int i = 0; i < kSizeB; ++i) { - k[i + kSizeA] = bK[i]; - v[i + kSizeA].key = bV[i].key; - v[i + kSizeA].value = bV[i].value; + k[i + kSizeA] = bK[i]; + v[i + kSizeA] = bV[i]; } } }; // Single warp (N == 1) sorting specialization template -struct BitonicSortStepKVP { - static inline __device__ void sort(K k[1], KeyValuePair v[1]) +struct BitonicSortStep { + static inline __device__ void sort(K k[1], V v[1]) { // Update this code if this changes - // should go from 1 -> kWarpSize in multiples of 2 - static_assert(kWarpSize == 32, "unexpected warp size"); - - warpBitonicMergeLE16KVP(k[0], v[0]); - warpBitonicMergeLE16KVP(k[0], v[0]); - warpBitonicMergeLE16KVP(k[0], v[0]); - warpBitonicMergeLE16KVP(k[0], v[0]); - warpBitonicMergeLE16KVP(k[0], v[0]); + // should go from 1 -> WarpSize in multiples of 2 + static_assert(WarpSize == 32, "unexpected warp size"); + + warpBitonicMergeLE16(k[0], v[0]); + warpBitonicMergeLE16(k[0], v[0]); + warpBitonicMergeLE16(k[0], v[0]); + warpBitonicMergeLE16(k[0], v[0]); + warpBitonicMergeLE16(k[0], v[0]); } }; -/// Sort a list of kWarpSize * N elements in registers, where N is an +/// Sort a list of WarpSize * N elements in registers, where N is an /// arbitrary >= 1 template -inline __device__ void warpSortAnyRegistersKVP(K k[N], KeyValuePair v[N]) +inline __device__ void warpSortAnyRegisters(K k[N], V v[N]) { - BitonicSortStepKVP::sort(k, v); + BitonicSortStep::sort(k, v); } -} // namespace gpu -} // namespace faiss + +} // namespace raft::spatial::knn::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh b/cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh new file mode 100644 index 0000000000..e4faff7a6c --- /dev/null +++ 
b/cpp/include/raft/spatial/knn/detail/faiss_select/Select.cuh @@ -0,0 +1,555 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file thirdparty/LICENSES/LICENSE.faiss + */ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace raft::spatial::knn::detail::faiss_select { + +// Specialization for block-wide monotonic merges producing a merge sort +// since what we really want is a constexpr loop expansion +template +struct FinalBlockMerge { +}; + +template +struct FinalBlockMerge<1, NumThreads, K, V, NumWarpQ, Dir, Comp> { + static inline __device__ void merge(K* sharedK, V* sharedV) + { + // no merge required; single warp + } +}; + +template +struct FinalBlockMerge<2, NumThreads, K, V, NumWarpQ, Dir, Comp> { + static inline __device__ void merge(K* sharedK, V* sharedV) + { + // Final merge doesn't need to fully merge the second list + blockMerge(sharedK, + sharedV); + } +}; + +template +struct FinalBlockMerge<4, NumThreads, K, V, NumWarpQ, Dir, Comp> { + static inline __device__ void merge(K* sharedK, V* sharedV) + { + blockMerge(sharedK, + sharedV); + // Final merge doesn't need to fully merge the second list + blockMerge( + sharedK, sharedV); + } +}; + +template +struct FinalBlockMerge<8, NumThreads, K, V, NumWarpQ, Dir, Comp> { + static inline __device__ void merge(K* sharedK, V* sharedV) + { + blockMerge(sharedK, + sharedV); + blockMerge(sharedK, + sharedV); + // Final merge doesn't need to fully merge the second list + blockMerge( + sharedK, sharedV); + } +}; + +// `Dir` true, produce largest values. +// `Dir` false, produce smallest values. +template +struct BlockSelect { + static constexpr int kNumWarps = ThreadsPerBlock / WarpSize; + static constexpr int kTotalWarpSortSize = NumWarpQ; + + __device__ inline BlockSelect(K initKVal, V initVVal, K* smemK, V* smemV, int k) + : initK(initKVal), + initV(initVVal), + numVals(0), + warpKTop(initKVal), + sharedK(smemK), + sharedV(smemV), + kMinus1(k - 1) + { + static_assert(utils::isPowerOf2(ThreadsPerBlock), "threads must be a power-of-2"); + static_assert(utils::isPowerOf2(NumWarpQ), "warp queue must be power-of-2"); + + // Fill the per-thread queue keys with the default value +#pragma unroll + for (int i = 0; i < NumThreadQ; ++i) { + threadK[i] = initK; + threadV[i] = initV; + } + + int laneId = raft::laneId(); + int warpId = threadIdx.x / WarpSize; + warpK = sharedK + warpId * kTotalWarpSortSize; + warpV = sharedV + warpId * kTotalWarpSortSize; + + // Fill warp queue (only the actual queue space is fine, not where + // we write the per-thread queues for merging) + for (int i = laneId; i < NumWarpQ; i += WarpSize) { + warpK[i] = initK; + warpV[i] = initV; + } + + warpFence(); + } + + __device__ inline void addThreadQ(K k, V v) + { + if (Dir ? 
Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) { + // Rotate right +#pragma unroll + for (int i = NumThreadQ - 1; i > 0; --i) { + threadK[i] = threadK[i - 1]; + threadV[i] = threadV[i - 1]; + } + + threadK[0] = k; + threadV[0] = v; + ++numVals; + } + } + + __device__ inline void checkThreadQ() + { + bool needSort = (numVals == NumThreadQ); + +#if CUDA_VERSION >= 9000 + needSort = __any_sync(0xffffffff, needSort); +#else + needSort = __any(needSort); +#endif + + if (!needSort) { + // no lanes have triggered a sort + return; + } + + // This has a trailing warpFence + mergeWarpQ(); + + // Any top-k elements have been merged into the warp queue; we're + // free to reset the thread queues + numVals = 0; + +#pragma unroll + for (int i = 0; i < NumThreadQ; ++i) { + threadK[i] = initK; + threadV[i] = initV; + } + + // We have to beat at least this element + warpKTop = warpK[kMinus1]; + + warpFence(); + } + + /// This function handles sorting and merging together the + /// per-thread queues with the warp-wide queue, creating a sorted + /// list across both + __device__ inline void mergeWarpQ() + { + int laneId = raft::laneId(); + + // Sort all of the per-thread queues + warpSortAnyRegisters(threadK, threadV); + + constexpr int kNumWarpQRegisters = NumWarpQ / WarpSize; + K warpKRegisters[kNumWarpQRegisters]; + V warpVRegisters[kNumWarpQRegisters]; + +#pragma unroll + for (int i = 0; i < kNumWarpQRegisters; ++i) { + warpKRegisters[i] = warpK[i * WarpSize + laneId]; + warpVRegisters[i] = warpV[i * WarpSize + laneId]; + } + + warpFence(); + + // The warp queue is already sorted, and now that we've sorted the + // per-thread queue, merge both sorted lists together, producing + // one sorted list + warpMergeAnyRegisters( + warpKRegisters, warpVRegisters, threadK, threadV); + + // Write back out the warp queue +#pragma unroll + for (int i = 0; i < kNumWarpQRegisters; ++i) { + warpK[i * WarpSize + laneId] = warpKRegisters[i]; + warpV[i * WarpSize + laneId] = warpVRegisters[i]; + } + + warpFence(); + } + + /// WARNING: all threads in a warp must participate in this. + /// Otherwise, you must call the constituent parts separately. + __device__ inline void add(K k, V v) + { + addThreadQ(k, v); + checkThreadQ(); + } + + __device__ inline void reduce() + { + // Have all warps dump and merge their queues; this will produce + // the final per-warp results + mergeWarpQ(); + + // block-wide dep; thus far, all warps have been completely + // independent + __syncthreads(); + + // All warp queues are contiguous in smem. + // Now, we have kNumWarps lists of NumWarpQ elements. + // This is a power of 2. 
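+ // FinalBlockMerge expands into log2(kNumWarps) rounds of pairwise merges
+ // over the contiguous warp queues, doubling the sorted list length each
+ // round. The last round skips fully sorting the second list, since callers
+ // only read the first NumWarpQ (>= k) results.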
+ FinalBlockMerge::merge(sharedK, sharedV); + + // The block-wide merge has a trailing syncthreads + } + + // Default element key + const K initK; + + // Default element value + const V initV; + + // Number of valid elements in our thread queue + int numVals; + + // The k-th highest (Dir) or lowest (!Dir) element + K warpKTop; + + // Thread queue values + K threadK[NumThreadQ]; + V threadV[NumThreadQ]; + + // Queues for all warps + K* sharedK; + V* sharedV; + + // Our warp's queue (points into sharedK/sharedV) + // warpK[0] is highest (Dir) or lowest (!Dir) + K* warpK; + V* warpV; + + // This is a cached k-1 value + int kMinus1; +}; + +/// Specialization for k == 1 (NumWarpQ == 1) +template +struct BlockSelect { + static constexpr int kNumWarps = ThreadsPerBlock / WarpSize; + + __device__ inline BlockSelect(K initK, V initV, K* smemK, V* smemV, int k) + : threadK(initK), threadV(initV), sharedK(smemK), sharedV(smemV) + { + } + + __device__ inline void addThreadQ(K k, V v) + { + bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK); + threadK = swap ? k : threadK; + threadV = swap ? v : threadV; + } + + __device__ inline void checkThreadQ() + { + // We don't need to do anything here, since the warp doesn't + // cooperate until the end + } + + __device__ inline void add(K k, V v) { addThreadQ(k, v); } + + __device__ inline void reduce() + { + // Reduce within the warp + KeyValuePair pair(threadK, threadV); + + if (Dir) { + pair = warpReduce(pair, max_op{}); + } else { + pair = warpReduce(pair, min_op{}); + } + + // Each warp writes out a single value + int laneId = raft::laneId(); + int warpId = threadIdx.x / WarpSize; + + if (laneId == 0) { + sharedK[warpId] = pair.key; + sharedV[warpId] = pair.value; + } + + __syncthreads(); + + // We typically use this for small blocks (<= 128), just having the + // first thread in the block perform the reduction across warps is + // faster + if (threadIdx.x == 0) { + threadK = sharedK[0]; + threadV = sharedV[0]; + +#pragma unroll + for (int i = 1; i < kNumWarps; ++i) { + K k = sharedK[i]; + V v = sharedV[i]; + + bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK); + threadK = swap ? k : threadK; + threadV = swap ? v : threadV; + } + + // Hopefully a thread's smem reads/writes are ordered wrt + // itself, so no barrier needed :) + sharedK[0] = threadK; + sharedV[0] = threadV; + } + + // In case other threads wish to read this value + __syncthreads(); + } + + // threadK is lowest (Dir) or highest (!Dir) + K threadK; + V threadV; + + // Where we reduce in smem + K* sharedK; + V* sharedV; +}; + +// +// per-warp WarpSelect +// + +// `Dir` true, produce largest values. +// `Dir` false, produce smallest values. +template +struct WarpSelect { + static constexpr int kNumWarpQRegisters = NumWarpQ / WarpSize; + + __device__ inline WarpSelect(K initKVal, V initVVal, int k) + : initK(initKVal), initV(initVVal), numVals(0), warpKTop(initKVal), kLane((k - 1) % WarpSize) + { + static_assert(utils::isPowerOf2(ThreadsPerBlock), "threads must be a power-of-2"); + static_assert(utils::isPowerOf2(NumWarpQ), "warp queue must be power-of-2"); + + // Fill the per-thread queue keys with the default value +#pragma unroll + for (int i = 0; i < NumThreadQ; ++i) { + threadK[i] = initK; + threadV[i] = initV; + } + + // Fill the warp queue with the default value +#pragma unroll + for (int i = 0; i < kNumWarpQRegisters; ++i) { + warpK[i] = initK; + warpV[i] = initV; + } + } + + __device__ inline void addThreadQ(K k, V v) + { + if (Dir ? 
Comp::gt(k, warpKTop) : Comp::lt(k, warpKTop)) { + // Rotate right +#pragma unroll + for (int i = NumThreadQ - 1; i > 0; --i) { + threadK[i] = threadK[i - 1]; + threadV[i] = threadV[i - 1]; + } + + threadK[0] = k; + threadV[0] = v; + ++numVals; + } + } + + __device__ inline void checkThreadQ() + { + bool needSort = (numVals == NumThreadQ); + +#if CUDA_VERSION >= 9000 + needSort = __any_sync(0xffffffff, needSort); +#else + needSort = __any(needSort); +#endif + + if (!needSort) { + // no lanes have triggered a sort + return; + } + + mergeWarpQ(); + + // Any top-k elements have been merged into the warp queue; we're + // free to reset the thread queues + numVals = 0; + +#pragma unroll + for (int i = 0; i < NumThreadQ; ++i) { + threadK[i] = initK; + threadV[i] = initV; + } + + // We have to beat at least this element + warpKTop = shfl(warpK[kNumWarpQRegisters - 1], kLane); + } + + /// This function handles sorting and merging together the + /// per-thread queues with the warp-wide queue, creating a sorted + /// list across both + __device__ inline void mergeWarpQ() + { + // Sort all of the per-thread queues + warpSortAnyRegisters(threadK, threadV); + + // The warp queue is already sorted, and now that we've sorted the + // per-thread queue, merge both sorted lists together, producing + // one sorted list + warpMergeAnyRegisters( + warpK, warpV, threadK, threadV); + } + + /// WARNING: all threads in a warp must participate in this. + /// Otherwise, you must call the constituent parts separately. + __device__ inline void add(K k, V v) + { + addThreadQ(k, v); + checkThreadQ(); + } + + __device__ inline void reduce() + { + // Have all warps dump and merge their queues; this will produce + // the final per-warp results + mergeWarpQ(); + } + + /// Dump final k selected values for this warp out + __device__ inline void writeOut(K* outK, V* outV, int k) + { + int laneId = raft::laneId(); + +#pragma unroll + for (int i = 0; i < kNumWarpQRegisters; ++i) { + int idx = i * WarpSize + laneId; + + if (idx < k) { + outK[idx] = warpK[i]; + outV[idx] = warpV[i]; + } + } + } + + // Default element key + const K initK; + + // Default element value + const V initV; + + // Number of valid elements in our thread queue + int numVals; + + // The k-th highest (Dir) or lowest (!Dir) element + K warpKTop; + + // Thread queue values + K threadK[NumThreadQ]; + V threadV[NumThreadQ]; + + // warpK[0] is highest (Dir) or lowest (!Dir) + K warpK[kNumWarpQRegisters]; + V warpV[kNumWarpQRegisters]; + + // This is what lane we should load an approximation (>=k) to the + // kth element from the last register in the warp queue (i.e., + // warpK[kNumWarpQRegisters - 1]). + int kLane; +}; + +/// Specialization for k == 1 (NumWarpQ == 1) +template +struct WarpSelect { + static constexpr int kNumWarps = ThreadsPerBlock / WarpSize; + + __device__ inline WarpSelect(K initK, V initV, int k) : threadK(initK), threadV(initV) {} + + __device__ inline void addThreadQ(K k, V v) + { + bool swap = Dir ? Comp::gt(k, threadK) : Comp::lt(k, threadK); + threadK = swap ? k : threadK; + threadV = swap ? 
v : threadV; + } + + __device__ inline void checkThreadQ() + { + // We don't need to do anything here, since the warp doesn't + // cooperate until the end + } + + __device__ inline void add(K k, V v) { addThreadQ(k, v); } + + __device__ inline void reduce() + { + // Reduce within the warp + KeyValuePair pair(threadK, threadV); + + if (Dir) { + pair = warpReduce(pair, max_op{}); + } else { + pair = warpReduce(pair, min_op{}); + } + + threadK = pair.key; + threadV = pair.value; + } + + /// Dump final k selected values for this warp out + __device__ inline void writeOut(K* outK, V* outV, int k) + { + if (raft::laneId() == 0) { + *outK = threadK; + *outV = threadV; + } + } + + // threadK is lowest (Dir) or highest (!Dir) + K threadK; + V threadV; +}; + +} // namespace raft::spatial::knn::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h b/cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h new file mode 100644 index 0000000000..bac051b68c --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/faiss_select/StaticUtils.h @@ -0,0 +1,48 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file thirdparty/LICENSES/LICENSE.faiss + */ + +#pragma once + +#include + +// allow usage for non-CUDA files +#ifndef __host__ +#define __host__ +#define __device__ +#endif + +namespace raft::spatial::knn::detail::faiss_select::utils { + +template +constexpr __host__ __device__ bool isPowerOf2(T v) +{ + return (v && !(v & (v - 1))); +} + +static_assert(isPowerOf2(2048), "isPowerOf2"); +static_assert(!isPowerOf2(3333), "isPowerOf2"); + +template +constexpr __host__ __device__ T nextHighestPowerOf2(T v) +{ + return (isPowerOf2(v) ? (T)2 * v : ((T)1 << (log2(v) + 1))); +} + +static_assert(nextHighestPowerOf2(1) == 2, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(2) == 4, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(3) == 4, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(4) == 8, "nextHighestPowerOf2"); + +static_assert(nextHighestPowerOf2(15) == 16, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(16) == 32, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2(17) == 32, "nextHighestPowerOf2"); + +static_assert(nextHighestPowerOf2(1536000000u) == 2147483648u, "nextHighestPowerOf2"); +static_assert(nextHighestPowerOf2((size_t)2147483648ULL) == (size_t)4294967296ULL, + "nextHighestPowerOf2"); + +} // namespace raft::spatial::knn::detail::faiss_select::utils diff --git a/cpp/include/raft/spatial/knn/detail/block_select_faiss.cuh b/cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh similarity index 80% rename from cpp/include/raft/spatial/knn/detail/block_select_faiss.cuh rename to cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh index 34240fba64..617a26a243 100644 --- a/cpp/include/raft/spatial/knn/detail/block_select_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/faiss_select/key_value_block_select.cuh @@ -2,26 +2,19 @@ * Copyright (c) Facebook, Inc. and its affiliates. * * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. 
+ * LICENSE file thirdparty/LICENSES/LICENSE.faiss */ #pragma once -#include -#include -#include -#include -#include -#include - -#include "warp_select_faiss.cuh" +#include +#include // TODO: Need to think further about the impact (and new boundaries created) on the registers // because this will change the max k that can be processed. One solution might be to break // up k into multiple batches for larger k. -namespace faiss { -namespace gpu { +namespace raft::spatial::knn::detail::faiss_select { // `Dir` true, produce largest values. // `Dir` false, produce smallest values. @@ -33,7 +26,7 @@ template struct KeyValueBlockSelect { - static constexpr int kNumWarps = ThreadsPerBlock / kWarpSize; + static constexpr int kNumWarps = ThreadsPerBlock / WarpSize; static constexpr int kTotalWarpSortSize = NumWarpQ; __device__ inline KeyValueBlockSelect( @@ -59,14 +52,14 @@ struct KeyValueBlockSelect { threadV[i].value = initVv; } - int laneId = getLaneId(); - int warpId = threadIdx.x / kWarpSize; + int laneId = raft::laneId(); + int warpId = threadIdx.x / WarpSize; warpK = sharedK + warpId * kTotalWarpSortSize; warpV = sharedV + warpId * kTotalWarpSortSize; // Fill warp queue (only the actual queue space is fine, not where // we write the per-thread queues for merging) - for (int i = laneId; i < NumWarpQ; i += kWarpSize) { + for (int i = laneId; i < NumWarpQ; i += WarpSize) { warpK[i] = initK; warpV[i].key = initVk; warpV[i].value = initVv; @@ -134,20 +127,20 @@ struct KeyValueBlockSelect { /// list across both __device__ inline void mergeWarpQ() { - int laneId = getLaneId(); + int laneId = raft::laneId(); // Sort all of the per-thread queues - warpSortAnyRegistersKVP(threadK, threadV); + warpSortAnyRegisters, NumThreadQ, !Dir, Comp>(threadK, threadV); - constexpr int kNumWarpQRegisters = NumWarpQ / kWarpSize; + constexpr int kNumWarpQRegisters = NumWarpQ / WarpSize; K warpKRegisters[kNumWarpQRegisters]; KeyValuePair warpVRegisters[kNumWarpQRegisters]; #pragma unroll for (int i = 0; i < kNumWarpQRegisters; ++i) { - warpKRegisters[i] = warpK[i * kWarpSize + laneId]; - warpVRegisters[i].key = warpV[i * kWarpSize + laneId].key; - warpVRegisters[i].value = warpV[i * kWarpSize + laneId].value; + warpKRegisters[i] = warpK[i * WarpSize + laneId]; + warpVRegisters[i].key = warpV[i * WarpSize + laneId].key; + warpVRegisters[i].value = warpV[i * WarpSize + laneId].value; } warpFence(); @@ -155,15 +148,15 @@ struct KeyValueBlockSelect { // The warp queue is already sorted, and now that we've sorted the // per-thread queue, merge both sorted lists together, producing // one sorted list - warpMergeAnyRegistersKVP( + warpMergeAnyRegisters, kNumWarpQRegisters, NumThreadQ, !Dir, Comp, false>( warpKRegisters, warpVRegisters, threadK, threadV); // Write back out the warp queue #pragma unroll for (int i = 0; i < kNumWarpQRegisters; ++i) { - warpK[i * kWarpSize + laneId] = warpKRegisters[i]; - warpV[i * kWarpSize + laneId].key = warpVRegisters[i].key; - warpV[i * kWarpSize + laneId].value = warpVRegisters[i].value; + warpK[i * WarpSize + laneId] = warpKRegisters[i]; + warpV[i * WarpSize + laneId].key = warpVRegisters[i].key; + warpV[i * WarpSize + laneId].value = warpVRegisters[i].value; } warpFence(); @@ -228,5 +221,4 @@ struct KeyValueBlockSelect { int kMinus1; }; -} // namespace gpu -} // namespace faiss \ No newline at end of file +} // namespace raft::spatial::knn::detail::faiss_select diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 
85a05877f1..f1f160a154 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,9 +15,9 @@ */ #pragma once #include -#include #include #include +#include // TODO: Need to hide the PairwiseDistance class impl and expose to public API #include "processing.cuh" #include @@ -219,8 +219,8 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x constexpr auto identity = std::numeric_limits::max(); constexpr auto keyMax = std::numeric_limits::max(); constexpr auto Dir = false; - typedef faiss::gpu:: - WarpSelect, NumWarpQ, NumThreadQ, 32> + typedef faiss_select:: + WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; auto rowEpilog_lambda = [m, n, numOfNN, out_dists, out_inds, mutexes] __device__( diff --git a/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh index 333fc1c573..e073841dd3 100644 --- a/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh +++ b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,12 +18,11 @@ #include #include - -#include -#include +#include #include #include +#include namespace raft { namespace spatial { @@ -61,21 +60,21 @@ __global__ void haversine_knn_kernel(value_idx* out_inds, size_t n_index_rows, int k) { - constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; + constexpr int kNumWarps = tpb / WarpSize; __shared__ value_t smemK[kNumWarps * warp_q]; __shared__ value_idx smemV[kNumWarps * warp_q]; - faiss::gpu:: - BlockSelect, warp_q, thread_q, tpb> - heap(faiss::gpu::Limits::getMax(), + faiss_select:: + BlockSelect, warp_q, thread_q, tpb> + heap(std::numeric_limits::max(), std::numeric_limits::max(), smemK, smemV, k); // Grid is exactly sized to rows available - int limit = faiss::gpu::utils::roundDown(n_index_rows, faiss::gpu::kWarpSize); + int limit = Pow2::roundDown(n_index_rows); const value_t* query_ptr = query + (blockIdx.x * 2); value_t x1 = query_ptr[0]; diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 086cae1089..b246121958 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
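Both the haversine kernel above and the knn_merge_parts kernel below drive the relocated BlockSelect through the same protocol: stream candidates with add() over a warp-aligned prefix, push any tail element through addThreadQ()/checkThreadQ(), then reduce() and copy the first k shared-memory slots out. A hedged skeleton of that pattern; the kernel name and include paths are illustrative, not code from this patch:

    #include <raft/spatial/knn/detail/faiss_select/Select.cuh>
    #include <raft/util/pow2_utils.cuh>

    #include <limits>

    template <int warp_q, int thread_q, int tpb>
    __global__ void topk_sketch(const float* dists, int n, float* outK, int* outV, int k)
    {
      namespace fs = raft::spatial::knn::detail::faiss_select;
      constexpr int kNumWarps = tpb / raft::WarpSize;
      __shared__ float smemK[kNumWarps * warp_q];
      __shared__ int smemV[kNumWarps * warp_q];

      // Dir=false selects the k smallest keys, as in the kNN kernels.
      fs::BlockSelect<float, int, false, fs::Comparator<float>, warp_q, thread_q, tpb> heap(
        std::numeric_limits<float>::max(), -1, smemK, smemV, k);

      // Warp-aligned main loop: within each warp, all 32 lanes iterate
      // together, so every lane participates in add().
      int limit = raft::Pow2<raft::WarpSize>::roundDown(n);
      int i     = threadIdx.x;
      for (; i < limit; i += tpb) {
        heap.add(dists[i], i);
      }

      // Tail: enqueue per-thread, then run the warp-synchronous check once.
      if (i < n) { heap.addThreadQ(dists[i], i); }
      heap.checkThreadQ();

      heap.reduce();

      // The merged top-k now sits in the first k slots of shared memory.
      for (int j = threadIdx.x; j < k; j += tpb) {
        outK[j] = smemK[j];
        outV[j] = smemV[j];
      }
    }
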
@@ -23,13 +23,12 @@ #include #include -#include -#include #include #include #include #include +#include #include #include #include @@ -61,7 +60,7 @@ __global__ void knn_merge_parts_kernel(value_t* inK, int k, value_idx* translations) { - constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; + constexpr int kNumWarps = tpb / WarpSize; __shared__ value_t smemK[kNumWarps * warp_q]; __shared__ value_idx smemV[kNumWarps * warp_q]; @@ -69,8 +68,8 @@ __global__ void knn_merge_parts_kernel(value_t* inK, /** * Uses shared memory */ - faiss::gpu:: - BlockSelect, warp_q, thread_q, tpb> + faiss_select:: + BlockSelect, warp_q, thread_q, tpb> heap(initK, initV, smemK, smemV, k); // Grid is exactly sized to rows available @@ -88,7 +87,7 @@ __global__ void knn_merge_parts_kernel(value_t* inK, value_t* inKStart = inK + (row_idx + col); value_idx* inVStart = inV + (row_idx + col); - int limit = faiss::gpu::utils::roundDown(total_k, faiss::gpu::kWarpSize); + int limit = Pow2::roundDown(total_k); value_idx translation = 0; for (; i < limit; i += tpb) { @@ -134,7 +133,7 @@ inline void knn_merge_parts_impl(value_t* inK, constexpr int n_threads = (warp_q <= 1024) ? 128 : 64; auto block = dim3(n_threads); - auto kInit = faiss::gpu::Limits::getMax(); + auto kInit = std::numeric_limits::max(); auto vInit = -1; knn_merge_parts_kernel <<>>( diff --git a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh index 27c7e006ca..2cdc0fae91 100644 --- a/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/selection_faiss.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ #include #include -#include +#include namespace raft { namespace spatial { @@ -50,9 +50,14 @@ __global__ void select_k_kernel(const key_t* inK, __shared__ key_t smemK[kNumWarps * warp_q]; __shared__ payload_t smemV[kNumWarps * warp_q]; - faiss::gpu:: - BlockSelect, warp_q, thread_q, tpb> - heap(initK, initV, smemK, smemV, k); + faiss_select::BlockSelect, + warp_q, + thread_q, + tpb> + heap(initK, initV, smemK, smemV, k); // Grid is exactly sized to rows available int row = blockIdx.x; diff --git a/thirdparty/LICENSES/LICENSE.faiss b/thirdparty/LICENSES/LICENSE.faiss new file mode 100644 index 0000000000..87cbf536c6 --- /dev/null +++ b/thirdparty/LICENSES/LICENSE.faiss @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) Facebook, Inc. and its affiliates. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file
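
A closing note on the merge networks themselves: blockMergeSmall, blockMergeLarge and warpMergeAnyRegisters all follow the classic bitonic pattern of one reversed compare-exchange stage between the two sorted inputs, followed by strided clean-up passes. A host-side miniature of the same network (ascending order, power-of-two run length assumed; bitonic_merge is a hypothetical name) that can be compiled and asserted against:

    #include <cassert>

    // Merge two ascending runs of length L stored contiguously in a[0..2L).
    // Stage 1 pairs a[i] with its mirror a[2L-1-i], as if the second run were
    // reversed; the strided passes then clean both bitonic halves. The pos
    // indexing mirrors blockMergeSmall above. Assumes L is a power of two.
    template <int L>
    void bitonic_merge(float* a)
    {
      for (int i = 0; i < L; ++i) {
        float& x = a[i];
        float& y = a[2 * L - 1 - i];
        if (x > y) {
          float t = x; x = y; y = t;
        }
      }
      for (int stride = L / 2; stride > 0; stride /= 2) {
        for (int i = 0; i < L; ++i) {
          int pos = 2 * i - (i & (stride - 1));
          if (a[pos] > a[pos + stride]) {
            float t = a[pos]; a[pos] = a[pos + stride]; a[pos + stride] = t;
          }
        }
      }
    }

    int main()
    {
      float a[8] = {1, 4, 6, 9, 2, 3, 5, 8};  // two sorted runs of length 4
      bitonic_merge<4>(a);
      for (int i = 1; i < 8; ++i) {
        assert(a[i - 1] <= a[i]);
      }
      return 0;
    }

The device code performs the same compare-exchanges spread across warp lanes and registers, and the Dir template parameter flips each comparison so the network keeps the largest values instead of the smallest.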