Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Two important issues:
1) Implementing a space efficient SIFT/L2 space #280
2) Python API: reading/writing data of complex structure #282
  • Loading branch information
searchivairus committed Feb 8, 2018
1 parent 6c1a64b commit 524883d
Show file tree
Hide file tree
Showing 4 changed files with 10,452 additions and 4 deletions.
54 changes: 51 additions & 3 deletions python_bindings/nmslib.cc
Expand Up @@ -30,8 +30,10 @@
#include "knnqueue.h"
#include "methodfactory.h"
#include "space.h"
#include "space/space_vector.h"
#include "spacefactory.h"
#include "space/space_sparse_vector.h"
#include "space/space_l2sqr_sift.h"
#include "thread_pool.h"

namespace py = pybind11;
Expand All @@ -47,6 +49,7 @@ enum DistType {

enum DataType {
DATATYPE_DENSE_VECTOR,
DATATYPE_DENSE_UINT8_VECTOR,
DATATYPE_SPARSE_VECTOR,
DATATYPE_OBJECT_AS_STRING,
};
Expand All @@ -69,6 +72,16 @@ struct IndexWrapper {
: method(method), space_type(space_type), data_type(data_type), dist_type(dist_type),
space(SpaceFactoryRegistry<dist_t>::Instance().CreateSpace(space_type,
loadParams(space_params))) {
auto vectSpacePtr = dynamic_cast<VectorSpace<dist_t>*>(space.get());
if (data_type == DATATYPE_DENSE_VECTOR && vectSpacePtr == nullptr) {
throw std::invalid_argument("The space type " + space_type +
" is not compatible with the type DENSE_VECTOR, only dense vector spaces are allowed!");
}
auto vectSiftPtr = dynamic_cast<SpaceL2SqrSift*>(space.get());
if (data_type == DATATYPE_DENSE_UINT8_VECTOR && vectSiftPtr == nullptr) {
throw std::invalid_argument("The space type " + space_type +
" is not compatible with the type DENSE_UINT8_VECTOR!");
}
}

void createIndex(py::object index_params, bool print_progress = false) {
Expand Down Expand Up @@ -162,7 +175,17 @@ struct IndexWrapper {
switch (data_type) {
case DATATYPE_DENSE_VECTOR: {
py::array_t<dist_t> temp(input);
return new Object(id, -1, temp.size() * sizeof(dist_t), temp.data(0));
std::vector<dist_t> tempVect(temp.data(0), temp.data(0) + temp.size());
auto vectSpacePtr = reinterpret_cast<VectorSpace<dist_t>*>(space.get());
vectSpacePtr->CreateObjFromVect(id, -1, tempVect);
// This way it will not always work properly
//return new Object(id, -1, temp.size() * sizeof(dist_t), temp.data(0));
}
case DATATYPE_DENSE_UINT8_VECTOR: {
py::array_t<uint8_t> temp(input);
std::vector<uint8_t> tempVect(temp.data(0), temp.data(0) + temp.size());
auto vectSiftPtr = reinterpret_cast<SpaceL2SqrSift*>(space.get());
vectSiftPtr->CreateObjFromUint8Vect(id, -1, tempVect);
}
case DATATYPE_OBJECT_AS_STRING: {
std::string temp = py::cast<std::string>(input);
Expand Down Expand Up @@ -209,9 +232,31 @@ struct IndexWrapper {
if (buffer.ndim != 2) throw std::runtime_error("data must be a 2d array");

size_t rows = buffer.shape[0], features = buffer.shape[1];
std::vector<dist_t> tempVect(features);
auto vectSpacePtr = reinterpret_cast<VectorSpace<dist_t>*>(space.get());
for (size_t row = 0; row < rows; ++row) {
int id = ids.size() ? ids.at(row) : row;
const dist_t* elemVecStart = items.data(row);
std::copy(elemVecStart, elemVecStart + features, tempVect.begin());
output->push_back(vectSpacePtr->CreateObjFromVect(id, -1, tempVect));
//this way it won't always work properly
//output->push_back(new Object(id, -1, features * sizeof(dist_t), items.data(row)));
}
return rows;
} else if (data_type == DATATYPE_DENSE_UINT8_VECTOR) {
// allow numpy arrays to be returned here too
py::array_t<uint8_t, py::array::c_style | py::array::forcecast> items(input);
auto buffer = items.request();
if (buffer.ndim != 2) throw std::runtime_error("data must be a 2d array");

size_t rows = buffer.shape[0], features = buffer.shape[1];
std::vector<uint8_t> tempVect(features);
auto vectSiftPtr = reinterpret_cast<SpaceL2SqrSift*>(space.get());
for (size_t row = 0; row < rows; ++row) {
int id = ids.size() ? ids.at(row) : row;
output->push_back(new Object(id, -1, features * sizeof(dist_t), items.data(row)));
const uint8_t* elemVecStart = items.data(row);
std::copy(elemVecStart, elemVecStart + features, tempVect.begin());
output->push_back(vectSiftPtr->CreateObjFromUint8Vect(id, -1, tempVect));
}
return rows;

Expand Down Expand Up @@ -251,9 +296,11 @@ struct IndexWrapper {
py::object writeObject(const Object * obj) {
switch (data_type) {
case DATATYPE_DENSE_VECTOR: {
auto vectSpacePtr = reinterpret_cast<VectorSpace<dist_t>*>(space.get());
py::list ret;
const dist_t * values = reinterpret_cast<const dist_t *>(obj->data());
for (size_t i = 0; i < obj->datalength() / sizeof(dist_t); ++i) {
size_t elemQty = vectSpacePtr->GetElemQty(obj);
for (size_t i = 0; i < elemQty; ++i) {
ret.append(py::cast(values[i]));
}
return ret;
Expand Down Expand Up @@ -376,6 +423,7 @@ PYBIND11_PLUGIN(nmslib) {

py::enum_<DataType>(m, "DataType")
.value("DENSE_VECTOR", DATATYPE_DENSE_VECTOR)
.value("DENSE_UINT8_VECTOR", DATATYPE_DENSE_UINT8_VECTOR)
.value("SPARSE_VECTOR", DATATYPE_SPARSE_VECTOR)
.value("OBJECT_AS_STRING", DATATYPE_OBJECT_AS_STRING);

Expand Down

0 comments on commit 524883d

Please sign in to comment.