Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
MssBenchmark/ann_benchmarks/algorithms/nmslib.py
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
113 lines (97 sloc)
4.36 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
import os | |
import nmslib | |
from ann_benchmarks.constants import INDEX_DIR | |
from ann_benchmarks.algorithms.base import BaseANN | |
from scipy.sparse import csr_matrix | |
import numpy | |
class NmslibReuseIndex(BaseANN):
    """Benchmark adapter for nmslib that can reuse an index stored on disk.

    Dense metrics ('angular', 'euclidean') hand the data to nmslib unchanged.
    For 'jaccard' every object is converted to a space-separated string of
    integer IDs, both when indexing and when querying.
    """

    @staticmethod
    def encode(d):
        """Return the dict *d* as a list of ``"key=value"`` strings."""
        return ["%s=%s" % (a, b) for (a, b) in d.items()]

    @staticmethod
    def matrToStrArray(sparseMatr):
        """Convert each row of a CSR matrix to a string of sorted column IDs.

        For every row, the column indices of the stored entries are sorted
        and joined with single spaces. Returns one string per row.
        """
        indptr = sparseMatr.indptr
        indices = sparseMatr.indices
        return [
            ' '.join(str(k) for k in sorted(indices[indptr[row]: indptr[row + 1]]))
            for row in range(sparseMatr.shape[0])
        ]

    @staticmethod
    def intMatrToStrArray(intMatr):
        """Convert each row of *intMatr* to a space-separated string.

        The row values are emitted in their existing order (no sorting).
        Returns one string per row.
        """
        return [
            ' '.join(str(k) for k in intMatr[row])
            for row in range(intMatr.shape[0])
        ]

    def __init__(self, metric, object_type, method_name, index_param, query_param):
        """Store the configuration and make sure the index directory exists.

        metric       -- 'angular', 'euclidean' or 'jaccard'; any other value
                        raises KeyError.
        object_type  -- 'Byte' selects the sparse-matrix conversion for the
                        jaccard metric; any other value selects the
                        integer-row conversion.
        method_name  -- nmslib method name, passed to ``nmslib.init``.
        index_param  -- dict of index-time parameters.
        query_param  -- dict of query-time parameters, or False for none.
        """
        self._nmslib_metric = {
            'angular': 'cosinesimil',
            'euclidean': 'l2',
            'jaccard': 'jaccard_sparse',
        }[metric]
        self._object_type = object_type
        self._method_name = method_name
        self._save_index = False
        self._index_param = NmslibReuseIndex.encode(index_param)
        # Kept as an equality test (not ``is``) so that callers passing any
        # value equal to False keep selecting the "no query params" branch.
        if query_param != False:
            self._query_param = NmslibReuseIndex.encode(query_param)
            self.name = 'Nmslib(method_name=%s, index_param=%s, query_param=%s)' % (
                self._method_name, self._index_param, self._query_param)
        else:
            self._query_param = None
            self.name = 'Nmslib(method_name=%s, index_param=%s)' % (
                self._method_name, self._index_param)
        self._index_name = os.path.join(
            INDEX_DIR,
            "nmslib_%s_%s_%s" % (self._method_name, metric, '_'.join(self._index_param)))
        d = os.path.dirname(self._index_name)
        if not os.path.exists(d):
            os.makedirs(d)

    def fit(self, X):
        """Build the index over *X*, or load it if the index file exists.

        A freshly built index is written to disk only when ``_save_index``
        is true. Query-time parameters given to the constructor are applied
        once the index is ready.
        """
        # Work on a copy so that calling fit() again does not keep appending
        # duplicate 'bucketSize' entries to the stored parameters.
        index_param = list(self._index_param)
        if self._method_name == 'vptree':
            # To avoid this issue:
            # terminate called after throwing an instance of 'std::runtime_error'
            # what(): The data size is too small or the bucket size is too big.
            # Select the parameters so that <total # of records> is NOT less
            # than <bucket size> * 1000
            # Aborted (core dumped)
            index_param.append('bucketSize=%d' % min(int(X.shape[0] * 0.0005), 1000))

        if self._nmslib_metric == 'jaccard_sparse':
            if self._object_type == 'Byte':
                X_trans = NmslibReuseIndex.matrToStrArray(csr_matrix(X))
            else:
                X_trans = NmslibReuseIndex.intMatrToStrArray(X)
            # The jaccard path feeds nmslib ID strings (query() and
            # batch_query() do the same), so the index must be created for
            # string objects and populated with the converted rows.
            self._index = nmslib.init(
                space=self._nmslib_metric,
                method=self._method_name,
                data_type=nmslib.DataType.OBJECT_AS_STRING)
            self._index.addDataPointBatch(X_trans)
        else:
            self._index = nmslib.init(space=self._nmslib_metric, method=self._method_name)
            self._index.addDataPointBatch(X)

        if os.path.exists(self._index_name):
            print('Loading index from file')
            self._index.loadIndex(self._index_name)
        else:
            self._index.createIndex(index_param)
            if self._save_index:
                self._index.saveIndex(self._index_name)
        if self._query_param is not None:
            self._index.setQueryTimeParams(self._query_param)

    def set_query_arguments(self, ef):
        """Set ``efSearch`` to *ef*; only applied for 'hnsw' and 'sw-graph'."""
        if self._method_name in ('hnsw', 'sw-graph'):
            self._index.setQueryTimeParams(["efSearch=%s" % (ef)])

    def query(self, v, n, rq=False):
        """Return the IDs nmslib reports for the query *v*.

        With ``rq`` false, *n* is passed to ``knnQuery``; with ``rq`` true,
        *n* is passed unchanged to the index's ``rangeQuery``.
        """
        # Chunjiang modified
        if self._nmslib_metric == 'jaccard_sparse':
            if self._object_type == 'Byte':
                # Encode the query as the positions of its non-zero entries.
                nz = numpy.nonzero(v)[0]
                v = ' '.join([str(k) for k in nz])
            else:
                v = ' '.join([str(k) for k in v])
        if rq:
            ids, _ = self._index.rangeQuery(v, n)
        else:
            ids, _ = self._index.knnQuery(v, n)
        return ids

    def batch_query(self, X, n):
        """Run ``knnQueryBatch`` for every row of *X*; results go to ``self.res``."""
        # Chunjiang modified
        if self._nmslib_metric == 'jaccard_sparse':
            if self._object_type == 'Byte':
                X = NmslibReuseIndex.matrToStrArray(csr_matrix(X))
            else:
                X = NmslibReuseIndex.intMatrToStrArray(X)
        self.res = self._index.knnQueryBatch(X, n)

    def get_batch_results(self):
        """Return the first element of each pair from the last batch_query()."""
        return [x for x, _ in self.res]