Skip to content

Commit

Permalink
Fix addDataPointBatch where ids != None
Browse files Browse the repository at this point in the history
I introduced a bug in my previous commit: ids passed to addDataPointBatch
were not being respected. Fix this, and add a unit test so that this will be
caught in the future.
  • Loading branch information
Ben Frederickson committed Feb 13, 2018
1 parent 09487c1 commit 22793f2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python_bindings/nmslib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ struct IndexWrapper {
size_t readObjectVector(py::object input, ObjectVector * output,
py::object ids_ = py::none()) {
std::vector<int> ids;
if (!ids_) {
if (!ids_.is_none()) {
ids = py::cast<std::vector<int>>(ids_);
}

Expand Down
12 changes: 12 additions & 0 deletions python_bindings/tests/bindings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ def testKnnQueryBatch(self):
for query, (ids, distances) in zip(queries, results):
self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5)

# test custom ids (set id to square of each row)
index = self._get_index()
index.addDataPointBatch(data, ids=np.arange(data.shape[0]) ** 2)
index.createIndex()

queries = data[:10]
results = index.knnQueryBatch(queries, k=10)
for query, (ids, distances) in zip(queries, results):
# convert from square back to row id
ids = np.sqrt(ids).astype(int)
self.assertTrue(get_hitrate(get_exact_cosine(query, data), ids) >= 5)

def testReloadIndex(self):
np.random.seed(23)
data = np.random.randn(1000, 10).astype(np.float32)
Expand Down

0 comments on commit 22793f2

Please sign in to comment.