Skip to content

Commit

Permalink
Testing script for #337
Browse files Browse the repository at this point in the history
  • Loading branch information
Leonid Boytsov authored and Leonid Boytsov committed Jul 26, 2018
1 parent c524735 commit ac5febb
Showing 1 changed file with 81 additions and 0 deletions.
81 changes: 81 additions & 0 deletions python_bindings/notebooks/test_hnsw_recall.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#!/usr/bin/env python3
import logging
# Uncomment to print logs to the screen
#logging.basicConfig(level=logging.INFO)

import numpy as np
import sys
import nmslib
import time
import math
from sklearn.neighbors import NearestNeighbors


def testHnswRecallL2(dataMatrix, queryMatrix, k, M=30, efC=200, efS=1000, numThreads=4):
queryQty = queryMatrix.shape[0]
indexTimeParams = {'M': M, 'indexThreadQty': numThreads, 'efConstruction': efC, 'post' : 0}

#Indexing
print('Index-time parameters', indexTimeParams)
spaceName='l2'
index = nmslib.init(method='hnsw', space=spaceName, data_type=nmslib.DataType.DENSE_VECTOR)
index.addDataPointBatch(dataMatrix)

start = time.time()
index.createIndex(indexTimeParams)
end = time.time()
print('Indexing time = %f' % (end-start))


# Querying
start = time.time()
nmslibFound = index.knnQueryBatch(queryMatrix, k=k, num_threads=numThreads)
end = time.time()
print('kNN time total=%f (sec), per query=%f (sec), per query adjusted for thread number=%f (sec)' %
(end - start, float(end - start) / queryQty, numThreads * float(end - start) / queryQty))


# Computing gold-standard data
print('Computing gold-standard data')

start = time.time()
sindx = NearestNeighbors(n_neighbors=k, metric='l2', algorithm='brute').fit(dataMatrix)
end = time.time()

print('Brute-force preparation time %f' % (end - start))

start = time.time()
bruteForceFound = sindx.kneighbors(queryMatrix)
end = time.time()

print('brute-force kNN time total=%f (sec), per query=%f (sec)' %
(end-start, float(end-start)/queryQty) )

# Setting query-time parameters
queryTimeParams = {'efSearch': efS}
print('Setting query-time parameters', queryTimeParams)
index.setQueryTimeParams(queryTimeParams)

# Finally computing recall for every i-th neighbor
for n in range(k):
recall=0.0
for i in range(0, queryQty):
correctSet = set(bruteForceFound[1][i])
retArr = nmslibFound[i][0]
retElem = retArr[n] if len(retArr) > n else -1

recall = recall + int(retElem in correctSet)
recall = recall / queryQty
print('kNN recall for neighbor %d %f' % (n+1, recall))


def testRandom(dataQty, queryQty, efS, dim, k):
queryQty = min(dataQty, queryQty)
dataMatrix = np.random.randn(dataQty, dim).astype(np.float32)
indx = np.random.choice(np.arange(dataQty), size=queryQty, replace=False)
queryMatrix = dataMatrix[indx, ].astype(np.float32)
testHnswRecallL2(dataMatrix, queryMatrix, k, efS=efS)

testRandom(100_000, 10, dim=100, k=10, efS=1000)


0 comments on commit ac5febb

Please sign in to comment.