Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
nmslib/python_bindings/tests/bindings_test.py
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
484 lines (335 sloc)
15.5 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import itertools | |
import tempfile | |
import unittest | |
import shutil | |
import numpy as np | |
import numpy.testing as npt | |
import nmslib | |
import time | |
import os, gc, psutil | |
# --- Tuning knobs for the memory-leak tests below ---
# Fraction of per-repetition growth checks allowed to fail before the
# whole leak test is considered failed.
MEM_TEST_CRIT_FAIL_RATE=0.25
# Leak test #1 (save/reload cycle): outer repetitions / inner reload iterations.
MEM_TEST_REPEAT_QTY1=4
MEM_TEST_ITER1=10
# Leak test #2 (rebuild cycle): outer repetitions / inner rebuild iterations.
MEM_TEST_REPEAT_QTY2=4
MEM_TEST_ITER2=5
# The key to stable memory testing is using a reasonably large number of points
MEM_TEST_DATA_QTY=25000
MEM_TEST_QUERY_QTY=200
# RSS growth after the last iteration may not exceed this multiple of the
# growth observed after the very first iteration.
MEM_GROW_COEFF=1.5 # This is a bit adhoc but seems to work in practice
MEM_TEST_DATA_DIM=4
def get_exact_cosine(row, data, N=10):
    """Brute-force the top-N cosine-similarity neighbours of `row` in `data`.

    Returns a list of (row_index, cosine_score) pairs ordered by
    decreasing score.  Used as ground truth for the approximate indices.
    """
    # Dot products scaled by each data row's norm; `row`'s own norm is
    # factored out until the end since it is common to every score.
    sims = data.dot(row) / np.linalg.norm(data, axis=-1)
    # argpartition gives the top-N indices in arbitrary order, O(n).
    top = np.argpartition(sims, -N)[-N:]
    row_norm = np.linalg.norm(row)
    pairs = [(idx, sims[idx] / row_norm) for idx in top]
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    return pairs
def get_hitrate(ground_truth, ids):
    """Count how many ground-truth neighbour ids were actually returned.

    `ground_truth` is a list of (id, score) pairs (as produced by
    get_exact_cosine); `ids` is the id sequence returned by a query.
    """
    true_ids = {gt_id for gt_id, _ in ground_truth}
    return len(true_ids.intersection(ids))
def bit_vector_to_str(bit_vect):
    """Render a boolean vector as a dense space-separated '1 0 1 ...' string."""
    return " ".join("1" if bit else "0" for bit in bit_vect)
def bit_vector_sparse_str(bit_vect):
    """Render a boolean vector as the space-separated indices of its set bits."""
    on_positions = [pos for pos, bit in enumerate(bit_vect) if bit]
    return " ".join(map(str, on_positions))
class TestCaseBase(unittest.TestCase):
    """Base class providing an order-insensitive result comparison helper."""

    def assert_allclose(self, orig, comp):
        """Assert two knn results agree, tolerating tie-induced reordering.

        Each result is a tuple (ids, dists).  `comp` is re-sorted into the
        id order of `orig` before the numeric comparison, so equal-distance
        neighbours returned in a different order do not cause a failure.
        """
        qty = len(orig[0])
        self.assertEqual(qty, len(orig[1]))
        self.assertEqual(qty, len(comp[0]))
        dist_by_id = {comp[0][k]: comp[1][k] for k in range(qty)}
        resorted_ids = []
        resorted_dists = []
        for one_id in orig[0]:
            # Every id from the reference result must be present in `comp`.
            self.assertIn(one_id, dist_by_id)
            resorted_ids.append(one_id)
            resorted_dists.append(dist_by_id[one_id])
        npt.assert_allclose(orig, (resorted_ids, resorted_dists))
class DenseIndexTestMixin(object):
    """Mixin with tests for dense-vector indices.

    Concrete cases mix this into a TestCaseBase subclass and override
    _get_index() to construct the index method under test.
    """

    def _get_index(self, space='cosinesimil'):
        # Must be overridden to return a fresh nmslib index for `space`.
        raise NotImplementedError()

    def testKnnQuery(self):
        """Single-query search must recover most of the exact top-10."""
        np.random.seed(23)
        # Fortran (column-major) layout deliberately exercises the
        # non-C-contiguous input path of addDataPointBatch.
        data = np.asfortranarray(np.random.randn(1000, 10).astype(np.float32))
        index = self._get_index()
        index.addDataPointBatch(data)
        index.createIndex()
        query = data[0]
        ids, distances = index.knnQuery(query, k=10)
        # Approximate search: require at least 5 of the exact top-10 hits.
        self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5)
        # There is a bug when different ways to specify the input query data
        # were causing the trouble: https://github.com/nmslib/nmslib/issues/370
        query = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1.])
        ids, distances = index.knnQuery(query, k=10)
        self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5)

    def testKnnQueryBatch(self):
        """Batch search over row-major, col-major, and custom-id inputs."""
        np.random.seed(23)
        data = np.random.randn(1000, 10).astype(np.float32)
        index = self._get_index()
        index.addDataPointBatch(data)
        index.createIndex()
        queries = data[:10]
        results = index.knnQueryBatch(queries, k=10)
        for query, (ids, distances) in zip(queries, results):
            self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5)
        # test col-major arrays
        queries = np.asfortranarray(queries)
        results = index.knnQueryBatch(queries, k=10)
        for query, (ids, distances) in zip(queries, results):
            self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5)
        # test custom ids (set id to square of each row)
        index = self._get_index()
        index.addDataPointBatch(data, ids=np.arange(data.shape[0]) ** 2)
        index.createIndex()
        queries = data[:10]
        results = index.knnQueryBatch(queries, k=10)
        for query, (ids, distances) in zip(queries, results):
            # convert from square back to row id
            ids = np.sqrt(ids).astype(int)
            self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5)

    def testReloadIndex(self):
        """A saved-then-reloaded index must return the same neighbours."""
        np.random.seed(23)
        data = np.random.randn(1000, 10).astype(np.float32)
        original = self._get_index()
        original.addDataPointBatch(data)
        original.createIndex()
        # test out saving/reloading index
        temp_dir = tempfile.mkdtemp()
        temp_file_pref = os.path.join(temp_dir, 'index')
        # save_data=0: index-only save, data re-added manually before load;
        # save_data=1: the saved index also carries the data.
        for save_data in [0, 1]:
            original.saveIndex(temp_file_pref, save_data=save_data)
            reloaded = self._get_index()
            if save_data == 0:
                reloaded.addDataPointBatch(data)
            reloaded.loadIndex(temp_file_pref, load_data=save_data)
            original_results = original.knnQuery(data[0])
            reloaded_results = reloaded.knnQuery(data[0])
            self.assert_allclose(original_results, reloaded_results)
        shutil.rmtree(temp_dir)
class BitVectorIndexTestMixin(object):
    """Mixin with tests for bit-vector spaces whose points are fed as
    strings: dense '1 0 1 ...' form for 'bit_*' spaces, sparse
    index-list form otherwise."""

    def _get_index(self, space='bit_jaccard'):
        # Must be overridden to return a fresh nmslib index for `space`.
        raise NotImplementedError()

    def _get_batches(self, index, nbits, num_elems, chunk_size):
        """Generate `num_elems` random bit vectors as [ids, strings] batches.

        Also selects and remembers (on self) the string encoder matching
        the index's space: dense encoding when 'bit_' appears in the
        index's string representation, sparse encoding otherwise.
        """
        if "bit_" in str(index):
            self.bit_vector_str_func = bit_vector_to_str
        else:
            self.bit_vector_str_func = bit_vector_sparse_str
        batches = []
        for i in range(0, num_elems, chunk_size):
            strs = []
            for j in range(chunk_size):
                # Random ~50%-dense bit vector.
                a = np.random.rand(nbits) > 0.5
                strs.append(self.bit_vector_str_func(a))
            batches.append([np.arange(i, i + chunk_size), strs])
        return batches

    def testKnnQuery(self):
        """Smoke test: build the index and run one query without errors."""
        np.random.seed(23)
        index = self._get_index()
        batches = self._get_batches(index, 512, 2000, 1000)
        for ids, data in batches:
            index.addDataPointBatch(ids=ids, data=data)
        index.createIndex()
        # Query with the all-ones vector; only checks the call succeeds.
        s = self.bit_vector_str_func(np.ones(512))
        index.knnQuery(s, k=10)

    def testReloadIndex(self):
        """A saved-then-reloaded index must return the same neighbours."""
        np.random.seed(23)
        original = self._get_index()
        batches = self._get_batches(original, 512, 2000, 1000)
        for ids, data in batches:
            original.addDataPointBatch(ids=ids, data=data)
        original.createIndex()
        temp_dir = tempfile.mkdtemp()
        temp_file_pref = os.path.join(temp_dir, 'index')
        for save_data in [0, 1]:
            # test out saving/reloading index
            original.saveIndex(temp_file_pref, save_data=save_data)
            reloaded = self._get_index()
            if save_data == 0:
                # Index-only save: the data must be re-added before loading.
                for ids, data in batches:
                    reloaded.addDataPointBatch(ids=ids, data=data)
            reloaded.loadIndex(temp_file_pref, load_data=save_data)
            s = self.bit_vector_str_func(np.ones(512))
            original_results = original.knnQuery(s)
            reloaded_results = reloaded.knnQuery(s)
            self.assert_allclose(original_results, reloaded_results)
        shutil.rmtree(temp_dir)
