Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
hnsw add rangequery
  • Loading branch information
ChunjiangZhu committed Oct 25, 2019
1 parent d9cf5dc commit fc2cfd1
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 1 deletion.
Binary file modified .DS_Store
Binary file not shown.
Binary file modified similarity_search/.DS_Store
Binary file not shown.
Binary file modified similarity_search/src/.DS_Store
Binary file not shown.
168 changes: 167 additions & 1 deletion similarity_search/src/method/hnsw.cc
Expand Up @@ -677,7 +677,43 @@ namespace similarity {
void
Hnsw<dist_t>::Search(RangeQuery<dist_t> *query, IdType) const
{
throw runtime_error("Range search is not supported!");
// throw runtime_error("Range search is not supported!");
if (this->data_.empty() && this->data_rearranged_.empty()) {
return;
}
bool useOld = searchAlgoType_ == kOld || (searchAlgoType_ == kHybrid && ef_ >= 1000);
// cout << "Ef = " << ef_ << " use old = " << useOld << endl;
switch (searchMethod_) {
default:
throw runtime_error("Invalid searchMethod: " + ConvertToString(searchMethod_));
break;
case 0:
/// Basic search using Nmslib data structure:
// if (useOld)
// const_cast<Hnsw *>(this)->baseSearchAlgorithmOld(query);
// else
const_cast<Hnsw *>(this)->baseSearchAlgorithmV1Merge(query);
break;
// case 1:
// /// Experimental search using Nmslib data structure (should not be used):
// const_cast<Hnsw *>(this)->listPassingModifiedAlgorithm(query);
// break;
// case 3:
// /// Basic search using optimized index(cosine+L2)
// if (useOld)
// const_cast<Hnsw *>(this)->SearchL2CustomOld(query);
// else
// const_cast<Hnsw *>(this)->SearchL2CustomV1Merge(query);
// break;
// case 4:
// /// Basic search using optimized index with one-time normalized cosine similarity
// /// Only for cosine similarity!
// if (useOld)
// const_cast<Hnsw *>(this)->SearchCosineNormalizedOld(query);
// else
// const_cast<Hnsw *>(this)->SearchCosineNormalizedV1Merge(query);
// break;
};
}

template <typename dist_t>
Expand Down Expand Up @@ -1274,6 +1310,136 @@ namespace similarity {

visitedlistpool->releaseVisitedList(vl);
}

template <typename dist_t>
void
Hnsw<dist_t>::baseSearchAlgorithmV1Merge(RangeQuery<dist_t> *query)
{
VisitedList *vl = visitedlistpool->getFreeVisitedList();
vl_type *massVisited = vl->mass;
vl_type currentV = vl->curV;

HnswNode *provider;
int maxlevel1 = enterpoint_->level;
provider = enterpoint_;

const Object *currObj = provider->getData();

dist_t d = query->DistanceObjLeft(currObj);
dist_t curdist = d;
HnswNode *curNode = provider;
for (int i = maxlevel1; i > 0; i--) {
bool changed = true;
while (changed) {
changed = false;

const vector<HnswNode *> &neighbor = curNode->getAllFriends(i);
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
_mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0);
}
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
currObj = (*iter)->getData();
d = query->DistanceObjLeft(currObj);
if (d < curdist) {
curdist = d;
curNode = *iter;
changed = true;
}
}
}
}

SortArrBI<dist_t, HnswNode *> sortedArr(ef_); // max<size_t>(ef_, query->GetK())
sortedArr.push_unsorted_grow(curdist, curNode);

int_fast32_t currElem = 0;

typedef typename SortArrBI<dist_t, HnswNode *>::Item QueueItem;
vector<QueueItem> &queueData = sortedArr.get_data();
vector<QueueItem> itemBuff(1 + max(maxM_, maxM0_));

massVisited[curNode->getId()] = currentV;
// visitedQueue.insert(curNode->getId());

////////////////////////////////////////////////////////////////////////////////
// PHASE TWO OF THE SEARCH
// Extraction of the neighborhood to find k nearest neighbors.
////////////////////////////////////////////////////////////////////////////////

while (currElem < min(sortedArr.size(), ef_)) {
auto &e = queueData[currElem];
CHECK(!e.used);
e.used = true;
HnswNode *initNode = e.data;
++currElem;

size_t itemQty = 0;
dist_t topKey = sortedArr.top_key();

const vector<HnswNode *> &neighbor = (initNode)->getAllFriends(0);

size_t curId;

for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
_mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0);
IdType curId = (*iter)->getId();
CHECK(curId >= 0 && curId < this->data_.size());
_mm_prefetch((char *)(massVisited + curId), _MM_HINT_T0);
}
// calculate distance to each neighbor
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
curId = (*iter)->getId();

if (!(massVisited[curId] == currentV)) {
massVisited[curId] = currentV;
currObj = (*iter)->getData();
d = query->DistanceObjLeft(currObj);

if (d < topKey || sortedArr.size() < ef_) {
CHECK_MSG(itemBuff.size() > itemQty,
"Perhaps a bug: buffer size is not enough " +
ConvertToString(itemQty) + " >= " + ConvertToString(itemBuff.size()));
itemBuff[itemQty++] = QueueItem(d, *iter);
}
}
}

if (itemQty) {
_mm_prefetch(const_cast<const char *>(reinterpret_cast<char *>(&itemBuff[0])), _MM_HINT_T0);
std::sort(itemBuff.begin(), itemBuff.begin() + itemQty);

size_t insIndex = 0;
if (itemQty > MERGE_BUFFER_ALGO_SWITCH_THRESHOLD) {
insIndex = sortedArr.merge_with_sorted_items(&itemBuff[0], itemQty);

if (insIndex < currElem) {
// LOG(LIB_INFO) << "@@@ " << currElem << " -> " << insIndex;
currElem = insIndex;
}
} else {
for (size_t ii = 0; ii < itemQty; ++ii) {
size_t insIndex = sortedArr.push_or_replace_non_empty_exp(itemBuff[ii].key, itemBuff[ii].data);

if (insIndex < currElem) {
// LOG(LIB_INFO) << "@@@ " << currElem << " -> " << insIndex;
currElem = insIndex;
}
}
}
}
// To ensure that we either reach the end of the unexplored queue or currElem points to the first unused element
while (currElem < sortedArr.size() && queueData[currElem].used == true)
++currElem;
}

for (uint_fast32_t i = 0; i < sortedArr.size(); ++i) { // i < query->GetK() &&
query->CheckAndAddToResult(queueData[i].key, queueData[i].data->getData());
}

visitedlistpool->releaseVisitedList(vl);
}


// Experimental search algorithm
template <typename dist_t>
void
Expand Down

0 comments on commit fc2cfd1

Please sign in to comment.