Skip to content

Commit

Permalink
Merge pull request #287 from benfred/pybind_fix_adddatapointbatch
Browse files Browse the repository at this point in the history
Fix addDataPointBatch where ids != None. Many thanks to @benfred !
  • Loading branch information
Leonid Boytsov authored and GitHub committed Feb 13, 2018
2 parents 1f355a8 + 22793f2 commit b9cfcf4
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python_bindings/nmslib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ struct IndexWrapper {
size_t readObjectVector(py::object input, ObjectVector * output,
py::object ids_ = py::none()) {
std::vector<int> ids;
if (!ids_) {
if (!ids_.is_none()) {
ids = py::cast<std::vector<int>>(ids_);
}

Expand Down
12 changes: 12 additions & 0 deletions python_bindings/tests/bindings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ def testKnnQueryBatch(self):
for query, (ids, distances) in zip(queries, results):
self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5)

# test custom ids (set id to square of each row)
index = self._get_index()
index.addDataPointBatch(data, ids=np.arange(data.shape[0]) ** 2)
index.createIndex()

queries = data[:10]
results = index.knnQueryBatch(queries, k=10)
for query, (ids, distances) in zip(queries, results):
# convert from square back to row id
ids = np.sqrt(ids).astype(int)
self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5)

def testReloadIndex(self):
np.random.seed(23)
data = np.random.randn(1000, 10).astype(np.float32)
Expand Down

0 comments on commit b9cfcf4

Please sign in to comment.