Commit abe2b0fd authored by dengos's avatar dengos Committed by Matthijs Douze

read/write index with std::function wrapper (#427)

* add access function to IndexIVF;

* - access for IndexIVF;
- write_index/read_index with std::function<...>;

* - fix test compile on mac;
- adjust write/read with std::function;

* replace std::function with IOReader/IOWriter;

* remove IndexIVF::access // tmp

* PFN_WRITE/READ => WRITE;

* revert mac compile fix;

* rename;

* fix compile;

* reset CMakeList;

* format; remove unused function/header;
parent 433f5c0f
......@@ -590,7 +590,6 @@ void IndexIVF::copy_subset_to (IndexIVF & other, int subset_type,
}
IndexIVF::~IndexIVF()
{
if (own_invlists) {
......
......@@ -73,14 +73,14 @@ static uint32_t fourcc (const char sx[4]) {
**************************************************************/
#define WRITEANDCHECK(ptr, n) { \
size_t ret = fwrite (ptr, sizeof (* (ptr)), n, f); \
FAISS_THROW_IF_NOT_MSG (ret == (n), "write error"); \
#define WRITEANDCHECK(ptr, n) { \
size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \
FAISS_THROW_IF_NOT_MSG(ret == (n), "write error"); \
}
#define READANDCHECK(ptr, n) { \
size_t ret = fread (ptr, sizeof (* (ptr)), n, f); \
FAISS_THROW_IF_NOT_MSG (ret == (n), "read error"); \
#define READANDCHECK(ptr, n) { \
size_t ret = (*f)(ptr, sizeof(*(ptr)), n); \
FAISS_THROW_IF_NOT_MSG(ret == (n), "read error"); \
}
#define WRITE1(x) WRITEANDCHECK(&(x), 1)
......@@ -106,15 +106,41 @@ struct ScopeFileCloser {
~ScopeFileCloser () {fclose (f); }
};
namespace {
struct FileIOReader: IOReader {
FILE *f = nullptr;
FileIOReader(FILE *rf): f(rf) {}
~FileIOReader() = default;
virtual size_t operator()(
void *ptr, size_t size, size_t nitems) override {
return fread(ptr, size, nitems, f);
}
};
struct FileIOWriter: IOWriter {
FILE *f = nullptr;
FileIOWriter(FILE *wf): f(wf) {}
~FileIOWriter() = default;
virtual size_t operator()(
const void *ptr, size_t size, size_t nitems) override {
return fwrite(ptr, size, nitems, f);
}
};
} // namespace
/*************************************************************
* Write
**************************************************************/
static void write_index_header (const Index *idx, FILE *f) {
static void write_index_header (const Index *idx, IOWriter *f) {
WRITE1 (idx->d);
WRITE1 (idx->ntotal);
Index::idx_t dummy = 1 << 20;
......@@ -124,7 +150,7 @@ static void write_index_header (const Index *idx, FILE *f) {
WRITE1 (idx->metric_type);
}
void write_VectorTransform (const VectorTransform *vt, FILE *f) {
void write_VectorTransform (const VectorTransform *vt, IOWriter *f) {
if (const LinearTransform * lt =
dynamic_cast < const LinearTransform *> (vt)) {
if (dynamic_cast<const RandomRotationMatrix *>(lt)) {
......@@ -167,14 +193,16 @@ void write_VectorTransform (const VectorTransform *vt, FILE *f) {
WRITE1 (vt->is_trained);
}
static void write_ProductQuantizer (const ProductQuantizer *pq, FILE *f) {
static void write_ProductQuantizer (
const ProductQuantizer *pq, IOWriter *f) {
WRITE1 (pq->d);
WRITE1 (pq->M);
WRITE1 (pq->nbits);
WRITEVECTOR (pq->centroids);
}
static void write_ScalarQuantizer (const ScalarQuantizer *ivsc, FILE *f) {
static void write_ScalarQuantizer (
const ScalarQuantizer *ivsc, IOWriter *f) {
WRITE1 (ivsc->qtype);
WRITE1 (ivsc->rangestat);
WRITE1 (ivsc->rangestat_arg);
......@@ -183,7 +211,7 @@ static void write_ScalarQuantizer (const ScalarQuantizer *ivsc, FILE *f) {
WRITEVECTOR (ivsc->trained);
}
static void write_InvertedLists (const InvertedLists *ils, FILE *f) {
static void write_InvertedLists (const InvertedLists *ils, IOWriter *f) {
if (ils == nullptr) {
uint32_t h = fourcc ("il00");
WRITE1 (h);
......@@ -258,10 +286,12 @@ void write_ProductQuantizer (const ProductQuantizer*pq, const char *fname) {
FILE *f = fopen (fname, "w");
FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname);
ScopeFileCloser closer(f);
write_ProductQuantizer (pq, f);
FileIOWriter writer(f);
write_ProductQuantizer (pq, &writer);
}
static void write_HNSW (const HNSW *hnsw, FILE *f) {
static void write_HNSW (const HNSW *hnsw, IOWriter *f) {
WRITEVECTOR (hnsw->assign_probas);
WRITEVECTOR (hnsw->cum_nneighbor_per_level);
......@@ -274,10 +304,9 @@ static void write_HNSW (const HNSW *hnsw, FILE *f) {
WRITE1 (hnsw->efConstruction);
WRITE1 (hnsw->efSearch);
WRITE1 (hnsw->upper_beam);
}
static void write_ivf_header (const IndexIVF * ivf, FILE *f) {
static void write_ivf_header (const IndexIVF *ivf, IOWriter *f) {
write_index_header (ivf, f);
WRITE1 (ivf->nlist);
WRITE1 (ivf->nprobe);
......@@ -286,7 +315,7 @@ static void write_ivf_header (const IndexIVF * ivf, FILE *f) {
WRITEVECTOR (ivf->direct_map);
}
void write_index (const Index *idx, FILE *f) {
void write_index (const Index *idx, IOWriter *f) {
if (const IndexFlat * idxf = dynamic_cast<const IndexFlat *> (idx)) {
uint32_t h = fourcc (
idxf->metric_type == METRIC_INNER_PRODUCT ? "IxFI" :
......@@ -418,6 +447,11 @@ void write_index (const Index *idx, FILE *f) {
}
}
void write_index (const Index *idx, FILE *f) {
FileIOWriter writer(f);
write_index(idx, &writer);
}
void write_index (const Index *idx, const char *fname) {
FILE *f = fopen (fname, "w");
FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname);
......@@ -429,14 +463,16 @@ void write_VectorTransform (const VectorTransform *vt, const char *fname) {
FILE *f = fopen (fname, "w");
FAISS_THROW_IF_NOT_FMT (f, "cannot open %s for writing", fname);
ScopeFileCloser closer(f);
write_VectorTransform (vt, f);
FileIOWriter writer(f);
write_VectorTransform (vt, &writer);
}
/*************************************************************
* Read
**************************************************************/
static void read_index_header (Index *idx, FILE *f) {
static void read_index_header (Index *idx, IOReader *f) {
READ1 (idx->d);
READ1 (idx->ntotal);
Index::idx_t dummy;
......@@ -447,7 +483,7 @@ static void read_index_header (Index *idx, FILE *f) {
idx->verbose = false;
}
VectorTransform* read_VectorTransform (FILE *f) {
VectorTransform* read_VectorTransform (IOReader *f) {
uint32_t h;
READ1 (h);
VectorTransform *vt = nullptr;
......@@ -497,7 +533,7 @@ VectorTransform* read_VectorTransform (FILE *f) {
static void read_ArrayInvertedLists_sizes (
FILE *f, std::vector<size_t> & sizes)
IOReader *f, std::vector<size_t> & sizes)
{
size_t nlist = sizes.size();
uint32_t list_type;
......@@ -518,8 +554,7 @@ static void read_ArrayInvertedLists_sizes (
}
}
InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
InvertedLists *read_InvertedLists (IOReader *f, int io_flags) {
uint32_t h;
READ1 (h);
if (h == fourcc ("il00")) {
......@@ -545,6 +580,10 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
}
return ails;
} else if (h == fourcc ("ilar") && (io_flags & IO_FLAG_MMAP)) {
auto impl = dynamic_cast<FileIOReader*>(f);
FAISS_THROW_IF_NOT(NULL != impl);
FILE *raw_f = impl->f;
auto ails = new OnDiskInvertedLists ();
READ1 (ails->nlist);
READ1 (ails->code_size);
......@@ -552,16 +591,16 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
ails->lists.resize (ails->nlist);
std::vector<size_t> sizes (ails->nlist);
read_ArrayInvertedLists_sizes (f, sizes);
size_t o0 = ftell (f), o = o0;
size_t o0 = ftell (raw_f), o = o0;
{ // do the mmap
struct stat buf;
int ret = fstat (fileno(f), &buf);
int ret = fstat (fileno(raw_f), &buf);
FAISS_THROW_IF_NOT_FMT (ret == 0,
"fstat failed: %s", strerror(errno));
ails->totsize = buf.st_size;
ails->ptr = (uint8_t*)mmap (nullptr, ails->totsize,
PROT_READ, MAP_SHARED,
fileno (f), 0);
fileno (raw_f), 0);
FAISS_THROW_IF_NOT_FMT (ails->ptr != MAP_FAILED,
"could not mmap: %s",
strerror(errno));
......@@ -574,7 +613,7 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
ails->code_size);
}
// resume normal reading of file
fseek (f, o, SEEK_SET);
fseek (raw_f, o, SEEK_SET);
return ails;
} else if (h == fourcc ("ilod")) {
OnDiskInvertedLists *od = new OnDiskInvertedLists();
......@@ -601,24 +640,24 @@ InvertedLists *read_InvertedLists (FILE *f, int io_flags) {
}
}
static void read_InvertedLists (IndexIVF *ivf, FILE *f, int io_flags) {
static void read_InvertedLists (
IndexIVF *ivf, IOReader *f, int io_flags) {
InvertedLists *ils = read_InvertedLists (f, io_flags);
FAISS_THROW_IF_NOT (ils->nlist == ivf->nlist &&
ils->code_size == ivf->code_size);
ivf->invlists = ils;
ivf->own_invlists = true;
}
static void read_ProductQuantizer (ProductQuantizer *pq, FILE *f) {
static void read_ProductQuantizer (ProductQuantizer *pq, IOReader *f) {
READ1 (pq->d);
READ1 (pq->M);
READ1 (pq->nbits);
pq->set_derived_values ();
READVECTOR (pq->centroids);
}
static void read_ScalarQuantizer (ScalarQuantizer *ivsc, FILE *f) {
static void read_ScalarQuantizer (ScalarQuantizer *ivsc, IOReader *f) {
READ1 (ivsc->qtype);
READ1 (ivsc->rangestat);
READ1 (ivsc->rangestat_arg);
......@@ -628,7 +667,7 @@ static void read_ScalarQuantizer (ScalarQuantizer *ivsc, FILE *f) {
}
static void read_HNSW (HNSW *hnsw, FILE *f) {
static void read_HNSW (HNSW *hnsw, IOReader *f) {
READVECTOR (hnsw->assign_probas);
READVECTOR (hnsw->cum_nneighbor_per_level);
READVECTOR (hnsw->levels);
......@@ -648,14 +687,16 @@ ProductQuantizer * read_ProductQuantizer (const char*fname) {
ScopeFileCloser closer(f);
ProductQuantizer *pq = new ProductQuantizer();
ScopeDeleter1<ProductQuantizer> del (pq);
read_ProductQuantizer(pq, f);
FileIOReader reader(f);
read_ProductQuantizer(pq, &reader);
del.release ();
return pq;
}
static void read_ivf_header (
IndexIVF * ivf, FILE *f,
std::vector<std::vector<Index::idx_t> > *ids = nullptr)
IndexIVF *ivf, IOReader *f,
std::vector<std::vector<Index::idx_t> > *ids = nullptr)
{
read_index_header (ivf, f);
READ1 (ivf->nlist);
......@@ -683,7 +724,7 @@ static ArrayInvertedLists *set_array_invlist(
return ail;
}
static IndexIVFPQ *read_ivfpq (FILE *f, uint32_t h, int io_flags)
static IndexIVFPQ *read_ivfpq (IOReader *f, uint32_t h, int io_flags)
{
bool legacy = h == fourcc ("IvQR") || h == fourcc ("IvPQ");
......@@ -720,7 +761,7 @@ static IndexIVFPQ *read_ivfpq (FILE *f, uint32_t h, int io_flags)
int read_old_fmt_hack = 0;
Index *read_index (FILE * f, int io_flags) {
Index *read_index (IOReader *f, int io_flags) {
Index * idx = nullptr;
uint32_t h;
READ1 (h);
......@@ -913,6 +954,10 @@ Index *read_index (FILE * f, int io_flags) {
}
Index *read_index (FILE * f, int io_flags) {
FileIOReader reader(f);
return read_index(&reader, io_flags);
}
Index *read_index (const char *fname, int io_flags) {
FILE *f = fopen (fname, "r");
......@@ -929,7 +974,9 @@ VectorTransform *read_VectorTransform (const char *fname) {
perror ("");
abort ();
}
VectorTransform *vt = read_VectorTransform (f);
FileIOReader reader(f);
VectorTransform *vt = read_VectorTransform (&reader);
fclose (f);
return vt;
}
......
......@@ -21,17 +21,21 @@ struct Index;
struct VectorTransform;
struct IndexIVF;
struct ProductQuantizer;
struct IOReader;
struct IOWriter;
void write_index (const Index *idx, FILE *f);
void write_index (const Index *idx, const char *fname);
void write_index (const Index *idx, IOWriter *writer);
const int IO_FLAG_MMAP = 1;
const int IO_FLAG_READ_ONLY = 2;
Index *read_index (FILE * f, int io_flags = 0);
Index *read_index (const char *fname, int io_flags = 0);
Index *read_index (IOReader *reader, int io_flags = 0);
void write_VectorTransform (const VectorTransform *vt, const char *fname);
......@@ -55,6 +59,21 @@ struct Cloner {
virtual ~Cloner() {}
};
struct IOReader {
// fread
virtual size_t operator()(
void *ptr, size_t size, size_t nitems) = 0;
virtual ~IOReader() {}
};
struct IOWriter {
// fwrite
virtual size_t operator()(
const void *ptr, size_t size, size_t nitems) = 0;
virtual ~IOWriter() {}
};
}
#endif
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment