12 #include "IVFUtils.cuh" 
   13 #include "../utils/DeviceUtils.h" 
   14 #include "../utils/Select.cuh" 
   15 #include "../utils/StaticUtils.h" 
   16 #include "../utils/Tensor.cuh" 
   24 namespace faiss { 
namespace gpu {
 
   26 constexpr 
auto kMax = std::numeric_limits<float>::max();
 
   27 constexpr 
auto kMin = std::numeric_limits<float>::min();
 
   29 template <
int ThreadsPerBlock, 
int NumWarpQ, 
int NumThreadQ, 
bool Dir>
 
   31 pass1SelectLists(Tensor<int, 2, true> prefixSumOffsets,
 
   32                  Tensor<float, 1, true> distance,
 
   35                  Tensor<float, 3, true> heapDistances,
 
   36                  Tensor<int, 3, true> heapIndices) {
 
   37   constexpr 
int kNumWarps = ThreadsPerBlock / kWarpSize;
 
   39   __shared__ 
float smemK[kNumWarps * NumWarpQ];
 
   40   __shared__ 
int smemV[kNumWarps * NumWarpQ];
 
   42   constexpr 
auto kInit = Dir ? kMin : kMax;
 
   43   BlockSelect<float, int, Dir, Comparator<float>,
 
   44             NumWarpQ, NumThreadQ, ThreadsPerBlock>
 
   45     heap(kInit, -1, smemK, smemV, k);
 
   47   auto queryId = blockIdx.y;
 
   48   auto sliceId = blockIdx.x;
 
   49   auto numSlices = gridDim.x;
 
   51   int sliceSize = (nprobe / numSlices);
 
   52   int sliceStart = sliceSize * sliceId;
 
   53   int sliceEnd = sliceId == (numSlices - 1) ? nprobe :
 
   54     sliceStart + sliceSize;
 
   55   auto offsets = prefixSumOffsets[queryId].data();
 
   58   int start = *(&offsets[sliceStart] - 1);
 
   59   int end = offsets[sliceEnd - 1];
 
   61   int num = end - start;
 
   62   int limit = utils::roundDown(num, kWarpSize);
 
   65   auto distanceStart = distance[start].data();
 
   69   for (; i < limit; i += blockDim.x) {
 
   70     heap.add(distanceStart[i], start + i);
 
   75     heap.addThreadQ(distanceStart[i], start + i);
 
   83   for (
int i = threadIdx.x; i < k; i += blockDim.x) {
 
   84     heapDistances[queryId][sliceId][i] = smemK[i];
 
   85     heapIndices[queryId][sliceId][i] = smemV[i];
 
   90 runPass1SelectLists(Tensor<int, 2, true>& prefixSumOffsets,
 
   91                     Tensor<float, 1, true>& distance,
 
   95                     Tensor<float, 3, true>& heapDistances,
 
   96                     Tensor<int, 3, true>& heapIndices,
 
   97                     cudaStream_t stream) {
 
   98   constexpr 
auto kThreadsPerBlock = 128;
 
  100   auto grid = dim3(heapDistances.getSize(1), prefixSumOffsets.getSize(0));
 
  101   auto block = dim3(kThreadsPerBlock);
 
  103 #define RUN_PASS(NUM_WARP_Q, NUM_THREAD_Q, DIR)                         \ 
  105     pass1SelectLists<kThreadsPerBlock, NUM_WARP_Q, NUM_THREAD_Q, DIR>   \ 
  106       <<<grid, block, 0, stream>>>(prefixSumOffsets,                    \ 
  115 #define RUN_PASS_DIR(DIR)                            \ 
  118       RUN_PASS(1, 1, DIR);                           \ 
  119     } else if (k <= 32) {                            \ 
  120       RUN_PASS(32, 2, DIR);                          \ 
  121     } else if (k <= 64) {                            \ 
  122       RUN_PASS(64, 3, DIR);                          \ 
  123     } else if (k <= 128) {                           \ 
  124       RUN_PASS(128, 3, DIR);                         \ 
  125     } else if (k <= 256) {                           \ 
  126       RUN_PASS(256, 4, DIR);                         \ 
  127     } else if (k <= 512) {                           \ 
  128       RUN_PASS(512, 8, DIR);                         \ 
  129     } else if (k <= 1024) {                          \ 
  130       RUN_PASS(1024, 8, DIR);                        \