Commit abe2b0fd authored by dengos's avatar dengos Committed by Matthijs Douze

read/write index with std::function wrapper (#427)

* add access function to IndexIVF;

* - access for IndexIVF;
- write_index/read_index with std::function<...>;

* - fix test compile on mac;
- adjust write/read with std::function;

* replace std::function with IOReader/IOWriter;

* remove IndexIVF::access // tmp

* PFN_WRITE/READ => WRITE;

* revert mac compile fix;

* rename;

* fix compile;

* reset CMakeList;

* format; remove unused function/header;
parent 433f5c0f
...@@ -590,7 +590,6 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type, ...@@ -590,7 +590,6 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
} }
IndexIVF::~IndexIVF() IndexIVF::~IndexIVF()
{ {
if (own_invlists) { if (own_invlists) {
......
...@@ -73,14 +73,14 @@ static uint32_t fourcc (const char sx[4]) { ...@@ -73,14 +73,14 @@ static uint32_t fourcc (const char sx[4]) {
**************************************************************/ **************************************************************/
#define WRITEANDCHECK(ptr, n) { \ #define WRITEANDCHECK(ptr, n) { \
size_t ret = fwrite (ptr, sizeof (* (ptr)), n, f); \ size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \
FAISS_THROW_IF_NOT_MSG (ret == (n), "write error"); \ FAISS_THROW_IF_NOT_MSG(ret == (n), "write error"); \
} }
#define READANDCHECK(ptr, n) { \ #define READANDCHECK(ptr, n) { \
size_t ret = fread (ptr, sizeof (* (ptr)), n, f); \ size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \
FAISS_THROW_IF_NOT_MSG (ret == (n), "read error"); \ FAISS_THROW_IF_NOT_MSG(ret == (n), "read error"); \
} }
#define WRITE1(x) WRITEANDCHECK(&(x), 1) #define WRITE1(x) WRITEANDCHECK(&(x), 1)
...@@ -106,15 +106,41 @@ struct ScopeFileCloser { ...@@ -106,15 +106,41 @@ struct ScopeFileCloser {
~ScopeFileCloser () {fclose (f); } ~ScopeFileCloser () {fclose (f); }
}; };
namespace {
struct FileIOReader: IOReader {
FILE *f = nullptr;
FileIOReader(FILE *rf): f(rf) {}
~FileIOReader() = default;
virtual size_t operator()(
void *ptr, size_t size, size_t nitems) override {
return fread(ptr, size, nitems, f);
}
};
struct FileIOWriter: IOWriter {
FILE *f = nullptr;
FileIOWriter(FILE *wf): f(wf) {}
~FileIOWriter() = default;
virtual size_t operator()(
const void *ptr, size_t size, size_t nitems) override {
return fwrite(ptr, size, nitems, f);
}
};
} // namespace
/************************************************************* /*************************************************************
* Write * Write
**************************************************************/ **************************************************************/
static void write_index_header (const Index *idx, IOWriter *f) {
static void write_index_header (const Index *idx, FILE *f) {
WRITE1 (idx->d); WRITE1 (idx->d);
WRITE1 (idx->ntotal); WRITE1 (idx->ntotal);
Index::idx_t dummy = 1 << 20; Index::idx_t dummy = 1 << 20;
...@@ -124,7 +150,7 @@ static void write_index_header (const Index *idx, FILE *f) { ...@@ -124,7 +150,7 @@ static void write_index_header (const Index *idx, FILE *f) {
WRITE1 (idx->metric_type); WRITE1 (idx->metric_type);
} }
void write_VectorTransform (const VectorTransform *vt, FILE *f) { void write_VectorTransform (const VectorTransform *vt, IOWriter *f) {
if (const LinearTransform * lt = if (const LinearTransform * lt =
dynamic_cast < const LinearTransform *> (vt)) { dynamic_cast < const LinearTransform *> (vt)) {
if (dynamic_cast<const RandomRotationMatrix *>(lt)) { if (dynamic_cast<const RandomRotationMatrix *>(lt)) {
...@@ -167,14 +193,16 @@ void write_VectorTransform (const VectorTransform *vt, FILE *f) { ...@@ -167,14 +193,16 @@ void write_VectorTransform (const VectorTransform *vt, FILE *f) {
WRITE1 (vt->is_trained); WRITE1 (vt->is_trained);
} }
static void write_ProductQuantizer (const ProductQuantizer *pq, FILE *f) { static void write_ProductQuantizer (
const ProductQuantizer *pq, IOWriter *f) {
WRITE1 (pq->d); WRITE1 (pq->d);
WRITE1 (pq->M); WRITE1 (pq->M);
WRITE1 (pq->nbits); WRITE1 (pq->nbits);
WRITEVECTOR (pq->centroids); WRITEVECTOR (pq->centroids);
} }
static void write_ScalarQuantizer (const ScalarQuantizer *ivsc, FILE *f) { static void write_ScalarQuantizer (
const ScalarQuantizer *ivsc, IOWriter *f) {
WRITE1 (ivsc->qtype); WRITE1 (ivsc->qtype);
WRITE1 (ivsc->rangestat); WRITE1 (ivsc->rangestat);
WRITE1 (ivsc->rangestat_arg); WRITE1 (ivsc->rangestat_arg);
...@@ -183,7 +211,7 @@ static void write_ScalarQuantizer (const ScalarQuantizer *ivsc, FILE *f) { ...@@ -183,7 +211,7 @@ static void write_ScalarQuantizer (const ScalarQuantizer *ivsc, FILE *f) {
WRITEVECTOR (ivsc->trained); WRITEVECTOR (ivsc->trained);
} }
static void write_InvertedLists (const InvertedLists *ils, FILE *f) { static void write_InvertedLists (const InvertedLists *ils, IOWriter *f) {
if (ils == nullptr) { if (ils == nullptr) {
uint32_t h = fourcc ("il00"); uint32_t h = fourcc ("il00");
WRITE1 (h); WRITE1 (h);
...@@ -258,10 +286,12 @@ void write_ProductQuantizer (const ProductQuantizer*pq, const char *fname) { ...@@ -258,10 +286,12 @@ void write_ProductQuantizer (const ProductQuantizer*pq, const char *fname) {
FILE *f = fopen (fname, "w"); FILE *f = fopen (fname, "w");
FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname); FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname);
ScopeFileCloser closer(f); ScopeFileCloser closer(f);
write_ProductQuantizer (pq, f);
FileIOWriter writer(f);
write_ProductQuantizer (pq, &writer);
} }
static void write_HNSW (const HNSW *hnsw, FILE *f) { static void write_HNSW (const HNSW *hnsw, IOWriter *f) {
WRITEVECTOR (hnsw->assign_probas); WRITEVECTOR (hnsw->assign_probas);
WRITEVECTOR (hnsw->cum_nneighbor_per_level); WRITEVECTOR (hnsw->cum_nneighbor_per_level);
...@@ -274,10 +304,9 @@ static void write_HNSW (const HNSW *hnsw, FILE *f) { ...@@ -274,10 +304,9 @@ static void write_HNSW (const HNSW *hnsw, FILE *f) {
WRITE1 (hnsw->efConstruction); WRITE1 (hnsw->efConstruction);
WRITE1 (hnsw->efSearch); WRITE1 (hnsw->efSearch);
WRITE1 (hnsw->upper_beam); WRITE1 (hnsw->upper_beam);
} }
static void write_ivf_header (const IndexIVF * ivf, FILE *f) { static void write_ivf_header (const IndexIVF *ivf, IOWriter *f) {
write_index_header (ivf, f); write_index_header (ivf, f);
WRITE1 (ivf->nlist); WRITE1 (ivf->nlist);
WRITE1 (ivf->nprobe); WRITE1 (ivf->nprobe);
...@@ -286,7 +315,7 @@ static void write_ivf_header (const IndexIVF * ivf, FILE *f) { ...@@ -286,7 +315,7 @@ static void write_ivf_header (const IndexIVF * ivf, FILE *f) {
WRITEVECTOR (ivf->direct_map); WRITEVECTOR (ivf->direct_map);
} }
void write_index (const Index *idx, FILE *f) { void write_index (const Index *idx, IOWriter *f) {
if (const IndexFlat * idxf = dynamic_cast<const IndexFlat *> (idx)) { if (const IndexFlat * idxf = dynamic_cast<const IndexFlat *> (idx)) {
uint32_t h = fourcc ( uint32_t h = fourcc (
idxf->metric_type == METRIC_INNER_PRODUCT ? "IxFI" : idxf->metric_type == METRIC_INNER_PRODUCT ? "IxFI" :
...@@ -418,6 +447,11 @@ void write_index (const Index *idx, FILE *f) { ...@@ -418,6 +447,11 @@ void write_index (const Index *idx, FILE *f) {
} }
} }
void write_index (const Index *idx, FILE *f) {
FileIOWriter writer(f);
write_index(idx, &writer);
}
void write_index (const Index *idx, const char *fname) { void write_index (const Index *idx, const char *fname) {
FILE *f = fopen (fname, "w"); FILE *f = fopen (fname, "w");
FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname); FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname);
...@@ -429,14 +463,16 @@ void write_VectorTransform (const VectorTransform *vt, const char *fname) { ...@@ -429,14 +463,16 @@ void write_VectorTransform (const VectorTransform *vt, const char *fname) {
FILE *f = fopen (fname, "w"); FILE *f = fopen (fname, "w");
FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname); FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname);
ScopeFileCloser closer(f); ScopeFileCloser closer(f);
write_VectorTransform (vt, f);
FileIOWriter writer(f);
write_VectorTransform (vt, &writer);
} }
/************************************************************* /*************************************************************
* Read * Read
**************************************************************/ **************************************************************/
static void read_index_header (Index *idx, FILE *f) { static void read_index_header (Index *idx, IOReader *f) {
READ1 (idx->d); READ1 (idx->d);
READ1 (idx->ntotal); READ1 (idx->ntotal);
Index::idx_t dummy; Index::idx_t dummy;
...@@ -447,7 +483,7 @@ static void read_index_header (Index *idx, FILE *f) { ...@@ -447,7 +483,7 @@ static void read_index_header (Index *idx, FILE *f) {
idx->verbose = false; idx->verbose = false;
} }
VectorTransform* read_VectorTransform (FILE *f) { VectorTransform* read_VectorTransform (IOReader *f) {
uint32_t h; uint32_t h;
READ1 (h); READ1 (h);
VectorTransform *vt = nullptr; VectorTransform *vt = nullptr;
...@@ -497,7 +533,7 @@ VectorTransform* read_VectorTransform (FILE *f) { ...@@ -497,7 +533,7 @@ VectorTransform* read_VectorTransform (FILE *f) {
static void read_ArrayInvertedLists_sizes ( static void read_ArrayInvertedLists_sizes (
FILE *f, std::vector<size_t> & sizes) IOReader *f, std::vector<size_t> & sizes)
{ {
size_t nlist = sizes.size(); size_t nlist = sizes.size();
uint32_t list_type; uint32_t list_type;
...@@ -518,8 +554,7 @@ static void read_ArrayInvertedLists_sizes ( ...@@ -518,8 +554,7 @@ static void read_ArrayInvertedLists_sizes (
} }
} }
InvertedLists *read_InvertedLists (IOReader *f, int io_flags) {
InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
uint32_t h; uint32_t h;
READ1 (h); READ1 (h);
if (h == fourcc ("il00")) { if (h == fourcc ("il00")) {
...@@ -545,6 +580,10 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) { ...@@ -545,6 +580,10 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
} }
return ails; return ails;
} else if (h == fourcc ("ilar") && (io_flags & IO_FLAG_MMAP)) { } else if (h == fourcc ("ilar") && (io_flags & IO_FLAG_MMAP)) {
auto impl = dynamic_cast<FileIOReader*>(f);
FAISS_THROW_IF_NOT(NULL != impl);
FILE *raw_f = impl->f;
auto ails = new OnDiskInvertedLists (); auto ails = new OnDiskInvertedLists ();
READ1 (ails->nlist); READ1 (ails->nlist);
READ1 (ails->code_size); READ1 (ails->code_size);
...@@ -552,16 +591,16 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) { ...@@ -552,16 +591,16 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
ails->lists.resize (ails->nlist); ails->lists.resize (ails->nlist);
std::vector<size_t> sizes (ails->nlist); std::vector<size_t> sizes (ails->nlist);
read_ArrayInvertedLists_sizes (f, sizes); read_ArrayInvertedLists_sizes (f, sizes);
size_t o0 = ftell (f), o = o0; size_t o0 = ftell (raw_f), o = o0;
{ // do the mmap { // do the mmap
struct stat buf; struct stat buf;
int ret = fstat (fileno(f), &buf); int ret = fstat (fileno(raw_f), &buf);
FAISS_THROW_IF_NOT_FMT (ret == 0, FAISS_THROW_IF_NOT_FMT (ret == 0,
"fstat failed: %s", strerror(errno)); "fstat failed: %s", strerror(errno));
ails->totsize = buf.st_size; ails->totsize = buf.st_size;
ails->ptr = (uint8_t*)mmap (nullptr, ails->totsize, ails->ptr = (uint8_t*)mmap (nullptr, ails->totsize,
PROT_READ, MAP_SHARED, PROT_READ, MAP_SHARED,
fileno (f), 0); fileno (raw_f), 0);
FAISS_THROW_IF_NOT_FMT (ails->ptr != MAP_FAILED, FAISS_THROW_IF_NOT_FMT (ails->ptr != MAP_FAILED,
"could not mmap: %s", "could not mmap: %s",
strerror(errno)); strerror(errno));
...@@ -574,7 +613,7 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) { ...@@ -574,7 +613,7 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
ails->code_size); ails->code_size);
} }
// resume normal reading of file // resume normal reading of file
fseek (f, o, SEEK_SET); fseek (raw_f, o, SEEK_SET);
return ails; return ails;
} else if (h == fourcc ("ilod")) { } else if (h == fourcc ("ilod")) {
OnDiskInvertedLists *od = new OnDiskInvertedLists(); OnDiskInvertedLists *od = new OnDiskInvertedLists();
...@@ -601,24 +640,24 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) { ...@@ -601,24 +640,24 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
} }
} }
static void read_InvertedLists (IndexIVF *ivf, FILE *f, int io_flags) { static void read_InvertedLists (
IndexIVF *ivf, IOReader *f, int io_flags) {
InvertedLists *ils = read_InvertedLists (f, io_flags); InvertedLists *ils = read_InvertedLists (f, io_flags);
FAISS_THROW_IF_NOT (ils->nlist == ivf->nlist && FAISS_THROW_IF_NOT (ils->nlist == ivf->nlist &&
ils->code_size == ivf->code_size); ils->code_size == ivf->code_size);
ivf->invlists = ils; ivf->invlists = ils;
ivf->own_invlists = true; ivf->own_invlists = true;
} }
static void read_ProductQuantizer (ProductQuantizer *pq, IOReader *f) {
static void read_ProductQuantizer (ProductQuantizer *pq, FILE *f) {
READ1 (pq->d); READ1 (pq->d);
READ1 (pq->M); READ1 (pq->M);
READ1 (pq->nbits); READ1 (pq->nbits);
pq->set_derived_values (); pq->set_derived_values ();
READVECTOR (pq->centroids); READVECTOR (pq->centroids);
} }
static void read_ScalarQuantizer (ScalarQuantizer *ivsc, FILE *f) { static void read_ScalarQuantizer (ScalarQuantizer *ivsc, IOReader *f) {
READ1 (ivsc->qtype); READ1 (ivsc->qtype);
READ1 (ivsc->rangestat); READ1 (ivsc->rangestat);
READ1 (ivsc->rangestat_arg); READ1 (ivsc->rangestat_arg);
...@@ -628,7 +667,7 @@ static void read_ScalarQuantizer (ScalarQuantizer *ivsc, FILE *f) { ...@@ -628,7 +667,7 @@ static void read_ScalarQuantizer (ScalarQuantizer *ivsc, FILE *f) {
} }
static void read_HNSW (HNSW *hnsw, FILE *f) { static void read_HNSW (HNSW *hnsw, IOReader *f) {
READVECTOR (hnsw->assign_probas); READVECTOR (hnsw->assign_probas);
READVECTOR (hnsw->cum_nneighbor_per_level); READVECTOR (hnsw->cum_nneighbor_per_level);
READVECTOR (hnsw->levels); READVECTOR (hnsw->levels);
...@@ -648,14 +687,16 @@ ProductQuantizer * read_ProductQuantizer (const char*fname) { ...@@ -648,14 +687,16 @@ ProductQuantizer * read_ProductQuantizer (const char*fname) {
ScopeFileCloser closer(f); ScopeFileCloser closer(f);
ProductQuantizer *pq = new ProductQuantizer(); ProductQuantizer *pq = new ProductQuantizer();
ScopeDeleter1<ProductQuantizer> del (pq); ScopeDeleter1<ProductQuantizer> del (pq);
read_ProductQuantizer(pq, f);
FileIOReader reader(f);
read_ProductQuantizer(pq, &reader);
del.release (); del.release ();
return pq; return pq;
} }
static void read_ivf_header ( static void read_ivf_header (
IndexIVF * ivf, FILE *f, IndexIVF *ivf, IOReader *f,
std::vector<std::vector<Index::idx_t> > *ids = nullptr) std::vector<std::vector<Index::idx_t> > *ids = nullptr)
{ {
read_index_header (ivf, f); read_index_header (ivf, f);
READ1 (ivf->nlist); READ1 (ivf->nlist);
...@@ -683,7 +724,7 @@ static ArrayInvertedLists *set_array_invlist( ...@@ -683,7 +724,7 @@ static ArrayInvertedLists *set_array_invlist(
return ail; return ail;
} }
static IndexIVFPQ *read_ivfpq (FILE *f, uint32_t h, int io_flags) static IndexIVFPQ *read_ivfpq (IOReader *f, uint32_t h, int io_flags)
{ {
bool legacy = h == fourcc ("IvQR") || h == fourcc ("IvPQ"); bool legacy = h == fourcc ("IvQR") || h == fourcc ("IvPQ");
...@@ -720,7 +761,7 @@ static IndexIVFPQ *read_ivfpq (FILE *f, uint32_t h, int io_flags) ...@@ -720,7 +761,7 @@ static IndexIVFPQ *read_ivfpq (FILE *f, uint32_t h, int io_flags)
int read_old_fmt_hack = 0; int read_old_fmt_hack = 0;
Index *read_index (FILE * f, int io_flags) { Index *read_index (IOReader *f, int io_flags) {
Index * idx = nullptr; Index * idx = nullptr;
uint32_t h; uint32_t h;
READ1 (h); READ1 (h);
...@@ -913,6 +954,10 @@ Index *read_index (FILE * f, int io_flags) { ...@@ -913,6 +954,10 @@ Index *read_index (FILE * f, int io_flags) {
} }
// FILE* convenience overload: wraps the caller's stream in a FileIOReader
// and delegates to the IOReader-based read_index, forwarding io_flags
// unchanged. The stream is not closed here.
Index *read_index (FILE * f, int io_flags) {
    FileIOReader file_reader(f);
    IOReader *source = &file_reader;
    Index *loaded = read_index (source, io_flags);
    return loaded;
}
Index *read_index (const char *fname, int io_flags) { Index *read_index (const char *fname, int io_flags) {
FILE *f = fopen (fname, "r"); FILE *f = fopen (fname, "r");
...@@ -929,7 +974,9 @@ VectorTransform *read_VectorTransform (const char *fname) { ...@@ -929,7 +974,9 @@ VectorTransform *read_VectorTransform (const char *fname) {
perror (""); perror ("");
abort (); abort ();
} }
VectorTransform *vt = read_VectorTransform (f);
FileIOReader reader(f);
VectorTransform *vt = read_VectorTransform (&reader);
fclose (f); fclose (f);
return vt; return vt;
} }
......
...@@ -21,17 +21,21 @@ struct Index; ...@@ -21,17 +21,21 @@ struct Index;
struct VectorTransform; struct VectorTransform;
struct IndexIVF; struct IndexIVF;
struct ProductQuantizer; struct ProductQuantizer;
struct IOReader;
struct IOWriter;
void write_index (const Index *idx, FILE *f); void write_index (const Index *idx, FILE *f);
void write_index (const Index *idx, const char *fname); void write_index (const Index *idx, const char *fname);
void write_index (const Index *idx, IOWriter *writer);
const int IO_FLAG_MMAP = 1; const int IO_FLAG_MMAP = 1;
const int IO_FLAG_READ_ONLY = 2; const int IO_FLAG_READ_ONLY = 2;
Index *read_index (FILE * f, int io_flags = 0); Index *read_index (FILE * f, int io_flags = 0);
Index *read_index (const char *fname, int io_flags = 0); Index *read_index (const char *fname, int io_flags = 0);
Index *read_index (IOReader *reader, int io_flags = 0);
void write_VectorTransform (const VectorTransform *vt, const char *fname); void write_VectorTransform (const VectorTransform *vt, const char *fname);
...@@ -55,6 +59,21 @@ struct Cloner { ...@@ -55,6 +59,21 @@ struct Cloner {
virtual ~Cloner() {} virtual ~Cloner() {}
}; };
// Abstract input source for the index reading routines.
struct IOReader {
    // Modeled on fread: fill ptr with up to nitems elements of size bytes
    // each and return how many elements were actually read.
    virtual size_t operator()(void *ptr, size_t size, size_t nitems) = 0;

    virtual ~IOReader() = default;
};
// Abstract output sink for the index writing routines.
struct IOWriter {
    // Modeled on fwrite: consume up to nitems elements of size bytes each
    // from ptr and return how many elements were actually written.
    virtual size_t operator()(const void *ptr, size_t size, size_t nitems) = 0;

    virtual ~IOWriter() = default;
};
} }
#endif #endif
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment