Skip to content

Commit

Permalink
add chemfp range query
Browse files Browse the repository at this point in the history
  • Loading branch information
ChunjiangZhu committed Jul 4, 2020
1 parent eff89a1 commit d8dbb78
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
Binary file modified ann_benchmarks/.DS_Store
Binary file not shown.
11 changes: 8 additions & 3 deletions ann_benchmarks/algorithms/chemfp.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,16 @@ def pre_query(self, v, n):
queryMatr = numpy.array([v])
self._queries = Chemfp.matrToArena(queryMatr)

def query(self, v, n):
self._results = chemfp.knearest_tanimoto_search(self._queries, self._target, k=n, threshold=0.0)
def query(self, v, n, rq=False):
if rq:
self._results = chemfp.threshold_tanimoto_search(self._queries, self._target, threshold=1.0-n)
else:
self._results = chemfp.knearest_tanimoto_search(self._queries, self._target, k=n, threshold=0.0)

def post_query(self):
def post_query(self, rq=False):
# parse the results
for (query_id, hits) in self._results:
if hits:
return hits.get_ids()
else:
return []
2 changes: 1 addition & 1 deletion ann_benchmarks/plotting/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def rangequery(dataset_distances, run_distances, radius, epsilon=1e-10):
for true_distances, found_distances in zip(dataset_distances, run_distances):
true = [d for d in true_distances if d <= radius + epsilon]
found = [d for d in found_distances if d <= radius + epsilon]
print('found: ' + str(len(found)) + '/true: ' + str(len(true)))
#print('found: ' + str(len(found)) + '/true: ' + str(len(true)))
if len(true) == 0:
if len(found) == 0:
total += 1.0
Expand Down
2 changes: 1 addition & 1 deletion ann_benchmarks/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def single_query(v):

# special code for the Risc, DivideSkip, and Chemfp
if algoname in ['Risc', 'DivideSkip', 'Chemfp']:
candidates = algo.post_query()
candidates = algo.post_query(rq)

if issparse(X_train):
candidates = [(int(idx), float(metrics[distance]['distance'](v, X_train[idx].toarray()[0])))
Expand Down

0 comments on commit d8dbb78

Please sign in to comment.