From 22793f266c1bf33ad0ada52a8242e27868ffeee6 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 13 Feb 2018 10:02:22 -0800 Subject: [PATCH] Fix addDataPointBatch where ids != None I introduced a bug in my previous commit, and passing ids to addDataPointBatch wasn't being respected. Fix this, and add a unittest so that this will be caught in the future. --- python_bindings/nmslib.cc | 2 +- python_bindings/tests/bindings_test.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python_bindings/nmslib.cc b/python_bindings/nmslib.cc index 1873c38..7bfb0ca 100644 --- a/python_bindings/nmslib.cc +++ b/python_bindings/nmslib.cc @@ -193,7 +193,7 @@ struct IndexWrapper { size_t readObjectVector(py::object input, ObjectVector * output, py::object ids_ = py::none()) { std::vector ids; - if (!ids_) { + if (!ids_.is_none()) { ids = py::cast>(ids_); } diff --git a/python_bindings/tests/bindings_test.py b/python_bindings/tests/bindings_test.py index 8a31dfb..7499275 100644 --- a/python_bindings/tests/bindings_test.py +++ b/python_bindings/tests/bindings_test.py @@ -53,6 +53,18 @@ def testKnnQueryBatch(self): for query, (ids, distances) in zip(queries, results): self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5) + # test custom ids (set id to square of each row) + index = self._get_index() + index.addDataPointBatch(data, ids=np.arange(data.shape[0]) ** 2) + index.createIndex() + + queries = data[:10] + results = index.knnQueryBatch(queries, k=10) + for query, (ids, distances) in zip(queries, results): + # convert from square back to row id + ids = np.sqrt(ids).astype(int) + self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5) + def testReloadIndex(self): np.random.seed(23) data = np.random.randn(1000, 10).astype(np.float32)