diff --git a/ann_benchmarks/algorithms/chemfp.py b/ann_benchmarks/algorithms/chemfp.py new file mode 100644 index 0000000..e700a6e --- /dev/null +++ b/ann_benchmarks/algorithms/chemfp.py @@ -0,0 +1,42 @@ +from __future__ import absolute_import +import chemfp +from ann_benchmarks.algorithms.base import BaseANN +from scipy.sparse import csr_matrix +import numpy +import os +from bitarray import bitarray + +class Chemfp(BaseANN): + + def __init__(self, metric): + if metric != "jaccard": + raise NotImplementedError("Chemfp doesn't support metric %s, only jaccard metric is supported." % metric) + self._metric = metric + self.name = "Chemfp()" + + @staticmethod + def matrToArena(X): + # convert X to Chemfp fingerprintArena in memory + fps = [] + for row in range(X.shape[0]): + fp = bitarray(endian='big') + fp.extend(X[row]) + fps.append((row,fp.tobytes())) + return chemfp.load_fingerprints(fps,chemfp.Metadata(num_bits=X.shape[1])) + + def pre_fit(self, X): + self._target = Chemfp.matrToArena(X) + + + def pre_query(self, v, n): + queryMatr = numpy.array([v]) + self._queries = Chemfp.matrToArena(queryMatr) + + def query(self, v, n): + self._results = chemfp.knearest_tanimoto_search(self._queries, self._target, k=n, threshold=0.0) + + def post_query(self): + # parse the results + for (query_id, hits) in self._results: + if hits: + return hits.get_ids()