Skip to content

Commit

Permalink
add chemfp batch K-NN search
Browse files Browse the repository at this point in the history
  • Loading branch information
ChunjiangZhu committed Jul 26, 2020
1 parent c4c913e commit 2cdb3bf
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 2 deletions.
Binary file modified ann_benchmarks/.DS_Store
Binary file not shown.
20 changes: 18 additions & 2 deletions ann_benchmarks/algorithms/chemfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ def __init__(self, metric):
self.name = "Chemfp()"

@staticmethod
# NOTE(review): the two `def` lines and the two `return` lines below look like
# the before/after pair of a rendered diff, the variant carrying `reorder`
# being the newer one -- confirm against the actual file.
def matrToArena(X):
def matrToArena(X, reorder=True):
# convert X to Chemfp fingerprintArena in memory
# One (row index, packed bytes) record is collected per row of X; the row
# index is presumably used as the fingerprint id -- confirm against
# chemfp.load_fingerprints.
fps = []
for row in range(X.shape[0]):
# Pack the values of this row into a big-endian bit array.
fp = bitarray(endian='big')
fp.extend(X[row])
fps.append((row,fp.tobytes()))
# num_bits is taken from the column count of X; in the second variant
# `reorder` is forwarded unchanged to chemfp.load_fingerprints.
return chemfp.load_fingerprints(fps,chemfp.Metadata(num_bits=X.shape[1]))
return chemfp.load_fingerprints(fps,chemfp.Metadata(num_bits=X.shape[1]),reorder=reorder)

def pre_fit(self, X):
self._target = Chemfp.matrToArena(X)
Expand All @@ -45,3 +45,19 @@ def post_query(self, rq=False):
return hits.get_ids()
else:
return []
def pre_batch_query(self, X, n):
    """Convert the query matrix into an arena ahead of the batch search.

    The arena is built with ``Chemfp.matrToArena`` with reordering turned
    off and stored on ``self._queries``. ``n`` is accepted but not read.
    """
    query_arena = Chemfp.matrToArena(X, reorder=False)
    self._queries = query_arena

def batch_query(self, X, n):
    """Run the k-nearest Tanimoto search for the prepared queries.

    ``X`` is not read here; the search uses ``self._queries`` against
    ``self._target`` with ``k=n`` and a threshold of 0.0. The outcome is
    stored on ``self._results``.
    """
    knn_search = chemfp.knearest_tanimoto_search
    self._results = knn_search(
        self._queries,
        self._target,
        k=n,
        threshold=0.0,
    )

def get_batch_results(self):
    """Return the ids found by the most recent batch search.

    ``self._results`` is iterated as ``(query_id, hits)`` pairs. The
    returned list has one entry per pair, in iteration order:
    ``hits.get_ids()`` when ``hits`` is truthy, otherwise an empty list.
    """
    # The former debug print of the whole result list was dropped: it wrote
    # every result list to stdout on each call.
    return [hits.get_ids() if hits else [] for _, hits in self._results]
3 changes: 3 additions & 0 deletions ann_benchmarks/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def single_query(v):
return (total, candidates)

def batch_query(X):
# special code for Chemfp
if algoname in ['Chemfp']:
algo.pre_batch_query(X, count)
start = time.time()
algo.batch_query(X, count)
total = (time.time() - start)
Expand Down

0 comments on commit 2cdb3bf

Please sign in to comment.