Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
update unit test
  • Loading branch information
Greg Friedland committed Feb 21, 2019
1 parent 03071c2 commit 56aa99c
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 306 deletions.
144 changes: 55 additions & 89 deletions python_bindings/tests/bindings_test.py
Expand Up @@ -218,94 +218,60 @@ class BitVectorIndexTestMixin(object):
def _get_index(self, space='bit_jaccard'):
raise NotImplementedError()

def testKnnQuery(self):
for num_elems in [30000, 100000, 300000, 1000000]:
for nbits in [512, 2048]:
self._testKnnQuery(nbits, num_elems)
def _get_batches(self, index, nbits, num_elems, chunk_size):
if "bit_" in str(index):
self.bit_vector_str_func = bit_vector_to_str
else:
self.bit_vector_str_func = bit_vector_sparse_str

batches = []
for i in range(0, num_elems, chunk_size):
strs = []
for j in range(chunk_size):
a = np.random.rand(nbits) > 0.5
strs.append(self.bit_vector_str_func(a))
batches.append([np.arange(i, i + chunk_size), strs])
return batches

def _testKnnQuery(self, nbits, num_elems):
chunk_size = 10000
def testKnnQuery(self):
np.random.seed(23)

ps_proc = psutil.Process()
# print(f"\n{ps_proc.memory_info()}")
index = self._get_index()
if "bit_jaccard" in str(index):
bit_vector_str_func = bit_vector_to_str
else:
bit_vector_str_func = bit_vector_sparse_str

# logging.basicConfig(level=logging.INFO)
# with PsUtil(interval=2, proc_attr=["memory_info"]):
with PeakMemoryUsage(f"AddData: vector={nbits}-bit elems={num_elems}"):
np.random.seed(23)
for i in range(0, num_elems, chunk_size):
strs = []
for j in range(chunk_size):
a = np.random.rand(nbits) > 0.5
strs.append(bit_vector_str_func(a))
index.addDataPointBatch(ids=np.arange(i, i + chunk_size), data=strs)

# print(f"\n{ps_proc.memory_info()}")
with PeakMemoryUsage(f"CreateIndex: vector={nbits}-bit of elems={num_elems}"):
index.createIndex()
# print(f"\n{ps_proc.memory_info()}")

a = np.ones(nbits)
ids, distances = index.knnQuery(bit_vector_str_func(a), k=10)
# print(ids)
print(distances)
# self.assertTrue(get_hitrate(get_exact_cosine(row, data), ids) >= 5)
# def testKnnQueryBatch(self):
# np.random.seed(23)
# data = np.random.randn(1000, 10).astype(np.float32)
#
# index = self._get_index()
# index.addDataPointBatch(data)
# index.createIndex()
#
# queries = data[:10]
# results = index.knnQueryBatch(queries, k=10)
# for query, (ids, distances) in zip(queries, results):
# self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5)
#
# # test col-major arrays
# queries = np.asfortranarray(queries)
# results = index.knnQueryBatch(queries, k=10)
# for query, (ids, distances) in zip(queries, results):
# self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5)
#
# # test custom ids (set id to square of each row)
# index = self._get_index()
# index.addDataPointBatch(data, ids=np.arange(data.shape[0]) ** 2)
# index.createIndex()
#
# queries = data[:10]
# results = index.knnQueryBatch(queries, k=10)
# for query, (ids, distances) in zip(queries, results):
# # convert from square back to row id
# ids = np.sqrt(ids).astype(int)
# self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5)

# def testReloadIndex(self):
# np.random.seed(23)
# data = np.random.randn(1000, 10).astype(np.float32)
#
# original = self._get_index()
# original.addDataPointBatch(data)
# original.createIndex()
#
# # test out saving/reloading index
# with tempfile.NamedTemporaryFile() as tmp:
# original.saveIndex(tmp.name + ".index")
#
# reloaded = self._get_index()
# reloaded.addDataPointBatch(data)
# reloaded.loadIndex(tmp.name + ".index")
#
# original_results = original.knnQuery(data[0])
# reloaded_results = reloaded.knnQuery(data[0])
# npt.assert_allclose(original_results,
# reloaded_results)
batches = self._get_batches(index, 512, 2000, 1000)
for ids, data in batches:
index.addDataPointBatch(ids=ids, data=data)

index.createIndex()

s = self.bit_vector_str_func(np.ones(512))
index.knnQuery(s, k=10)

def testReloadIndex(self):
np.random.seed(23)

original = self._get_index()
batches = self._get_batches(original, 512, 2000, 1000)
for ids, data in batches:
original.addDataPointBatch(ids=ids, data=data)
original.createIndex()

# test out saving/reloading index
with tempfile.NamedTemporaryFile() as tmp:
original.saveIndex(tmp.name + ".index")

reloaded = self._get_index()
for ids, data in batches:
reloaded.addDataPointBatch(ids=ids, data=data)
reloaded.loadIndex(tmp.name + ".index")

s = self.bit_vector_str_func(np.ones(512))
original_results = original.knnQuery(s)
reloaded_results = reloaded.knnQuery(s)
original_results = list(zip(list(original_results[0]), list(original_results[1])))
original_results = sorted(original_results, key=lambda x: x[1])
reloaded_results = list(zip(list(reloaded_results[0]), list(reloaded_results[1])))
reloaded_results = sorted(reloaded_results, key=lambda x: x[1])
npt.assert_allclose(original_results, reloaded_results)


class HNSWTestCase(unittest.TestCase, DenseIndexTestMixin):
Expand All @@ -325,10 +291,10 @@ class SparseJaccardTestCase(unittest.TestCase, BitVectorIndexTestMixin):
dtype=nmslib.DistType.FLOAT)


# class BitHammingTestCase(unittest.TestCase, BitVectorIndexTestMixin):
# def _get_index(self, space='bit_hamming'):
# return nmslib.init(method='hnsw', space='bit_hamming', data_type=nmslib.DataType.OBJECT_AS_STRING,
# dtype=nmslib.DistType.INT)
class BitHammingTestCase(unittest.TestCase, BitVectorIndexTestMixin):
    """Bit-Hamming variant of the shared bit-vector index tests."""

    def _get_index(self, space='bit_hamming'):
        # Honor the *space* argument instead of hard-coding 'bit_hamming':
        # the parameter was accepted but silently ignored. The default keeps
        # existing callers' behavior unchanged.
        return nmslib.init(method='hnsw', space=space,
                           data_type=nmslib.DataType.OBJECT_AS_STRING,
                           dtype=nmslib.DistType.INT)


class SWGraphTestCase(unittest.TestCase, DenseIndexTestMixin):
Expand Down
167 changes: 0 additions & 167 deletions python_bindings/tests/jaccard_comparison.py

This file was deleted.

11 changes: 0 additions & 11 deletions python_bindings/tests/jaccard_comparison.sh

This file was deleted.

39 changes: 0 additions & 39 deletions python_bindings/tests/jaccard_comparison_plot.py

This file was deleted.

0 comments on commit 56aa99c

Please sign in to comment.