diff --git a/ann_benchmarks/.DS_Store b/ann_benchmarks/.DS_Store index f4d2c77..e7f84d2 100644 Binary files a/ann_benchmarks/.DS_Store and b/ann_benchmarks/.DS_Store differ diff --git a/ann_benchmarks/algorithms/chemfp.py b/ann_benchmarks/algorithms/chemfp.py index 3a6f29e..3e8a7c2 100644 --- a/ann_benchmarks/algorithms/chemfp.py +++ b/ann_benchmarks/algorithms/chemfp.py @@ -15,14 +15,14 @@ def __init__(self, metric): self.name = "Chemfp()" @staticmethod - def matrToArena(X): + def matrToArena(X, reorder=True): # convert X to Chemfp fingerprintArena in memory fps = [] for row in range(X.shape[0]): fp = bitarray(endian='big') fp.extend(X[row]) fps.append((row,fp.tobytes())) - return chemfp.load_fingerprints(fps,chemfp.Metadata(num_bits=X.shape[1])) + return chemfp.load_fingerprints(fps,chemfp.Metadata(num_bits=X.shape[1]),reorder=reorder) def pre_fit(self, X): self._target = Chemfp.matrToArena(X) @@ -45,3 +45,19 @@ def post_query(self, rq=False): return hits.get_ids() else: return [] + def pre_batch_query(self, X, n): + self._queries = Chemfp.matrToArena(X, False) + + def batch_query(self, X, n): + self._results = chemfp.knearest_tanimoto_search(self._queries, self._target, k=n, threshold=0.0) + + def get_batch_results(self): + # parse the results + res = [] + for (query_id, hits) in self._results: + if hits: + res.append(hits.get_ids()) + else: + res.append([]) + print(res) + return res diff --git a/ann_benchmarks/runner.py b/ann_benchmarks/runner.py index 9acaa4d..fce0723 100644 --- a/ann_benchmarks/runner.py +++ b/ann_benchmarks/runner.py @@ -63,6 +63,9 @@ def single_query(v): return (total, candidates) def batch_query(X): + # special code for Chemfp + if algoname in ['Chemfp']: + algo.pre_batch_query(X, count) start = time.time() algo.batch_query(X, count) total = (time.time() - start)