From a1edf498e8b8c43ab3547230b4ebca780fb5f873 Mon Sep 17 00:00:00 2001 From: searchivarius Date: Sat, 25 May 2019 16:53:12 -0400 Subject: [PATCH] Fixing knnQuery #370 --- python_bindings/nmslib.cc | 2 +- python_bindings/tests/bindings_test.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/python_bindings/nmslib.cc b/python_bindings/nmslib.cc index 5496f5d..950a28c 100644 --- a/python_bindings/nmslib.cc +++ b/python_bindings/nmslib.cc @@ -174,7 +174,7 @@ struct IndexWrapper { const Object * readObject(py::object input, int id = 0) { switch (data_type) { case DATATYPE_DENSE_VECTOR: { - py::array_t temp(input); + py::array_t temp(input); std::vector tempVect(temp.data(0), temp.data(0) + temp.size()); auto vectSpacePtr = reinterpret_cast*>(space.get()); return vectSpacePtr->CreateObjFromVect(id, -1, tempVect); diff --git a/python_bindings/tests/bindings_test.py b/python_bindings/tests/bindings_test.py index 56add76..667e994 100644 --- a/python_bindings/tests/bindings_test.py +++ b/python_bindings/tests/bindings_test.py @@ -24,15 +24,23 @@ def _get_index(self, space='cosinesimil'): def testKnnQuery(self): np.random.seed(23) - data = np.random.randn(1000, 10).astype(np.float32) + data = np.asfortranarray(np.random.randn(1000, 10).astype(np.float32)) index = self._get_index() index.addDataPointBatch(data) index.createIndex() - row = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1.]) - ids, distances = index.knnQuery(row, k=10) - self.assertTrue(get_hitrate(get_exact_cosine(row, data), ids) >= 5) + query = data[0] + + ids, distances = index.knnQuery(query, k=10) + self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5) + + # There is a bug when different ways to specify the input query data + # were causing the trouble: https://github.com/nmslib/nmslib/issues/370 + query = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1.]) + + ids, distances = index.knnQuery(query, k=10) + self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5) def testKnnQueryBatch(self): np.random.seed(23)