class HNSWTestCase(TestCaseBase, DenseIndexTestMixin):
    """Dense-vector test suite run against the HNSW method."""

    def _get_index(self, space='cosinesimil'):
        """Create a fresh HNSW index over the given space."""
        index = nmslib.init(method='hnsw', space=space)
        return index
class BitJaccardTestCase(TestCaseBase, BitVectorIndexTestMixin):
    """Bit-vector test suite for the dense-bit Jaccard space (HNSW)."""

    def _get_index(self, space='bit_jaccard'):
        """Create a fresh HNSW index taking string-encoded bit vectors."""
        index = nmslib.init(method='hnsw',
                            space=space,
                            data_type=nmslib.DataType.OBJECT_AS_STRING,
                            dtype=nmslib.DistType.FLOAT)
        return index
class SparseJaccardTestCase(TestCaseBase, BitVectorIndexTestMixin):
    """Bit-vector test suite for the sparse Jaccard space (HNSW)."""

    def _get_index(self, space='jaccard_sparse'):
        """Create a fresh HNSW index taking sparse string-encoded vectors."""
        index = nmslib.init(method='hnsw',
                            space=space,
                            data_type=nmslib.DataType.OBJECT_AS_STRING,
                            dtype=nmslib.DistType.FLOAT)
        return index
class BitHammingTestCase(TestCaseBase, BitVectorIndexTestMixin):
    """Bit-vector test suite for the Hamming space (HNSW, integer distance)."""

    def _get_index(self, space='bit_hamming'):
        """Create a fresh HNSW index with an integer-valued distance."""
        index = nmslib.init(method='hnsw',
                            space=space,
                            data_type=nmslib.DataType.OBJECT_AS_STRING,
                            dtype=nmslib.DistType.INT)
        return index
class SWGraphTestCase(TestCaseBase, DenseIndexTestMixin):
    """Dense-vector test suite run against the sw-graph method."""

    def _get_index(self, space='cosinesimil'):
        # Fresh sw-graph index over the given space.
        return nmslib.init(method='sw-graph', space=space)

    def testReloadIndex(self):
        """A saved-then-reloaded sw-graph index must return the same results.

        NOTE(review): this override appears to duplicate
        DenseIndexTestMixin.testReloadIndex line-for-line; confirm there is
        no sw-graph-specific reason before removing it.
        """
        np.random.seed(23)
        data = np.random.randn(1000, 10).astype(np.float32)
        original = self._get_index()
        original.addDataPointBatch(data)
        original.createIndex()
        temp_dir = tempfile.mkdtemp()
        temp_file_pref = os.path.join(temp_dir, 'index')
        # test out saving/reloading index
        for save_data in [0, 1]:
            original.saveIndex(temp_file_pref, save_data=save_data)
            reloaded = self._get_index()
            if save_data == 0:
                # Index-only save: data must be re-added before loading.
                reloaded.addDataPointBatch(data)
            reloaded.loadIndex(temp_file_pref, load_data=save_data)
            original_results = original.knnQuery(data[0])
            reloaded_results = reloaded.knnQuery(data[0])
            self.assert_allclose(original_results, reloaded_results)
        shutil.rmtree(temp_dir)
class BallTreeTestCase(TestCaseBase, DenseIndexTestMixin):
    """Dense-vector test suite run against the vptree method."""

    def _get_index(self, space='cosinesimil'):
        """Create a fresh vptree index over the given space."""
        return nmslib.init(method='vptree', space=space)

    def testReloadIndex(self):
        # Bug fix: the original `return NotImplemented` made unittest report
        # this as a silent PASS.  Skipping makes the unsupported
        # save/load path visible in the test report instead.
        self.skipTest('vptree does not support saving/loading an index')
