Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
MssBenchmark/ann_benchmarks/algorithms/bruteforce.py
Go to file — This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
148 lines (129 sloc)
6.33 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import absolute_import | |
import numpy | |
import sklearn.neighbors | |
from ann_benchmarks.distance import metrics as pd | |
from ann_benchmarks.algorithms.base import BaseANN | |
from scipy.sparse import issparse | |
import chemfp | |
from bitarray import bitarray | |
class BruteForce(BaseANN):
    """Exhaustive nearest-neighbour search delegated to scikit-learn.

    The benchmark metric name is translated to the metric name handed to
    ``sklearn.neighbors.NearestNeighbors`` (always with ``algorithm='brute'``).
    """

    # Benchmark metric name -> metric name passed to NearestNeighbors.
    _SKLEARN_METRIC_NAMES = {
        'angular': 'cosine',
        'euclidean': 'l2',
        'hamming': 'hamming',
        'jaccard': 'jaccard',
    }

    def __init__(self, metric):
        """Remember `metric`; raise NotImplementedError if it is unsupported."""
        supported = ('angular', 'euclidean', 'hamming', 'jaccard')
        if metric not in supported:
            raise NotImplementedError("BruteForce doesn't support metric %s" % metric)
        self._metric = metric
        self.name = 'BruteForce()'

    def fit(self, X):
        """Create the brute-force NearestNeighbors object and fit it on `X`."""
        sklearn_metric = self._SKLEARN_METRIC_NAMES[self._metric]
        self._nbrs = sklearn.neighbors.NearestNeighbors(
            algorithm='brute', metric=sklearn_metric)
        self._nbrs.fit(X)

    def query(self, v, n):
        """Return a list with the positions of the `n` neighbours found for `v`."""
        # kneighbors works on a batch of queries; wrap `v` and unwrap row 0.
        positions = self._nbrs.kneighbors(
            [v], n_neighbors=n, return_distance=False)
        return list(positions[0])

    def query_with_distances(self, v, n):
        """Return ``(position, distance)`` pairs for the `n` neighbours of `v`."""
        distances, positions = self._nbrs.kneighbors(
            [v], n_neighbors=n, return_distance=True)
        first_positions = list(positions[0])
        first_distances = list(distances[0])
        return zip(first_positions, first_distances)
class BruteForceBLAS(BaseANN):
    """kNN search that uses a linear scan = brute force.

    For the angular, euclidean and hamming metrics the candidate ranking is
    computed with a single matrix/vector product; for jaccard the distance
    function from ``pd`` is evaluated against every indexed point.
    """

    def __init__(self, metric, precision=numpy.float32):
        """Validate and store the metric and the index precision.

        Raises:
            NotImplementedError: if `metric` is unsupported, or if `metric`
                is 'hamming' and `precision` is not a boolean type.
        """
        if metric not in ('angular', 'euclidean', 'hamming', 'jaccard'):
            raise NotImplementedError("BruteForceBLAS doesn't support metric %s" % metric)
        # `numpy.bool` was merely an alias of the builtin `bool` and does not
        # exist in NumPy 1.24-1.26, so compare against the builtin and the
        # NumPy scalar type instead of touching the removed attribute.
        if metric == 'hamming' and precision not in (bool, numpy.bool_):
            raise NotImplementedError("BruteForceBLAS doesn't support precision %s with Hamming distances" % precision)
        self._metric = metric
        self._precision = precision
        self.name = 'BruteForceBLAS()'

    def fit(self, X):
        """Initialize the search index.

        Sets ``self.index`` and, for euclidean and hamming, ``self.lengths``
        (the squared length of every indexed vector).
        """
        if self._metric == 'angular':
            lens = (X ** 2).sum(-1)  # precompute (squared) length of each vector
            # Normalize index vectors to unit length. Done out of place so the
            # caller's array is not modified (and integer input does not fail
            # on an in-place true division).
            unit_vectors = X / numpy.sqrt(lens)[..., numpy.newaxis]
            self.index = numpy.ascontiguousarray(unit_vectors, dtype=self._precision)
        elif self._metric == 'hamming':
            # Regarding bitvectors as vectors in l_2 is faster for blas
            X = X.astype(numpy.float32)
            lens = (X ** 2).sum(-1)  # precompute (squared) length of each vector
            self.index = numpy.ascontiguousarray(X, dtype=numpy.float32)
            self.lengths = numpy.ascontiguousarray(lens, dtype=numpy.float32)
        elif self._metric == 'euclidean':
            lens = (X ** 2).sum(-1)  # precompute (squared) length of each vector
            self.index = numpy.ascontiguousarray(X, dtype=self._precision)
            self.lengths = numpy.ascontiguousarray(lens, dtype=self._precision)
        elif self._metric == 'jaccard':
            # Stored untouched; may be sparse (query code checks issparse).
            self.index = X
        else:
            # Shouldn't get past the constructor. Raised explicitly so the
            # guard survives `python -O`, which strips assert statements.
            raise AssertionError("invalid metric")

    def query(self, v, n):
        """Return only the indices produced by `query_with_distances`."""
        return [index for index, _ in self.query_with_distances(v, n)]

    def query_with_distances(self, v, n):
        """Find indices of `n` most similar vectors from the index to query vector `v`.

        Returns an iterable of ``(index, distance)`` pairs. The distance is
        recomputed with ``pd[metric]['distance']``; candidates whose ranking
        score fails ``pd[metric]['distance_valid']`` are dropped. The pairs
        come out in the order produced by the partition step, which is not
        sorted by distance.
        """
        if self._metric != 'jaccard':
            # use same precision for query as for index
            v = numpy.ascontiguousarray(v, dtype=self.index.dtype)

        metric_fns = pd[self._metric]
        sparse_index = issparse(self.index)

        # HACK we ignore query length as that's a constant not affecting the final ordering
        if self._metric == 'angular':
            # argmax_a cossim(a, b) = argmax_a dot(a, b) / |a||b| = argmin_a -dot(a, b)
            dists = -numpy.dot(self.index, v)
        elif self._metric == 'euclidean':
            # argmin_a (a - b)^2 = argmin_a a^2 - 2ab + b^2 = argmin_a a^2 - 2ab
            dists = self.lengths - 2 * numpy.dot(self.index, v)
        elif self._metric == 'hamming':
            # Just compute hamming distance using euclidean distance
            dists = self.lengths - 2 * numpy.dot(self.index, v)
        elif self._metric == 'jaccard':
            if sparse_index:
                dists = [metric_fns['distance'](v, e.toarray()[0]) for e in self.index]
            else:
                dists = [metric_fns['distance'](v, e) for e in self.index]
        else:
            # Shouldn't get past the constructor (see fit for why not assert).
            raise AssertionError("invalid metric")

        n_points = len(dists)
        if n < n_points:
            # partition-sort by distance, get `n` closest
            nearest_indices = numpy.argpartition(dists, n)[:n]
        else:
            # argpartition rejects kth >= len(dists); when at least as many
            # neighbours are requested as there are points, every point is
            # a candidate.
            nearest_indices = numpy.arange(n_points)

        indices = [idx for idx in nearest_indices
                   if metric_fns["distance_valid"](dists[idx])]

        def fix(index):
            # Replace the ranking score by the real metric distance.
            if sparse_index:
                ep = self.index[index].toarray()[0]
            else:
                ep = self.index[index]
            return (index, metric_fns['distance'](ep, v))

        return map(fix, indices)
