# bench_scalar_quantizer.py
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD+Patents license found in the
# LICENSE file in the root directory of this source tree.

#!/usr/bin/env python2

import time
import numpy as np
import faiss

#################################################################
# I/O functions
#################################################################


def ivecs_read(fname):
    """Read an .ivecs file into a 2-D int32 array, one vector per row.

    The file is read as a flat stream of int32 values. The very first
    value is taken as the vector dimension ``d``; the stream is then cut
    into records of ``d + 1`` values and the leading value of each record
    is dropped.

    :param fname: path of the file to read
    :return: a freshly allocated array of shape (n, d) with dtype int32
    """
    raw = np.fromfile(fname, dtype='int32')
    dim = raw[0]
    records = raw.reshape(-1, dim + 1)
    # copy() returns a new array instead of a slice view into `raw`
    return records[:, 1:].copy()


def fvecs_read(fname):
    """Read an .fvecs file into a 2-D float32 array, one vector per row.

    The file is parsed by ivecs_read() as int32 records, and the resulting
    array is then reinterpreted as float32 without converting the values.

    :param fname: path of the file to read
    :return: float32 view of the array returned by ivecs_read(fname)
    """
    as_int32 = ivecs_read(fname)
    return as_int32.view('float32')


#################################################################
#  Main program
#################################################################

print "load data"

xt = fvecs_read("sift1M/sift_learn.fvecs")
xb = fvecs_read("sift1M/sift_base.fvecs")
xq = fvecs_read("sift1M/sift_query.fvecs")

# xq = xq[:1000]
# xb = xb[:100000]

nq, d = xq.shape

print "load GT"
gt = ivecs_read("sift1M/sift_groundtruth.ivecs")

# gt = gt[:1000]


ncent = 256

variants = [(name, getattr(faiss.ScalarQuantizer, name))
            for name in dir(faiss.ScalarQuantizer)
            if name.startswith('QT_')]

quantizer = faiss.IndexFlatL2(d)
# quantizer.add(np.zeros((1, d), dtype='float32'))

if False:
    for name, qtype in [('flat', 0)] + variants:

        print "============== test", name
        t0 = time.time()

        if name == 'flat':
            index = faiss.IndexIVFFlat(quantizer, d, ncent,
                                       faiss.METRIC_L2)
        else:
            index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent,
                                                  qtype, faiss.METRIC_L2)

        index.nprobe = 16
        print "[%.3f s] train" % (time.time() - t0)
        index.train(xt)
        print "[%.3f s] add" % (time.time() - t0)
        index.add(xb)
        print "[%.3f s] search" % (time.time() - t0)
        D, I = index.search(xq, 100)
        print "[%.3f s] eval" % (time.time() - t0)

        for rank in 1, 10, 100:
            n_ok = (I[:, :rank] == gt[:, :1]).sum()
            print "%.4f" % (n_ok / float(nq)),
        print

if True:
    for name, qtype in variants:

        print "============== test", name

        for rsname, vals in [('RS_minmax',
                              [-0.4, -0.2, -0.1, -0.05, 0.0, 0.1, 0.5]),
                             ('RS_meanstd', [0.8, 1.0, 1.5, 2.0, 3.0, 5.0, 10.0]),
                             ('RS_quantiles', [0.02, 0.05, 0.1, 0.15]),
                             ('RS_optim', [0.0])]:
            for val in vals:
                print "%-15s %5g    " % (rsname, val),
                index = faiss.IndexIVFScalarQuantizer(quantizer, d, ncent,
                                                      qtype, faiss.METRIC_L2)
                index.nprobe = 16
                index.sq.rangestat = getattr(faiss.ScalarQuantizer,
                                          rsname)

                index.rangestat_arg = val

                index.train(xt)
                index.add(xb)
                t0 = time.time()
                D, I = index.search(xq, 100)
                t1 = time.time()

                for rank in 1, 10, 100:
                    n_ok = (I[:, :rank] == gt[:, :1]).sum()
                    print "%.4f" % (n_ok / float(nq)),
                print "   %.3f s" % (t1 - t0)