Skip to content

Commit

Permalink
binding tests to deal better with ties #371
Browse files Browse the repository at this point in the history
  • Loading branch information
searchivarius committed May 26, 2019
1 parent 59f49ec commit e203c94
Showing 1 changed file with 35 additions and 12 deletions.
47 changes: 35 additions & 12 deletions python_bindings/tests/bindings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,33 @@ def bit_vector_to_str(bit_vect):
def bit_vector_sparse_str(bit_vect):
    """Return the positions of the truthy entries of ``bit_vect`` as text.

    The indices are rendered in increasing order and separated by single
    spaces; an input with no truthy entries yields the empty string.
    """
    return " ".join([str(k) for k, b in enumerate(bit_vect) if b])


class TestCaseBase(unittest.TestCase):
    """Base test case offering a tie-tolerant comparison of query results."""

    def assert_allclose(self, orig, comp):
        """Assert that two results hold the same ids with close distances.

        Each result is a tuple ``(ids, dists)``. Entries whose distances are
        tied may be returned in a different order, so the distances of
        ``comp`` are re-arranged into the id order of ``orig`` before the
        numeric comparison.

        Fails with ``AssertionError`` when the lengths disagree, when an id
        of ``orig`` is absent from ``comp``, or when the distances for the
        same id are not close.
        """
        orig_ids, orig_dists = orig[0], orig[1]
        comp_ids, comp_dists = comp[0], comp[1]

        qty = len(orig_ids)
        self.assertEqual(qty, len(orig_dists))
        self.assertEqual(qty, len(comp_ids))
        # Without this check a short distance list would surface as an
        # IndexError and a long one would be silently truncated.
        self.assertEqual(qty, len(comp_dists))

        ids2dist = dict(zip(comp_ids, comp_dists))

        comp_resort_dists = []
        for one_id in orig_ids:
            self.assertIn(one_id, ids2dist)
            comp_resort_dists.append(ids2dist[one_id])

        # The ids match by construction once every lookup above succeeded,
        # so only the distances need a numeric comparison.
        npt.assert_allclose(orig_dists, comp_resort_dists)

class DenseIndexTestMixin(object):
def _get_index(self, space='cosinesimil'):
raise NotImplementedError()


def testKnnQuery(self):
np.random.seed(23)
Expand Down Expand Up @@ -99,7 +122,7 @@ def testReloadIndex(self):

original_results = original.knnQuery(data[0])
reloaded_results = reloaded.knnQuery(data[0])
npt.assert_allclose(original_results,
self.assert_allclose(original_results,
reloaded_results)


Expand Down Expand Up @@ -156,34 +179,34 @@ def testReloadIndex(self):
s = self.bit_vector_str_func(np.ones(512))
original_results = original.knnQuery(s)
reloaded_results = reloaded.knnQuery(s)
npt.assert_allclose(original_results,
self.assert_allclose(original_results,
reloaded_results)


class HNSWTestCase(unittest.TestCase, DenseIndexTestMixin):
class HNSWTestCase(TestCaseBase, DenseIndexTestMixin):
    """Runs the DenseIndexTestMixin tests against an 'hnsw' index."""

    def _get_index(self, space='cosinesimil'):
        """Return a new 'hnsw' index created for ``space``."""
        hnsw_index = nmslib.init(method='hnsw', space=space)
        return hnsw_index


class BitJaccardTestCase(unittest.TestCase, BitVectorIndexTestMixin):
class BitJaccardTestCase(TestCaseBase, BitVectorIndexTestMixin):
    """Runs the BitVectorIndexTestMixin tests in the 'bit_jaccard' space."""

    def _get_index(self, space='bit_jaccard'):
        """Return a new 'hnsw' index taking string objects, float distances."""
        init_kwargs = dict(
            method='hnsw',
            space=space,
            data_type=nmslib.DataType.OBJECT_AS_STRING,
            dtype=nmslib.DistType.FLOAT,
        )
        return nmslib.init(**init_kwargs)


class SparseJaccardTestCase(unittest.TestCase, BitVectorIndexTestMixin):
class SparseJaccardTestCase(TestCaseBase, BitVectorIndexTestMixin):
    """Runs the BitVectorIndexTestMixin tests in the 'jaccard_sparse' space."""

    def _get_index(self, space='jaccard_sparse'):
        """Return a new 'hnsw' index taking string objects, float distances."""
        init_kwargs = dict(
            method='hnsw',
            space=space,
            data_type=nmslib.DataType.OBJECT_AS_STRING,
            dtype=nmslib.DistType.FLOAT,
        )
        return nmslib.init(**init_kwargs)


class BitHammingTestCase(unittest.TestCase, BitVectorIndexTestMixin):
class BitHammingTestCase(TestCaseBase, BitVectorIndexTestMixin):
    """Runs the BitVectorIndexTestMixin tests in the 'bit_hamming' space."""

    def _get_index(self, space='bit_hamming'):
        """Return a new 'hnsw' index taking string objects, int distances."""
        init_kwargs = dict(
            method='hnsw',
            space=space,
            data_type=nmslib.DataType.OBJECT_AS_STRING,
            dtype=nmslib.DistType.INT,
        )
        return nmslib.init(**init_kwargs)


class SWGraphTestCase(unittest.TestCase, DenseIndexTestMixin):
class SWGraphTestCase(TestCaseBase, DenseIndexTestMixin):
def _get_index(self, space='cosinesimil'):
return nmslib.init(method='sw-graph', space=space)

Expand All @@ -205,19 +228,19 @@ def testReloadIndex(self):

original_results = original.knnQuery(data[0])
reloaded_results = reloaded.knnQuery(data[0])
npt.assert_allclose(original_results,
self.assert_allclose(original_results,
reloaded_results)


class BallTreeTestCase(unittest.TestCase, DenseIndexTestMixin):
class BallTreeTestCase(TestCaseBase, DenseIndexTestMixin):
    """Runs the DenseIndexTestMixin tests against a 'vptree' index."""

    def _get_index(self, space='cosinesimil'):
        """Return a new 'vptree' index created for ``space``."""
        return nmslib.init(method='vptree', space=space)

    def testReloadIndex(self):
        """Disable the reload test for this index type.

        NOTE(review): presumably this overrides the mixin's reload test
        because save/load is not exercised for vptree -- confirm the reason.
        """
        # Returning a value such as NotImplemented from a test method is
        # deprecated in unittest and reports the test as passed; an explicit
        # skip states the intent and shows up in the test report.
        self.skipTest('index reload is not tested for vptree')


class StringTestCase(unittest.TestCase):
class StringTestCase(TestCaseBase):
def testStringLeven(self):
index = nmslib.init(space='leven',
dtype=nmslib.DistType.INT,
Expand All @@ -240,7 +263,7 @@ def testStringLeven(self):
self.assertEqual(index[len(index)-2], 'atat')


class SparseTestCase(unittest.TestCase):
class SparseTestCase(TestCaseBase):
def testSparse(self):
index = nmslib.init(method='small_world_rand', space='cosinesimil_sparse',
data_type=nmslib.DataType.SPARSE_VECTOR)
Expand All @@ -260,7 +283,7 @@ def testSparse(self):
self.assertEqual(index[3], [(3, 1.0)])


class GlobalTestCase(unittest.TestCase):
class GlobalTestCase(TestCaseBase):
def testGlobal(self):
# this is a one line reproduction of https://github.com/nmslib/nmslib/issues/327
GlobalTestCase.index = nmslib.init()
Expand Down

0 comments on commit e203c94

Please sign in to comment.