class BruteForceFPS(BaseANN):
    """Brute-force Jaccard/Tanimoto search that runs on chemfp fingerprint arenas.

    The work is split into phases that store their output on the instance:
    ``pre_fit`` -> ``fit`` for the data set and ``pre_query`` -> ``query`` ->
    ``post_query`` for a single query vector.
    """

    def __init__(self, metric):
        """Store `metric`; raise NotImplementedError unless it is 'jaccard'."""
        if metric != 'jaccard':
            # Message names this class (it previously said "BruteForce").
            raise NotImplementedError("BruteForceFPS doesn't support metric %s" % metric)
        self._metric = metric
        self.name = 'BruteForceFPS()'

    @staticmethod
    def _pack_rows(X):
        """Return ``[(row_number, packed_bytes), ...]`` for every row of `X`.

        Each row is loaded into a big-endian ``bitarray`` and serialised with
        ``tobytes()``.
        """
        packed = []
        for row in range(X.shape[0]):
            fp = bitarray(endian='big')
            fp.extend(X[row])
            packed.append((row, fp.tobytes()))
        return packed

    @staticmethod
    def matrToArena(X):
        """Convert `X` to a chemfp fingerprint arena in memory."""
        fps = BruteForceFPS._pack_rows(X)
        return chemfp.load_fingerprints(fps, chemfp.Metadata(num_bits=X.shape[1]), reorder=False)

    def pre_fit(self, X):
        """Pack the rows of `X` and keep them in ``self._fps`` for `fit`."""
        self._fps = BruteForceFPS._pack_rows(X)

    def fit(self, X):
        """Load the packed fingerprints into the target arena ``self._target``."""
        # Tolerate callers that skip pre_fit instead of failing with an
        # AttributeError on the missing attribute.
        if getattr(self, '_fps', None) is None:
            self.pre_fit(X)
        self._target = chemfp.load_fingerprints(self._fps, chemfp.Metadata(num_bits=X.shape[1]), reorder=False)
        # To ensure that BitBound is not used
        self._target.popcount_indices = ""

    def pre_query(self, v, n):
        """Wrap the single vector `v` as a one-row query arena (``n`` is unused)."""
        queryMatr = numpy.array([v])
        self._queries = BruteForceFPS.matrToArena(queryMatr)

    def query(self, v, n, rq=False):
        """Run the search prepared by `pre_query`; results go to ``self._results``.

        With ``rq`` set, a threshold search is run with ``threshold=1.0 - n``
        (presumably `n` is then a distance radius -- confirm against the
        caller); otherwise the `n` nearest fingerprints are requested.
        `v` itself is not read here.
        """
        if rq:
            self._results = chemfp.threshold_tanimoto_search(self._queries, self._target, threshold=1.0-n)
        else:
            self._results = chemfp.knearest_tanimoto_search(self._queries, self._target, k=n, threshold=0.0)

    def post_query(self, rq=False):
        """Return the ids hit by the first query in ``self._results``.

        Returns an empty list when that query has no hits, and also when the
        result set contains no queries at all (previously that case fell off
        the end of the loop and returned None).
        """
        # parse the results; pre_query builds a one-row arena, so only the
        # first entry is inspected
        for (_query_id, hits) in self._results:
            if hits:
                return hits.get_ids()
            return []
        return []