From 55f9fc49ba39dfefc9f8492c168522cf769926e5 Mon Sep 17 00:00:00 2001 From: searchivarius Date: Fri, 31 May 2019 00:07:13 -0400 Subject: [PATCH] Resolving #356 --- python_bindings/nmslib.cc | 30 ++++++--- python_bindings/setup.py | 2 +- python_bindings/tests/bindings_test.py | 72 +++++++++++---------- similarity_search/src/space.cc | 2 +- similarity_search/test/test_space_serial.cc | 16 ++--- 5 files changed, 71 insertions(+), 51 deletions(-) diff --git a/python_bindings/nmslib.cc b/python_bindings/nmslib.cc index 950a28c..add5588 100644 --- a/python_bindings/nmslib.cc +++ b/python_bindings/nmslib.cc @@ -39,7 +39,8 @@ namespace py = pybind11; namespace similarity { -const char * module_name = "nmslib"; +const char* module_name = "nmslib"; +const char* data_suff = ".dat"; enum DistType { DISTTYPE_FLOAT, @@ -93,10 +94,16 @@ struct IndexWrapper { index->CreateIndex(params); } - void loadIndex(const std::string & filename, bool print_progress = false) { + void loadIndex(const std::string & filename, bool load_data = false) { py::gil_scoped_release l; auto factory = MethodFactoryRegistry::Instance(); + bool print_progress=false; // We are not going to create the index anyway, only to load an existing one index.reset(factory.CreateMethod(print_progress, method, space_type, *space, data)); + if (load_data) { + vector dummy; + data.clear(); + space->ReadObjectVectorFromBinData(data, dummy, filename + data_suff); + } index->LoadIndex(filename); // querying reloaded indices don't seem to work correctly (at least hnsw ones) until @@ -104,11 +111,15 @@ struct IndexWrapper { index->ResetQueryTimeParams(); } - void saveIndex(const std::string & filename) { + void saveIndex(const std::string & filename, bool save_data = false) { if (!index) { throw std::invalid_argument("Must call createIndex or loadIndex before this method"); } py::gil_scoped_release l; + if (save_data) { + vector dummy; + space->WriteObjectVectorBinData(data, dummy, filename + data_suff); + } 
index->SaveIndex(filename); } @@ -573,22 +584,25 @@ void exportIndex(py::module * m) { .def("loadIndex", &IndexWrapper::loadIndex, py::arg("filename"), - py::arg("print_progress") = false, + py::arg("load_data") = false, "Loads the index from disk\n\n" "Parameters\n" "----------\n" "filename: str\n" " The filename to read from\n" - "print_progress: bool optional\n" - " Whether or not to display progress bar when creating index\n") + "load_data: bool optional\n" + " Whether or not to load previously saved data.\n") .def("saveIndex", &IndexWrapper::saveIndex, py::arg("filename"), - "Saves the index to disk\n\n" + py::arg("save_data") = false, + "Saves the index and/or data to disk\n\n" "Parameters\n" "----------\n" "filename: str\n" - " The filename to save to\n") + " The filename to save to\n" + "save_data: bool optional\n" + " Whether or not to save data\n") .def("setQueryTimeParams", [](IndexWrapper * self, py::object params) { diff --git a/python_bindings/setup.py b/python_bindings/setup.py index 78fb79d..2ca8138 100755 --- a/python_bindings/setup.py +++ b/python_bindings/setup.py @@ -4,7 +4,7 @@ import sys import setuptools -__version__ = '1.7.3.6' +__version__ = '1.8' libdir = os.path.join(".", "nmslib", "similarity_search") if not os.path.isdir(libdir) and sys.platform.startswith("win"): diff --git a/python_bindings/tests/bindings_test.py b/python_bindings/tests/bindings_test.py index 61920fb..b84dc53 100644 --- a/python_bindings/tests/bindings_test.py +++ b/python_bindings/tests/bindings_test.py @@ -112,18 +112,20 @@ def testReloadIndex(self): original.addDataPointBatch(data) original.createIndex() - # test out saving/reloading index - with tempfile.NamedTemporaryFile() as tmp: - original.saveIndex(tmp.name + ".index") + for save_data in [0, 1]: + # test out saving/reloading index + with tempfile.NamedTemporaryFile() as tmp: + original.saveIndex(tmp.name + ".index", save_data=save_data) - reloaded = self._get_index() - reloaded.addDataPointBatch(data) - 
reloaded.loadIndex(tmp.name + ".index") + reloaded = self._get_index() + if save_data == 0: + reloaded.addDataPointBatch(data) + reloaded.loadIndex(tmp.name + ".index", load_data=save_data) - original_results = original.knnQuery(data[0]) - reloaded_results = reloaded.knnQuery(data[0]) - self.assert_allclose(original_results, - reloaded_results) + original_results = original.knnQuery(data[0]) + reloaded_results = reloaded.knnQuery(data[0]) + self.assert_allclose(original_results, + reloaded_results) class BitVectorIndexTestMixin(object): @@ -167,20 +169,22 @@ def testReloadIndex(self): original.addDataPointBatch(ids=ids, data=data) original.createIndex() - # test out saving/reloading index - with tempfile.NamedTemporaryFile() as tmp: - original.saveIndex(tmp.name + ".index") + for save_data in [0, 1]: + # test out saving/reloading index + with tempfile.NamedTemporaryFile() as tmp: + original.saveIndex(tmp.name + ".index", save_data=save_data) - reloaded = self._get_index() - for ids, data in batches: - reloaded.addDataPointBatch(ids=ids, data=data) - reloaded.loadIndex(tmp.name + ".index") + reloaded = self._get_index() + if save_data == 0: + for ids, data in batches: + reloaded.addDataPointBatch(ids=ids, data=data) + reloaded.loadIndex(tmp.name + ".index", load_data=save_data) - s = self.bit_vector_str_func(np.ones(512)) - original_results = original.knnQuery(s) - reloaded_results = reloaded.knnQuery(s) - self.assert_allclose(original_results, - reloaded_results) + s = self.bit_vector_str_func(np.ones(512)) + original_results = original.knnQuery(s) + reloaded_results = reloaded.knnQuery(s) + self.assert_allclose(original_results, + reloaded_results) class HNSWTestCase(TestCaseBase, DenseIndexTestMixin): @@ -219,17 +223,19 @@ def testReloadIndex(self): original.createIndex() # test out saving/reloading index - with tempfile.NamedTemporaryFile() as tmp: - original.saveIndex(tmp.name + ".index") - - reloaded = self._get_index() - reloaded.addDataPointBatch(data) - 
reloaded.loadIndex(tmp.name + ".index") - - original_results = original.knnQuery(data[0]) - reloaded_results = reloaded.knnQuery(data[0]) - self.assert_allclose(original_results, - reloaded_results) + for save_data in [0, 1]: + with tempfile.NamedTemporaryFile() as tmp: + original.saveIndex(tmp.name + ".index", save_data=save_data) + + reloaded = self._get_index() + if save_data == 0: + reloaded.addDataPointBatch(data) + reloaded.loadIndex(tmp.name + ".index", load_data=save_data) + + original_results = original.knnQuery(data[0]) + reloaded_results = reloaded.knnQuery(data[0]) + self.assert_allclose(original_results, + reloaded_results) class BallTreeTestCase(TestCaseBase, DenseIndexTestMixin): diff --git a/similarity_search/src/space.cc b/similarity_search/src/space.cc index af77ca4..447c64d 100644 --- a/similarity_search/src/space.cc +++ b/similarity_search/src/space.cc @@ -67,7 +67,7 @@ Space::ReadObjectVectorFromBinData(ObjectVector& data, size_t qty; size_t objSize; std::ifstream input(fileName, std::ios::binary); - CHECK_MSG(input, "Cannot open file '" + fileName + "' for writing"); + CHECK_MSG(input, "Cannot open file '" + fileName + "' for reading"); input.exceptions(std::ios::badbit | std::ios::failbit); vExternIds.clear(); diff --git a/similarity_search/test/test_space_serial.cc b/similarity_search/test/test_space_serial.cc index 8da38e2..6ff1419 100644 --- a/similarity_search/test/test_space_serial.cc +++ b/similarity_search/test/test_space_serial.cc @@ -54,23 +54,23 @@ bool fullTestCommon(bool binTest, Space* pSpace, unique_ptr inpState; if (binTest) { - pSpace->WriteDataset(dataSet1, vExternIds1, tmpFileName); - inpState = pSpace->ReadDataset(dataSet2, vExternIds2, tmpFileName); - } else { pSpace->WriteObjectVectorBinData(dataSet1, vExternIds1, tmpFileName); inpState = pSpace->ReadObjectVectorFromBinData(dataSet2, vExternIds2, tmpFileName); + } else { + pSpace->WriteDataset(dataSet1, vExternIds1, tmpFileName); + inpState = pSpace->ReadDataset(dataSet2, 
vExternIds2, tmpFileName); } pSpace->UpdateParamsFromFile(*inpState); if (maxNumRec != dataSet2.size()) { - LOG(LIB_ERROR) << "Expected to read " << maxNumRec << " records from " + LOG(LIB_ERROR) << "binTest" << binTest << "Expected to read " << maxNumRec << " records from " "dataSet, but read only: " << dataSet2.size(); return false; } if (vExternIds2.size() != dataSet2.size()) { - LOG(LIB_ERROR) << "The number of external IDs (" << vExternIds1.size() << ") is different from the number of records: " << dataSet2.size(); + LOG(LIB_ERROR) << "binTest" << binTest << "The number of external IDs (" << vExternIds1.size() << ") is different from the number of records: " << dataSet2.size(); return false; } @@ -81,12 +81,12 @@ bool fullTestCommon(bool binTest, Space* pSpace, if (bTestExternId) { if (vExternIds1[i] != vExternIds2[i]) { - LOG(LIB_ERROR) << "External IDs are different, i = " << i << " id1 = '" << vExternIds1[i] << "' id2 = '" << vExternIds2[i] << "'" ; + LOG(LIB_ERROR) << "binTest" << binTest << " External IDs are different, i = " << i << " id1 = '" << vExternIds1[i] << "' id2 = '" << vExternIds2[i] << "'" ; return false; } } if (!pSpace->ApproxEqual(*dataSet1[i], *dataSet2[i])) { - LOG(LIB_ERROR) << "Objects are different, i = " << i; + LOG(LIB_ERROR) << "binTest" << binTest << "Objects are different, i = " << i; LOG(LIB_ERROR) << "Object 1 string representation produced by the space:" << pSpace->CreateStrFromObj(dataSet1[i], vExternIds1[i]); LOG(LIB_ERROR) << "Object 2 string representation produced by the space:" << @@ -94,7 +94,7 @@ bool fullTestCommon(bool binTest, Space* pSpace, return false; } if (dataSet1[i]->id() != dataSet2[i]->id()) { - LOG(LIB_ERROR) << "Objects IDs are different, i = " << i; + LOG(LIB_ERROR) << "binTest" << binTest << "Objects IDs are different, i = " << i; LOG(LIB_ERROR) << "Object 1 id: "<< dataSet1[i]->id() << " Object 2 id: " << dataSet2[i]->id(); } }