diff --git a/python_bindings/tests/bindings_test.py b/python_bindings/tests/bindings_test.py index a42939e..b5b1062 100644 --- a/python_bindings/tests/bindings_test.py +++ b/python_bindings/tests/bindings_test.py @@ -7,6 +7,124 @@ import nmslib import psutil +import logging +import multiprocessing +import time +import os +import threading + +MB = 1024 * 1024 + + +class StoppableThread(threading.Thread): + """Thread class with a stop() method. The thread itself has to check + regularly for the stopped() condition.""" + + def __init__(self, *args, **kwargs): + super().__init__() + self._stop_event = threading.Event() + + def stop(self): + self._stop_event.set() + + def stopped(self): + return self._stop_event.is_set() + + +class Timer: + """ Context manager for timing named blocks of code """ + def __init__(self, name, logger=None): + self.name = name + self.logger = logger if logger else logging.getLogger() + + def __enter__(self): + self.start = time.time() + self.logger.debug("Starting {}".format(self.name)) + + def __exit__(self, type, value, trace): + self.logger.info("{}: {:0.2f}s".format(self.name, time.time() - self.start)) + + +class PeakMemoryUsage: + class Worker(StoppableThread): + def __init__(self, interval, *args, **kwargs): + super().__init__(*args, **kwargs) + self.interval = interval + self.max_rss = self.max_vms = 0 + + def run(self): + process = psutil.Process() + while not self.stopped(): + mem = process.memory_info() + self.max_rss = max(self.max_rss, mem.rss) + self.max_vms = max(self.max_vms, mem.vms) + time.sleep(self.interval) + + """ Context manager to calculate peak memory usage in a statement block """ + def __init__(self, name, logger=None, interval=1): + self.name = name + self.logger = logger if logger else logging.getLogger() + self.interval = interval + self.start = time.time() + self.worker = None + + def __enter__(self): + if self.interval > 0: + pid = os.getpid() + mem = psutil.Process(pid).memory_info() + self.start_rss, self.start_vms = mem.rss, mem.vms + + self.worker = PeakMemoryUsage.Worker(self.interval) + self.worker.start() + return self + + def __exit__(self, _, value, trace): + if self.worker: + self.worker.stop() + self.worker.join() + self.logger.warning("Peak memory usage for '{}' in MBs: orig=(rss={:0.1f} vms={:0.1f}) " + "peak=(rss={:0.1f} vms={:0.1f}) in {:0.2f}s" + .format(self.name, self.start_rss / MB, self.start_vms / MB, + self.worker.max_rss / MB, + self.worker.max_vms / MB, time.time() - self.start)) + + +class PsUtil(object): + def __init__(self, attr=('virtual_memory',), proc_attr=None, + logger=None, interval=60): + """ attr can be multiple methods of psutil (e.g. attr=['virtual_memory', 'cpu_times_percent']) """ + self.ps_mon = None + self.attr = attr + self.proc_attr = proc_attr + self.logger = logger if logger else logging.getLogger() + self.interval = interval + + def psutil_worker(self, pid): + root_proc = psutil.Process(pid) + while True: + for attr in self.attr: + self.logger.warning("PSUTIL {}".format(getattr(psutil, attr)())) + if self.proc_attr: + procs = set(root_proc.children(recursive=True)) + procs.add(root_proc) + procs = sorted(procs, key=lambda p: p.pid) + + for proc in procs: + self.logger.warning("PSUTIL process={}: {}" + .format(proc.pid, proc.as_dict(self.proc_attr))) + + time.sleep(self.interval) + + def __enter__(self): + if self.interval > 0: + self.ps_mon = multiprocessing.Process(target=self.psutil_worker, args=(os.getpid(),)) + self.ps_mon.start() + time.sleep(1) # sleep so the first iteration doesn't include statements in the PsUtil context + return self + + def __exit__(self, type, value, trace): + if self.ps_mon is not None: + self.ps_mon.terminate() def get_exact_cosine(row, data, N=10): @@ -19,6 +137,14 @@ def get_hitrate(ground_truth, ids): return len(set(i for i, _ in ground_truth).intersection(ids)) +def bit_vector_to_str(bit_vect): + return " ".join(["1" if e else "0" for e in bit_vect]) + + +def bit_vector_sparse_str(bit_vect): + return " ".join([str(k) for k, b in enumerate(bit_vect) if b]) + + class DenseIndexTestMixin(object): def _get_index(self, space='cosinesimil'): raise NotImplementedError() @@ -93,33 +219,39 @@ def _get_index(self, space='bit_jaccard'): raise NotImplementedError() def testKnnQuery(self): - nbits = 2048 - chunk_size = 1000 + nbits = 512 + chunk_size = 10000 + num_elems = 100000 ps_proc = psutil.Process() - print(f"\n{ps_proc.memory_info()}") + # print(f"\n{ps_proc.memory_info()}") index = self._get_index() - - np.random.seed(23) - for i in range(0, 10000, chunk_size): - strs = [] - for j in range(chunk_size): - a = np.random.rand(nbits) > 0.5 - s = " ".join(["1" if e else "0" for e in a]) - strs.append(s) - index.addDataPointBatch(ids=np.arange(i, i + chunk_size), data=strs) - - print(f"\n{ps_proc.memory_info()}") - index.createIndex() - print(f"\n{ps_proc.memory_info()}") + if "bit_jaccard" in str(index): + bit_vector_str_func = bit_vector_to_str + else: + bit_vector_str_func = bit_vector_sparse_str + + # logging.basicConfig(level=logging.INFO) + # with PsUtil(interval=2, proc_attr=["memory_info"]): + with PeakMemoryUsage(f"AddData: vector={nbits}-bit elems={num_elems}"): + np.random.seed(23) + for i in range(0, num_elems, chunk_size): + strs = [] + for j in range(chunk_size): + a = np.random.rand(nbits) > 0.5 + strs.append(bit_vector_str_func(a)) + index.addDataPointBatch(ids=np.arange(i, i + chunk_size), data=strs) + + # print(f"\n{ps_proc.memory_info()}") + with PeakMemoryUsage(f"CreateIndex: vector={nbits}-bit of elems={num_elems}"): + index.createIndex() + # print(f"\n{ps_proc.memory_info()}") a = np.ones(nbits) - s = " ".join(["1" if e else "0" for e in a]) - ids, distances = index.knnQuery(s, k=10) + ids, distances = index.knnQuery(bit_vector_str_func(a), k=10) # print(ids) print(distances) # self.assertTrue(get_hitrate(get_exact_cosine(row, data), ids) >= 5) - # def testKnnQueryBatch(self): # np.random.seed(23) # data = np.random.randn(1000, 10).astype(np.float32) diff --git a/similarity_search/include/distcomp.h b/similarity_search/include/distcomp.h index d16e80b..84e960b 100644 --- a/similarity_search/include/distcomp.h +++ b/similarity_search/include/distcomp.h @@ -233,7 +233,7 @@ dist_t inline BitJaccard(const dist_uint_t* a, const dist_uint_t* b, size_t qty) den += __builtin_popcount(a[i] | b[i]); } - return dist_t(num) / dist_t(den); + return 1 - (dist_t(num) / dist_t(den)); } //unsigned BitHamming(const uint32_t* a, const uint32_t* b, size_t qty);