Skip to content
Permalink
2b7cf56841
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
1592 lines (1377 sloc) 61.7 KB
/**
* Non-metric Space Library
*
* Main developers: Bilegsaikhan Naidan, Leonid Boytsov, Yury Malkov, Ben Frederickson, David Novak
*
* For the complete list of contributors and further details see:
* https://github.com/searchivarius/NonMetricSpaceLib
*
* Copyright (c) 2013-2018
*
* This code is released under the
* Apache License Version 2.0 http://www.apache.org/licenses/.
*
*/
/*
*
* A Hierarchical Navigable Small World (HNSW) approach.
*
* The main publication is (available on arxiv: http://arxiv.org/abs/1603.09320):
* "Efficient and robust approximate nearest neighbor search using Hierarchical Navigable Small World graphs" by Yu. A. Malkov, D. A. Yashunin
* This code was contributed by Yu. A. Malkov. It also was used in tests from the paper.
*
*
*/
#include <cmath>
#include <iostream>
#include <memory>
// This is only for _mm_prefetch
#include <mmintrin.h>
#include "portable_simd.h"
#include "knnquery.h"
#include "method/hnsw.h"
#include "ported_boost_progress.h"
#include "rangequery.h"
#include "space.h"
#include "space/space_lp.h"
#include "thread_pool.h"
#include "utils.h"
#include <map>
#include <set>
#include <sstream>
#include <typeinfo>
#include <vector>
#include "sort_arr_bi.h"
#define MERGE_BUFFER_ALGO_SWITCH_THRESHOLD 100
#define USE_BITSET_FOR_INDEXING 1
#define EXTEND_USE_EXTENDED_NEIGHB_AT_CONSTR (0) // 0 is faster build, 1 is faster search on clustered data
#if defined(__GNUC__)
#define PORTABLE_ALIGN16 __attribute__((aligned(16)))
#else
#define PORTABLE_ALIGN16 __declspec(align(16))
#endif
// For debug purposes we also implemented saving an index to a text file
#define USE_TEXT_REGULAR_INDEX (false)
#define TOTAL_QTY "TOTAL_QTY"
#define MAX_LEVEL "MAX_LEVEL"
#define ENTER_POINT_ID "ENTER_POINT_ID"
#define FIELD_M "M"
#define FIELD_MAX_M "MAX_M"
#define FIELD_MAX_M0 "MAX_M0"
#define CURR_LEVEL "CURR_LEVEL"
#define MAXIMUM_K 1000
namespace similarity {
// This is the counter to keep the size of neighborhood information (for one node)
// TODO Can this one overflow? I really doubt
typedef uint32_t SIZEMASS_TYPE;
using namespace std;
/*Functions from hnsw_distfunc_opt.cc:*/
float L2SqrSIMDExt(const float *pVect1, const float *pVect2, size_t &qty, float *TmpRes);
float L2SqrSIMD16Ext(const float *pVect1, const float *pVect2, size_t &qty, float *TmpRes);
float NormScalarProductSIMD(const float *pVect1, const float *pVect2, size_t &qty, float *TmpRes);
template <typename dist_t>
Hnsw<dist_t>::Hnsw(bool PrintProgress, const Space<dist_t> &space, const ObjectVector &data)
: Index<dist_t>(data)
, space_(space)
, PrintProgress_(PrintProgress)
, visitedlistpool(nullptr)
, enterpoint_(nullptr)
, data_level0_memory_(nullptr)
, linkLists_(nullptr)
, fstdistfunc_(nullptr)
{
}
void
checkList1(vector<HnswNode *> list)
{
int ok = 1;
for (size_t i = 0; i < list.size(); i++) {
for (size_t j = 0; j < list[i]->allFriends_[0].size(); j++) {
for (size_t k = j + 1; k < list[i]->allFriends_[0].size(); k++) {
if (list[i]->allFriends_[0][j] == list[i]->allFriends_[0][k]) {
cout << "\nDuplicate links\n\n\n\n\n!!!!!";
ok = 0;
}
}
if (list[i]->allFriends_[0][j] == list[i]) {
cout << "\nLink to the same element\n\n\n\n\n!!!!!";
ok = 0;
}
}
}
if (ok)
cout << "\nOK\n";
else
cout << "\nNOT OK!!!\n";
return;
}
void
getDegreeDistr(string filename, vector<HnswNode *> list)
{
ofstream out(filename);
size_t maxdegree = 0;
for (HnswNode *node : list) {
if (node->allFriends_[0].size() > maxdegree)
maxdegree = node->allFriends_[0].size();
}
vector<int> distrin = vector<int>(1000);
vector<int> distrout = vector<int>(1000);
vector<int> inconnections = vector<int>(list.size());
vector<int> outconnections = vector<int>(list.size());
for (size_t i = 0; i < list.size(); i++) {
for (HnswNode *node : list[i]->allFriends_[0]) {
outconnections[list[i]->getId()]++;
inconnections[node->getId()]++;
}
}
for (size_t i = 0; i < list.size(); i++) {
distrin[inconnections[i]]++;
distrout[outconnections[i]]++;
}
for (size_t i = 0; i < distrin.size(); i++) {
out << i << "\t" << distrin[i] << "\t" << distrout[i] << "\n";
}
out.close();
return;
}
template <typename dist_t>
void
Hnsw<dist_t>::CreateIndex(const AnyParams &IndexParams)
{
AnyParamManager pmgr(IndexParams);
pmgr.GetParamOptional("M", M_, 16);
// Let's use a generic algorithm by default!
pmgr.GetParamOptional(
"searchMethod", searchMethod_, 0); // this is just to prevent terminating the program when searchMethod is specified
searchMethod_ = 0;
indexThreadQty_ = std::thread::hardware_concurrency();
pmgr.GetParamOptional("indexThreadQty", indexThreadQty_, indexThreadQty_);
// indexThreadQty_ = 1;
pmgr.GetParamOptional("efConstruction", efConstruction_, 200);
pmgr.GetParamOptional("maxM", maxM_, M_);
pmgr.GetParamOptional("maxM0", maxM0_, M_ * 2);
pmgr.GetParamOptional("mult", mult_, 1 / log(1.0 * M_));
pmgr.GetParamOptional("delaunay_type", delaunay_type_, 2);
int post_;
pmgr.GetParamOptional("post", post_, 0);
int skip_optimized_index = 0;
pmgr.GetParamOptional("skip_optimized_index", skip_optimized_index, 0);
LOG(LIB_INFO) << "M = " << M_;
LOG(LIB_INFO) << "indexThreadQty = " << indexThreadQty_;
LOG(LIB_INFO) << "efConstruction = " << efConstruction_;
LOG(LIB_INFO) << "maxM = " << maxM_;
LOG(LIB_INFO) << "maxM0 = " << maxM0_;
LOG(LIB_INFO) << "mult = " << mult_;
LOG(LIB_INFO) << "skip_optimized_index= " << skip_optimized_index;
LOG(LIB_INFO) << "delaunay_type = " << delaunay_type_;
SetQueryTimeParams(getEmptyParams());
if (this->data_.empty()) {
pmgr.CheckUnused();
return;
}
ElList_.resize(this->data_.size());
// One entry should be added before all the threads are started, or else add() will not work properly
HnswNode *first = new HnswNode(this->data_[0], 0 /* id == 0 */);
first->init(getRandomLevel(mult_), maxM_, maxM0_);
maxlevel_ = first->level;
enterpoint_ = first;
ElList_[0] = first;
visitedlistpool = new VisitedListPool(indexThreadQty_, this->data_.size());
unique_ptr<ProgressDisplay> progress_bar(PrintProgress_ ? new ProgressDisplay(this->data_.size(), cerr) : NULL);
ParallelFor(1, this->data_.size(), indexThreadQty_, [&](int id, int threadId) {
HnswNode *node = new HnswNode(this->data_[id], id);
add(&space_, node);
{
unique_lock<mutex> lock(ElListGuard_);
ElList_[id] = node;
if (progress_bar)
++(*progress_bar);
}
});
if (progress_bar)
progress_bar->finish();
if (post_ == 1 || post_ == 2) {
vector<HnswNode *> temp;
temp.swap(ElList_);
ElList_.resize(this->data_.size());
first = new HnswNode(this->data_[0], 0 /* id == 0 */);
first->init(getRandomLevel(mult_), maxM_, maxM0_);
maxlevel_ = first->level;
enterpoint_ = first;
ElList_[0] = first;
/// Making the same index in reverse order
unique_ptr<ProgressDisplay> progress_bar1(PrintProgress_ ? new ProgressDisplay(this->data_.size(), cerr) : NULL);
ParallelFor(1, this->data_.size(), indexThreadQty_, [&](int pos_id, int threadId) {
// reverse ordering (so we iterate decreasing). given
// parallelfor, this might not make a difference
int id = this->data_.size() - pos_id;
HnswNode *node = new HnswNode(this->data_[id], id);
add(&space_, node);
{
unique_lock<mutex> lock(ElListGuard_);
ElList_[id] = node;
if (progress_bar1)
++(*progress_bar1);
}
if (progress_bar1)
progress_bar1->finish();
});
int maxF = 0;
// int degrees[100] = {0};
ParallelFor(1, this->data_.size(), indexThreadQty_, [&](int id, int threadId) {
HnswNode *node1 = ElList_[id];
HnswNode *node2 = temp[id];
vector<HnswNode *> f1 = node1->getAllFriends(0);
vector<HnswNode *> f2 = node2->getAllFriends(0);
unordered_set<size_t> intersect = unordered_set<size_t>();
for (HnswNode *cur : f1) {
intersect.insert(cur->getId());
}
for (HnswNode *cur : f2) {
intersect.insert(cur->getId());
}
if (intersect.size() > maxF)
maxF = intersect.size();
vector<HnswNode *> rez = vector<HnswNode *>();
if (post_ == 2) {
priority_queue<HnswNodeDistCloser<dist_t>> resultSet;
for (int cur : intersect) {
resultSet.emplace(space_.IndexTimeDistance(ElList_[cur]->getData(), ElList_[id]->getData()),
ElList_[cur]);
}
switch (delaunay_type_) {
case 0:
while (resultSet.size() > maxM0_)
resultSet.pop();
break;
case 2:
case 1:
ElList_[id]->getNeighborsByHeuristic1(resultSet, maxM0_, &space_);
break;
case 3:
ElList_[id]->getNeighborsByHeuristic3(resultSet, maxM0_, &space_, 0);
break;
}
while (!resultSet.empty()) {
rez.push_back(resultSet.top().getMSWNodeHier());
resultSet.pop();
}
} else if (post_ == 1) {
maxM0_ = maxF;
for (int cur : intersect) {
rez.push_back(ElList_[cur]);
}
}
{
unique_lock<mutex> lock(ElList_[id]->accessGuard_);
ElList_[id]->allFriends_[0].swap(rez);
}
// degrees[ElList_[id]->allFriends_[0].size()]++;
});
for (int i = 0; i < temp.size(); i++)
delete temp[i];
temp.clear();
}
// Uncomment for debug mode
// checkList1(ElList_);
data_level0_memory_ = NULL;
linkLists_ = NULL;
enterpointId_ = enterpoint_->getId();
if (skip_optimized_index) {
LOG(LIB_INFO) << "searchMethod = " << searchMethod_;
pmgr.CheckUnused();
return;
}
int friendsSectionSize = (maxM0_ + 1) * sizeof(int);
// Checking for maximum size of the datasection:
int dataSectionSize = 1;
for (int i = 0; i < ElList_.size(); i++) {
if (ElList_[i]->getData()->bufferlength() > dataSectionSize)
dataSectionSize = ElList_[i]->getData()->bufferlength();
}
// Selecting custom made functions
if (space_.StrDesc().compare("SpaceLp: p = 2 do we have a special implementation for this p? : 1") == 0 &&
sizeof(dist_t) == 4) {
LOG(LIB_INFO) << "\nThe space is Euclidean";
vectorlength_ = ((dataSectionSize - 16) >> 2);
LOG(LIB_INFO) << "Vector length=" << vectorlength_;
if (vectorlength_ % 16 == 0) {
LOG(LIB_INFO) << "Thus using an optimised function for base 16";
fstdistfunc_ = L2SqrSIMD16Ext;
dist_func_type_ = 1;
searchMethod_ = 3;
} else {
LOG(LIB_INFO) << "Thus using function with any base";
fstdistfunc_ = L2SqrSIMDExt;
dist_func_type_ = 2;
searchMethod_ = 3;
}
} else if (space_.StrDesc().compare("CosineSimilarity") == 0 && sizeof(dist_t) == 4) {
LOG(LIB_INFO) << "\nThe vectorspace is Cosine Similarity";
vectorlength_ = ((dataSectionSize - 16) >> 2);
LOG(LIB_INFO) << "Vector length=" << vectorlength_;
iscosine_ = true;
if (vectorlength_ % 4 == 0) {
LOG(LIB_INFO) << "Thus using an optimised function for base 4";
fstdistfunc_ = NormScalarProductSIMD;
dist_func_type_ = 3;
searchMethod_ = 4;
} else {
LOG(LIB_INFO) << "Thus using function with any base";
LOG(LIB_INFO) << "Search method 4 is not allowed in this case";
fstdistfunc_ = NormScalarProductSIMD;
dist_func_type_ = 3;
searchMethod_ = 3;
}
} else {
LOG(LIB_INFO) << "No appropriate custom distance function for " << space_.StrDesc();
// if (searchMethod_ != 0 && searchMethod_ != 1)
searchMethod_ = 0;
LOG(LIB_INFO) << "searchMethod = " << searchMethod_;
pmgr.CheckUnused();
return; // No optimized index
}
pmgr.CheckUnused();
LOG(LIB_INFO) << "searchMethod = " << searchMethod_;
memoryPerObject_ = dataSectionSize + friendsSectionSize;
size_t total_memory_allocated = (memoryPerObject_ * ElList_.size());
data_level0_memory_ = (char *)malloc(memoryPerObject_ * ElList_.size());
CHECK(data_level0_memory_);
offsetLevel0_ = dataSectionSize;
offsetData_ = 0;
memset(data_level0_memory_, 1, memoryPerObject_ * ElList_.size());
LOG(LIB_INFO) << "Making optimized index";
data_rearranged_.resize(ElList_.size());
for (long i = 0; i < ElList_.size(); i++) {
ElList_[i]->copyDataAndLevel0LinksToOptIndex(
data_level0_memory_ + (size_t)i * memoryPerObject_, offsetLevel0_, offsetData_);
data_rearranged_[i] = new Object(data_level0_memory_ + (i)*memoryPerObject_ + offsetData_);
};
////////////////////////////////////////////////////////////////////////
//
// The next step is needed only fos cosine similarity space
// All vectors are normalized, so we don't have to normalize them later
//
////////////////////////////////////////////////////////////////////////
if (iscosine_) {
for (long i = 0; i < ElList_.size(); i++) {
float *v = (float *)(data_level0_memory_ + (size_t)i * memoryPerObject_ + offsetData_ + 16);
float sum = 0;
for (int i = 0; i < vectorlength_; i++) {
sum += v[i] * v[i];
}
if (sum != 0.0) {
sum = 1 / sqrt(sum);
for (int i = 0; i < vectorlength_; i++) {
v[i] *= sum;
}
}
};
}
/////////////////////////////////////////////////////////
////////////////////////////////////////////////////////
linkLists_ = (char **)malloc(sizeof(void *) * ElList_.size());
CHECK(linkLists_);
for (long i = 0; i < ElList_.size(); i++) {
if (ElList_[i]->level < 1) {
linkLists_[i] = nullptr;
continue;
}
// TODO Can this one overflow? I really doubt
SIZEMASS_TYPE sizemass = ((ElList_[i]->level) * (maxM_ + 1)) * sizeof(int);
total_memory_allocated += sizemass;
char *linkList = (char *)malloc(sizemass);
CHECK(linkList);
linkLists_[i] = linkList;
ElList_[i]->copyHigherLevelLinksToOptIndex(linkList, 0);
};
LOG(LIB_INFO) << "Finished making optimized index";
LOG(LIB_INFO) << "Maximum level = " << enterpoint_->level;
LOG(LIB_INFO) << "Total memory allocated for optimized index+data: " << (total_memory_allocated >> 20) << " Mb";
}
template <typename dist_t>
void
Hnsw<dist_t>::SetQueryTimeParams(const AnyParams &QueryTimeParams)
{
AnyParamManager pmgr(QueryTimeParams);
if (pmgr.hasParam("ef") && pmgr.hasParam("efSearch")) {
throw runtime_error("The user shouldn't specify parameters ef and efSearch at the same time (they are synonyms)");
}
// ef and efSearch are going to be parameter-synonyms with the default value 20
pmgr.GetParamOptional("ef", ef_, 20);
pmgr.GetParamOptional("efSearch", ef_, ef_);
int tmp;
pmgr.GetParamOptional(
"searchMethod", tmp, 0); // this is just to prevent terminating the program when searchMethod is specified
string tmps;
pmgr.GetParamOptional("algoType", tmps, "hybrid");
ToLower(tmps);
if (tmps == "v1merge")
searchAlgoType_ = kV1Merge;
else if (tmps == "old")
searchAlgoType_ = kOld;
else if (tmps == "hybrid")
searchAlgoType_ = kHybrid;
else {
throw runtime_error("algoType should be one of the following: old, v1merge");
}
pmgr.CheckUnused();
LOG(LIB_INFO) << "Set HNSW query-time parameters:";
LOG(LIB_INFO) << "ef(Search) =" << ef_;
LOG(LIB_INFO) << "algoType =" << searchAlgoType_;
}
template <typename dist_t>
const std::string
Hnsw<dist_t>::StrDesc() const
{
return METH_HNSW;
}
template <typename dist_t> Hnsw<dist_t>::~Hnsw()
{
delete visitedlistpool;
if (data_level0_memory_)
free(data_level0_memory_);
if (linkLists_) {
for (int i = 0; i < data_rearranged_.size(); i++) {
if (linkLists_[i])
free(linkLists_[i]);
}
free(linkLists_);
}
for (int i = 0; i < ElList_.size(); i++)
delete ElList_[i];
for (const Object *p : data_rearranged_)
delete p;
}
template <typename dist_t>
void
Hnsw<dist_t>::add(const Space<dist_t> *space, HnswNode *NewElement)
{
int curlevel = getRandomLevel(mult_);
unique_lock<mutex> *lock = nullptr;
if (curlevel > maxlevel_)
lock = new unique_lock<mutex>(MaxLevelGuard_);
NewElement->init(curlevel, maxM_, maxM0_);
int maxlevelcopy = maxlevel_;
HnswNode *ep = enterpoint_;
if (curlevel < maxlevelcopy) {
const Object *currObj = ep->getData();
dist_t d = space->IndexTimeDistance(NewElement->getData(), currObj);
dist_t curdist = d;
HnswNode *curNode = ep;
for (int level = maxlevelcopy; level > curlevel; level--) {
bool changed = true;
while (changed) {
changed = false;
unique_lock<mutex> lock(curNode->accessGuard_);
const vector<HnswNode *> &neighbor = curNode->getAllFriends(level);
int size = neighbor.size();
for (int i = 0; i < size; i++) {
HnswNode *node = neighbor[i];
_mm_prefetch((char *)(node)->getData(), _MM_HINT_T0);
}
for (int i = 0; i < size; i++) {
currObj = (neighbor[i])->getData();
d = space->IndexTimeDistance(NewElement->getData(), currObj);
if (d < curdist) {
curdist = d;
curNode = neighbor[i];
changed = true;
}
}
}
}
ep = curNode;
}
for (int level = min(curlevel, maxlevelcopy); level >= 0; level--) {
priority_queue<HnswNodeDistCloser<dist_t>> resultSet;
kSearchElementsWithAttemptsLevel(space, NewElement->getData(), efConstruction_, resultSet, ep, level);
switch (delaunay_type_) {
case 0:
while (resultSet.size() > M_)
resultSet.pop();
break;
case 1:
NewElement->getNeighborsByHeuristic1(resultSet, M_, space);
break;
case 2:
NewElement->getNeighborsByHeuristic2(resultSet, M_, space, level);
break;
case 3:
NewElement->getNeighborsByHeuristic3(resultSet, M_, space, level);
break;
}
while (!resultSet.empty()) {
ep = resultSet.top().getMSWNodeHier(); // memorizing the closest
link(resultSet.top().getMSWNodeHier(), NewElement, level, space, delaunay_type_);
resultSet.pop();
}
}
if (curlevel > enterpoint_->level) {
enterpoint_ = NewElement;
maxlevel_ = curlevel;
}
if (lock != nullptr)
delete lock;
}
template <typename dist_t>
void
Hnsw<dist_t>::kSearchElementsWithAttemptsLevel(const Space<dist_t> *space, const Object *queryObj, size_t efConstruction,
priority_queue<HnswNodeDistCloser<dist_t>> &resultSet, HnswNode *ep,
int level) const
{
#if EXTEND_USE_EXTENDED_NEIGHB_AT_CONSTR != 0
priority_queue<HnswNodeDistCloser<dist_t>> fullResultSet;
#endif
#if USE_BITSET_FOR_INDEXING
VisitedList *vl = visitedlistpool->getFreeVisitedList();
vl_type *mass = vl->mass;
vl_type curV = vl->curV;
#else
unordered_set<HnswNode *> visited;
#endif
HnswNode *provider = ep;
priority_queue<HnswNodeDistFarther<dist_t>> candidateSet;
dist_t d = space->IndexTimeDistance(queryObj, provider->getData());
HnswNodeDistFarther<dist_t> ev(d, provider);
candidateSet.push(ev);
resultSet.emplace(d, provider);
#if EXTEND_USE_EXTENDED_NEIGHB_AT_CONSTR != 0
fullResultSet.emplace(d, provider);
#endif
#if USE_BITSET_FOR_INDEXING
size_t nodeId = provider->getId();
mass[nodeId] = curV;
#else
visited.insert(provider);
#endif
while (!candidateSet.empty()) {
const HnswNodeDistFarther<dist_t> &currEv = candidateSet.top();
dist_t lowerBound = resultSet.top().getDistance();
/*
* Check if we reached a local minimum.
*/
if (currEv.getDistance() > lowerBound) {
break;
}
HnswNode *currNode = currEv.getMSWNodeHier();
/*
* This lock protects currNode from being modified
* while we are accessing elements of currNode.
*/
unique_lock<mutex> lock(currNode->accessGuard_);
const vector<HnswNode *> &neighbor = currNode->getAllFriends(level);
// Can't access curEv anymore! The reference would become invalid
candidateSet.pop();
// calculate distance to each neighbor
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
_mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0);
}
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
#if USE_BITSET_FOR_INDEXING
size_t nodeId = (*iter)->getId();
if (mass[nodeId] != curV) {
mass[nodeId] = curV;
#else
if (visited.find((*iter)) == visited.end()) {
visited.insert(*iter);
#endif
d = space->IndexTimeDistance(queryObj, (*iter)->getData());
HnswNodeDistFarther<dist_t> evE1(d, *iter);
#if EXTEND_USE_EXTENDED_NEIGHB_AT_CONSTR != 0
fullResultSet.emplace(d, *iter);
#endif
if (resultSet.size() < efConstruction || resultSet.top().getDistance() > d) {
resultSet.emplace(d, *iter);
candidateSet.push(evE1);
if (resultSet.size() > efConstruction) {
resultSet.pop();
}
}
}
}
}
#if EXTEND_USE_EXTENDED_NEIGHB_AT_CONSTR != 0
resultSet.swap(fullResultSet);
#endif
#if USE_BITSET_FOR_INDEXING
visitedlistpool->releaseVisitedList(vl);
#endif
}
template <typename dist_t>
void
Hnsw<dist_t>::Search(RangeQuery<dist_t> *query, IdType) const
{
// throw runtime_error("Range search is not supported!");
if (this->data_.empty() && this->data_rearranged_.empty()) {
return;
}
bool useOld = searchAlgoType_ == kOld || (searchAlgoType_ == kHybrid && ef_ >= 1000);
// cout << "Ef = " << ef_ << " use old = " << useOld << endl;
switch (searchMethod_) {
default:
throw runtime_error("Invalid searchMethod: " + ConvertToString(searchMethod_));
break;
case 0:
/// Basic search using Nmslib data structure:
// if (useOld)
// const_cast<Hnsw *>(this)->baseSearchAlgorithmOld(query);
// else
const_cast<Hnsw *>(this)->baseSearchAlgorithmV1Merge(query);
break;
// case 1:
// /// Experimental search using Nmslib data structure (should not be used):
// const_cast<Hnsw *>(this)->listPassingModifiedAlgorithm(query);
// break;
// case 3:
// /// Basic search using optimized index(cosine+L2)
// if (useOld)
// const_cast<Hnsw *>(this)->SearchL2CustomOld(query);
// else
// const_cast<Hnsw *>(this)->SearchL2CustomV1Merge(query);
// break;
// case 4:
// /// Basic search using optimized index with one-time normalized cosine similarity
// /// Only for cosine similarity!
// if (useOld)
// const_cast<Hnsw *>(this)->SearchCosineNormalizedOld(query);
// else
// const_cast<Hnsw *>(this)->SearchCosineNormalizedV1Merge(query);
// break;
};
}
template <typename dist_t>
void
Hnsw<dist_t>::Search(KNNQuery<dist_t> *query, IdType) const
{
if (this->data_.empty() && this->data_rearranged_.empty()) {
return;
}
bool useOld = searchAlgoType_ == kOld || (searchAlgoType_ == kHybrid && ef_ >= 1000);
// cout << "Ef = " << ef_ << " use old = " << useOld << endl;
switch (searchMethod_) {
default:
throw runtime_error("Invalid searchMethod: " + ConvertToString(searchMethod_));
break;
case 0:
/// Basic search using Nmslib data structure:
if (useOld)
const_cast<Hnsw *>(this)->baseSearchAlgorithmOld(query);
else
const_cast<Hnsw *>(this)->baseSearchAlgorithmV1Merge(query);
break;
case 1:
/// Experimental search using Nmslib data structure (should not be used):
const_cast<Hnsw *>(this)->listPassingModifiedAlgorithm(query);
break;
case 3:
/// Basic search using optimized index(cosine+L2)
if (useOld)
const_cast<Hnsw *>(this)->SearchL2CustomOld(query);
else
const_cast<Hnsw *>(this)->SearchL2CustomV1Merge(query);
break;
case 4:
/// Basic search using optimized index with one-time normalized cosine similarity
/// Only for cosine similarity!
if (useOld)
const_cast<Hnsw *>(this)->SearchCosineNormalizedOld(query);
else
const_cast<Hnsw *>(this)->SearchCosineNormalizedV1Merge(query);
break;
};
}
template <typename dist_t>
void
Hnsw<dist_t>::SaveIndex(const string &location) {
std::ofstream output(location,
std::ios::binary /* text files can be opened in binary mode as well */);
CHECK_MSG(output, "Cannot open file '" + location + "' for writing");
output.exceptions(ios::badbit | ios::failbit);
unsigned int optimIndexFlag = data_level0_memory_ != nullptr;
if (!optimIndexFlag) {
#if USE_TEXT_REGULAR_INDEX
SaveRegularIndexText(output);
#else
writeBinaryPOD(output, optimIndexFlag);
SaveRegularIndexBin(output);
#endif
} else {
writeBinaryPOD(output, optimIndexFlag);
SaveOptimizedIndex(output);
}
output.close();
}
template <typename dist_t>
void
Hnsw<dist_t>::SaveOptimizedIndex(std::ostream& output) {
totalElementsStored_ = ElList_.size();
writeBinaryPOD(output, totalElementsStored_);
writeBinaryPOD(output, memoryPerObject_);
writeBinaryPOD(output, offsetLevel0_);
writeBinaryPOD(output, offsetData_);
writeBinaryPOD(output, maxlevel_);
writeBinaryPOD(output, enterpointId_);
writeBinaryPOD(output, maxM_);
writeBinaryPOD(output, maxM0_);
writeBinaryPOD(output, dist_func_type_);
writeBinaryPOD(output, searchMethod_);
size_t data_plus_links0_size = memoryPerObject_ * totalElementsStored_;
LOG(LIB_INFO) << "writing " << data_plus_links0_size << " bytes";
output.write(data_level0_memory_, data_plus_links0_size);
// output.write(data_level0_memory_, memoryPerObject_*totalElementsStored_);
// size_t total_memory_allocated = 0;
for (size_t i = 0; i < totalElementsStored_; i++) {
// TODO Can this one overflow? I really doubt
SIZEMASS_TYPE sizemass = ((ElList_[i]->level) * (maxM_ + 1)) * sizeof(int);
writeBinaryPOD(output, sizemass);
if ((sizemass))
output.write(linkLists_[i], sizemass);
};
}
template <typename dist_t>
void
Hnsw<dist_t>::SaveRegularIndexBin(std::ostream& output) {
totalElementsStored_ = ElList_.size();
writeBinaryPOD(output, totalElementsStored_);
writeBinaryPOD(output, maxlevel_);
writeBinaryPOD(output, enterpointId_);
writeBinaryPOD(output, M_);
writeBinaryPOD(output, maxM_);
writeBinaryPOD(output, maxM0_);
for (unsigned i = 0; i < totalElementsStored_; ++i) {
const HnswNode& node = *ElList_[i];
unsigned currlevel = node.level;
CHECK(currlevel + 1 == node.allFriends_.size());
/*
* This check strangely fails ...
CHECK_MSG(maxlevel_ >= currlevel, ""
"maxlevel_ (" + ConvertToString(maxlevel_) + ") < node.allFriends_.size() (" + ConvertToString(currlevel));
*/
writeBinaryPOD(output, currlevel);
for (unsigned level = 0; level <= currlevel; ++level) {
const auto& friends = node.allFriends_[level];
unsigned friendQty = friends.size();
writeBinaryPOD(output, friendQty);
for (unsigned k = 0; k < friendQty; ++k) {
IdType friendId = friends[k]->id_;
writeBinaryPOD(output, friendId);
}
}
}
}
template <typename dist_t>
void
Hnsw<dist_t>::SaveRegularIndexText(std::ostream& output) {
size_t lineNum = 0;
totalElementsStored_ = ElList_.size();
WriteField(output, TOTAL_QTY, totalElementsStored_); lineNum++;
WriteField(output, MAX_LEVEL, maxlevel_); lineNum++;
WriteField(output, ENTER_POINT_ID, enterpointId_); lineNum++;
WriteField(output, FIELD_M, M_); lineNum++;
WriteField(output, FIELD_MAX_M, maxM_); lineNum++;
WriteField(output, FIELD_MAX_M0, maxM0_); lineNum++;
vector<IdType> friendIds;
for (unsigned i = 0; i < totalElementsStored_; ++i) {
const HnswNode& node = *ElList_[i];
unsigned currlevel = node.level;
CHECK(currlevel + 1 == node.allFriends_.size());
/*
* This check strangely fails ...
CHECK_MSG(maxlevel_ >= currlevel, ""
"maxlevel_ (" + ConvertToString(maxlevel_) + ") < node.allFriends_.size() (" + ConvertToString(currlevel));
*/
WriteField(output, CURR_LEVEL, currlevel); lineNum++;
for (unsigned level = 0; level <= currlevel; ++level) {
const auto& friends = node.allFriends_[level];
unsigned friendQty = friends.size();
friendIds.resize(friendQty);
for (unsigned k = 0; k < friendQty; ++k) {
friendIds[k] = friends[k]->id_;
}
output << MergeIntoStr(friendIds, ' ') << endl; lineNum++;
}
}
WriteField(output, LINE_QTY, lineNum);
}
template <typename dist_t>
void
Hnsw<dist_t>::LoadRegularIndexText(std::istream& input) {
LOG(LIB_INFO) << "Loading regular index.";
size_t lineNum = 0;
ReadField(input, TOTAL_QTY, totalElementsStored_); lineNum++;
ReadField(input, MAX_LEVEL, maxlevel_); lineNum++;
ReadField(input, ENTER_POINT_ID, enterpointId_); lineNum++;
ReadField(input, FIELD_M, M_); lineNum++;
ReadField(input, FIELD_MAX_M, maxM_); lineNum++;
ReadField(input, FIELD_MAX_M0, maxM0_); lineNum++;
fstdistfunc_ = nullptr;
dist_func_type_ = 0;
searchMethod_ = 0;
ElList_.resize(totalElementsStored_);
for (unsigned id = 0; id < totalElementsStored_; ++id) {
ElList_[id] = new HnswNode(this->data_[id], id);
}
enterpoint_ = ElList_[enterpointId_];
string line;
vector<IdType> friendIds;
for (unsigned id = 0; id < totalElementsStored_; ++id) {
HnswNode& node = *ElList_[id];
unsigned currlevel;
ReadField(input, CURR_LEVEL, currlevel); lineNum++;
node.level = currlevel;
node.allFriends_.resize(currlevel + 1);
for (unsigned level = 0; level <= currlevel; ++level) {
CHECK_MSG(getline(input, line),
"Failed to read line #" + ConvertToString(lineNum)); lineNum++;
CHECK_MSG(SplitStr(line, friendIds, ' '),
"Failed to extract neighbor IDs from line #" + ConvertToString(lineNum));
unsigned friendQty = friendIds.size();
auto& friends = node.allFriends_[level];
friends.resize(friendQty);
for (unsigned k = 0; k < friendQty; ++k) {
IdType friendId = friendIds[k];
CHECK_MSG(friendId >= 0 && friendId < totalElementsStored_,
"Invalid friendId = " + ConvertToString(friendId) + " for node id: " + ConvertToString(id));
friends[k] = ElList_[friendId];
}
}
}
size_t ExpLineNum;
ReadField(input, LINE_QTY, ExpLineNum);
CHECK_MSG(lineNum == ExpLineNum,
DATA_MUTATION_ERROR_MSG + " (expected number of lines " + ConvertToString(ExpLineNum) +
" read so far doesn't match the number of read lines: " + ConvertToString(lineNum));
}
template <typename dist_t>
void
Hnsw<dist_t>::LoadRegularIndexBin(std::istream& input) {
LOG(LIB_INFO) << "Loading regular index.";
readBinaryPOD(input, totalElementsStored_);
readBinaryPOD(input, maxlevel_);
readBinaryPOD(input, enterpointId_);
readBinaryPOD(input, M_);
readBinaryPOD(input, maxM_);
readBinaryPOD(input, maxM0_);
fstdistfunc_ = nullptr;
dist_func_type_ = 0;
searchMethod_ = 0;
CHECK_MSG(totalElementsStored_ == this->data_.size(),
"The number of stored elements " + ConvertToString(totalElementsStored_) +
" doesn't match the number of data points " + ConvertToString(this->data_.size() +
"! Did you forget to re-load data?"))
ElList_.resize(totalElementsStored_);
for (unsigned id = 0; id < totalElementsStored_; ++id) {
ElList_[id] = new HnswNode(this->data_[id], id);
}
enterpoint_ = ElList_[enterpointId_];
for (unsigned id = 0; id < totalElementsStored_; ++id) {
HnswNode& node = *ElList_[id];
unsigned currlevel;
readBinaryPOD(input, currlevel);
node.level = currlevel;
node.allFriends_.resize(currlevel + 1);
for (unsigned level = 0; level <= currlevel; ++level) {
auto& friends = node.allFriends_[level];
unsigned friendQty;
readBinaryPOD(input, friendQty);
friends.resize(friendQty);
for (unsigned k = 0; k < friendQty; ++k) {
IdType friendId;
readBinaryPOD(input, friendId);
CHECK_MSG(friendId >= 0 && friendId < totalElementsStored_,
"Invalid friendId = " + ConvertToString(friendId) + " for node id: " + ConvertToString(id));
friends[k] = ElList_[friendId];
}
}
}
}
template <typename dist_t>
void
Hnsw<dist_t>::LoadIndex(const string &location) {
LOG(LIB_INFO) << "Loading index from " << location;
std::ifstream input(location,
std::ios::binary); /* text files can be opened in binary mode as well */
CHECK_MSG(input, "Cannot open file '" + location + "' for reading");
input.exceptions(ios::badbit | ios::failbit);
#if USE_TEXT_REGULAR_INDEX
LoadRegularIndexText(input);
#else
unsigned int optimIndexFlag= 0;
readBinaryPOD(input, optimIndexFlag);
if (!optimIndexFlag) {
LoadRegularIndexBin(input);
} else {
LoadOptimizedIndex(input);
}
#endif
input.close();
LOG(LIB_INFO) << "Finished loading index";
visitedlistpool = new VisitedListPool(1, totalElementsStored_);
}
template <typename dist_t>
void
Hnsw<dist_t>::LoadOptimizedIndex(std::istream& input) {
LOG(LIB_INFO) << "Loading optimized index.";
readBinaryPOD(input, totalElementsStored_);
readBinaryPOD(input, memoryPerObject_);
readBinaryPOD(input, offsetLevel0_);
readBinaryPOD(input, offsetData_);
readBinaryPOD(input, maxlevel_);
readBinaryPOD(input, enterpointId_);
readBinaryPOD(input, maxM_);
readBinaryPOD(input, maxM0_);
readBinaryPOD(input, dist_func_type_);
readBinaryPOD(input, searchMethod_);
LOG(LIB_INFO) << "searchMethod: " << searchMethod_;
if (dist_func_type_ == 1)
fstdistfunc_ = L2SqrSIMD16Ext;
else if (dist_func_type_ == 2)
fstdistfunc_ = L2SqrSIMDExt;
else if (dist_func_type_ == 3)
fstdistfunc_ = NormScalarProductSIMD;
// LOG(LIB_INFO) << input.tellg();
LOG(LIB_INFO) << "Total: " << totalElementsStored_ << ", Memory per object: " << memoryPerObject_;
size_t data_plus_links0_size = memoryPerObject_ * totalElementsStored_;
data_level0_memory_ = (char *)malloc(data_plus_links0_size);
CHECK(data_level0_memory_);
input.read(data_level0_memory_, data_plus_links0_size);
linkLists_ = (char **)malloc(sizeof(void *) * totalElementsStored_);
CHECK(linkLists_);
data_rearranged_.resize(totalElementsStored_);
for (size_t i = 0; i < totalElementsStored_; i++) {
SIZEMASS_TYPE linkListSize;
readBinaryPOD(input, linkListSize);
if (linkListSize == 0) {
linkLists_[i] = nullptr;
} else {
linkLists_[i] = (char *)malloc(linkListSize);
CHECK(linkLists_[i]);
input.read(linkLists_[i], linkListSize);
}
data_rearranged_[i] = new Object(data_level0_memory_ + (i)*memoryPerObject_ + offsetData_);
}
}
template <typename dist_t>
void
Hnsw<dist_t>::baseSearchAlgorithmOld(KNNQuery<dist_t> *query)
{
VisitedList *vl = visitedlistpool->getFreeVisitedList();
vl_type *massVisited = vl->mass;
vl_type currentV = vl->curV;
HnswNode *provider;
int maxlevel1 = enterpoint_->level;
provider = enterpoint_;
const Object *currObj = provider->getData();
dist_t d = query->DistanceObjLeft(currObj);
dist_t curdist = d;
HnswNode *curNode = provider;
for (int i = maxlevel1; i > 0; i--) {
bool changed = true;
while (changed) {
changed = false;
const vector<HnswNode *> &neighbor = curNode->getAllFriends(i);
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
_mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0);
}
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
currObj = (*iter)->getData();
d = query->DistanceObjLeft(currObj);
if (d < curdist) {
curdist = d;
curNode = *iter;
changed = true;
}
}
}
}
priority_queue<HnswNodeDistFarther<dist_t>> candidateQueue; // the set of elements which we can use to evaluate
priority_queue<HnswNodeDistCloser<dist_t>> closestDistQueue1; // The set of closest found elements
HnswNodeDistFarther<dist_t> ev(curdist, curNode);
candidateQueue.emplace(curdist, curNode);
closestDistQueue1.emplace(curdist, curNode);
query->CheckAndAddToResult(curdist, curNode->getData());
massVisited[curNode->getId()] = currentV;
// visitedQueue.insert(curNode->getId());
////////////////////////////////////////////////////////////////////////////////
// PHASE TWO OF THE SEARCH
// Extraction of the neighborhood to find k nearest neighbors.
////////////////////////////////////////////////////////////////////////////////
while (!candidateQueue.empty()) {
auto iter = candidateQueue.top(); // This one was already compared to the query
const HnswNodeDistFarther<dist_t> &currEv = iter;
// Check condition to end the search
dist_t lowerBound = closestDistQueue1.top().getDistance();
if (currEv.getDistance() > lowerBound) {
break;
}
HnswNode *initNode = currEv.getMSWNodeHier();
candidateQueue.pop();
const vector<HnswNode *> &neighbor = (initNode)->getAllFriends(0);
size_t curId;
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
_mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0);
_mm_prefetch((char *)(massVisited + (*iter)->getId()), _MM_HINT_T0);
}
// calculate distance to each neighbor
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
curId = (*iter)->getId();
if (!(massVisited[curId] == currentV)) {
massVisited[curId] = currentV;
currObj = (*iter)->getData();
d = query->DistanceObjLeft(currObj);
if (closestDistQueue1.top().getDistance() > d || closestDistQueue1.size() < ef_) {
{
query->CheckAndAddToResult(d, currObj);
candidateQueue.emplace(d, *iter);
closestDistQueue1.emplace(d, *iter);
if (closestDistQueue1.size() > ef_) {
closestDistQueue1.pop();
}
}
}
}
}
}
visitedlistpool->releaseVisitedList(vl);
}
template <typename dist_t>
void
Hnsw<dist_t>::baseSearchAlgorithmV1Merge(KNNQuery<dist_t> *query)
{
VisitedList *vl = visitedlistpool->getFreeVisitedList();
vl_type *massVisited = vl->mass;
vl_type currentV = vl->curV;
HnswNode *provider;
int maxlevel1 = enterpoint_->level;
provider = enterpoint_;
const Object *currObj = provider->getData();
dist_t d = query->DistanceObjLeft(currObj);
dist_t curdist = d;
HnswNode *curNode = provider;
for (int i = maxlevel1; i > 0; i--) {
bool changed = true;
while (changed) {
changed = false;
const vector<HnswNode *> &neighbor = curNode->getAllFriends(i);
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
_mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0);
}
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
currObj = (*iter)->getData();
d = query->DistanceObjLeft(currObj);
if (d < curdist) {
curdist = d;
curNode = *iter;
changed = true;
}
}
}
}
SortArrBI<dist_t, HnswNode *> sortedArr(max<size_t>(ef_, query->GetK()));
sortedArr.push_unsorted_grow(curdist, curNode);
int_fast32_t currElem = 0;
typedef typename SortArrBI<dist_t, HnswNode *>::Item QueueItem;
vector<QueueItem> &queueData = sortedArr.get_data();
vector<QueueItem> itemBuff(1 + max(maxM_, maxM0_));
massVisited[curNode->getId()] = currentV;
// visitedQueue.insert(curNode->getId());
////////////////////////////////////////////////////////////////////////////////
// PHASE TWO OF THE SEARCH
// Extraction of the neighborhood to find k nearest neighbors.
////////////////////////////////////////////////////////////////////////////////
while (currElem < min(sortedArr.size(), ef_)) {
auto &e = queueData[currElem];
CHECK(!e.used);
e.used = true;
HnswNode *initNode = e.data;
++currElem;
size_t itemQty = 0;
dist_t topKey = sortedArr.top_key();
const vector<HnswNode *> &neighbor = (initNode)->getAllFriends(0);
size_t curId;
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
_mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0);
IdType curId = (*iter)->getId();
CHECK(curId >= 0 && curId < this->data_.size());
_mm_prefetch((char *)(massVisited + curId), _MM_HINT_T0);
}
// calculate distance to each neighbor
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
curId = (*iter)->getId();
if (!(massVisited[curId] == currentV)) {
massVisited[curId] = currentV;
currObj = (*iter)->getData();
d = query->DistanceObjLeft(currObj);
if (d < topKey || sortedArr.size() < ef_) {
CHECK_MSG(itemBuff.size() > itemQty,
"Perhaps a bug: buffer size is not enough " +
ConvertToString(itemQty) + " >= " + ConvertToString(itemBuff.size()));
itemBuff[itemQty++] = QueueItem(d, *iter);
}
}
}
if (itemQty) {
_mm_prefetch(const_cast<const char *>(reinterpret_cast<char *>(&itemBuff[0])), _MM_HINT_T0);
std::sort(itemBuff.begin(), itemBuff.begin() + itemQty);
size_t insIndex = 0;
if (itemQty > MERGE_BUFFER_ALGO_SWITCH_THRESHOLD) {
insIndex = sortedArr.merge_with_sorted_items(&itemBuff[0], itemQty);
if (insIndex < currElem) {
// LOG(LIB_INFO) << "@@@ " << currElem << " -> " << insIndex;
currElem = insIndex;
}
} else {
for (size_t ii = 0; ii < itemQty; ++ii) {
size_t insIndex = sortedArr.push_or_replace_non_empty_exp(itemBuff[ii].key, itemBuff[ii].data);
if (insIndex < currElem) {
// LOG(LIB_INFO) << "@@@ " << currElem << " -> " << insIndex;
currElem = insIndex;
}
}
}
}
// To ensure that we either reach the end of the unexplored queue or currElem points to the first unused element
while (currElem < sortedArr.size() && queueData[currElem].used == true)
++currElem;
}
for (uint_fast32_t i = 0; i < query->GetK() && i < sortedArr.size(); ++i) {
query->CheckAndAddToResult(queueData[i].key, queueData[i].data->getData());
}
visitedlistpool->releaseVisitedList(vl);
}
template <typename dist_t>
void
Hnsw<dist_t>::baseSearchAlgorithmV1Merge(RangeQuery<dist_t> *query)
{
VisitedList *vl = visitedlistpool->getFreeVisitedList();
vl_type *massVisited = vl->mass;
vl_type currentV = vl->curV;
HnswNode *provider;
int maxlevel1 = enterpoint_->level;
provider = enterpoint_;
const Object *currObj = provider->getData();
dist_t d = query->DistanceObjLeft(currObj);
dist_t curdist = d;
HnswNode *curNode = provider;
for (int i = maxlevel1; i > 0; i--) {
bool changed = true;
while (changed) {
changed = false;
const vector<HnswNode *> &neighbor = curNode->getAllFriends(i);
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
_mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0);
}
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
currObj = (*iter)->getData();
d = query->DistanceObjLeft(currObj);
if (d < curdist) {
curdist = d;
curNode = *iter;
changed = true;
}
}
}
}
SortArrBI<dist_t, HnswNode *> sortedArr(max<size_t>(ef_, MAXIMUM_K)); // max<size_t>(ef_, query->GetK())
sortedArr.push_unsorted_grow(curdist, curNode);
int_fast32_t currElem = 0;
typedef typename SortArrBI<dist_t, HnswNode *>::Item QueueItem;
vector<QueueItem> &queueData = sortedArr.get_data();
vector<QueueItem> itemBuff(1 + max(maxM_, maxM0_));
massVisited[curNode->getId()] = currentV;
// visitedQueue.insert(curNode->getId());
////////////////////////////////////////////////////////////////////////////////
// PHASE TWO OF THE SEARCH
// Extraction of the neighborhood to find k nearest neighbors.
////////////////////////////////////////////////////////////////////////////////
while (currElem < min(sortedArr.size(), ef_)) {
auto &e = queueData[currElem];
CHECK(!e.used);
e.used = true;
HnswNode *initNode = e.data;
++currElem;
size_t itemQty = 0;
dist_t topKey = sortedArr.top_key();
const vector<HnswNode *> &neighbor = (initNode)->getAllFriends(0);
size_t curId;
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
_mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0);
IdType curId = (*iter)->getId();
CHECK(curId >= 0 && curId < this->data_.size());
_mm_prefetch((char *)(massVisited + curId), _MM_HINT_T0);
}
// calculate distance to each neighbor
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
curId = (*iter)->getId();
if (!(massVisited[curId] == currentV)) {
massVisited[curId] = currentV;
currObj = (*iter)->getData();
d = query->DistanceObjLeft(currObj);
if (d < topKey || sortedArr.size() < ef_) {
CHECK_MSG(itemBuff.size() > itemQty,
"Perhaps a bug: buffer size is not enough " +
ConvertToString(itemQty) + " >= " + ConvertToString(itemBuff.size()));
itemBuff[itemQty++] = QueueItem(d, *iter);
}
}
}
if (itemQty) {
_mm_prefetch(const_cast<const char *>(reinterpret_cast<char *>(&itemBuff[0])), _MM_HINT_T0);
std::sort(itemBuff.begin(), itemBuff.begin() + itemQty);
size_t insIndex = 0;
if (itemQty > MERGE_BUFFER_ALGO_SWITCH_THRESHOLD) {
insIndex = sortedArr.merge_with_sorted_items(&itemBuff[0], itemQty);
if (insIndex < currElem) {
// LOG(LIB_INFO) << "@@@ " << currElem << " -> " << insIndex;
currElem = insIndex;
}
} else {
for (size_t ii = 0; ii < itemQty; ++ii) {
size_t insIndex = sortedArr.push_or_replace_non_empty_exp(itemBuff[ii].key, itemBuff[ii].data);
if (insIndex < currElem) {
// LOG(LIB_INFO) << "@@@ " << currElem << " -> " << insIndex;
currElem = insIndex;
}
}
}
}
// To ensure that we either reach the end of the unexplored queue or currElem points to the first unused element
while (currElem < sortedArr.size() && queueData[currElem].used == true)
++currElem;
}
for (uint_fast32_t i = 0; i < sortedArr.size(); ++i) { // i < query->GetK() &&
query->CheckAndAddToResult(queueData[i].key, queueData[i].data->getData());
}
visitedlistpool->releaseVisitedList(vl);
}
// Experimental search algorithm
template <typename dist_t>
void
Hnsw<dist_t>::listPassingModifiedAlgorithm(KNNQuery<dist_t> *query)
{
int efSearchL = 4; // This parameters defines the confidence of searches at level higher than zero
// for zero level it is set to ef
// Getting the visitedlist
VisitedList *vl = visitedlistpool->getFreeVisitedList();
vl_type *massVisited = vl->mass;
vl_type currentV = vl->curV;
int maxlevel1 = enterpoint_->level;
const Object *currObj = enterpoint_->getData();
dist_t d = query->DistanceObjLeft(currObj);
dist_t curdist = d;
HnswNode *curNode = enterpoint_;
priority_queue<HnswNodeDistFarther<dist_t>> candidateQueue; // the set of elements which we can use to evaluate
priority_queue<HnswNodeDistCloser<dist_t>> closestDistQueue =
priority_queue<HnswNodeDistCloser<dist_t>>(); // The set of closest found elements
priority_queue<HnswNodeDistCloser<dist_t>> closestDistQueueCpy = priority_queue<HnswNodeDistCloser<dist_t>>();
HnswNodeDistFarther<dist_t> ev(curdist, curNode);
candidateQueue.emplace(curdist, curNode);
closestDistQueue.emplace(curdist, curNode);
massVisited[curNode->getId()] = currentV;
for (int i = maxlevel1; i > 0; i--) {
while (!candidateQueue.empty()) {
auto iter = candidateQueue.top();
const HnswNodeDistFarther<dist_t> &currEv = iter;
// Check condtion to end the search
dist_t lowerBound = closestDistQueue.top().getDistance();
if (currEv.getDistance() > lowerBound) {
break;
}
HnswNode *initNode = currEv.getMSWNodeHier();
candidateQueue.pop();
const vector<HnswNode *> &neighbor = (initNode)->getAllFriends(i);
size_t curId;
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
_mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0);
_mm_prefetch((char *)(massVisited + (*iter)->getId()), _MM_HINT_T0);
}
// calculate distance to each neighbor
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
curId = (*iter)->getId();
if (!(massVisited[curId] == currentV)) {
massVisited[curId] = currentV;
currObj = (*iter)->getData();
d = query->DistanceObjLeft(currObj);
if (closestDistQueue.top().getDistance() > d || closestDistQueue.size() < efSearchL) {
candidateQueue.emplace(d, *iter);
closestDistQueue.emplace(d, *iter);
if (closestDistQueue.size() > efSearchL) {
closestDistQueue.pop();
}
}
}
}
}
// Updating the bitset key:
currentV++;
vl->curV++; // not to forget updating in the pool
if (currentV == 0) {
memset(massVisited, 0, ElList_.size() * sizeof(vl_type));
currentV++;
vl->curV++; // not to forget updating in the pool
}
candidateQueue = priority_queue<HnswNodeDistFarther<dist_t>>();
closestDistQueueCpy = priority_queue<HnswNodeDistCloser<dist_t>>(closestDistQueue);
if (i > 1) { // Passing the closest neighbors to layers higher than zero:
while (closestDistQueueCpy.size() > 0) {
massVisited[closestDistQueueCpy.top().getMSWNodeHier()->getId()] = currentV;
candidateQueue.emplace(closestDistQueueCpy.top().getDistance(), closestDistQueueCpy.top().getMSWNodeHier());
closestDistQueueCpy.pop();
}
} else { // Passing the closest neighbors to the 0 zero layer(one has to add also to query):
while (closestDistQueueCpy.size() > 0) {
massVisited[closestDistQueueCpy.top().getMSWNodeHier()->getId()] = currentV;
candidateQueue.emplace(closestDistQueueCpy.top().getDistance(), closestDistQueueCpy.top().getMSWNodeHier());
query->CheckAndAddToResult(closestDistQueueCpy.top().getDistance(),
closestDistQueueCpy.top().getMSWNodeHier()->getData());
closestDistQueueCpy.pop();
}
}
}
////////////////////////////////////////////////////////////////////////////////
// PHASE TWO OF THE SEARCH
// Extraction of the neighborhood to find k nearest neighbors.
////////////////////////////////////////////////////////////////////////////////
while (!candidateQueue.empty()) {
auto iter = candidateQueue.top();
const HnswNodeDistFarther<dist_t> &currEv = iter;
// Check condtion to end the search
dist_t lowerBound = closestDistQueue.top().getDistance();
if (currEv.getDistance() > lowerBound) {
break;
}
HnswNode *initNode = currEv.getMSWNodeHier();
candidateQueue.pop();
const vector<HnswNode *> &neighbor = (initNode)->getAllFriends(0);
size_t curId;
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
_mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0);
_mm_prefetch((char *)(massVisited + (*iter)->getId()), _MM_HINT_T0);
}
// calculate distance to each neighbor
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
curId = (*iter)->getId();
if (!(massVisited[curId] == currentV)) {
massVisited[curId] = currentV;
currObj = (*iter)->getData();
d = query->DistanceObjLeft(currObj);
if (closestDistQueue.top().getDistance() > d || closestDistQueue.size() < ef_) {
{
query->CheckAndAddToResult(d, currObj);
candidateQueue.emplace(d, *iter);
closestDistQueue.emplace(d, *iter);
if (closestDistQueue.size() > ef_) {
closestDistQueue.pop();
}
}
}
}
}
}
visitedlistpool->releaseVisitedList(vl);
}
template class Hnsw<float>;
template class Hnsw<double>;
template class Hnsw<int>;
}