From e203c944b1c9d0cddffb0376cc338ce0179f0bf4 Mon Sep 17 00:00:00 2001 From: searchivarius Date: Sat, 25 May 2019 21:19:15 -0400 Subject: [PATCH] binding tests to deal better with ties #371 --- python_bindings/tests/bindings_test.py | 47 +++++++++++++++++++------- 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/python_bindings/tests/bindings_test.py b/python_bindings/tests/bindings_test.py index 61daf70..61920fb 100644 --- a/python_bindings/tests/bindings_test.py +++ b/python_bindings/tests/bindings_test.py @@ -25,10 +25,33 @@ def bit_vector_to_str(bit_vect): def bit_vector_sparse_str(bit_vect): return " ".join([str(k) for k, b in enumerate(bit_vect) if b]) + +class TestCaseBase(unittest.TestCase): + # Each result is a tuple (ids, dists) + # This version deals properly with ties by resorting the second result set + # to be in the same order as the first one + def assert_allclose(self, orig, comp): + qty = len(orig[0]) + self.assertEqual(qty, len(orig[1])) + self.assertEqual(qty, len(comp[0])) + ids2dist = { comp[0][k] : comp[1][k] for k in range(qty) } + + comp_resort_ids = [] + comp_resort_dists = [] + + for i in range(qty): + one_id = orig[0][i] + comp_resort_ids.append(one_id) + self.assertTrue(one_id in ids2dist) + comp_resort_dists.append(ids2dist[one_id]) + + npt.assert_allclose(orig, + (comp_resort_ids, comp_resort_dists)) class DenseIndexTestMixin(object): def _get_index(self, space='cosinesimil'): raise NotImplementedError() + def testKnnQuery(self): np.random.seed(23) @@ -99,7 +122,7 @@ def testReloadIndex(self): original_results = original.knnQuery(data[0]) reloaded_results = reloaded.knnQuery(data[0]) - npt.assert_allclose(original_results, + self.assert_allclose(original_results, reloaded_results) @@ -156,34 +179,34 @@ def testReloadIndex(self): s = self.bit_vector_str_func(np.ones(512)) original_results = original.knnQuery(s) reloaded_results = reloaded.knnQuery(s) - npt.assert_allclose(original_results, + self.assert_allclose(original_results, reloaded_results) -class HNSWTestCase(unittest.TestCase, DenseIndexTestMixin): +class HNSWTestCase(TestCaseBase, DenseIndexTestMixin): def _get_index(self, space='cosinesimil'): return nmslib.init(method='hnsw', space=space) -class BitJaccardTestCase(unittest.TestCase, BitVectorIndexTestMixin): +class BitJaccardTestCase(TestCaseBase, BitVectorIndexTestMixin): def _get_index(self, space='bit_jaccard'): return nmslib.init(method='hnsw', space=space, data_type=nmslib.DataType.OBJECT_AS_STRING, dtype=nmslib.DistType.FLOAT) -class SparseJaccardTestCase(unittest.TestCase, BitVectorIndexTestMixin): +class SparseJaccardTestCase(TestCaseBase, BitVectorIndexTestMixin): def _get_index(self, space='jaccard_sparse'): return nmslib.init(method='hnsw', space=space, data_type=nmslib.DataType.OBJECT_AS_STRING, dtype=nmslib.DistType.FLOAT) -class BitHammingTestCase(unittest.TestCase, BitVectorIndexTestMixin): +class BitHammingTestCase(TestCaseBase, BitVectorIndexTestMixin): def _get_index(self, space='bit_hamming'): return nmslib.init(method='hnsw', space=space, data_type=nmslib.DataType.OBJECT_AS_STRING, dtype=nmslib.DistType.INT) -class SWGraphTestCase(unittest.TestCase, DenseIndexTestMixin): +class SWGraphTestCase(TestCaseBase, DenseIndexTestMixin): def _get_index(self, space='cosinesimil'): return nmslib.init(method='sw-graph', space=space) @@ -205,11 +228,11 @@ def testReloadIndex(self): original_results = original.knnQuery(data[0]) reloaded_results = reloaded.knnQuery(data[0]) - npt.assert_allclose(original_results, + self.assert_allclose(original_results, reloaded_results) -class BallTreeTestCase(unittest.TestCase, DenseIndexTestMixin): +class BallTreeTestCase(TestCaseBase, DenseIndexTestMixin): def _get_index(self, space='cosinesimil'): return nmslib.init(method='vptree', space=space) @@ -217,7 +240,7 @@ def testReloadIndex(self): return NotImplemented -class StringTestCase(unittest.TestCase): +class StringTestCase(TestCaseBase): def testStringLeven(self): index = nmslib.init(space='leven', dtype=nmslib.DistType.INT, @@ -240,7 +263,7 @@ def testStringLeven(self): self.assertEqual(index[len(index)-2], 'atat') -class SparseTestCase(unittest.TestCase): +class SparseTestCase(TestCaseBase): def testSparse(self): index = nmslib.init(method='small_world_rand', space='cosinesimil_sparse', data_type=nmslib.DataType.SPARSE_VECTOR) @@ -260,7 +283,7 @@ def testSparse(self): self.assertEqual(index[3], [(3, 1.0)]) -class GlobalTestCase(unittest.TestCase): +class GlobalTestCase(TestCaseBase): def testGlobal(self): # this is a one line reproduction of https://github.com/nmslib/nmslib/issues/327 GlobalTestCase.index = nmslib.init()