class StringTestCase(TestCaseBase):
    """Tests for string-object indices using the Levenshtein distance."""

    def testStringLeven(self):
        """Index short DNA-like strings and check query/container behavior."""
        index = nmslib.init(space='leven',
                            dtype=nmslib.DistType.INT,
                            data_type=nmslib.DataType.OBJECT_AS_STRING,
                            method='small_world_rand')
        # All 24 permutations of 'a','t','c','g', added as one batch,
        # then two extra points added one at a time.
        strings = [''.join(x) for x in itertools.permutations(['a', 't', 'c', 'g'])]
        index.addDataPointBatch(strings)
        index.addDataPoint(len(index), "atat")
        index.addDataPoint(len(index), "gaga")
        index.createIndex()
        # Distances reported by a query must agree with getDistance().
        for i, distance in zip(*index.knnQuery(strings[0])):
            self.assertEqual(index.getDistance(0, i), distance)
        # Container protocol: length and item access.
        self.assertEqual(len(index), len(strings) + 2)
        self.assertEqual(index[0], strings[0])
        self.assertEqual(index[len(index)-2], 'atat')
class SparseTestCase(TestCaseBase):
    """Tests for sparse vectors given as lists of (index, value) pairs."""

    def testSparse(self):
        """Index a few sparse points and check query/container behavior."""
        index = nmslib.init(method='small_world_rand', space='cosinesimil_sparse',
                            data_type=nmslib.DataType.SPARSE_VECTOR)
        index.addDataPoint(0, [(1, 2.), (2, 3.)])
        index.addDataPoint(1, [(0, 1.), (1, 2.)])
        index.addDataPoint(2, [(2, 3.), (3, 3.)])
        index.addDataPoint(3, [(3, 1.)])
        index.createIndex()
        # Querying with point 0's own vector must return it at distance 0.
        ids, distances = index.knnQuery([(1, 2.), (2, 3.)])
        self.assertEqual(ids[0], 0)
        self.assertEqual(distances[0], 0)
        # Container protocol: length and item access.
        self.assertEqual(len(index), 4)
        self.assertEqual(index[3], [(3, 1.0)])
class MemoryLeak1TestCase(TestCaseBase):
    """Checks that repeated index load + query cycles do not leak memory."""

    def testMemoryLeak1(self):
        """Save an index once per repetition, then repeatedly reload and
        query it while watching the process RSS.

        A repetition "fails" when the RSS delta after its last iteration
        exceeds MEM_GROW_COEFF times the delta observed right after the
        very first load; the test fails only when more than
        MEM_TEST_CRIT_FAIL_RATE of the repetitions fail.
        """
        process = psutil.Process(os.getpid())

        np.random.seed(23)
        data = np.random.randn(MEM_TEST_DATA_QTY, MEM_TEST_DATA_DIM).astype(np.float32)
        query = np.random.randn(MEM_TEST_QUERY_QTY, MEM_TEST_DATA_DIM).astype(np.float32)
        space_name = 'l2'

        num_threads = 4
        index_time_params = {'M': 20,
                             'efConstruction': 100,
                             'indexThreadQty': num_threads,
                             'post': 0,
                             'skip_optimized_index': 1}  # using non-optimized index!
        query_time_params = {'efSearch': 100}

        fail_qty = 0
        test_qty = 0

        delta_first = None

        # Settle before taking the memory baseline.
        gc.collect()
        time.sleep(0.25)
        init_mem = process.memory_info().rss

        temp_dir = tempfile.mkdtemp()
        temp_file_pref = os.path.join(temp_dir, 'index')

        for tid in range(MEM_TEST_REPEAT_QTY1):
            # Build and save the index once; the inner loop only reloads it.
            index = nmslib.init(method='hnsw', space=space_name, data_type=nmslib.DataType.DENSE_VECTOR)
            index.addDataPointBatch(data)
            index.createIndex(index_time_params)
            index.saveIndex(temp_file_pref, save_data=True)
            index = None
            gc.collect()

            for iter_id in range(MEM_TEST_ITER1):
                index = nmslib.init(method='hnsw', space=space_name, data_type=nmslib.DataType.DENSE_VECTOR)
                index.loadIndex(temp_file_pref, load_data=True)
                index.setQueryTimeParams(query_time_params)
                if iter_id == 0 and tid == 0:
                    # Baseline: RSS growth right after the very first load.
                    delta_first = process.memory_info().rss - init_mem
                nbrs = index.knnQueryBatch(query, k=10, num_threads=num_threads)
                # Drop references so gc can reclaim the index and results.
                nbrs = None
                index = None
                gc.collect()

            gc.collect()
            time.sleep(0.25)
            delta_last = process.memory_info().rss - init_mem
            print('Delta first/last %d/%d' % (delta_first, delta_last))
            test_qty += 1
            if delta_last >= delta_first * MEM_GROW_COEFF:
                fail_qty += 1

        shutil.rmtree(temp_dir)

        print('')
        print('Fail qty %d out of %d' % (fail_qty, test_qty))

        # Bug fix: the failure budget must be a fraction of the checks
        # actually performed (test_qty == MEM_TEST_REPEAT_QTY1), matching
        # the "out of" printed above.  The original compared against
        # MEM_TEST_ITER1 (the inner iteration count), which made the
        # threshold far too lenient.
        self.assertTrue(fail_qty < test_qty * MEM_TEST_CRIT_FAIL_RATE)
