Skip to content

Commit

Permalink
add chemfp.py
Browse files Browse the repository at this point in the history
  • Loading branch information
cjz18001 authored Jun 21, 2020
1 parent 5221721 commit ac4e3b4
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions ann_benchmarks/algorithms/chemfp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import absolute_import
import chemfp
from ann_benchmarks.algorithms.base import BaseANN
from scipy.sparse import csr_matrix
import numpy
import os
from bitarray import bitarray

class Chemfp(BaseANN):
    """Benchmark wrapper that answers k-nearest queries with chemfp.

    Only the "jaccard" metric is accepted; searches are delegated to
    chemfp's Tanimoto k-nearest search. The work is split into phases
    (``pre_fit``, ``pre_query``, ``query``, ``post_query``) that pass state
    to each other through private attributes on the instance.

    NOTE(review): the order in which the harness calls these phases is not
    visible in this file -- presumably pre_fit -> pre_query -> query ->
    post_query; confirm against the runner.
    """

    def __init__(self, metric):
        """Store the metric and set the display name.

        Raises:
            NotImplementedError: if ``metric`` is anything but "jaccard".
        """
        if metric != "jaccard":
            raise NotImplementedError(
                "Chemfp doesn't support metric %s, only jaccard metric is supported." % metric
            )
        self._metric = metric
        self.name = "Chemfp()"

    @staticmethod
    def matrToArena(X):
        """Convert matrix ``X`` into an in-memory chemfp fingerprint arena.

        Each row ``X[row]`` is packed into a big-endian ``bitarray`` and
        serialised to bytes; the row index is used as the fingerprint id.
        ``X.shape[1]`` is passed to chemfp as the number of bits.

        NOTE(review): assumes ``X`` is a dense 2-D array whose rows iterate
        as 0/1 (or boolean) values -- confirm against the dataset loader.
        """
        fps = []
        for row in range(X.shape[0]):
            fp = bitarray(endian='big')
            fp.extend(X[row])
            fps.append((row, fp.tobytes()))
        return chemfp.load_fingerprints(fps, chemfp.Metadata(num_bits=X.shape[1]))

    def pre_fit(self, X):
        """Build the target arena from the dataset ``X`` and keep it."""
        self._target = Chemfp.matrToArena(X)

    def pre_query(self, v, n):
        """Wrap the single query vector ``v`` in a one-row query arena.

        ``n`` is accepted for interface symmetry with ``query`` and is not
        used here.
        """
        queryMatr = numpy.array([v])
        self._queries = Chemfp.matrToArena(queryMatr)

    def query(self, v, n):
        """Run the k-nearest Tanimoto search for the prepared query.

        The query arena built by ``pre_query`` is searched, not ``v``
        itself; ``n`` is the number of neighbours requested. The result is
        stored on the instance for ``post_query`` rather than returned.

        NOTE(review): if chemfp evaluates this search lazily, part of the
        work would happen while iterating in ``post_query`` -- confirm.
        """
        self._results = chemfp.knearest_tanimoto_search(
            self._queries, self._target, k=n, threshold=0.0
        )

    def post_query(self):
        """Return the ids of the hits for the first query that has any.

        Returns an empty list when no query produced hits, so callers
        always receive a sequence instead of an implicit ``None``.
        """
        # parse the results
        for (query_id, hits) in self._results:
            if hits:
                return hits.get_ids()
        return []

0 comments on commit ac4e3b4

Please sign in to comment.