Commit f7aedbdf authored by matthijs

sync with FB version 2017-07-18

- implemented ScalarQuantizer (without IVF)
- implemented update for IndexIVFFlat
- implemented L2 normalization preproc
parent 602debae
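For orientation, here is a rough usage sketch of the three features this sync brings in, written against the Python bindings (the dataset sizes and factory strings are illustrative, not taken from the patch):

    import faiss
    import numpy as np

    d = 64
    xb = np.random.random((10000, d)).astype('float32')

    # 1) ScalarQuantizer without IVF: flat storage with 8-bit codes per component
    sq = faiss.IndexScalarQuantizer(d, faiss.ScalarQuantizer.QT_8bit, faiss.METRIC_L2)
    sq.train(xb)
    sq.add(xb)

    # 2) in-place update of vectors stored in an IndexIVFFlat (requires a direct map)
    ivf = faiss.index_factory(d, "IVF256,Flat")
    ivf.train(xb)
    ivf.add(xb)                          # ids are implicitly 0..nb-1
    ivf.make_direct_map()
    ivf.update_vectors(np.arange(100).astype('int64'),
                       np.random.random((100, d)).astype('float32'))

    # 3) L2 normalization as a preprocessing step, here in front of an IP index
    cos = faiss.index_factory(d, "L2norm,Flat", faiss.METRIC_INNER_PRODUCT)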
@@ -22,7 +22,7 @@
#include "IndexIVF.h"
#include "IndexIVFPQ.h"
#include "MetaIndexes.h"
- #include "IndexIVFScalarQuantizer.h"
+ #include "IndexScalarQuantizer.h"
namespace faiss {
@@ -623,18 +623,28 @@ void ParameterSpace::explore (Index *index,
* index_factory
***************************************************************/
namespace {
struct VTChain {
std::vector<VectorTransform *> chain;
~VTChain () {
for (int i = 0; i < chain.size(); i++) {
delete chain[i];
}
}
};
}
Index *index_factory (int d, const char *description_in, MetricType metric)
{
- VectorTransform *vt = nullptr;
+ VTChain vts;
Index *coarse_quantizer = nullptr;
Index *index = nullptr;
bool add_idmap = false;
bool make_IndexRefineFlat = false;
ScopeDeleter1<Index> del_coarse_quantizer, del_index;
- ScopeDeleter1<VectorTransform> del_vt;
char description[strlen(description_in) + 1];
char *ptr;
@@ -656,18 +666,27 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
Index *index_1 = nullptr;
// VectorTransforms
- if (!vt && sscanf (tok, "PCA%d", &d_out) == 1) {
+ if (sscanf (tok, "PCA%d", &d_out) == 1) {
vt_1 = new PCAMatrix (d, d_out);
d = d_out;
- } else if (!vt && sscanf (tok, "PCAR%d", &d_out) == 1) {
+ } else if (sscanf (tok, "PCAR%d", &d_out) == 1) {
vt_1 = new PCAMatrix (d, d_out, 0, true);
d = d_out;
+ } else if (sscanf (tok, "PCAW%d", &d_out) == 1) {
+ vt_1 = new PCAMatrix (d, d_out, -0.5, false);
+ d = d_out;
+ } else if (sscanf (tok, "PCAWR%d", &d_out) == 1) {
+ vt_1 = new PCAMatrix (d, d_out, -0.5, true);
+ d = d_out;
- } else if (!vt && sscanf (tok, "OPQ%d_%d", &opq_M, &d_out) == 2) {
+ } else if (sscanf (tok, "OPQ%d_%d", &opq_M, &d_out) == 2) {
vt_1 = new OPQMatrix (d, opq_M, d_out);
d = d_out;
- } else if (!vt && sscanf (tok, "OPQ%d", &opq_M) == 1) {
+ } else if (sscanf (tok, "OPQ%d", &opq_M) == 1) {
vt_1 = new OPQMatrix (d, opq_M);
+ } else if (stok == "L2norm") {
+ vt_1 = new NormalizationTransform (d, 2.0);
// coarse quantizers
} else if (!coarse_quantizer &&
sscanf (tok, "IVF%d", &ncentroids) == 1) {
if (metric == METRIC_L2) {
@@ -698,28 +717,25 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
index_1 = index_ivf;
} else {
index_1 = new IndexFlat (d, metric);
- if (add_idmap) {
- IndexIDMap *idmap = new IndexIDMap(index_1);
- idmap->own_fields = true;
- index_1 = idmap;
- add_idmap = false;
- }
}
} else if (!index && (stok == "SQ8" || stok == "SQ4")) {
- FAISS_THROW_IF_NOT_MSG(coarse_quantizer,
- "ScalarQuantizer works only with an IVF");
ScalarQuantizer::QuantizerType qt =
stok == "SQ8" ? ScalarQuantizer::QT_8bit :
stok == "SQ4" ? ScalarQuantizer::QT_4bit :
ScalarQuantizer::QT_4bit;
- IndexIVFScalarQuantizer *index_ivf = new IndexIVFScalarQuantizer (
- coarse_quantizer, d, ncentroids, qt, metric);
- index_ivf->quantizer_trains_alone =
- dynamic_cast<MultiIndexQuantizer*>(coarse_quantizer)
- != nullptr;
- del_coarse_quantizer.release ();
- index_ivf->own_fields = true;
- index_1 = index_ivf;
+ if (coarse_quantizer) {
+ IndexIVFScalarQuantizer *index_ivf =
+ new IndexIVFScalarQuantizer (
+ coarse_quantizer, d, ncentroids, qt, metric);
+ index_ivf->quantizer_trains_alone =
+ dynamic_cast<MultiIndexQuantizer*>(coarse_quantizer)
+ != nullptr;
+ del_coarse_quantizer.release ();
+ index_ivf->own_fields = true;
+ index_1 = index_ivf;
+ } else {
+ index_1 = new IndexScalarQuantizer (d, qt, metric);
+ }
} else if (!index && sscanf (tok, "PQ%d+%d", &M, &M2) == 2) {
FAISS_THROW_IF_NOT_MSG(coarse_quantizer,
"PQ with + works only with an IVF");
@@ -750,13 +766,6 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
IndexPQ *index_pq = new IndexPQ (d, M, 8, metric);
index_pq->do_polysemous_training = true;
index_1 = index_pq;
- if (add_idmap) {
- IndexIDMap *idmap = new IndexIDMap(index_1);
- del_index.set (idmap);
- idmap->own_fields = true;
- index_1 = idmap;
- add_idmap = false;
- }
}
} else if (stok == "RFlat") {
make_IndexRefineFlat = true;
@@ -765,9 +774,16 @@
tok, description_in);
}
if (index_1 && add_idmap) {
IndexIDMap *idmap = new IndexIDMap(index_1);
del_index.set (idmap);
idmap->own_fields = true;
index_1 = idmap;
add_idmap = false;
}
if (vt_1) {
- vt = vt_1;
- del_vt.set (vt);
+ vts.chain.push_back (vt_1);
}
if (coarse_quantizer_1) {
@@ -793,10 +809,14 @@ Index *index_factory (int d, const char *description_in, MetricType metric)
"IDMap option not used\n");
}
- if (vt) {
- IndexPreTransform *index_pt = new IndexPreTransform (vt, index);
- del_vt.release ();
+ if (vts.chain.size() > 0) {
+ IndexPreTransform *index_pt = new IndexPreTransform (index);
index_pt->own_fields = true;
+ // add from back
+ while (vts.chain.size() > 0) {
+ index_pt->prepend_transform (vts.chain.back());
+ vts.chain.pop_back ();
+ }
index = index_pt;
}
......
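Since index_factory now collects transforms in a VTChain instead of a single pointer, several preprocessing tokens can be chained, and SQ4/SQ8 no longer require an IVF in front. A hedged sketch from Python (the strings below are illustrative combinations, assuming the parser accepts them as written):

    import faiss

    d = 128
    # whitening PCA to 64 dims after L2 normalization, then IVF + 8-bit scalar codes
    index1 = faiss.index_factory(d, "L2norm,PCAW64,IVF1024,SQ8")
    # a standalone scalar quantizer behind a rotated PCA, no coarse quantizer needed
    index2 = faiss.index_factory(d, "PCAR32,SQ4")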
@@ -158,6 +158,10 @@ void RangeSearchPartialResult::set_result (bool incremental)
}
+ /***********************************************************************
+ * IDSelectorRange
+ ***********************************************************************/
IDSelectorRange::IDSelectorRange (idx_t imin, idx_t imax):
imin (imin), imax (imax)
{
@@ -169,6 +173,9 @@ bool IDSelectorRange::is_member (idx_t id) const
}
+ /***********************************************************************
+ * IDSelectorBatch
+ ***********************************************************************/
IDSelectorBatch::IDSelectorBatch (long n, const idx_t *indices)
{
......
@@ -15,12 +15,7 @@
#define FAISS_AUX_INDEX_STRUCTURES_H
#include <vector>
- #if __cplusplus >= 201103L
#include <unordered_set>
- #endif
- #include <set>
#include "Index.h"
@@ -80,11 +75,7 @@ struct IDSelectorRange: IDSelector {
* hash collisions if lsb's are always the same */
struct IDSelectorBatch: IDSelector {
- #if __cplusplus >= 201103L
std::unordered_set<idx_t> set;
- #else
- std::set<idx_t> set;
- #endif
typedef unsigned char uint8_t;
std::vector<uint8_t> bloom; // assumes low bits of id are a good hash value
......
@@ -9,7 +9,6 @@
// Copyright 2004-present Facebook. All Rights Reserved.
#include "FaissException.h"
- #include <cstdio>
namespace faiss {
@@ -28,4 +27,9 @@ FaissException::FaissException(const std::string& m,
funcName, file, line, m.c_str());
}
+ const char*
+ FaissException::what() const noexcept {
+ return msg.c_str();
+ }
}
@@ -27,9 +27,7 @@ class FaissException : public std::exception {
int line);
/// from std::exception
- const char* what() const noexcept override
- { return msg.c_str(); }
- ~FaissException () noexcept override {}
+ const char* what() const noexcept override;
std::string msg;
};
......
@@ -65,21 +65,28 @@ void IndexIVF::add (idx_t n, const float * x)
add_with_ids (n, x, nullptr);
}
- void IndexIVF::make_direct_map ()
+ void IndexIVF::make_direct_map (bool new_maintain_direct_map)
{
- if (maintain_direct_map) return;
- direct_map.resize (ntotal, -1);
- for (size_t key = 0; key < nlist; key++) {
- const std::vector<long> & idlist = ids[key];
- for (long ofs = 0; ofs < idlist.size(); ofs++) {
- direct_map [idlist [ofs]] =
- key << 32 | ofs;
- }
- }
- maintain_direct_map = true;
+ // nothing to do
+ if (new_maintain_direct_map == maintain_direct_map)
+ return;
+ if (new_maintain_direct_map) {
+ direct_map.resize (ntotal, -1);
+ for (size_t key = 0; key < nlist; key++) {
+ const std::vector<long> & idlist = ids[key];
+ for (long ofs = 0; ofs < idlist.size(); ofs++) {
+ FAISS_THROW_IF_NOT_MSG (
+ 0 <= idlist [ofs] && idlist[ofs] < ntotal,
+ "direct map supported only for seuquential ids");
+ direct_map [idlist [ofs]] = key << 32 | ofs;
+ }
+ }
+ } else {
+ direct_map.clear ();
+ }
+ maintain_direct_map = new_maintain_direct_map;
}
@@ -183,7 +190,6 @@ void IndexIVF::merge_from (IndexIVF &other, idx_t add_id)
IndexIVF::~IndexIVF()
{
if (own_fields) delete quantizer;
@@ -217,6 +223,8 @@ void IndexIVFFlat::add_core (idx_t n, const float * x, const long *xids,
{
FAISS_THROW_IF_NOT (is_trained);
+ FAISS_THROW_IF_NOT_MSG (!(maintain_direct_map && xids),
+ "cannot have direct map and add with ids");
const long * idx;
ScopeDeleter<long> del;
@@ -477,6 +485,49 @@ void IndexIVFFlat::copy_subset_to (IndexIVFFlat & other, int subset_type,
}
}
void IndexIVFFlat::update_vectors (int n, idx_t *new_ids, const float *x)
{
FAISS_THROW_IF_NOT (maintain_direct_map);
FAISS_THROW_IF_NOT (is_trained);
std::vector<idx_t> assign (n);
quantizer->assign (n, x, assign.data());
for (int i = 0; i < n; i++) {
idx_t id = new_ids[i];
FAISS_THROW_IF_NOT_MSG (0 <= id && id < ntotal,
"id to update out of range");
{ // remove old one
long dm = direct_map[id];
long ofs = dm & 0xffffffff;
long il = dm >> 32;
size_t l = ids[il].size();
if (ofs != l - 1) {
long id2 = ids[il].back();
ids[il][ofs] = id2;
direct_map[id2] = (il << 32) | ofs;
memcpy (vecs[il].data() + ofs * d,
vecs[il].data() + (l - 1) * d,
d * sizeof(vecs[il][0]));
}
ids[il].pop_back();
vecs[il].resize((l - 1) * d);
}
{ // insert new one
long il = assign[i];
size_t l = ids[il].size();
long dm = (il << 32) | l;
direct_map[id] = dm;
ids[il].push_back (id);
vecs[il].resize((l + 1) * d);
memcpy (vecs[il].data() + l * d,
x + i * d,
d * sizeof(vecs[il][0]));
}
}
}
void IndexIVFFlat::reset()
......
@@ -91,9 +91,12 @@ struct IndexIVF: Index {
size_t get_list_size (size_t list_no) const
{ return ids[list_no].size(); }
- /// intialize a direct map
- void make_direct_map ();
+ /** intialize a direct map
+ *
+ * @param new_maintain_direct_map if true, create a direct map,
+ * else clear it
+ */
+ void make_direct_map (bool new_maintain_direct_map=true);
/// 1= perfectly balanced, >1: imbalanced
double imbalance_factor () const;
@@ -184,6 +187,16 @@ struct IndexIVFFlat: IndexIVF {
const long * keys,
float_maxheap_array_t * res) const;
/** Update a subset of vectors.
*
* The index must have a direct_map
*
* @param nv nb of vectors to update
* @param idx vector indices to update, size nv
* @param v vectors of new values, size nv*d
*/
void update_vectors (int nv, idx_t *idx, const float *v);
void reconstruct(idx_t key, float* recons) const override;
void merge_from_residuals(IndexIVF& other) override;
......
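The new update_vectors entry point assumes sequential ids (0..ntotal-1) and an enabled direct map; a minimal sketch from the Python side:

    import faiss
    import numpy as np

    d, nb = 16, 1000
    xb = np.random.random((nb, d)).astype('float32')

    index = faiss.index_factory(d, "IVF32,Flat")
    index.train(xb)
    index.add(xb)                # ids are implicitly 0..nb-1
    index.make_direct_map()      # required before update_vectors

    keys = np.array([3, 42, 999], dtype='int64')
    index.update_vectors(keys, np.random.random((3, d)).astype('float32'))
    # updated vectors may move to different inverted lists; the direct map is kept in sync
    print(index.reconstruct(42))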
/**
* Copyright (c) 2015-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the CC-by-NC license found in the
* LICENSE file in the root directory of this source tree.
*/
#ifndef FAISS_INDEX_IVF_SCALAR_QUANTIZER_H
#define FAISS_INDEX_IVF_SCALAR_QUANTIZER_H
#include <stdint.h>
#include <vector>
#include "IndexIVF.h"
namespace faiss {
/** An IVF implementation where the components of the residuals are
* encoded with a scalar uniform quantizer. All distance computations
* are asymmetric, so the encoded vectors are decoded and approximate
* distances are computed.
*
* The uniform quantizer has a range [vmin, vmax]. The range can be
* the same for all dimensions (uniform) or specific per dimension
* (default).
*/
struct ScalarQuantizer {
enum QuantizerType {
QT_8bit, ///< 8 bits per component
QT_4bit, ///< 4 bits per component
QT_8bit_uniform, ///< same, shared range for all dimensions
QT_4bit_uniform,
};
QuantizerType qtype;
/** The uniform encoder can estimate the range of representable
* values of the unform encoder using different statistics. Here
* rs = rangestat_arg */
// rangestat_arg.
enum RangeStat {
RS_minmax, ///< [min - rs*(max-min), max + rs*(max-min)]
RS_meanstd, ///< [mean - std * rs, mean + std * rs]
RS_quantiles, ///< [Q(rs), Q(1-rs)]
RS_optim, ///< alternate optimization of reconstruction error
};
RangeStat rangestat;
float rangestat_arg;
/// dimension of input vectors
size_t d;
/// bytes per vector
size_t code_size;
/// trained values (including the range)
std::vector<float> trained;
ScalarQuantizer (size_t d, QuantizerType qtype);
ScalarQuantizer ();
void train (size_t n, const float *x);
/// same as compute_code for several vectors
void compute_codes (const float * x,
uint8_t * codes,
size_t n) const ;
/// decode a vector from a given code (or n vectors if third argument)
void decode (const uint8_t *code, float *x, size_t n) const;
};
struct IndexIVFScalarQuantizer:IndexIVF {
ScalarQuantizer sq;
size_t code_size;
/// inverted list codes.
std::vector<std::vector<uint8_t> > codes;
IndexIVFScalarQuantizer(Index *quantizer, size_t d, size_t nlist,
ScalarQuantizer::QuantizerType qtype,
MetricType metric = METRIC_L2);
IndexIVFScalarQuantizer();
void train_residual(idx_t n, const float* x) override;
void add_with_ids(idx_t n, const float* x, const long* xids) override;
void search(
idx_t n,
const float* x,
idx_t k,
float* distances,
idx_t* labels) const override;
void merge_from_residuals(IndexIVF& other) override;
};
}
#endif
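A small sketch of the standalone index declared above, mirroring the new test: QT_8bit stores one byte per component and QT_4bit packs two components per byte (the expected code sizes below are assumptions, not asserted by the patch):

    import faiss
    import numpy as np

    d = 32
    xt = np.random.random((2000, d)).astype('float32')

    for qtype, expected_code_size in [(faiss.ScalarQuantizer.QT_8bit, d),
                                      (faiss.ScalarQuantizer.QT_4bit, d // 2)]:
        index = faiss.IndexScalarQuantizer(d, qtype, faiss.METRIC_L2)
        index.train(xt)          # estimates the per-dimension [vmin, vmax] range
        assert index.code_size == expected_code_size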
@@ -29,7 +29,7 @@ LIBOBJ=hamming.o utils.o \
Clustering.o Heap.o VectorTransform.o index_io.o \
PolysemousTraining.o MetaIndexes.o Index.o \
ProductQuantizer.o AutoTune.o AuxIndexStructures.o \
- IndexIVFScalarQuantizer.o FaissException.o
+ IndexScalarQuantizer.o FaissException.o
$(LIBNAME).a: $(LIBOBJ)
@@ -71,7 +71,7 @@ tests/demo_sift1M: tests/demo_sift1M.cpp $(LIBNAME).a
HFILES = IndexFlat.h Index.h IndexLSH.h IndexPQ.h IndexIVF.h \
IndexIVFPQ.h VectorTransform.h index_io.h utils.h \
PolysemousTraining.h Heap.h MetaIndexes.h AuxIndexStructures.h \
- Clustering.h hamming.h AutoTune.h IndexIVFScalarQuantizer.h FaissException.h
+ Clustering.h hamming.h AutoTune.h IndexScalarQuantizer.h FaissException.h
# also silently generates python/swigfaiss.py
python/swigfaiss_wrap.cxx: swigfaiss.swig $(HFILES)
@@ -89,11 +89,12 @@ _swigfaiss.so: python/_swigfaiss.so
#############################
# Dependencies
- # for i in *.cpp ; do gcc -I.. -MM $i -msse4; done
+ # for i in *.cpp ; do g++ -std=c++11 -I.. -MM $i -msse4; done
AutoTune.o: AutoTune.cpp AutoTune.h Index.h FaissAssert.h \
FaissException.h utils.h Heap.h IndexFlat.h VectorTransform.h IndexLSH.h \
IndexPQ.h ProductQuantizer.h Clustering.h PolysemousTraining.h \
- IndexIVF.h IndexIVFPQ.h MetaIndexes.h IndexIVFScalarQuantizer.h
+ IndexIVF.h IndexIVFPQ.h MetaIndexes.h IndexScalarQuantizer.h
AuxIndexStructures.o: AuxIndexStructures.cpp AuxIndexStructures.h Index.h
Clustering.o: Clustering.cpp Clustering.h Index.h utils.h Heap.h \
FaissAssert.h FaissException.h IndexFlat.h
@@ -106,7 +107,7 @@ IndexFlat.o: IndexFlat.cpp IndexFlat.h Index.h utils.h Heap.h \
index_io.o: index_io.cpp index_io.h FaissAssert.h FaissException.h \
IndexFlat.h Index.h VectorTransform.h IndexLSH.h IndexPQ.h \
ProductQuantizer.h Clustering.h Heap.h PolysemousTraining.h IndexIVF.h \
- IndexIVFPQ.h MetaIndexes.h IndexIVFScalarQuantizer.h
+ IndexIVFPQ.h MetaIndexes.h IndexScalarQuantizer.h
IndexIVF.o: IndexIVF.cpp IndexIVF.h Index.h Clustering.h Heap.h utils.h \
hamming.h FaissAssert.h FaissException.h IndexFlat.h \
AuxIndexStructures.h
@@ -114,13 +115,13 @@ IndexIVFPQ.o: IndexIVFPQ.cpp IndexIVFPQ.h IndexIVF.h Index.h Clustering.h \
Heap.h IndexPQ.h ProductQuantizer.h PolysemousTraining.h utils.h \
IndexFlat.h hamming.h FaissAssert.h FaissException.h \
AuxIndexStructures.h
- IndexIVFScalarQuantizer.o: IndexIVFScalarQuantizer.cpp \
- IndexIVFScalarQuantizer.h IndexIVF.h Index.h Clustering.h Heap.h utils.h \
- FaissAssert.h FaissException.h
IndexLSH.o: IndexLSH.cpp IndexLSH.h Index.h VectorTransform.h utils.h \
Heap.h hamming.h FaissAssert.h FaissException.h
IndexPQ.o: IndexPQ.cpp IndexPQ.h Index.h ProductQuantizer.h Clustering.h \
Heap.h PolysemousTraining.h FaissAssert.h FaissException.h hamming.h
+ IndexScalarQuantizer.o: IndexScalarQuantizer.cpp IndexScalarQuantizer.h \
+ IndexIVF.h Index.h Clustering.h Heap.h utils.h FaissAssert.h \
+ FaissException.h
MetaIndexes.o: MetaIndexes.cpp MetaIndexes.h Index.h FaissAssert.h \
FaissException.h Heap.h AuxIndexStructures.h
PolysemousTraining.o: PolysemousTraining.cpp PolysemousTraining.h \
......
@@ -120,6 +120,48 @@ IndexIDMap::~IndexIDMap ()
if (own_fields) delete index;
}
/*****************************************************
* IndexIDMap2 implementation
*******************************************************/
IndexIDMap2::IndexIDMap2 (Index *index): IndexIDMap (index)
{}
void IndexIDMap2::add_with_ids(idx_t n, const float* x, const long* xids)
{
size_t prev_ntotal = ntotal;
IndexIDMap::add_with_ids (n, x, xids);
for (size_t i = prev_ntotal; i < ntotal; i++) {
rev_map [id_map [i]] = i;
}
}
void IndexIDMap2::construct_rev_map ()
{
rev_map.clear ();
for (size_t i = 0; i < ntotal; i++) {
rev_map [id_map [i]] = i;
}
}
long IndexIDMap2::remove_ids(const IDSelector& sel)
{
// This is quite inefficient
long nremove = IndexIDMap::remove_ids (sel);
construct_rev_map ();
return nremove;
}
void IndexIDMap2::reconstruct (idx_t key, float * recons) const
{
try {
index->reconstruct (rev_map.at (key), recons);
} catch (const std::out_of_range& e) {
FAISS_THROW_FMT ("key %ld not found", key);
}
}
/*****************************************************
......
@@ -14,6 +14,7 @@
#include <vector>
+ #include <unordered_map>
#include "Index.h"
@@ -54,6 +55,28 @@ struct IndexIDMap : Index {
IndexIDMap () {own_fields=false; index=nullptr; }
};
/** same as IndexIDMap but also provides an efficient reconstruction
implementation via a 2-way index */
struct IndexIDMap2 : IndexIDMap {
std::unordered_map<idx_t, idx_t> rev_map;
explicit IndexIDMap2 (Index *index);
/// make the rev_map from scratch
void construct_rev_map ();
void add_with_ids(idx_t n, const float* x, const long* xids) override;
long remove_ids(const IDSelector& sel) override;
void reconstruct (idx_t key, float * recons) const override;
~IndexIDMap2() override {}
IndexIDMap2 () {}
};
/** Index that concatenates the results from several sub-indexes
*
*/
......
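Sketch of what the 2-way map buys: reconstruct() by user-provided id, which a plain IndexIDMap cannot do efficiently:

    import faiss
    import numpy as np

    d = 8
    xb = np.random.random((100, d)).astype('float32')
    ids = (np.arange(100) + 100000).astype('int64')   # arbitrary, non-sequential ids

    sub = faiss.IndexFlatL2(d)       # keep a reference; the wrapper does not own it here
    index = faiss.IndexIDMap2(sub)
    index.add_with_ids(xb, ids)

    rec = index.reconstruct(100042)  # rev_map lookup, then reconstruct in the sub-index
    assert np.allclose(rec, xb[42])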
@@ -711,6 +711,32 @@ void OPQMatrix::reverse_transform (idx_t n, const float * xt,
transform_transpose (n, xt, x);
}
/*********************************************
* NormalizationTransform
*********************************************/
NormalizationTransform::NormalizationTransform (int d, float norm):
VectorTransform (d, d), norm (norm)
{
}
NormalizationTransform::NormalizationTransform ():
VectorTransform (-1, -1), norm (-1)
{
}
void NormalizationTransform::apply_noalloc
(idx_t n, const float* x, float* xt) const
{
if (norm == 2.0) {
memcpy (xt, x, sizeof (x[0]) * n * d_in);
fvec_renorm_L2 (d_in, n, xt);
} else {
FAISS_THROW_MSG ("not implemented");
}
}
/*********************************************
* IndexPreTransform
*********************************************/
@@ -730,8 +756,6 @@ IndexPreTransform::IndexPreTransform (
}
IndexPreTransform::IndexPreTransform (
- VectorTransform * ltrans,
Index * index):
@@ -766,9 +790,16 @@ IndexPreTransform::~IndexPreTransform ()
void IndexPreTransform::train (idx_t n, const float *x)
{
int last_untrained = 0;
- for (int i = 0; i < chain.size(); i++)
- if (!chain[i]->is_trained) last_untrained = i;
- if (!index->is_trained) last_untrained = chain.size();
+ if (index->is_trained) {
+ last_untrained = chain.size();
+ } else {
+ for (int i = chain.size() - 1; i >= 0; i--) {
+ if (!chain[i]->is_trained) {
+ last_untrained = i;
+ break;
+ }
+ }
+ }
const float *prev_x = x;
ScopeDeleter<float> del;
......
@@ -76,7 +76,6 @@ struct VectorTransform {
*/
struct LinearTransform: VectorTransform {
bool have_bias; ///! whether to use the bias term
/// Transformation matrix, size d_out * d_in
@@ -85,7 +84,6 @@ struct LinearTransform: VectorTransform {
/// bias vector, size d_out
std::vector<float> b;
/// both d_in > d_out and d_out < d_in are supported
explicit LinearTransform (int d_in = 0, int d_out = 0,
bool have_bias = false);
@@ -204,7 +202,6 @@ struct OPQMatrix: LinearTransform {
* to compute it with matrix multiplies */
struct RemapDimensionsTransform: VectorTransform {
/// map from output dimension to input, size d_out
/// -1 -> set output to 0
std::vector<int> map;
@@ -225,6 +222,18 @@ struct RemapDimensionsTransform: VectorTransform {
};
/** per-vector normalization */
struct NormalizationTransform: VectorTransform {
float norm;
explicit NormalizationTransform (int d, float norm = 2.0);
NormalizationTransform ();
void apply_noalloc(idx_t n, const float* x, float* xt) const override;
};
/** Index that applies a LinearTransform transform on vectors before
* handing them over to a sub-index */
struct IndexPreTransform: Index {
......
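Sketch of the new transform used for cosine similarity, equivalent to the "L2norm" factory token, assembled by hand here to show the pieces:

    import faiss
    import numpy as np

    d = 64
    xb = np.random.random((1000, d)).astype('float32')

    sub = faiss.IndexFlatIP(d)                 # keep references alive on the Python side
    norm = faiss.NormalizationTransform(d, 2.0)
    index = faiss.IndexPreTransform(norm, sub)
    index.add(xb)                              # vectors are L2-normalized before storage
    D, I = index.search(xb[:5], 4)             # D holds cosine similarities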
@@ -34,8 +34,13 @@ except ImportError as e:
##################################################################
- def replace_method(the_class, name, replacement):
- orig_method = getattr(the_class, name)
+ def replace_method(the_class, name, replacement, ignore_missing=False):
+ try:
+ orig_method = getattr(the_class, name)
+ except AttributeError:
+ if ignore_missing:
+ return
+ raise
if orig_method.__name__ == 'replacement_' + name:
# replacement was done in parent class
return
@@ -123,12 +128,31 @@ def handle_Index(the_class):
sel = IDSelectorBatch(x.size, swig_ptr(x))
return self.remove_ids_c(sel)
def replacement_reconstruct(self, key):
x = np.empty(self.d, dtype=np.float32)
self.reconstruct_c(key, swig_ptr(x))
return x
def replacement_reconstruct_n(self, n0, ni):
x = np.empty((ni, self.d), dtype=np.float32)
self.reconstruct_n_c(n0, ni, swig_ptr(x))
return x
def replacement_update_vectors(self, keys, x):
n = keys.size
assert keys.shape == (n, )
assert x.shape == (n, self.d)
self.update_vectors_c(n, swig_ptr(keys), swig_ptr(x))
replace_method(the_class, 'add', replacement_add)
replace_method(the_class, 'add_with_ids', replacement_add_with_ids)
replace_method(the_class, 'train', replacement_train)
replace_method(the_class, 'search', replacement_search)
replace_method(the_class, 'remove_ids', replacement_remove_ids)
replace_method(the_class, 'reconstruct', replacement_reconstruct)
replace_method(the_class, 'reconstruct_n', replacement_reconstruct_n)
replace_method(the_class, 'update_vectors', replacement_update_vectors,
ignore_missing=True)
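With the wrappers above, reconstruction returns numpy arrays directly; for example:

    import faiss
    import numpy as np

    d = 16
    xb = np.random.random((50, d)).astype('float32')
    index = faiss.IndexFlatL2(d)
    index.add(xb)

    x5 = index.reconstruct(5)            # shape (d,)
    x0_10 = index.reconstruct_n(0, 10)   # shape (10, d)
    assert np.allclose(x0_10[5], x5)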
def handle_VectorTransform(the_class):
@@ -228,12 +252,13 @@ def vector_float_to_array(v):
class Kmeans:
- def __init__(self, d, k, niter=25, verbose=False):
+ def __init__(self, d, k, niter=25, verbose=False, spherical = False):
self.d = d
self.k = k
self.cp = ClusteringParameters()
self.cp.niter = niter
self.cp.verbose = verbose
+ self.cp.spherical = spherical
self.centroids = None
def train(self, x):
@@ -241,7 +266,10 @@ class Kmeans:
n, d = x.shape
assert d == self.d
clus = Clustering(d, self.k, self.cp)
- self.index = IndexFlatL2(d)
+ if self.cp.spherical:
+ self.index = IndexFlatIP(d)
+ else:
+ self.index = IndexFlatL2(d)
clus.train(x, self.index)
centroids = vector_float_to_array(clus.centroids)
self.centroids = centroids.reshape(self.k, d)
......
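The spherical flag switches the assignment index to inner product, i.e. spherical k-means on L2-normalized data; a short sketch (normalization done by hand):

    import faiss
    import numpy as np

    d = 32
    x = np.random.random((10000, d)).astype('float32')
    x /= np.linalg.norm(x, axis=1, keepdims=True)   # put the data on the unit sphere

    km = faiss.Kmeans(d, 100, niter=20, spherical=True)
    km.train(x)
    print(km.centroids.shape)                       # (100, d)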
@@ -24,7 +24,7 @@
#include "IndexIVF.h"
#include "IndexIVFPQ.h"
#include "MetaIndexes.h"
- #include "IndexIVFScalarQuantizer.h"
+ #include "IndexScalarQuantizer.h"
/*************************************************************
* The I/O format is the content of the class. For objects that are
@@ -184,6 +184,11 @@ void write_VectorTransform (const VectorTransform *vt, FILE *f) {
uint32_t h = fourcc ("RmDT");
WRITE1 (h);
WRITEVECTOR (rdt->map);
+ } else if (const NormalizationTransform *nt =
+ dynamic_cast<const NormalizationTransform *>(vt)) {
+ uint32_t h = fourcc ("VNrm");
+ WRITE1 (h);
+ WRITE1 (nt->norm);
} else {
FAISS_THROW_MSG ("cannot serialize this");
}
@@ -261,6 +266,13 @@ void write_index (const Index *idx, FILE *f) {
WRITE1 (idxp->search_type);
WRITE1 (idxp->encode_signs);
WRITE1 (idxp->polysemous_ht);
+ } else if(const IndexScalarQuantizer * idxs =
+ dynamic_cast<const IndexScalarQuantizer *> (idx)) {
+ uint32_t h = fourcc ("IxSQ");
+ WRITE1 (h);
+ write_index_header (idx, f);
+ write_ScalarQuantizer (&idxs->sq, f);
+ WRITEVECTOR (idxs->codes);
} else if(const IndexIVFFlat * ivfl =
dynamic_cast<const IndexIVFFlat *> (idx)) {
uint32_t h = fourcc ("IvFl");
@@ -329,7 +341,10 @@ void write_index (const Index *idx, FILE *f) {
WRITE1 (idxrf->k_factor);
} else if(const IndexIDMap * idxmap =
dynamic_cast<const IndexIDMap *> (idx)) {
- uint32_t h = fourcc ("IxMp");
+ uint32_t h =
+ dynamic_cast<const IndexIDMap2 *> (idx) ? fourcc ("IxM2") :
+ fourcc ("IxMp");
+ // no need to store additional info for IndexIDMap2
WRITE1 (h);
write_index_header (idxmap, f);
write_index (idxmap->index, f);
@@ -400,6 +415,10 @@ VectorTransform* read_VectorTransform (FILE *f) {
RemapDimensionsTransform *rdt = new RemapDimensionsTransform ();
READVECTOR (rdt->map);
vt = rdt;
+ } else if (h == fourcc ("VNrm")) {
+ NormalizationTransform *nt = new NormalizationTransform ();
+ READ1 (nt->norm);
+ vt = nt;
} else {
FAISS_THROW_MSG("fourcc not recognized");
}
@@ -582,6 +601,13 @@ Index *read_index (FILE * f, bool try_mmap) {
for (size_t i = 0; i < ivfl->nlist; i++)
READVECTOR (ivfl->vecs[i]);
idx = ivfl;
} else if (h == fourcc ("IxSQ")) {
IndexScalarQuantizer * idxs = new IndexScalarQuantizer ();
read_index_header (idxs, f);
read_ScalarQuantizer (&idxs->sq, f);
READVECTOR (idxs->codes);
idxs->code_size = idxs->sq.code_size;
idx = idxs;
} else if(h == fourcc ("IvSQ")) {
IndexIVFScalarQuantizer * ivsc = new IndexIVFScalarQuantizer();
read_ivf_header (ivsc, f);
@@ -606,8 +632,9 @@ Index *read_index (FILE * f, bool try_mmap) {
} else {
READ1 (nt);
}
- for (int i = 0; i < nt; i++)
+ for (int i = 0; i < nt; i++) {
ixpt->chain.push_back (read_VectorTransform (f));
+ }
ixpt->index = read_index (f);
idx = ixpt;
} else if(h == fourcc ("Imiq")) {
@@ -625,12 +652,16 @@ Index *read_index (FILE * f, bool try_mmap) {
delete rf;
READ1 (idxrf->k_factor);
idx = idxrf;
- } else if(h == fourcc ("IxMp")) {
- IndexIDMap * idxmap = new IndexIDMap ();
+ } else if(h == fourcc ("IxMp") || h == fourcc ("IxM2")) {
+ bool is_map2 = h == fourcc ("IxM2");
+ IndexIDMap * idxmap = is_map2 ? new IndexIDMap2 () : new IndexIDMap ();
read_index_header (idxmap, f);
idxmap->index = read_index (f);
idxmap->own_fields = true;
READVECTOR (idxmap->id_map);
+ if (is_map2) {
+ static_cast<IndexIDMap2*>(idxmap)->construct_rev_map ();
+ }
idx = idxmap;
} else {
fprintf (stderr, "Index type 0x%08x not supported\n", h);
@@ -698,6 +729,7 @@ IndexIVF * Cloner::clone_IndexIVF (const IndexIVF *ivf)
TRYCLONE (IndexIVFPQR, ivf)
TRYCLONE (IndexIVFPQ, ivf)
TRYCLONE (IndexIVFFlat, ivf)
+ TRYCLONE (IndexIVFScalarQuantizer, ivf)
{
FAISS_THROW_MSG("clone not supported for this type of IndexIVF");
}
@@ -711,6 +743,7 @@ Index *Cloner::clone_Index (const Index *index)
TRYCLONE (IndexFlatL2, index)
TRYCLONE (IndexFlatIP, index)
TRYCLONE (IndexFlat, index)
+ TRYCLONE (IndexScalarQuantizer, index)
TRYCLONE (MultiIndexQuantizer, index)
if (const IndexIVF * ivf = dynamic_cast<const IndexIVF*>(index)) {
IndexIVF *res = clone_IndexIVF (ivf);
......
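The new fourcc codes ("VNrm", "IxSQ", "IxM2") let these objects round-trip through write_index/read_index; a hedged sketch (the temporary file path is illustrative):

    import faiss
    import numpy as np

    d = 16
    xb = np.random.random((200, d)).astype('float32')

    sub = faiss.IndexFlatL2(d)
    index = faiss.IndexIDMap2(sub)
    index.add_with_ids(xb, (np.arange(200) + 1000).astype('int64'))

    faiss.write_index(index, "/tmp/idmap2.index")
    index2 = faiss.read_index("/tmp/idmap2.index")  # construct_rev_map() is called on load
    print(index2.reconstruct(1005))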
@@ -1094,6 +1094,27 @@ class RemapDimensionsTransform(VectorTransform):
RemapDimensionsTransform_swigregister = _swigfaiss.RemapDimensionsTransform_swigregister
RemapDimensionsTransform_swigregister(RemapDimensionsTransform)
class NormalizationTransform(VectorTransform):
__swig_setmethods__ = {}
for _s in [VectorTransform]: __swig_setmethods__.update(getattr(_s,'__swig_setmethods__',{}))
__setattr__ = lambda self, name, value: _swig_setattr(self, NormalizationTransform, name, value)
__swig_getmethods__ = {}
for _s in [VectorTransform]: __swig_getmethods__.update(getattr(_s,'__swig_getmethods__',{}))
__getattr__ = lambda self, name: _swig_getattr(self, NormalizationTransform, name)
__repr__ = _swig_repr
__swig_setmethods__["norm"] = _swigfaiss.NormalizationTransform_norm_set
__swig_getmethods__["norm"] = _swigfaiss.NormalizationTransform_norm_get
if _newclass:norm = _swig_property(_swigfaiss.NormalizationTransform_norm_get, _swigfaiss.NormalizationTransform_norm_set)
def __init__(self, *args):
this = _swigfaiss.new_NormalizationTransform(*args)
try: self.this.append(this)
except: self.this = this
def apply_noalloc(self, *args): return _swigfaiss.NormalizationTransform_apply_noalloc(self, *args)
__swig_destroy__ = _swigfaiss.delete_NormalizationTransform
__del__ = lambda self : None;
NormalizationTransform_swigregister = _swigfaiss.NormalizationTransform_swigregister
NormalizationTransform_swigregister(NormalizationTransform)
class IndexPreTransform(Index):
__swig_setmethods__ = {}
for _s in [Index]: __swig_setmethods__.update(getattr(_s,'__swig_setmethods__',{}))
@@ -1635,7 +1656,7 @@ class IndexIVF(Index):
__swig_destroy__ = _swigfaiss.delete_IndexIVF
__del__ = lambda self : None;
def get_list_size(self, *args): return _swigfaiss.IndexIVF_get_list_size(self, *args)
- def make_direct_map(self): return _swigfaiss.IndexIVF_make_direct_map(self)
+ def make_direct_map(self, new_maintain_direct_map=True): return _swigfaiss.IndexIVF_make_direct_map(self, new_maintain_direct_map)
def imbalance_factor(self): return _swigfaiss.IndexIVF_imbalance_factor(self)
def print_stats(self): return _swigfaiss.IndexIVF_print_stats(self)
IndexIVF_swigregister = _swigfaiss.IndexIVF_swigregister
@@ -1690,6 +1711,7 @@ class IndexIVFFlat(IndexIVF):
def remove_ids(self, *args): return _swigfaiss.IndexIVFFlat_remove_ids(self, *args)
def search_knn_inner_product(self, *args): return _swigfaiss.IndexIVFFlat_search_knn_inner_product(self, *args)
def search_knn_L2sqr(self, *args): return _swigfaiss.IndexIVFFlat_search_knn_L2sqr(self, *args)
+ def update_vectors(self, *args): return _swigfaiss.IndexIVFFlat_update_vectors(self, *args)
def reconstruct(self, *args): return _swigfaiss.IndexIVFFlat_reconstruct(self, *args)
def merge_from_residuals(self, *args): return _swigfaiss.IndexIVFFlat_merge_from_residuals(self, *args)
def __init__(self, *args):
@@ -1770,6 +1792,38 @@ class ScalarQuantizer(_object):
ScalarQuantizer_swigregister = _swigfaiss.ScalarQuantizer_swigregister
ScalarQuantizer_swigregister(ScalarQuantizer)
class IndexScalarQuantizer(Index):
__swig_setmethods__ = {}
for _s in [Index]: __swig_setmethods__.update(getattr(_s,'__swig_setmethods__',{}))
__setattr__ = lambda self, name, value: _swig_setattr(self, IndexScalarQuantizer, name, value)
__swig_getmethods__ = {}
for _s in [Index]: __swig_getmethods__.update(getattr(_s,'__swig_getmethods__',{}))
__getattr__ = lambda self, name: _swig_getattr(self, IndexScalarQuantizer, name)
__repr__ = _swig_repr
__swig_setmethods__["sq"] = _swigfaiss.IndexScalarQuantizer_sq_set
__swig_getmethods__["sq"] = _swigfaiss.IndexScalarQuantizer_sq_get
if _newclass:sq = _swig_property(_swigfaiss.IndexScalarQuantizer_sq_get, _swigfaiss.IndexScalarQuantizer_sq_set)
__swig_setmethods__["codes"] = _swigfaiss.IndexScalarQuantizer_codes_set
__swig_getmethods__["codes"] = _swigfaiss.IndexScalarQuantizer_codes_get
if _newclass:codes = _swig_property(_swigfaiss.IndexScalarQuantizer_codes_get, _swigfaiss.IndexScalarQuantizer_codes_set)
__swig_setmethods__["code_size"] = _swigfaiss.IndexScalarQuantizer_code_size_set
__swig_getmethods__["code_size"] = _swigfaiss.IndexScalarQuantizer_code_size_get
if _newclass:code_size = _swig_property(_swigfaiss.IndexScalarQuantizer_code_size_get, _swigfaiss.IndexScalarQuantizer_code_size_set)
def __init__(self, *args):
this = _swigfaiss.new_IndexScalarQuantizer(*args)
try: self.this.append(this)
except: self.this = this
def train(self, *args): return _swigfaiss.IndexScalarQuantizer_train(self, *args)
def add(self, *args): return _swigfaiss.IndexScalarQuantizer_add(self, *args)
def search(self, *args): return _swigfaiss.IndexScalarQuantizer_search(self, *args)
def reset(self): return _swigfaiss.IndexScalarQuantizer_reset(self)
def reconstruct_n(self, *args): return _swigfaiss.IndexScalarQuantizer_reconstruct_n(self, *args)
def reconstruct(self, *args): return _swigfaiss.IndexScalarQuantizer_reconstruct(self, *args)
__swig_destroy__ = _swigfaiss.delete_IndexScalarQuantizer
__del__ = lambda self : None;
IndexScalarQuantizer_swigregister = _swigfaiss.IndexScalarQuantizer_swigregister
IndexScalarQuantizer_swigregister(IndexScalarQuantizer)
class IndexIVFScalarQuantizer(IndexIVF):
__swig_setmethods__ = {}
for _s in [IndexIVF]: __swig_setmethods__.update(getattr(_s,'__swig_setmethods__',{}))
@@ -2024,6 +2078,30 @@ class IndexIDMap(Index):
IndexIDMap_swigregister = _swigfaiss.IndexIDMap_swigregister
IndexIDMap_swigregister(IndexIDMap)
class IndexIDMap2(IndexIDMap):
__swig_setmethods__ = {}
for _s in [IndexIDMap]: __swig_setmethods__.update(getattr(_s,'__swig_setmethods__',{}))
__setattr__ = lambda self, name, value: _swig_setattr(self, IndexIDMap2, name, value)
__swig_getmethods__ = {}
for _s in [IndexIDMap]: __swig_getmethods__.update(getattr(_s,'__swig_getmethods__',{}))
__getattr__ = lambda self, name: _swig_getattr(self, IndexIDMap2, name)
__repr__ = _swig_repr
__swig_setmethods__["rev_map"] = _swigfaiss.IndexIDMap2_rev_map_set
__swig_getmethods__["rev_map"] = _swigfaiss.IndexIDMap2_rev_map_get
if _newclass:rev_map = _swig_property(_swigfaiss.IndexIDMap2_rev_map_get, _swigfaiss.IndexIDMap2_rev_map_set)
def construct_rev_map(self): return _swigfaiss.IndexIDMap2_construct_rev_map(self)
def add_with_ids(self, *args): return _swigfaiss.IndexIDMap2_add_with_ids(self, *args)
def remove_ids(self, *args): return _swigfaiss.IndexIDMap2_remove_ids(self, *args)
def reconstruct(self, *args): return _swigfaiss.IndexIDMap2_reconstruct(self, *args)
__swig_destroy__ = _swigfaiss.delete_IndexIDMap2
__del__ = lambda self : None;
def __init__(self, *args):
this = _swigfaiss.new_IndexIDMap2(*args)
try: self.this.append(this)
except: self.this = this
IndexIDMap2_swigregister = _swigfaiss.IndexIDMap2_swigregister
IndexIDMap2_swigregister(IndexIDMap2)
class IndexShards(Index):
__swig_setmethods__ = {}
for _s in [Index]: __swig_setmethods__.update(getattr(_s,'__swig_setmethods__',{}))
......
@@ -1163,6 +1163,27 @@ class RemapDimensionsTransform(VectorTransform):
RemapDimensionsTransform_swigregister = _swigfaiss_gpu.RemapDimensionsTransform_swigregister
RemapDimensionsTransform_swigregister(RemapDimensionsTransform)
class NormalizationTransform(VectorTransform):
__swig_setmethods__ = {}
for _s in [VectorTransform]: __swig_setmethods__.update(getattr(_s,'__swig_setmethods__',{}))
__setattr__ = lambda self, name, value: _swig_setattr(self, NormalizationTransform, name, value)
__swig_getmethods__ = {}
for _s in [VectorTransform]: __swig_getmethods__.update(getattr(_s,'__swig_getmethods__',{}))
__getattr__ = lambda self, name: _swig_getattr(self, NormalizationTransform, name)
__repr__ = _swig_repr
__swig_setmethods__["norm"] = _swigfaiss_gpu.NormalizationTransform_norm_set
__swig_getmethods__["norm"] = _swigfaiss_gpu.NormalizationTransform_norm_get
if _newclass:norm = _swig_property(_swigfaiss_gpu.NormalizationTransform_norm_get, _swigfaiss_gpu.NormalizationTransform_norm_set)
def __init__(self, *args):
this = _swigfaiss_gpu.new_NormalizationTransform(*args)
try: self.this.append(this)
except: self.this = this
def apply_noalloc(self, *args): return _swigfaiss_gpu.NormalizationTransform_apply_noalloc(self, *args)
__swig_destroy__ = _swigfaiss_gpu.delete_NormalizationTransform
__del__ = lambda self : None;
NormalizationTransform_swigregister = _swigfaiss_gpu.NormalizationTransform_swigregister
NormalizationTransform_swigregister(NormalizationTransform)
class IndexPreTransform(Index):
__swig_setmethods__ = {}
for _s in [Index]: __swig_setmethods__.update(getattr(_s,'__swig_setmethods__',{}))
@@ -1704,7 +1725,7 @@ class IndexIVF(Index):
__swig_destroy__ = _swigfaiss_gpu.delete_IndexIVF
__del__ = lambda self : None;
def get_list_size(self, *args): return _swigfaiss_gpu.IndexIVF_get_list_size(self, *args)
- def make_direct_map(self): return _swigfaiss_gpu.IndexIVF_make_direct_map(self)
+ def make_direct_map(self, new_maintain_direct_map=True): return _swigfaiss_gpu.IndexIVF_make_direct_map(self, new_maintain_direct_map)
def imbalance_factor(self): return _swigfaiss_gpu.IndexIVF_imbalance_factor(self)
def print_stats(self): return _swigfaiss_gpu.IndexIVF_print_stats(self)
IndexIVF_swigregister = _swigfaiss_gpu.IndexIVF_swigregister
@@ -1759,6 +1780,7 @@ class IndexIVFFlat(IndexIVF):
def remove_ids(self, *args): return _swigfaiss_gpu.IndexIVFFlat_remove_ids(self, *args)
def search_knn_inner_product(self, *args): return _swigfaiss_gpu.IndexIVFFlat_search_knn_inner_product(self, *args)
def search_knn_L2sqr(self, *args): return _swigfaiss_gpu.IndexIVFFlat_search_knn_L2sqr(self, *args)
+ def update_vectors(self, *args): return _swigfaiss_gpu.IndexIVFFlat_update_vectors(self, *args)
def reconstruct(self, *args): return _swigfaiss_gpu.IndexIVFFlat_reconstruct(self, *args)
def merge_from_residuals(self, *args): return _swigfaiss_gpu.IndexIVFFlat_merge_from_residuals(self, *args)
def __init__(self, *args):
@@ -1839,6 +1861,38 @@ class ScalarQuantizer(_object):
ScalarQuantizer_swigregister = _swigfaiss_gpu.ScalarQuantizer_swigregister
ScalarQuantizer_swigregister(ScalarQuantizer)
class IndexScalarQuantizer(Index):
__swig_setmethods__ = {}
for _s in [Index]: __swig_setmethods__.update(getattr(_s,'__swig_setmethods__',{}))
__setattr__ = lambda self, name, value: _swig_setattr(self, IndexScalarQuantizer, name, value)
__swig_getmethods__ = {}
for _s in [Index]: __swig_getmethods__.update(getattr(_s,'__swig_getmethods__',{}))
__getattr__ = lambda self, name: _swig_getattr(self, IndexScalarQuantizer, name)
__repr__ = _swig_repr
__swig_setmethods__["sq"] = _swigfaiss_gpu.IndexScalarQuantizer_sq_set
__swig_getmethods__["sq"] = _swigfaiss_gpu.IndexScalarQuantizer_sq_get
if _newclass:sq = _swig_property(_swigfaiss_gpu.IndexScalarQuantizer_sq_get, _swigfaiss_gpu.IndexScalarQuantizer_sq_set)
__swig_setmethods__["codes"] = _swigfaiss_gpu.IndexScalarQuantizer_codes_set
__swig_getmethods__["codes"] = _swigfaiss_gpu.IndexScalarQuantizer_codes_get
if _newclass:codes = _swig_property(_swigfaiss_gpu.IndexScalarQuantizer_codes_get, _swigfaiss_gpu.IndexScalarQuantizer_codes_set)
__swig_setmethods__["code_size"] = _swigfaiss_gpu.IndexScalarQuantizer_code_size_set
__swig_getmethods__["code_size"] = _swigfaiss_gpu.IndexScalarQuantizer_code_size_get
if _newclass:code_size = _swig_property(_swigfaiss_gpu.IndexScalarQuantizer_code_size_get, _swigfaiss_gpu.IndexScalarQuantizer_code_size_set)
def __init__(self, *args):
this = _swigfaiss_gpu.new_IndexScalarQuantizer(*args)
try: self.this.append(this)
except: self.this = this
def train(self, *args): return _swigfaiss_gpu.IndexScalarQuantizer_train(self, *args)
def add(self, *args): return _swigfaiss_gpu.IndexScalarQuantizer_add(self, *args)
def search(self, *args): return _swigfaiss_gpu.IndexScalarQuantizer_search(self, *args)
def reset(self): return _swigfaiss_gpu.IndexScalarQuantizer_reset(self)
def reconstruct_n(self, *args): return _swigfaiss_gpu.IndexScalarQuantizer_reconstruct_n(self, *args)
def reconstruct(self, *args): return _swigfaiss_gpu.IndexScalarQuantizer_reconstruct(self, *args)
__swig_destroy__ = _swigfaiss_gpu.delete_IndexScalarQuantizer
__del__ = lambda self : None;
IndexScalarQuantizer_swigregister = _swigfaiss_gpu.IndexScalarQuantizer_swigregister
IndexScalarQuantizer_swigregister(IndexScalarQuantizer)
class IndexIVFScalarQuantizer(IndexIVF):
__swig_setmethods__ = {}
for _s in [IndexIVF]: __swig_setmethods__.update(getattr(_s,'__swig_setmethods__',{}))
@@ -2093,6 +2147,30 @@ class IndexIDMap(Index):
IndexIDMap_swigregister = _swigfaiss_gpu.IndexIDMap_swigregister
IndexIDMap_swigregister(IndexIDMap)
class IndexIDMap2(IndexIDMap):
__swig_setmethods__ = {}
for _s in [IndexIDMap]: __swig_setmethods__.update(getattr(_s,'__swig_setmethods__',{}))
__setattr__ = lambda self, name, value: _swig_setattr(self, IndexIDMap2, name, value)
__swig_getmethods__ = {}
for _s in [IndexIDMap]: __swig_getmethods__.update(getattr(_s,'__swig_getmethods__',{}))
__getattr__ = lambda self, name: _swig_getattr(self, IndexIDMap2, name)
__repr__ = _swig_repr
__swig_setmethods__["rev_map"] = _swigfaiss_gpu.IndexIDMap2_rev_map_set
__swig_getmethods__["rev_map"] = _swigfaiss_gpu.IndexIDMap2_rev_map_get
if _newclass:rev_map = _swig_property(_swigfaiss_gpu.IndexIDMap2_rev_map_get, _swigfaiss_gpu.IndexIDMap2_rev_map_set)
def construct_rev_map(self): return _swigfaiss_gpu.IndexIDMap2_construct_rev_map(self)
def add_with_ids(self, *args): return _swigfaiss_gpu.IndexIDMap2_add_with_ids(self, *args)
def remove_ids(self, *args): return _swigfaiss_gpu.IndexIDMap2_remove_ids(self, *args)
def reconstruct(self, *args): return _swigfaiss_gpu.IndexIDMap2_reconstruct(self, *args)
__swig_destroy__ = _swigfaiss_gpu.delete_IndexIDMap2
__del__ = lambda self : None;
def __init__(self, *args):
this = _swigfaiss_gpu.new_IndexIDMap2(*args)
try: self.this.append(this)
except: self.this = this
IndexIDMap2_swigregister = _swigfaiss_gpu.IndexIDMap2_swigregister
IndexIDMap2_swigregister(IndexIDMap2)
class IndexShards(Index):
__swig_setmethods__ = {}
for _s in [Index]: __swig_setmethods__.update(getattr(_s,'__swig_setmethods__',{}))
......
@@ -74,7 +74,7 @@ extern "C" {
#include "IndexPQ.h"
#include "IndexIVF.h"
#include "IndexIVFPQ.h"
- #include "IndexIVFScalarQuantizer.h"
+ #include "IndexScalarQuantizer.h"
#include "MetaIndexes.h"
#include "FaissAssert.h"
@@ -240,7 +240,7 @@ int get_num_gpus()
%include "PolysemousTraining.h"
%include "IndexPQ.h"
%include "IndexIVF.h"
- %include "IndexIVFScalarQuantizer.h"
+ %include "IndexScalarQuantizer.h"
%ignore faiss::IndexIVFPQ::alloc_type;
%include "IndexIVFPQ.h"
@@ -426,6 +426,7 @@ struct AsyncIndexSearchC {
DOWNCAST ( IndexIVF )
DOWNCAST ( IndexFlat )
DOWNCAST ( IndexPQ )
+ DOWNCAST ( IndexScalarQuantizer )
DOWNCAST ( IndexLSH )
DOWNCAST ( IndexPreTransform )
DOWNCAST ( MultiIndexQuantizer )
@@ -457,6 +458,7 @@ struct AsyncIndexSearchC {
DOWNCAST (PCAMatrix)
DOWNCAST (RandomRotationMatrix)
DOWNCAST (LinearTransform)
+ DOWNCAST (NormalizationTransform)
DOWNCAST (VectorTransform)
{
assert(false);
......
@@ -11,6 +11,7 @@ import numpy as np
import faiss
import unittest
class TestClustering(unittest.TestCase):
def test_clustering(self):
@@ -34,6 +35,17 @@ class TestClustering(unittest.TestCase):
# check that 64 centroids give a lower quantization error than 32
self.assertGreater(err32, err64)
def test_nasty_clustering(self):
d = 2
np.random.seed(123)
x = np.zeros((100, d), dtype='float32')
for i in range(5):
x[i * 20:i * 20 + 20] = np.random.random(size=d)
# we have 5 distinct points but ask for 10 centroids...
km = faiss.Kmeans(d, 10, niter=10, verbose=True)
km.train(x)
class TestPCA(unittest.TestCase):
......
@@ -6,10 +6,8 @@
#! /usr/bin/env python2
- """this is a basic test script that works with fbmake to check if
- some simple indices work"""
- import sys
+ """this is a basic test script for simple indices work"""
import numpy as np
import unittest
import faiss
@@ -75,9 +73,9 @@ class TestMultiIndexQuantizer(unittest.TestCase):
self.assertEqual(np.abs(D1[:, :1] - D5[:, :1]).max(), 0)
- class TestIVFScalarQuantizer(unittest.TestCase):
+ class TestScalarQuantizer(unittest.TestCase):
- def test_4variants(self):
+ def test_4variants_ivf(self):
d = 32
nt = 1500
nq = 200
@@ -127,19 +125,39 @@ class TestIVFScalarQuantizer(unittest.TestCase):
self.assertGreaterEqual(nok['QT_8bit'], nok['QT_8bit_uniform'])
self.assertGreaterEqual(nok['QT_4bit'], nok['QT_4bit_uniform'])
+ def test_4variants(self):
+ d = 32
+ nt = 1500
+ nq = 200
+ nb = 10000
+ np.random.seed(123)
+ xt = np.random.random(size=(nt, d)).astype('float32')
+ xq = np.random.random(size=(nq, d)).astype('float32')
+ xb = np.random.random(size=(nb, d)).astype('float32')
+ index_gt = faiss.IndexFlatL2(d)
+ index_gt.add(xb)
+ D, I_ref = index_gt.search(xq, 10)
+ nok = {}
+ for qname in "QT_4bit QT_4bit_uniform QT_8bit QT_8bit_uniform".split():
+ qtype = getattr(faiss.ScalarQuantizer, qname)
+ index = faiss.IndexScalarQuantizer(d, qtype, faiss.METRIC_L2)
+ index.train(xt)
+ index.add(xb)
+ D, I = index.search(xq, 10)
+ nok[qname] = (I[:, 0] == I_ref[:, 0]).sum()
+ print(nok)
+ self.assertGreaterEqual(nok['QT_8bit'], nok['QT_4bit'])
+ self.assertGreaterEqual(nok['QT_8bit'], nok['QT_8bit_uniform'])
+ self.assertGreaterEqual(nok['QT_4bit'], nok['QT_4bit_uniform'])
- class TestRemove(unittest.TestCase):
- def test_remove(self):
- # only tests the python interface
- index = faiss.IndexFlat(5)
- xb = np.zeros((10, 5), dtype='float32')
- xb[:, 0] = np.arange(10) + 1000
- index.add(xb)
- index.remove_ids(np.arange(5) * 2)
- xb2 = faiss.vector_float_to_array(index.xb).reshape(5, 5)
- assert np.all(xb2[:, 0] == xb[np.arange(5) * 2 + 1, 0])
if __name__ == '__main__':
......
@@ -1418,7 +1418,7 @@ int km_update_centroids (const float * x,
for (size_t ci = 0; ci < k; ci++) {
if (hassign[ci] == 0) { /* need to redefine a centroid */
size_t cj;
- for (cj = 0; 1; cj = (cj+1) % k) {
+ for (cj = 0; 1; cj = (cj + 1) % k) {
/* probability to pick this cluster for split */
float p = (hassign[cj] - 1.0) / (float) (n - k);
float r = rng.rand_float ();
@@ -1429,15 +1429,15 @@ int km_update_centroids (const float * x,
memcpy (centroids+ci*d, centroids+cj*d, sizeof(*centroids) * d);
/* small symmetric pertubation. Much better than */
- for (size_t j = 0; j < d; j++)
+ for (size_t j = 0; j < d; j++) {
if (j % 2 == 0) {
centroids[ci * d + j] *= 1 + EPS;
centroids[cj * d + j] *= 1 - EPS;
+ } else {
+ centroids[ci * d + j] *= 1 - EPS;
+ centroids[cj * d + j] *= 1 + EPS;
}
- else {
- centroids[ci * d + j] *= 1 + EPS;
- centroids[cj * d + j] *= 1 - EPS;
- }
+ }
/* assume even split of the cluster */
hassign[ci] = hassign[cj] / 2;
......
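For reference, the fixed loop now perturbs the two copies in opposite directions on alternating dimensions; a standalone numpy sketch of the same split (the EPS value is assumed, the constant next to km_update_centroids may differ):

    import numpy as np

    EPS = 1 / 1024.            # assumed; see the definition near km_update_centroids
    def split_centroid(c):
        ci, cj = c.copy(), c.copy()
        even = np.arange(c.size) % 2 == 0
        ci[even] *= 1 + EPS; cj[even] *= 1 - EPS
        ci[~even] *= 1 - EPS; cj[~even] *= 1 + EPS
        return ci, cj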