class MemoryLeak2TestCase(TestCaseBase):
    """Checks that repeatedly rebuilding and querying an index does not leak."""

    def testMemoryLeak2(self):
        """Build, query, and discard the index from scratch in every
        iteration, comparing each repetition's final RSS delta against the
        delta seen right after the very first build.

        A repetition "fails" when its final delta exceeds MEM_GROW_COEFF
        times the first-build delta; the test fails when more than
        MEM_TEST_CRIT_FAIL_RATE of the repetitions fail.
        """
        process = psutil.Process(os.getpid())

        np.random.seed(23)
        data = np.random.randn(MEM_TEST_DATA_QTY, 10).astype(np.float32)
        query = np.random.randn(MEM_TEST_QUERY_QTY, 10).astype(np.float32)
        space_name = 'l2'

        num_threads = 4
        index_time_params = {'M': 20,
                             'efConstruction': 100,
                             'indexThreadQty': num_threads,
                             'post': 0,
                             'skip_optimized_index': 1}  # using non-optimized index!
        query_time_params = {'efSearch': 100}

        # Settle before the first repetition takes its baseline.
        gc.collect()
        time.sleep(0.25)

        fail_qty = 0
        test_qty = 0

        delta_first = None

        for tid in range(MEM_TEST_REPEAT_QTY2):
            gc.collect()
            # Re-baseline each repetition so deltas measure growth within it.
            init_mem = process.memory_info().rss
            for iter_id in range(MEM_TEST_ITER2):
                index = nmslib.init(method='hnsw', space=space_name, data_type=nmslib.DataType.DENSE_VECTOR)
                index.addDataPointBatch(data)
                index.createIndex(index_time_params)
                if iter_id == 0 and tid == 0:
                    # Baseline: RSS growth right after the very first build.
                    delta_first = process.memory_info().rss - init_mem
                index.setQueryTimeParams(query_time_params)
                nbrs = index.knnQueryBatch(query, k=10, num_threads=num_threads)
                # Drop references so gc can reclaim the index and results.
                nbrs = None
                index = None
                gc.collect()

            gc.collect()
            time.sleep(0.25)
            delta_last = process.memory_info().rss - init_mem
            test_qty += 1
            if delta_last >= delta_first * MEM_GROW_COEFF:
                fail_qty += 1

        print('')
        print('Fail qty %d out of %d' % (fail_qty, test_qty))

        # Bug fix: the failure budget must be a fraction of the checks
        # actually performed (test_qty == MEM_TEST_REPEAT_QTY2), matching
        # the "out of" printed above; the original compared against
        # MEM_TEST_ITER2.  Also removed dead code: an unused `delta1`
        # variable and a temp directory that was created and deleted but
        # never used (this test saves no index).
        self.assertTrue(fail_qty < test_qty * MEM_TEST_CRIT_FAIL_RATE)
class GlobalTestCase(TestCaseBase):
    """Regression test for https://github.com/nmslib/nmslib/issues/327."""

    def testGlobal(self):
        """Stash an index on the class so it survives past the test run."""
        # Keeping the index alive in a module-lifetime (class) attribute is
        # the one-line reproduction of the issue-327 crash.
        fresh_index = nmslib.init()
        GlobalTestCase.index = fresh_index
# Allow running this test module directly, e.g. `python bindings_test.py`.
if __name__ == "__main__":
    unittest.main()