Commit fa91c139 authored by Matthijs Douze's avatar Matthijs Douze

take into account torch offset when getting ptr

parent 5555ae7f
......@@ -11,16 +11,18 @@ import unittest
import faiss
import torch
def swig_ptr_from_FloatTensor(x):
assert x.is_contiguous()
assert x.dtype == torch.float32
return faiss.cast_integer_to_float_ptr(x.storage().data_ptr())
return faiss.cast_integer_to_float_ptr(
x.storage().data_ptr() + x.storage_offset() * 4)
def swig_ptr_from_LongTensor(x):
assert x.is_contiguous()
assert x.dtype == torch.int64, 'dtype=%s' % x.dtype
return faiss.cast_integer_to_long_ptr(x.storage().data_ptr())
return faiss.cast_integer_to_long_ptr(
x.storage().data_ptr() + x.storage_offset() * 8)
def search_index_pytorch(index, x, k, D=None, I=None):
......@@ -155,10 +157,15 @@ class PytorchFaissInterop(unittest.TestCase):
assert np.all(I == gt_I)
assert np.all(np.abs(D - gt_D).max() < 1e-4)
# test on subset
D, I = search_raw_array_pytorch(res, xb_t, xq_t[60:80], k)
# back to CPU for verification
D = D.cpu().numpy()
I = I.cpu().numpy()
assert np.all(I == gt_I[60:80])
assert np.all(np.abs(D - gt_D[60:80]).max() < 1e-4)
if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment