Commit cd884114 authored by Ailing's avatar Ailing Committed by Matthijs Douze

Make tests compatible with py3 (#348)

parent 0c482e54
...@@ -144,7 +144,7 @@ class TestException(unittest.TestCase): ...@@ -144,7 +144,7 @@ class TestException(unittest.TestCase):
try: try:
# an unsupported operation for IndexFlat # an unsupported operation for IndexFlat
index.add_with_ids(a, b) index.add_with_ids(a, b)
except RuntimeError, e: except RuntimeError as e:
assert 'add_with_ids not implemented' in str(e) assert 'add_with_ids not implemented' in str(e)
else: else:
assert False, 'exception did not fire???' assert False, 'exception did not fire???'
...@@ -153,7 +153,7 @@ class TestException(unittest.TestCase): ...@@ -153,7 +153,7 @@ class TestException(unittest.TestCase):
try: try:
faiss.index_factory(12, 'IVF256,Flat,PQ8') faiss.index_factory(12, 'IVF256,Flat,PQ8')
except RuntimeError, e: except RuntimeError as e:
assert 'could not parse' in str(e) assert 'could not parse' in str(e)
else: else:
assert False, 'exception did not fire???' assert False, 'exception did not fire???'
......
...@@ -12,6 +12,7 @@ import numpy as np ...@@ -12,6 +12,7 @@ import numpy as np
import unittest import unittest
import faiss import faiss
import os import os
import tempfile
def get_dataset_2(d, nb, nt, nq): def get_dataset_2(d, nb, nt, nq):
"""A dataset that is not completely random but still challenging to """A dataset that is not completely random but still challenging to
...@@ -48,26 +49,26 @@ class TestRemove(unittest.TestCase): ...@@ -48,26 +49,26 @@ class TestRemove(unittest.TestCase):
filename = None filename = None
if ondisk: if ondisk:
filename = os.tmpnam() filename = tempfile.mkstemp()[1]
invlists = faiss.OnDiskInvertedLists( invlists = faiss.OnDiskInvertedLists(
index1.nlist, index1.code_size, index1.nlist, index1.code_size,
filename) filename)
index1.replace_invlists(invlists) index1.replace_invlists(invlists)
index1.add(xb[:nb / 2]) index1.add(xb[:int(nb / 2)])
index2 = faiss.IndexIVFFlat(quantizer, d, 20) index2 = faiss.IndexIVFFlat(quantizer, d, 20)
assert index2.is_trained assert index2.is_trained
index2.add(xb[nb / 2:]) index2.add(xb[int(nb / 2):])
Dref, Iref = index1.search(xq, 10) Dref, Iref = index1.search(xq, 10)
index1.merge_from(index2, nb / 2) index1.merge_from(index2, int(nb / 2))
assert index1.ntotal == nb assert index1.ntotal == nb
index1.remove_ids(faiss.IDSelectorRange(nb / 2, nb)) index1.remove_ids(faiss.IDSelectorRange(int(nb / 2), nb))
assert index1.ntotal == nb / 2 assert index1.ntotal == int(nb / 2)
Dnew, Inew = index1.search(xq, 10) Dnew, Inew = index1.search(xq, 10)
assert np.all(Dnew == Dref) assert np.all(Dnew == Dref)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment