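# NOTE: the lines below are GitHub web-UI residue from the page this file was
# copied from; they are not part of the module and are kept only as comments.
# Skip to content
# Permalink
# 0e81a7b148
# Switch branches/tags
#
# Name already in use
#
# A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
# Go to file
#
#
# Cannot retrieve contributors at this time
# 113 lines (97 sloc) 4.36 KB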
from __future__ import absolute_import
import os
import nmslib
from ann_benchmarks.constants import INDEX_DIR
from ann_benchmarks.algorithms.base import BaseANN
from scipy.sparse import csr_matrix
import numpy
class NmslibReuseIndex(BaseANN):
    """ann-benchmarks adapter for NMSLIB that reuses an index file when one exists.

    The index file path is derived from the method name, the metric and the
    encoded index parameters, and lives under ``INDEX_DIR``.  ``fit`` loads
    that file if it is present; otherwise it builds the index and, only when
    ``self._save_index`` is true (it is initialised to ``False``), saves it.

    For the ``jaccard`` metric the data and the queries are handed to NMSLIB
    as space-separated strings of integer IDs.
    """

    @staticmethod
    def encode(d):
        """Return the items of dict ``d`` as a list of ``"key=value"`` strings."""
        return ["%s=%s" % (a, b) for (a, b) in d.items()]

    @staticmethod
    def matrToStrArray(sparseMatr):
        """Convert each row of a CSR sparse matrix to a string of column IDs.

        For every row the stored column indices are sorted in ascending order
        and joined with single spaces.  Returns one string per row.
        """
        indptr = sparseMatr.indptr
        indices = sparseMatr.indices
        return [
            ' '.join(str(k) for k in sorted(indices[indptr[row]:indptr[row + 1]]))
            for row in range(sparseMatr.shape[0])
        ]

    @staticmethod
    def intMatrToStrArray(intMatr):
        """Convert each row of an integer matrix to a space-separated string.

        Values are emitted in their original order (no sorting, no
        de-duplication).  Returns one string per row.
        """
        return [
            ' '.join(str(k) for k in intMatr[row])
            for row in range(intMatr.shape[0])
        ]

    def __init__(self, metric, object_type, method_name, index_param, query_param):
        """Store the configuration and make sure the index directory exists.

        Args:
            metric: one of ``'angular'``, ``'euclidean'`` or ``'jaccard'``;
                any other value raises ``KeyError``.
            object_type: compared against ``'Byte'`` to pick how jaccard data
                is converted to strings.
            method_name: NMSLIB method name (e.g. ``'hnsw'``, ``'vptree'``).
            index_param: dict of index-time parameters.
            query_param: dict of query-time parameters, or ``False`` when
                there are none.
        """
        self._nmslib_metric = {
            'angular': 'cosinesimil',
            'euclidean': 'l2',
            'jaccard': 'jaccard_sparse',
        }[metric]
        self._object_type = object_type
        self._method_name = method_name
        self._save_index = False
        self._index_param = NmslibReuseIndex.encode(index_param)
        # Equality (not identity) check kept on purpose so that existing
        # callers see exactly the same behaviour as before.
        if query_param != False:  # noqa: E712
            self._query_param = NmslibReuseIndex.encode(query_param)
            self.name = 'Nmslib(method_name=%s, index_param=%s, query_param=%s)' % (
                self._method_name, self._index_param, self._query_param)
        else:
            self._query_param = None
            self.name = 'Nmslib(method_name=%s, index_param=%s)' % (
                self._method_name, self._index_param)

        self._index_name = os.path.join(
            INDEX_DIR,
            "nmslib_%s_%s_%s" % (self._method_name, metric, '_'.join(self._index_param)))

        d = os.path.dirname(self._index_name)
        if not os.path.exists(d):
            os.makedirs(d)

    def _to_jaccard_strings(self, X):
        """Convert a data matrix to the string rows used for ``jaccard_sparse``."""
        if self._object_type == 'Byte':
            return NmslibReuseIndex.matrToStrArray(csr_matrix(X))
        return NmslibReuseIndex.intMatrToStrArray(X)

    def fit(self, X):
        """Create the NMSLIB index for ``X``, loading it from disk if available.

        Args:
            X: data matrix; ``X.shape[0]`` is used as the number of records.
        """
        # Work on a copy: appending to self._index_param directly would add
        # another bucketSize entry every time fit() is called.
        index_param = list(self._index_param)
        if self._method_name == 'vptree':
            # To avoid this issue:
            # terminate called after throwing an instance of 'std::runtime_error'
            # what(): The data size is too small or the bucket size is too big.
            # Select the parameters so that <total # of records> is NOT less
            # than <bucket size> * 1000
            # Aborted (core dumped)
            index_param.append('bucketSize=%d' % min(int(X.shape[0] * 0.0005), 1000))

        if self._nmslib_metric == 'jaccard_sparse':
            # Queries for this metric are sent as strings (see query() and
            # batch_query()), so the index has to be created in string mode
            # and populated with the converted rows.
            X_trans = self._to_jaccard_strings(X)
            self._index = nmslib.init(
                space=self._nmslib_metric,
                method=self._method_name,
                data_type=nmslib.DataType.OBJECT_AS_STRING)
            self._index.addDataPointBatch(X_trans)
        else:
            self._index = nmslib.init(space=self._nmslib_metric, method=self._method_name)
            self._index.addDataPointBatch(X)

        if os.path.exists(self._index_name):
            print('Loading index from file')
            self._index.loadIndex(self._index_name)
        else:
            self._index.createIndex(index_param)
            if self._save_index:
                self._index.saveIndex(self._index_name)

        if self._query_param is not None:
            self._index.setQueryTimeParams(self._query_param)

    def set_query_arguments(self, ef):
        """Set ``efSearch`` to ``ef``; does nothing unless the method is hnsw or sw-graph."""
        if self._method_name == 'hnsw' or self._method_name == 'sw-graph':
            self._index.setQueryTimeParams(["efSearch=%s" % (ef)])

    def query(self, v, n, rq=False):
        """Run a single query and return the IDs of the matches.

        Args:
            v: query vector.
            n: second argument passed to ``knnQuery`` (or to ``rangeQuery``
                when ``rq`` is true).
            rq: when true, ``rangeQuery`` is called instead of ``knnQuery``.

        Returns:
            The ``ids`` part of the index result; distances are discarded.
        """
        # Chunjiang modified
        if self._nmslib_metric == 'jaccard_sparse':
            if self._object_type == 'Byte':
                # Only the positions of the non-zero entries are sent.
                nz = numpy.nonzero(v)[0]
                v = ' '.join([str(k) for k in nz])
            else:
                v = ' '.join([str(k) for k in v])
        if rq:
            ids, distances = self._index.rangeQuery(v, n)
        else:
            ids, distances = self._index.knnQuery(v, n)
        return ids

    def batch_query(self, X, n):
        """Run ``knnQueryBatch`` for all rows of ``X`` and keep the result in ``self.res``."""
        # Chunjiang modified
        if self._nmslib_metric == 'jaccard_sparse':
            X = self._to_jaccard_strings(X)
        self.res = self._index.knnQueryBatch(X, n)

    def get_batch_results(self):
        """Return the first element of every pair stored by the last ``batch_query``."""
        return [x for x, _ in self.res]