12 #include "PQScanMultiPassNoPrecomputed.cuh" 
   13 #include "../GpuResources.h" 
   14 #include "PQCodeDistances.cuh" 
   15 #include "PQCodeLoad.cuh" 
   16 #include "IVFUtils.cuh" 
   17 #include "../utils/ConversionOperators.cuh" 
   18 #include "../utils/DeviceTensor.cuh" 
   19 #include "../utils/DeviceUtils.h" 
   20 #include "../utils/Float16.cuh" 
   21 #include "../utils/LoadStoreOperators.cuh" 
   22 #include "../utils/NoTypeTensor.cuh" 
   23 #include "../utils/StaticUtils.h" 
   25 #include "../utils/HostTensor.cuh" 
   27 namespace faiss { 
namespace gpu {
 
   29 bool isSupportedNoPrecomputedSubDimSize(
int dims) {
 
   51 template <
typename LookupT, 
typename LookupVecT>
 
   53   static inline __device__ 
void load(LookupT* smem,
 
   56     constexpr 
int kWordSize = 
sizeof(LookupVecT) / 
sizeof(LookupT);
 
   61     if (numCodes % kWordSize == 0) {
 
   65       constexpr 
int kUnroll = 2;
 
   66       int limitVec = numCodes / (kUnroll * kWordSize * blockDim.x);
 
   67       limitVec *= kUnroll * blockDim.x;
 
   69       LookupVecT* smemV = (LookupVecT*) smem;
 
   70       LookupVecT* codesV = (LookupVecT*) codes;
 
   72       for (
int i = threadIdx.x; i < limitVec; i += kUnroll * blockDim.x) {
 
   73         LookupVecT vals[kUnroll];
 
   76         for (
int j = 0; j < kUnroll; ++j) {
 
   82         for (
int j = 0; j < kUnroll; ++j) {
 
   89       int remainder = limitVec * kWordSize;
 
   91       for (
int i = remainder + threadIdx.x; i < numCodes; i += blockDim.x) {
 
   96       constexpr 
int kUnroll = 4;
 
   98       int limit = utils::roundDown(numCodes, kUnroll * blockDim.x);
 
  101       for (; i < limit; i += kUnroll * blockDim.x) {
 
  102         LookupT vals[kUnroll];
 
  105         for (
int j = 0; j < kUnroll; ++j) {
 
  106           vals[j] = codes[i + j * blockDim.x];
 
  110         for (
int j = 0; j < kUnroll; ++j) {
 
  111           smem[i + j * blockDim.x] = vals[j];
 
  115       for (; i < numCodes; i += blockDim.x) {
 
  122 template <
int NumSubQuantizers, 
typename LookupT, 
typename LookupVecT>
 
  132   const auto codesPerSubQuantizer = pqCentroids.
getSize(2);
 
  135   extern __shared__ 
char smemCodeDistances[];
 
  136   LookupT* codeDist = (LookupT*) smemCodeDistances;
 
  139   auto queryId = blockIdx.y;
 
  140   auto probeId = blockIdx.x;
 
  144   int outBase = *(prefixSumOffsets[queryId][probeId].
data() - 1);
 
  145   float* distanceOut = distance[outBase].
data();
 
  147   auto listId = topQueryToCentroid[queryId][probeId];
 
  153   unsigned char* codeList = (
unsigned char*) listCodes[listId];
 
  154   int limit = listLengths[listId];
 
  156   constexpr 
int kNumCode32 = NumSubQuantizers <= 4 ? 1 :
 
  157     (NumSubQuantizers / 4);
 
  158   unsigned int code32[kNumCode32];
 
  159   unsigned int nextCode32[kNumCode32];
 
  162   if (threadIdx.x < limit) {
 
  163     LoadCode32<NumSubQuantizers>::load(code32, codeList, threadIdx.x);
 
  166   LoadCodeDistances<LookupT, LookupVecT>::load(
 
  168     codeDistances[queryId][probeId].data(),
 
  176   for (
int codeIndex = threadIdx.x;
 
  178        codeIndex += blockDim.x) {
 
  180     if (codeIndex + blockDim.x < limit) {
 
  181       LoadCode32<NumSubQuantizers>::load(
 
  182         nextCode32, codeList, codeIndex + blockDim.x);
 
  188     for (
int word = 0; word < kNumCode32; ++word) {
 
  189       constexpr 
int kBytesPerCode32 =
 
  190         NumSubQuantizers < 4 ? NumSubQuantizers : 4;
 
  192       if (kBytesPerCode32 == 1) {
 
  193         auto code = code32[0];
 
  194         dist = ConvertTo<float>::to(codeDist[code]);
 
  198         for (
int byte = 0; byte < kBytesPerCode32; ++byte) {
 
  199           auto code = getByte(code32[word], byte * 8, 8);
 
  202             codesPerSubQuantizer * (word * kBytesPerCode32 + byte);
 
  204           dist += ConvertTo<float>::to(codeDist[offset + code]);
 
  212     distanceOut[codeIndex] = dist;
 
  216     for (
int word = 0; word < kNumCode32; ++word) {
 
  217       code32[word] = nextCode32[word];
 
  223 runMultiPassTile(Tensor<float, 2, true>& queries,
 
  224                  Tensor<float, 2, true>& centroids,
 
  225                  Tensor<float, 3, true>& pqCentroidsInnermostCode,
 
  226                  NoTypeTensor<4, true>& codeDistances,
 
  227                  Tensor<int, 2, true>& topQueryToCentroid,
 
  228                  bool useFloat16Lookup,
 
  230                  int numSubQuantizers,
 
  231                  int numSubQuantizerCodes,
 
  232                  thrust::device_vector<void*>& listCodes,
 
  233                  thrust::device_vector<void*>& listIndices,
 
  234                  IndicesOptions indicesOptions,
 
  235                  thrust::device_vector<int>& listLengths,
 
  236                  Tensor<char, 1, true>& thrustMem,
 
  237                  Tensor<int, 2, true>& prefixSumOffsets,
 
  238                  Tensor<float, 1, true>& allDistances,
 
  239                  Tensor<float, 3, true>& heapDistances,
 
  240                  Tensor<int, 3, true>& heapIndices,
 
  242                  Tensor<float, 2, true>& outDistances,
 
  243                  Tensor<long, 2, true>& outIndices,
 
  244                  cudaStream_t stream) {
 
  245 #ifndef FAISS_USE_FLOAT16 
  246   FAISS_ASSERT(!useFloat16Lookup);
 
  251   runCalcListOffsets(topQueryToCentroid, listLengths, prefixSumOffsets,
 
  256   runPQCodeDistances(pqCentroidsInnermostCode,
 
  267     auto kThreadsPerBlock = 256;
 
  269     auto grid = dim3(topQueryToCentroid.getSize(1),
 
  270                      topQueryToCentroid.getSize(0));
 
  271     auto block = dim3(kThreadsPerBlock);
 
  274     auto smem = 
sizeof(float);
 
  275 #ifdef FAISS_USE_FLOAT16 
  276     if (useFloat16Lookup) {
 
  280     smem *= numSubQuantizers * numSubQuantizerCodes;
 
  281     FAISS_ASSERT(smem <= getMaxSharedMemPerBlockCurrentDevice());
 
  283 #define RUN_PQ_OPT(NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T)                   \ 
  285       auto codeDistancesT = codeDistances.toTensor<LOOKUP_T>();         \ 
  287       pqScanNoPrecomputedMultiPass<NUM_SUB_Q, LOOKUP_T, LOOKUP_VEC_T>   \ 
  288         <<<grid, block, smem, stream>>>(                                \ 
  290           pqCentroidsInnermostCode,                                     \ 
  291           topQueryToCentroid,                                           \ 
  293           listCodes.data().get(),                                       \ 
  294           listLengths.data().get(),                                     \ 
  299 #ifdef FAISS_USE_FLOAT16 
  300 #define RUN_PQ(NUM_SUB_Q)                       \ 
  302       if (useFloat16Lookup) {                   \ 
  303         RUN_PQ_OPT(NUM_SUB_Q, half, Half8);     \ 
  305         RUN_PQ_OPT(NUM_SUB_Q, float, float4);   \ 
  309 #define RUN_PQ(NUM_SUB_Q)                       \ 
  311       RUN_PQ_OPT(NUM_SUB_Q, float, float4);     \ 
  313 #endif // FAISS_USE_FLOAT16 
  315     switch (bytesPerCode) {
 
  374   runPass1SelectLists(prefixSumOffsets,
 
  376                       topQueryToCentroid.getSize(1),
 
  384   auto flatHeapDistances = heapDistances.downcastInner<2>();
 
  385   auto flatHeapIndices = heapIndices.downcastInner<2>();
 
  387   runPass2SelectLists(flatHeapDistances,
 
  399   CUDA_VERIFY(cudaGetLastError());
 
  402 void runPQScanMultiPassNoPrecomputed(Tensor<float, 2, true>& queries,
 
  403                                      Tensor<float, 2, true>& centroids,
 
  404                                      Tensor<float, 3, true>& pqCentroidsInnermostCode,
 
  405                                      Tensor<int, 2, true>& topQueryToCentroid,
 
  406                                      bool useFloat16Lookup,
 
  408                                      int numSubQuantizers,
 
  409                                      int numSubQuantizerCodes,
 
  410                                      thrust::device_vector<void*>& listCodes,
 
  411                                      thrust::device_vector<void*>& listIndices,
 
  412                                      IndicesOptions indicesOptions,
 
  413                                      thrust::device_vector<int>& listLengths,
 
  417                                      Tensor<float, 2, true>& outDistances,
 
  419                                      Tensor<long, 2, true>& outIndices,
 
  421   constexpr 
int kMinQueryTileSize = 8;
 
  422   constexpr 
int kMaxQueryTileSize = 128;
 
  423   constexpr 
int kThrustMemSize = 16384;
 
  425   int nprobe = topQueryToCentroid.getSize(1);
 
  427   auto& mem = res->getMemoryManagerCurrentDevice();
 
  428   auto stream = res->getDefaultStreamCurrentDevice();
 
  432   DeviceTensor<char, 1, true> thrustMem1(
 
  433     mem, {kThrustMemSize}, stream);
 
  434   DeviceTensor<char, 1, true> thrustMem2(
 
  435     mem, {kThrustMemSize}, stream);
 
  436   DeviceTensor<char, 1, true>* thrustMem[2] =
 
  437     {&thrustMem1, &thrustMem2};
 
  441   size_t sizeAvailable = mem.getSizeAvailable();
 
  445   constexpr 
int kNProbeSplit = 8;
 
  446   int pass2Chunks = std::min(nprobe, kNProbeSplit);
 
  448   size_t sizeForFirstSelectPass =
 
  449     pass2Chunks * k * (
sizeof(float) + 
sizeof(
int));
 
  452   size_t sizePerQuery =
 
  454     ((nprobe * 
sizeof(int) + 
sizeof(
int)) + 
 
  455      nprobe * maxListLength * 
sizeof(
float) + 
 
  457      nprobe * numSubQuantizers * numSubQuantizerCodes * 
sizeof(float) +
 
  458      sizeForFirstSelectPass);
 
  460   int queryTileSize = (int) (sizeAvailable / sizePerQuery);
 
  462   if (queryTileSize < kMinQueryTileSize) {
 
  463     queryTileSize = kMinQueryTileSize;
 
  464   } 
else if (queryTileSize > kMaxQueryTileSize) {
 
  465     queryTileSize = kMaxQueryTileSize;
 
  470   FAISS_ASSERT(queryTileSize * nprobe * maxListLength <
 
  471          std::numeric_limits<int>::max());
 
  476   DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
 
  477     mem, {queryTileSize * nprobe + 1}, stream);
 
  478   DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
 
  479     mem, {queryTileSize * nprobe + 1}, stream);
 
  481   DeviceTensor<int, 2, true> prefixSumOffsets1(
 
  482     prefixSumOffsetSpace1[1].data(),
 
  483     {queryTileSize, nprobe});
 
  484   DeviceTensor<int, 2, true> prefixSumOffsets2(
 
  485     prefixSumOffsetSpace2[1].data(),
 
  486     {queryTileSize, nprobe});
 
  487   DeviceTensor<int, 2, true>* prefixSumOffsets[2] =
 
  488     {&prefixSumOffsets1, &prefixSumOffsets2};
 
  492   CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace1.data(),
 
  496   CUDA_VERIFY(cudaMemsetAsync(prefixSumOffsetSpace2.data(),
 
  501   int codeDistanceTypeSize = 
sizeof(float);
 
  502 #ifdef FAISS_USE_FLOAT16 
  503   if (useFloat16Lookup) {
 
  504     codeDistanceTypeSize = 
sizeof(half);
 
  507   FAISS_ASSERT(!useFloat16Lookup);
 
  508   int codeSize = 
sizeof(float);
 
  511   int totalCodeDistancesSize =
 
  512     queryTileSize * nprobe * numSubQuantizers * numSubQuantizerCodes *
 
  513     codeDistanceTypeSize;
 
  515   DeviceTensor<char, 1, true> codeDistances1Mem(
 
  516     mem, {totalCodeDistancesSize}, stream);
 
  517   NoTypeTensor<4, true> codeDistances1(
 
  518     codeDistances1Mem.data(),
 
  519     codeDistanceTypeSize,
 
  520     {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});
 
  522   DeviceTensor<char, 1, true> codeDistances2Mem(
 
  523     mem, {totalCodeDistancesSize}, stream);
 
  524   NoTypeTensor<4, true> codeDistances2(
 
  525     codeDistances2Mem.data(),
 
  526     codeDistanceTypeSize,
 
  527     {queryTileSize, nprobe, numSubQuantizers, numSubQuantizerCodes});
 
  529   NoTypeTensor<4, true>* codeDistances[2] =
 
  530     {&codeDistances1, &codeDistances2};
 
  532   DeviceTensor<float, 1, true> allDistances1(
 
  533     mem, {queryTileSize * nprobe * maxListLength}, stream);
 
  534   DeviceTensor<float, 1, true> allDistances2(
 
  535     mem, {queryTileSize * nprobe * maxListLength}, stream);
 
  536   DeviceTensor<float, 1, true>* allDistances[2] =
 
  537     {&allDistances1, &allDistances2};
 
  539   DeviceTensor<float, 3, true> heapDistances1(
 
  540     mem, {queryTileSize, pass2Chunks, k}, stream);
 
  541   DeviceTensor<float, 3, true> heapDistances2(
 
  542     mem, {queryTileSize, pass2Chunks, k}, stream);
 
  543   DeviceTensor<float, 3, true>* heapDistances[2] =
 
  544     {&heapDistances1, &heapDistances2};
 
  546   DeviceTensor<int, 3, true> heapIndices1(
 
  547     mem, {queryTileSize, pass2Chunks, k}, stream);
 
  548   DeviceTensor<int, 3, true> heapIndices2(
 
  549     mem, {queryTileSize, pass2Chunks, k}, stream);
 
  550   DeviceTensor<int, 3, true>* heapIndices[2] =
 
  551     {&heapIndices1, &heapIndices2};
 
  553   auto streams = res->getAlternateStreamsCurrentDevice();
 
  554   streamWait(streams, {stream});
 
  558   for (
int query = 0; query < queries.getSize(0); query += queryTileSize) {
 
  559     int numQueriesInTile =
 
  560       std::min(queryTileSize, queries.getSize(0) - query);
 
  562     auto prefixSumOffsetsView =
 
  563       prefixSumOffsets[curStream]->narrowOutermost(0, numQueriesInTile);
 
  565     auto codeDistancesView =
 
  566       codeDistances[curStream]->narrowOutermost(0, numQueriesInTile);
 
  567     auto coarseIndicesView =
 
  568       topQueryToCentroid.narrowOutermost(query, numQueriesInTile);
 
  570       queries.narrowOutermost(query, numQueriesInTile);
 
  572     auto heapDistancesView =
 
  573       heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
 
  574     auto heapIndicesView =
 
  575       heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);
 
  577     auto outDistanceView =
 
  578       outDistances.narrowOutermost(query, numQueriesInTile);
 
  579     auto outIndicesView =
 
  580       outIndices.narrowOutermost(query, numQueriesInTile);
 
  582     runMultiPassTile(queryView,
 
  584                      pqCentroidsInnermostCode,
 
  590                      numSubQuantizerCodes,
 
  595                      *thrustMem[curStream],
 
  596                      prefixSumOffsetsView,
 
  597                      *allDistances[curStream],
 
  605     curStream = (curStream + 1) % 2;
 
  608   streamWait({stream}, streams);
 
__host__ __device__ DataPtrType data()
Returns a raw pointer to the start of our data. 
__host__ __device__ IndexT getSize(int i) const