Skip to content

Commit

Permalink
fix tanimoto
Browse files Browse the repository at this point in the history
  • Loading branch information
Greg Friedland committed Feb 17, 2019
1 parent ea72c88 commit c37e33d
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 20 deletions.
170 changes: 151 additions & 19 deletions python_bindings/tests/bindings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,124 @@

import nmslib
import psutil
import logging
import multiprocessing
import time
import os
import threading

MB = 1024 * 1024


class StoppableThread(threading.Thread):
    """Thread with a cooperative stop flag.

    Calling :meth:`stop` only sets an event; the code executed by the thread
    has to poll :meth:`stopped` regularly and return on its own.

    All positional and keyword arguments are forwarded to
    ``threading.Thread`` (e.g. ``target``, ``name``, ``daemon``).
    """

    def __init__(self, *args, **kwargs):
        # Forward the constructor arguments; previously they were accepted
        # but silently discarded, so name/daemon/target had no effect.
        super().__init__(*args, **kwargs)
        self._stop_event = threading.Event()

    def stop(self):
        """Request that the thread stop; does not wait for it to finish."""
        self._stop_event.set()

    def stopped(self):
        """Return True once :meth:`stop` has been called."""
        return self._stop_event.is_set()


class Timer:
    """Context manager that logs how long a named block of code took.

    Logs "Starting <name>" at DEBUG level on entry and
    "<name>: <seconds>s" at INFO level on exit.

    Args:
        name: Label used in the log messages.
        logger: Logger to write to; defaults to the root logger.
    """

    def __init__(self, name, logger=None):
        self.name = name
        self.logger = logger if logger else logging.getLogger()

    def __enter__(self):
        self.start = time.time()
        self.logger.debug("Starting {}".format(self.name))
        # Return the instance so ``with Timer(...) as t`` binds the timer
        # rather than None.
        return self

    def __exit__(self, type, value, trace):
        # Implicitly returns None, so an exception raised inside the block
        # still propagates after the elapsed time has been logged.
        self.logger.info("{}: {:0.2f}s".format(self.name, time.time() - self.start))


