Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Add the folding algorithm
  • Loading branch information
cjz18001 committed Sep 21, 2020
1 parent cf566a6 commit 528f4b5
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions ann_benchmarks/algorithms/folding.py
@@ -0,0 +1,45 @@
from __future__ import absolute_import
import chemfp
from ann_benchmarks.algorithms.base import BaseANN
from scipy.sparse import csr_matrix
import numpy
import os

class Folding(BaseANN):
    """Jaccard (Tanimoto) search over chemfp fingerprints folded to ``num_bits``.

    Data rows are converted to chemfp byte fingerprints and loaded into an
    in-memory arena; queries are converted the same way and searched against
    that arena with chemfp's Tanimoto search functions.
    """

    def __init__(self, metric, num_bits):
        """Store the configuration for a folded-fingerprint index.

        Args:
            metric: Distance metric name; only ``"jaccard"`` is accepted.
            num_bits: Fingerprint length (in bits) passed to chemfp when
                building fingerprints and arena metadata.

        Raises:
            NotImplementedError: If ``metric`` is anything other than
                ``"jaccard"``.
        """
        if metric != "jaccard":
            raise NotImplementedError("Folding doesn't support metric %s, only jaccard metric is supported." % metric)
        self._metric = metric
        self.num_bits = num_bits
        self.name = 'Folding(num_bits==%s)' % (num_bits)

    @staticmethod
    def matrToArena(X, num_bits):
        """Convert the rows of ``X`` into an in-memory chemfp fingerprint arena.

        Each row is reduced to the positions of its non-zero entries, turned
        into a byte fingerprint of ``num_bits`` bits, and paired with its row
        index as the fingerprint id.

        Args:
            X: Row-indexable data exposing ``shape[0]``.
                NOTE(review): presumably a dense 2-D 0/1 array; the
                ``numpy.nonzero(...)[0]`` lookup only yields column positions
                when ``X[row]`` is one-dimensional -- confirm against callers.
            num_bits: Number of bits per fingerprint.

        Returns:
            The object returned by ``chemfp.load_fingerprints`` for the
            generated ``(row_index, fingerprint)`` pairs.
        """
        from chemfp import bitops

        # Bit positions are handed to chemfp together with num_bits so that
        # the fingerprint comes back at the required (folded) length.
        fps = [
            (row, bitops.byte_from_bitlist(list(numpy.nonzero(X[row])[0]), num_bits))
            for row in range(X.shape[0])
        ]
        return chemfp.load_fingerprints(fps, chemfp.Metadata(num_bits=num_bits))

    def fit(self, X):
        """Build the target arena that later queries are searched against."""
        self._target = Folding.matrToArena(X, self.num_bits)

    def query(self, v, n, rq=False):
        """Start a search for a single query vector and stash the results.

        Nothing is returned; the outcome is stored on ``self._results`` and
        read back by ``post_query``.

        Args:
            v: The query vector (one row in the same format as the fit data).
            n: Number of neighbours for a k-nearest search, or the distance
                radius when ``rq`` is true.
            rq: When true, run a threshold (range) search instead of k-NN.
        """
        # Wrap the single vector as a one-row matrix so it can go through the
        # same conversion as the fit data.
        self._queries = Folding.matrToArena(numpy.array([v]), self.num_bits)
        if rq:
            # n is a distance radius here; chemfp thresholds are similarities,
            # hence 1.0 - n.
            self._results = chemfp.threshold_tanimoto_search(self._queries, self._target, threshold=1.0-n)
        else:
            self._results = chemfp.knearest_tanimoto_search(self._queries, self._target, k=n, threshold=0.0)
        # NOTE(review): if chemfp evaluates these result iterators lazily, part
        # of the search cost is paid in post_query rather than here -- confirm.

    def post_query(self, rq=False):
        """Return the target ids hit by the most recent ``query`` call.

        Only the first ``(query_id, hits)`` pair of the stored results is
        consulted, since ``query`` searches with exactly one fingerprint.

        Args:
            rq: Unused; kept for interface compatibility.

        Returns:
            The ids reported by the hits object, or an empty list when there
            are no hits or no result entry at all.
        """
        first = next(iter(self._results), None)
        if first is None:
            # Previously this path fell off the end of the loop and returned
            # None; return an empty list so callers always get a sequence.
            return []
        _query_id, hits = first
        return hits.get_ids() if hits else []

0 comments on commit 528f4b5

Please sign in to comment.