Skip to content

Commit

Permalink
Merge pull request #275 from benfred/support_pybind_v2.2
Browse files Browse the repository at this point in the history
support pybind v2.2
  • Loading branch information
Leonid Boytsov authored and GitHub committed Jan 31, 2018
2 parents 805a53a + 09487c1 commit 5cbbef6
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions python_bindings/nmslib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,12 @@ struct IndexWrapper {
size_t size = res->Size();
py::array_t<int> ids(size);
py::array_t<dist_t> distances(size);
auto raw_ids = ids.mutable_unchecked();
auto raw_distances = distances.mutable_unchecked();

while (!res->Empty() && size > 0) {
// iterating here in reversed order, undo that
size -= 1;
raw_ids(size) = res->TopObject()->id();
raw_distances(size) = res->TopDistance();
ids.mutable_at(size) = res->TopObject()->id();
distances.mutable_at(size) = res->TopDistance();
res->Pop();
}
return py::make_tuple(ids, distances);
Expand Down Expand Up @@ -195,7 +193,7 @@ struct IndexWrapper {
size_t readObjectVector(py::object input, ObjectVector * output,
py::object ids_ = py::none()) {
std::vector<int> ids;
if (ids_ != py::none()) {
if (!ids_) {
ids = py::cast<std::vector<int>>(ids_);
}

Expand Down Expand Up @@ -234,10 +232,10 @@ struct IndexWrapper {
// read each row from the sparse matrix, and insert
auto sparse_space = reinterpret_cast<const SpaceSparseVector<dist_t>*>(space.get());
std::vector<SparseVectElem<dist_t>> sparse_items;
for (size_t rowid = 0; rowid < indptr.size() - 1; ++rowid) {
for (int rowid = 0; rowid < indptr.size() - 1; ++rowid) {
sparse_items.clear();

for (size_t i = indptr.at(rowid); i < indptr.at(rowid + 1); ++i) {
for (int i = indptr.at(rowid); i < indptr.at(rowid + 1); ++i) {
sparse_items.push_back(SparseVectElem<dist_t>(indices.at(i),
sparse_data.at(i)));
}
Expand Down Expand Up @@ -352,15 +350,20 @@ class PythonLogger
}
};

#ifdef PYBIND11_MODULE
PYBIND11_MODULE(nmslib, m) {
m.doc() = "Python Bindings for Non-Metric Space Library (NMSLIB)";
#else
PYBIND11_PLUGIN(nmslib) {
py::module m(module_name, "Python Bindings for Non-Metric Space Library (NMSLIB)");
#endif
// Log using the python logger, instead of defaults built in here
py::module logging = py::module::import("logging");
py::module nmslibLogger = logging.attr("getLogger")("nmslib");
setGlobalLogger(new PythonLogger(nmslibLogger));

initLibrary(0 /* seed */, LIB_LOGCUSTOM, NULL);

py::module m(module_name, "Bindings for Non-Metric Space Library (NMSLIB)");

#ifdef VERSION_INFO
m.attr("__version__") = py::str(VERSION_INFO);
Expand Down Expand Up @@ -438,7 +441,10 @@ PYBIND11_PLUGIN(nmslib) {
exportIndex<double>(&dist_module);

exportLegacyAPI(&m);

#ifndef PYBIND11_MODULE
return m.ptr();
#endif
}

template <typename dist_t>
Expand Down Expand Up @@ -583,7 +589,7 @@ AnyParams loadParams(py::object o) {
if (py::isinstance<py::dict>(o)) {
AnyParams ret;
py::dict items(o);
for (auto & item : items) {
for (auto item : items) {
std::string key = py::cast<std::string>(item.first);
auto & value = item.second;

Expand Down

0 comments on commit 5cbbef6

Please sign in to comment.