class PeakMemoryUsage:
    """Context manager that logs the peak memory usage of a statement block.

    While the block runs, a background thread samples this process's RSS and
    VMS every ``interval`` seconds and keeps the maxima. On exit a single
    WARNING line is logged with the values at entry, the peaks (both in MB)
    and the elapsed time.

    Args:
        name: Label used in the log message.
        logger: Logger to write to; defaults to the root logger.
        interval: Seconds between samples. A value <= 0 disables monitoring,
            in which case nothing is logged.
    """

    class Worker(StoppableThread):
        """Sampling thread that tracks the largest RSS/VMS seen so far."""

        def __init__(self, interval, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.interval = interval
            self.max_rss = self.max_vms = 0

        def _sample(self, process):
            """Fold one ``memory_info()`` reading into the running maxima."""
            mem = process.memory_info()
            self.max_rss = max(self.max_rss, mem.rss)
            self.max_vms = max(self.max_vms, mem.vms)

        def run(self):
            process = psutil.Process()
            while not self.stopped():
                self._sample(process)
                # Wait on the stop event instead of time.sleep() so that
                # stop() wakes the thread immediately and join() does not
                # block for up to a full interval.
                self._stop_event.wait(self.interval)
            # One last reading so that a block shorter than the interval
            # still reports its memory at exit, not just at entry.
            self._sample(process)

    def __init__(self, name, logger=None, interval=1):
        self.name = name
        self.logger = logger if logger else logging.getLogger()
        self.interval = interval
        self.start = time.time()
        self.worker = None

    def __enter__(self):
        # Restart the clock on entry so the reported duration covers the
        # block itself, not the gap between construction and entry.
        self.start = time.time()
        if self.interval > 0:
            mem = psutil.Process(os.getpid()).memory_info()
            self.start_rss, self.start_vms = mem.rss, mem.vms

            self.worker = PeakMemoryUsage.Worker(self.interval)
            self.worker.start()
        return self

    def __exit__(self, _, value, trace):
        # The worker only exists when monitoring was enabled in __enter__.
        if self.worker:
            self.worker.stop()
            self.worker.join()
            self.logger.warning("Peak memory usage for '{}' in MBs: orig=(rss={:0.1f} vms={:0.1f}) "
                                "peak=(rss={:0.1f} vms={:0.1f}) in {:0.2f}s"
                                .format(self.name, self.start_rss / MB, self.start_vms / MB,
                                        self.worker.max_rss / MB,
                                        self.worker.max_vms / MB, time.time() - self.start))


class PsUtil(object):
    """Context manager that periodically logs psutil statistics.

    On entry a child process is started which, every ``interval`` seconds,
    logs (at WARNING level) the result of each psutil function named in
    ``attr`` and, when ``proc_attr`` is set, the requested attributes of the
    process that entered the context and of all its descendants. On exit the
    child process is terminated and reaped.

    Args:
        attr: Names of module-level psutil functions to call, e.g.
            ``['virtual_memory', 'cpu_times_percent']``.
        proc_attr: Attribute names passed to ``Process.as_dict`` for each
            process; a falsy value disables per-process logging.
        logger: Logger to write to; defaults to the root logger.
        interval: Seconds between reports. A value <= 0 disables monitoring.
    """

    def __init__(self, attr=('virtual_memory',), proc_attr=None,
                 logger=None, interval=60):
        self.ps_mon = None
        self.attr = attr
        self.proc_attr = proc_attr
        self.logger = logger if logger else logging.getLogger()
        self.interval = interval

    def psutil_worker(self, pid):
        """Log statistics forever; runs in the child until it is terminated.

        Args:
            pid: PID of the process whose process tree is reported.
        """
        root_proc = psutil.Process(pid)
        while True:
            for attr in self.attr:
                self.logger.warning("PSUTIL {}".format(getattr(psutil, attr)()))
            if self.proc_attr:
                # Root process plus every descendant, in a stable PID order.
                procs = set(root_proc.children(recursive=True))
                procs.add(root_proc)
                procs = sorted(procs, key=lambda p: p.pid)

                for proc in procs:
                    self.logger.warning("PSUTIL process={}: {}"
                                        .format(proc.pid, proc.as_dict(self.proc_attr)))

            time.sleep(self.interval)

    def __enter__(self):
        if self.interval > 0:
            self.ps_mon = multiprocessing.Process(target=self.psutil_worker, args=(os.getpid(),))
            self.ps_mon.start()
            time.sleep(1)  # sleep so the first iteration doesn't include statements in the PsUtil context
        return self

    def __exit__(self, type, value, trace):
        if self.ps_mon is not None:
            self.ps_mon.terminate()
            # Reap the terminated child; without the join it is left behind
            # unreaped until the parent exits or is garbage-collected.
            self.ps_mon.join()


def get_exact_cosine(row, data, N=10):
Expand All @@ -19,6 +137,14 @@ def get_hitrate(ground_truth, ids):
return len(set(i for i, _ in ground_truth).intersection(ids))


def bit_vector_to_str(bit_vect):
    """Render a bit vector densely: one space-separated "1"/"0" per element."""
    tokens = []
    for bit in bit_vect:
        if bit:
            tokens.append("1")
        else:
            tokens.append("0")
    return " ".join(tokens)


def bit_vector_sparse_str(bit_vect):
    """Render a bit vector sparsely: space-separated indices of its set bits."""
    set_positions = []
    for position, bit in enumerate(bit_vect):
        if bit:
            set_positions.append(str(position))
    return " ".join(set_positions)


class DenseIndexTestMixin(object):
def _get_index(self, space='cosinesimil'):
raise NotImplementedError()
Expand Down Expand Up @@ -93,33 +219,39 @@ def _get_index(self, space='bit_jaccard'):
raise NotImplementedError()

def testKnnQuery(self):
# Smoke test: load random bit vectors into the index in batches, build the
# index, then run a k=10 query with an all-ones vector and print the
# returned distances.
# NOTE(review): this span looks like a rendered commit diff whose +/- markers
# and indentation were stripped, so superseded and replacement lines appear
# interleaved (e.g. nbits and chunk_size are each assigned twice, and the
# load/createIndex/knnQuery steps appear twice) — confirm against the
# repository before treating it as runnable code.
nbits = 2048
chunk_size = 1000
nbits = 512
chunk_size = 10000
num_elems = 100000

ps_proc = psutil.Process()
print(f"\n{ps_proc.memory_info()}")
# print(f"\n{ps_proc.memory_info()}")
index = self._get_index()

# Fixed seed so the generated vectors are the same on every run.
np.random.seed(23)
for i in range(0, 10000, chunk_size):
strs = []
for j in range(chunk_size):
a = np.random.rand(nbits) > 0.5
s = " ".join(["1" if e else "0" for e in a])
strs.append(s)
index.addDataPointBatch(ids=np.arange(i, i + chunk_size), data=strs)

print(f"\n{ps_proc.memory_info()}")
index.createIndex()
print(f"\n{ps_proc.memory_info()}")
# Choose the string encoding from the index's string form: dense "1"/"0"
# tokens when it mentions bit_jaccard, otherwise the indices of the set bits.
if "bit_jaccard" in str(index):
bit_vector_str_func = bit_vector_to_str
else:
bit_vector_str_func = bit_vector_sparse_str

# logging.basicConfig(level=logging.INFO)
# with PsUtil(interval=2, proc_attr=["memory_info"]):
# Add num_elems random vectors, chunk_size at a time, while logging the peak
# memory usage of the whole loading phase.
with PeakMemoryUsage(f"AddData: vector={nbits}-bit elems={num_elems}"):
np.random.seed(23)
for i in range(0, num_elems, chunk_size):
strs = []
for j in range(chunk_size):
a = np.random.rand(nbits) > 0.5
strs.append(bit_vector_str_func(a))
index.addDataPointBatch(ids=np.arange(i, i + chunk_size), data=strs)

# print(f"\n{ps_proc.memory_info()}")
# Build the index, logging the peak memory usage of that step separately.
with PeakMemoryUsage(f"CreateIndex: vector={nbits}-bit of elems={num_elems}"):
index.createIndex()
# print(f"\n{ps_proc.memory_info()}")

# Query with a vector in which every bit is set.
a = np.ones(nbits)
s = " ".join(["1" if e else "0" for e in a])
ids, distances = index.knnQuery(s, k=10)
ids, distances = index.knnQuery(bit_vector_str_func(a), k=10)
# print(ids)
print(distances)
# NOTE(review): the only assertion below is commented out, so as shown this
# test verifies nothing beyond the calls not raising.
# self.assertTrue(get_hitrate(get_exact_cosine(row, data), ids) >= 5)

# def testKnnQueryBatch(self):
# np.random.seed(23)
# data = np.random.randn(1000, 10).astype(np.float32)
Expand Down
2 changes: 1 addition & 1 deletion similarity_search/include/distcomp.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ dist_t inline BitJaccard(const dist_uint_t* a, const dist_uint_t* b, size_t qty)
den += __builtin_popcount(a[i] | b[i]);
}

return dist_t(num) / dist_t(den);
return 1 - (dist_t(num) / dist_t(den));
}

//unsigned BitHamming(const uint32_t* a, const uint32_t* b, size_t qty);
Expand Down

0 comments on commit c37e33d

Please sign in to comment.