11 #include "IVFUtils.cuh" 
   12 #include "../utils/DeviceUtils.h" 
   13 #include "../utils/StaticUtils.h" 
   14 #include "../utils/Tensor.cuh" 
   15 #include "../utils/ThrustAllocator.cuh" 
   16 #include <thrust/scan.h> 
   17 #include <thrust/execution_policy.h> 
   19 namespace faiss { 
namespace gpu {
 
   // getResultLengths: each thread handles one element of the 2-D
   // topQueryToCentroid tensor. It reads the centroid (list) id stored
   // there and writes the matching entry of listLengths into the same
   // [queryId][listId] slot of `length`. When the centroid id is -1, it
   // writes 0 instead.
   // NOTE(review): this listing is missing lines from the original file.
   // The missing lines include the kernel qualifier/return type, the
   // declarations of the `listLengths` and `totalSize` parameters used
   // below, the early-return body and the closing braces. Confirm these
   // against the real source before relying on them.
   24 getResultLengths(Tensor<int, 2, true> topQueryToCentroid,

   27                  Tensor<int, 2, true> length) {

   // Flat 1-D global thread index. Threads at or beyond totalSize do not
   // process an element. totalSize is presumably the element count of
   // topQueryToCentroid, since the caller sizes the launch from
   // numElements() — TODO confirm.
   28   int linearThreadId = blockIdx.x * blockDim.x + threadIdx.x;

   29   if (linearThreadId >= totalSize) {

   // Split the flat index into (row, column) of topQueryToCentroid.
   // The column count (dimension 1) is the number of probes per query.
   33   int nprobe = topQueryToCentroid.getSize(1);

   34   int queryId = linearThreadId / nprobe;

   35   int listId = linearThreadId % nprobe;

   37   int centroidId = topQueryToCentroid[queryId][listId];

   // A centroid id of -1 is treated as "no list". Its length is recorded
   // as 0 rather than indexing listLengths with -1.
   40   length[queryId][listId] = (centroidId != -1) ? listLengths[centroidId] : 0;
 
   43 void runCalcListOffsets(Tensor<int, 2, true>& topQueryToCentroid,
 
   44                         thrust::device_vector<int>& listLengths,
 
   45                         Tensor<int, 2, true>& prefixSumOffsets,
 
   46                         Tensor<char, 1, true>& thrustMem,
 
   47                         cudaStream_t stream) {
 
   48   FAISS_ASSERT(topQueryToCentroid.getSize(0) == prefixSumOffsets.getSize(0));
 
   49   FAISS_ASSERT(topQueryToCentroid.getSize(1) == prefixSumOffsets.getSize(1));
 
   51   int totalSize = topQueryToCentroid.numElements();
 
   53   int numThreads = std::min(totalSize, getMaxThreadsCurrentDevice());
 
   54   int numBlocks = utils::divUp(totalSize, numThreads);
 
   56   auto grid = dim3(numBlocks);
 
   57   auto block = dim3(numThreads);
 
   59   getResultLengths<<<grid, block, 0, stream>>>(
 
   61     listLengths.data().get(),
 
   70   GpuResourcesThrustAllocator alloc(thrustMem.data(),
 
   71                                     thrustMem.getSizeInBytes());
 
   73   thrust::inclusive_scan(thrust::cuda::par(alloc).on(stream),
 
   74                          prefixSumOffsets.data(),
 
   75                          prefixSumOffsets.data() + totalSize,
 
   76                          prefixSumOffsets.data());