Commit c5f21730 authored by matthijs's avatar matthijs

port to python 3

parent 12049dae
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
#! /usr/bin/env python2
import libfb.py.mkl # noqa
import numpy as np
from libfb import testutil
import faiss
class TestClustering(testutil.BaseFacebookTestCase):
def test_clustering(self):
d = 64
n = 1000
np.random.seed(123)
x = np.random.random(size=(n, d)).astype('float32')
km = faiss.Kmeans(d, 32, niter=10)
err32 = km.train(x)
# check that objective is decreasing
prev = 1e50
for o in km.obj:
self.assertGreater(prev, o)
prev = o
km = faiss.Kmeans(d, 64, niter=10)
err64 = km.train(x)
# check that 64 centroids give a lower quantization error than 32
self.assertGreater(err32, err64)
class TestPCA(testutil.BaseFacebookTestCase):
def test_pca(self):
d = 64
n = 1000
np.random.seed(123)
x = np.random.random(size=(n, d)).astype('float32')
pca = faiss.PCAMatrix(d, 10)
pca.train(x)
y = pca.apply_py(x)
# check that energy per component is decreasing
column_norm2 = (y**2).sum(0)
prev = 1e50
for o in column_norm2:
self.assertGreater(prev, o)
prev = o
class TestProductQuantizer(testutil.BaseFacebookTestCase):
def test_pq(self):
d = 64
n = 1000
cs = 4
np.random.seed(123)
x = np.random.random(size=(n, d)).astype('float32')
pq = faiss.ProductQuantizer(d, cs, 8)
pq.train(x)
codes = pq.compute_codes(x)
x2 = pq.decode(codes)
diff = ((x - x2)**2).sum()
# print "diff=", diff
# diff= 1807.98
self.assertGreater(2500, diff)
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
#! /usr/bin/env python2
"""this is a basic test script that works with fbmake to check if
some simple indices work"""
import libfb.py.mkl # noqa
import numpy as np
import pdb
from libfb import testutil
import faiss
class EvalIVFPQAccuracy(testutil.BaseFacebookTestCase):
def get_dataset(self):
d = 64
nb = 1000
nt = 1500
nq = 200
np.random.seed(123)
xb = np.random.random(size=(nb, d)).astype('float32')
xt = np.random.random(size=(nt, d)).astype('float32')
xq = np.random.random(size=(nq, d)).astype('float32')
return (xt, xb, xq)
def test_IndexIVFPQ(self):
(xt, xb, xq) = self.get_dataset()
d = xt.shape[1]
gt_index = faiss.IndexFlatL2(d)
gt_index.add(xb)
D, gt_nns = gt_index.search(xq, 1)
coarse_quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(coarse_quantizer, d, 25, 16, 8)
index.train(xt)
index.add(xb)
index.nprobe = 5
D, nns = index.search(xq, 10)
n_ok = (nns == gt_nns).sum()
nq = xq.shape[0]
self.assertGreater(n_ok, nq * 0.4)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment