Faiss
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator Friends
GpuIndex.cu
1 
2 /**
3  * Copyright (c) 2015-present, Facebook, Inc.
4  * All rights reserved.
5  *
6  * This source code is licensed under the CC-by-NC license found in the
7  * LICENSE file in the root directory of this source tree.
8  */
9 
10 // Copyright 2004-present Facebook. All Rights Reserved.
11 
12 #include "GpuIndex.h"
13 #include "../FaissAssert.h"
14 #include "GpuResources.h"
15 #include "utils/DeviceUtils.h"
16 
namespace faiss { namespace gpu {

/// If the `x` input to add() / add_with_ids() occupies more than this many
/// bytes, addInternal_ splits it into pages and calls addImpl_ once per page
constexpr size_t kAddPageSize = (size_t) 256 * 1024 * 1024;

/// If the `x` input to search() occupies more than this many bytes, search
/// splits it into pages and calls searchImpl_ once per page
constexpr size_t kSearchPageSize = (size_t) 256 * 1024 * 1024;

/// Constructs a GPU index of dimension `dims` using `metric`.
/// Asserts that `device` is a valid device ordinal and that `resources` is
/// non-null, then calls resources->initializeForDevice(device).
GpuIndex::GpuIndex(GpuResources* resources,
                   int device,
                   int dims,
                   faiss::MetricType metric) :
    Index(dims, metric),
    resources_(resources),
    device_(device) {
  // Reject negative ordinals as well as ones past the last visible device;
  // a negative value would otherwise slip past the upper-bound check alone
  FAISS_ASSERT(device_ >= 0 && device_ < getNumDevices());

  FAISS_ASSERT(resources_);
  resources_->initializeForDevice(device_);
}

/// Adds `n` vectors from `x` without user-supplied ids (addImpl_ receives
/// nullptr for `ids`).
void
GpuIndex::add(Index::idx_t n, const float* x) {
  addInternal_(n, x, nullptr);
}

/// Adds `n` vectors from `x`, with `ids[i]` associated with vector `i`.
void
GpuIndex::add_with_ids(Index::idx_t n,
                       const float* x,
                       const Index::idx_t* ids) {
  addInternal_(n, x, ids);
}

/// Shared implementation of add() and add_with_ids().
/// Requires is_trained. A non-positive `n` is a no-op. If the `n * d` floats
/// of `x` exceed kAddPageSize bytes, the input is forwarded to addImpl_ in
/// pages of at most kAddPageSize bytes (but always at least one vector per
/// page); `ids`, when non-null, is offset in step with `x`.
void
GpuIndex::addInternal_(Index::idx_t n,
                       const float* x,
                       const Index::idx_t* ids) {
  // Held for the whole call; presumably makes device_ the current device
  DeviceScope scope(device_);
  FAISS_ASSERT(this->is_trained);

  if (n <= 0) {
    return;
  }

  // n > 0 here, so this conversion is value-preserving and lets the loop
  // below compare unsigned values only
  const size_t numVecs = (size_t) n;
  const size_t bytesPerVec = (size_t) this->d * sizeof(float);

  if (numVecs * bytesPerVec <= kAddPageSize) {
    addImpl_(n, x, ids);
    return;
  }

  // How many vectors fit into kAddPageSize? Always add at least 1 vector,
  // even if a single vector is larger than a page. (bytesPerVec > 0 here,
  // since a zero-dimension input never exceeds the page size.)
  const size_t numVecsPerPage =
    std::max(kAddPageSize / bytesPerVec, (size_t) 1);

  for (size_t i = 0; i < numVecs; i += numVecsPerPage) {
    const size_t curNum = std::min(numVecsPerPage, numVecs - i);

    addImpl_((Index::idx_t) curNum,
             x + i * (size_t) this->d,
             ids ? ids + i : nullptr);
  }
}

/// Searches for the `k` nearest neighbors of each of the `n` queries in `x`,
/// writing `n * k` entries into each of `distances` and `labels`.
/// Requires is_trained. A non-positive `n` is a no-op. If the `n * d` floats
/// of `x` exceed kSearchPageSize bytes, the queries are forwarded to
/// searchImpl_ in pages of at most kSearchPageSize bytes (but always at least
/// one query per page), with the output pointers offset by `i * k`.
void
GpuIndex::search(Index::idx_t n,
                 const float* x,
                 Index::idx_t k,
                 float* distances,
                 Index::idx_t* labels) const {
  // Held for the whole call; presumably makes device_ the current device
  DeviceScope scope(device_);
  FAISS_ASSERT(this->is_trained);

  if (n <= 0) {
    return;
  }

  // n > 0 here, so this conversion is value-preserving
  const size_t numQueries = (size_t) n;
  const size_t bytesPerVec = (size_t) this->d * sizeof(float);

  if (numQueries * bytesPerVec <= kSearchPageSize) {
    searchImpl_(n, x, k, distances, labels);
    return;
  }

  // How many vectors fit into kSearchPageSize?
  // Just consider `x`, not the size of `distances` or `labels`
  // since they should be small, relatively speaking.
  // Always search at least 1 vector, if we have huge vectors.
  const size_t numVecsPerPage =
    std::max(kSearchPageSize / bytesPerVec, (size_t) 1);

  for (size_t i = 0; i < numQueries; i += numVecsPerPage) {
    const size_t curNum = std::min(numVecsPerPage, numQueries - i);

    searchImpl_((Index::idx_t) curNum,
                x + i * (size_t) this->d,
                k,
                distances + i * (size_t) k,
                labels + i * (size_t) k);
  }
}

} } // namespace
virtual void searchImpl_(faiss::Index::idx_t n, const float *x, faiss::Index::idx_t k, float *distances, faiss::Index::idx_t *labels) const =0
Overridden to actually perform the search.
void addInternal_(Index::idx_t n, const float *x, const Index::idx_t *ids)
Definition: GpuIndex.cu:49
int device_
The GPU device we are resident on.
Definition: GpuIndex.h:80
int d
vector dimension
Definition: Index.h:66
long idx_t
all indices are this type
Definition: Index.h:64
virtual void addImpl_(Index::idx_t n, const float *x, const Index::idx_t *ids)=0
Overridden to actually perform the add.
virtual void add(faiss::Index::idx_t, const float *x)
Definition: GpuIndex.cu:37
virtual void add_with_ids(Index::idx_t n, const float *x, const Index::idx_t *ids)
Definition: GpuIndex.cu:42
virtual void search(faiss::Index::idx_t n, const float *x, faiss::Index::idx_t k, float *distances, faiss::Index::idx_t *labels) const
Definition: GpuIndex.cu:80
bool is_trained
set if the Index does not require training, or if training is done already
Definition: Index.h:71
MetricType
Some algorithms support both an inner product version and an L2 search version.
Definition: Index.h:44