diff --git a/ann_benchmarks/algorithms/folding.py b/ann_benchmarks/algorithms/folding.py new file mode 100644 index 0000000..7573f28 --- /dev/null +++ b/ann_benchmarks/algorithms/folding.py @@ -0,0 +1,45 @@ +from __future__ import absolute_import +import chemfp +from ann_benchmarks.algorithms.base import BaseANN +from scipy.sparse import csr_matrix +import numpy +import os + +class Folding(BaseANN): + + def __init__(self, metric, num_bits): + if metric != "jaccard": + raise NotImplementedError("Folding doesn't support metric %s, only jaccard metric is supported." % metric) + self._metric = metric + self.num_bits = num_bits + self.name = 'Folding(num_bits==%s)' % (num_bits) + + @staticmethod + def matrToArena(X, num_bits): + from chemfp import bitops + # convert X to Chemfp fingerprintArena in memory + fps = [] + for row in range(X.shape[0]): + bit_list = list(numpy.nonzero(X[row])[0]) + # folded to the required number of bits + fps.append((row,bitops.byte_from_bitlist(bit_list, num_bits))) + return chemfp.load_fingerprints(fps,chemfp.Metadata(num_bits=num_bits)) + + def fit(self, X): + self._target = Folding.matrToArena(X, self.num_bits) + + def query(self, v, n, rq=False): + queryMatr = numpy.array([v]) + self._queries = Folding.matrToArena(queryMatr, self.num_bits) + if rq: + self._results = chemfp.threshold_tanimoto_search(self._queries, self._target, threshold=1.0-n) + else: + self._results = chemfp.knearest_tanimoto_search(self._queries, self._target, k=n, threshold=0.0) + + def post_query(self, rq=False): + # parse the results + for (query_id, hits) in self._results: + if hits: + return hits.get_ids() + else: + return []