Skip to content

Commit

Permalink
Resolving #356
Browse files Browse the repository at this point in the history
  • Loading branch information
searchivarius committed May 31, 2019
1 parent 9067080 commit 55f9fc4
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 51 deletions.
30 changes: 22 additions & 8 deletions python_bindings/nmslib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
namespace py = pybind11;

namespace similarity {
const char * module_name = "nmslib";
const char* module_name = "nmslib";
const char* data_suff = ".dat";

enum DistType {
DISTTYPE_FLOAT,
Expand Down Expand Up @@ -93,22 +94,32 @@ struct IndexWrapper {
index->CreateIndex(params);
}

void loadIndex(const std::string & filename, bool print_progress = false) {
void loadIndex(const std::string & filename, bool load_data = false) {
py::gil_scoped_release l;
auto factory = MethodFactoryRegistry<dist_t>::Instance();
bool print_progress=false; // We are not going to creat the index anyways, only to load an existing one
index.reset(factory.CreateMethod(print_progress, method, space_type, *space, data));
if (load_data) {
vector<string> dummy;
data.clear();
space->ReadObjectVectorFromBinData(data, dummy, filename + data_suff);
}
index->LoadIndex(filename);

// querying reloaded indices don't seem to work correctly (at least hnsw ones) until
// SetQueryTimeParams is called
index->ResetQueryTimeParams();
}

void saveIndex(const std::string & filename) {
void saveIndex(const std::string & filename, bool save_data = false) {
if (!index) {
throw std::invalid_argument("Must call createIndex or loadIndex before this method");
}
py::gil_scoped_release l;
if (save_data) {
vector<string> dummy;
space->WriteObjectVectorBinData(data, dummy, filename + data_suff);
}
index->SaveIndex(filename);
}

Expand Down Expand Up @@ -573,22 +584,25 @@ void exportIndex(py::module * m) {

.def("loadIndex", &IndexWrapper<dist_t>::loadIndex,
py::arg("filename"),
py::arg("print_progress") = false,
py::arg("load_data") = false,
"Loads the index from disk\n\n"
"Parameters\n"
"----------\n"
"filename: str\n"
" The filename to read from\n"
"print_progress: bool optional\n"
" Whether or not to display progress bar when creating index\n")
"load_data: bool optional\n"
" Whether or not to load previously saved data.\n")

.def("saveIndex", &IndexWrapper<dist_t>::saveIndex,
py::arg("filename"),
"Saves the index to disk\n\n"
py::arg("save_data") = false,
"Saves the index and/or data to disk\n\n"
"Parameters\n"
"----------\n"
"filename: str\n"
" The filename to save to\n")
" The filename to save to\n"
"save_data: bool optional\n"
" Whether or not to save data\n")

.def("setQueryTimeParams",
[](IndexWrapper<dist_t> * self, py::object params) {
Expand Down
2 changes: 1 addition & 1 deletion python_bindings/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
import setuptools

__version__ = '1.7.3.6'
__version__ = '1.8'

libdir = os.path.join(".", "nmslib", "similarity_search")
if not os.path.isdir(libdir) and sys.platform.startswith("win"):
Expand Down
72 changes: 39 additions & 33 deletions python_bindings/tests/bindings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,20 @@ def testReloadIndex(self):
original.addDataPointBatch(data)
original.createIndex()

# test out saving/reloading index
with tempfile.NamedTemporaryFile() as tmp:
original.saveIndex(tmp.name + ".index")
for save_data in [0, 1]:
# test out saving/reloading index
with tempfile.NamedTemporaryFile() as tmp:
original.saveIndex(tmp.name + ".index", save_data=save_data)

reloaded = self._get_index()
reloaded.addDataPointBatch(data)
reloaded.loadIndex(tmp.name + ".index")
reloaded = self._get_index()
if save_data == 0:
reloaded.addDataPointBatch(data)
reloaded.loadIndex(tmp.name + ".index", load_data=save_data)

original_results = original.knnQuery(data[0])
reloaded_results = reloaded.knnQuery(data[0])
self.assert_allclose(original_results,
reloaded_results)
original_results = original.knnQuery(data[0])
reloaded_results = reloaded.knnQuery(data[0])
self.assert_allclose(original_results,
reloaded_results)


class BitVectorIndexTestMixin(object):
Expand Down Expand Up @@ -167,20 +169,22 @@ def testReloadIndex(self):
original.addDataPointBatch(ids=ids, data=data)
original.createIndex()

# test out saving/reloading index
with tempfile.NamedTemporaryFile() as tmp:
original.saveIndex(tmp.name + ".index")
for save_data in [0, 1]:
# test out saving/reloading index
with tempfile.NamedTemporaryFile() as tmp:
original.saveIndex(tmp.name + ".index", save_data=save_data)

reloaded = self._get_index()
for ids, data in batches:
reloaded.addDataPointBatch(ids=ids, data=data)
reloaded.loadIndex(tmp.name + ".index")
reloaded = self._get_index()
if save_data == 0:
for ids, data in batches:
reloaded.addDataPointBatch(ids=ids, data=data)
reloaded.loadIndex(tmp.name + ".index", load_data=save_data)

s = self.bit_vector_str_func(np.ones(512))
original_results = original.knnQuery(s)
reloaded_results = reloaded.knnQuery(s)
self.assert_allclose(original_results,
reloaded_results)
s = self.bit_vector_str_func(np.ones(512))
original_results = original.knnQuery(s)
reloaded_results = reloaded.knnQuery(s)
self.assert_allclose(original_results,
reloaded_results)


class HNSWTestCase(TestCaseBase, DenseIndexTestMixin):
Expand Down Expand Up @@ -219,17 +223,19 @@ def testReloadIndex(self):
original.createIndex()

# test out saving/reloading index
with tempfile.NamedTemporaryFile() as tmp:
original.saveIndex(tmp.name + ".index")

reloaded = self._get_index()
reloaded.addDataPointBatch(data)
reloaded.loadIndex(tmp.name + ".index")

original_results = original.knnQuery(data[0])
reloaded_results = reloaded.knnQuery(data[0])
self.assert_allclose(original_results,
reloaded_results)
for save_data in [0, 1]:
with tempfile.NamedTemporaryFile() as tmp:
original.saveIndex(tmp.name + ".index", save_data=save_data)

reloaded = self._get_index()
if save_data == 0:
reloaded.addDataPointBatch(data)
reloaded.loadIndex(tmp.name + ".index", load_data=save_data)

original_results = original.knnQuery(data[0])
reloaded_results = reloaded.knnQuery(data[0])
self.assert_allclose(original_results,
reloaded_results)


class BallTreeTestCase(TestCaseBase, DenseIndexTestMixin):
Expand Down
2 changes: 1 addition & 1 deletion similarity_search/src/space.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ Space<dist_t>::ReadObjectVectorFromBinData(ObjectVector& data,
size_t qty;
size_t objSize;
std::ifstream input(fileName, std::ios::binary);
CHECK_MSG(input, "Cannot open file '" + fileName + "' for writing");
CHECK_MSG(input, "Cannot open file '" + fileName + "' for reading");
input.exceptions(std::ios::badbit | std::ios::failbit);
vExternIds.clear();

Expand Down
16 changes: 8 additions & 8 deletions similarity_search/test/test_space_serial.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,23 @@ bool fullTestCommon(bool binTest, Space<dist_t>* pSpace,
unique_ptr<DataFileInputState> inpState;

if (binTest) {
pSpace->WriteDataset(dataSet1, vExternIds1, tmpFileName);
inpState = pSpace->ReadDataset(dataSet2, vExternIds2, tmpFileName);
} else {
pSpace->WriteObjectVectorBinData(dataSet1, vExternIds1, tmpFileName);
inpState = pSpace->ReadObjectVectorFromBinData(dataSet2, vExternIds2, tmpFileName);
} else {
pSpace->WriteDataset(dataSet1, vExternIds1, tmpFileName);
inpState = pSpace->ReadDataset(dataSet2, vExternIds2, tmpFileName);
}

pSpace->UpdateParamsFromFile(*inpState);

if (maxNumRec != dataSet2.size()) {
LOG(LIB_ERROR) << "Expected to read " << maxNumRec << " records from "
LOG(LIB_ERROR) << "binTest" << binTest << "Expected to read " << maxNumRec << " records from "
"dataSet, but read only: " << dataSet2.size();
return false;
}

if (vExternIds2.size() != dataSet2.size()) {
LOG(LIB_ERROR) << "The number of external IDs (" << vExternIds1.size() << ") is different from the number of records: " << dataSet2.size();
LOG(LIB_ERROR) << "binTest" << binTest << "The number of external IDs (" << vExternIds1.size() << ") is different from the number of records: " << dataSet2.size();
return false;
}

Expand All @@ -81,20 +81,20 @@ bool fullTestCommon(bool binTest, Space<dist_t>* pSpace,

if (bTestExternId) {
if (vExternIds1[i] != vExternIds2[i]) {
LOG(LIB_ERROR) << "External IDs are different, i = " << i << " id1 = '" << vExternIds1[i] << "' id2 = '" << vExternIds2[i] << "'" ;
LOG(LIB_ERROR) << "binTest" << binTest << " External IDs are different, i = " << i << " id1 = '" << vExternIds1[i] << "' id2 = '" << vExternIds2[i] << "'" ;
return false;
}
}
if (!pSpace->ApproxEqual(*dataSet1[i], *dataSet2[i])) {
LOG(LIB_ERROR) << "Objects are different, i = " << i;
LOG(LIB_ERROR) << "binTest" << binTest << "Objects are different, i = " << i;
LOG(LIB_ERROR) << "Object 1 string representation produced by the space:" <<
pSpace->CreateStrFromObj(dataSet1[i], vExternIds1[i]);
LOG(LIB_ERROR) << "Object 2 string representation produced by the space:" <<
pSpace->CreateStrFromObj(dataSet2[i], vExternIds2[i]);
return false;
}
if (dataSet1[i]->id() != dataSet2[i]->id()) {
LOG(LIB_ERROR) << "Objects IDs are different, i = " << i;
LOG(LIB_ERROR) << "binTest" << binTest << "Objects IDs are different, i = " << i;
LOG(LIB_ERROR) << "Object 1 id: "<< dataSet1[i]->id() << " Object 2 id: " << dataSet2[i]->id();
}
}
Expand Down

0 comments on commit 55f9fc4

Please sign in to comment.