diff --git a/.DS_Store b/.DS_Store index 150336f..f522d6d 100644 Binary files a/.DS_Store and b/.DS_Store differ diff --git a/similarity_search/.DS_Store b/similarity_search/.DS_Store index 7f148cd..803ba20 100644 Binary files a/similarity_search/.DS_Store and b/similarity_search/.DS_Store differ diff --git a/similarity_search/src/.DS_Store b/similarity_search/src/.DS_Store index 6db9917..45dfbda 100644 Binary files a/similarity_search/src/.DS_Store and b/similarity_search/src/.DS_Store differ diff --git a/similarity_search/src/method/hnsw.cc b/similarity_search/src/method/hnsw.cc index f5c7fca..4813619 100644 --- a/similarity_search/src/method/hnsw.cc +++ b/similarity_search/src/method/hnsw.cc @@ -677,7 +677,43 @@ namespace similarity { void Hnsw::Search(RangeQuery *query, IdType) const { - throw runtime_error("Range search is not supported!"); +// throw runtime_error("Range search is not supported!"); + if (this->data_.empty() && this->data_rearranged_.empty()) { + return; + } + bool useOld = searchAlgoType_ == kOld || (searchAlgoType_ == kHybrid && ef_ >= 1000); + // cout << "Ef = " << ef_ << " use old = " << useOld << endl; + switch (searchMethod_) { + default: + throw runtime_error("Invalid searchMethod: " + ConvertToString(searchMethod_)); + break; + case 0: + /// Basic search using Nmslib data structure: +// if (useOld) +// const_cast(this)->baseSearchAlgorithmOld(query); +// else + const_cast(this)->baseSearchAlgorithmV1Merge(query); + break; +// case 1: +// /// Experimental search using Nmslib data structure (should not be used): +// const_cast(this)->listPassingModifiedAlgorithm(query); +// break; +// case 3: +// /// Basic search using optimized index(cosine+L2) +// if (useOld) +// const_cast(this)->SearchL2CustomOld(query); +// else +// const_cast(this)->SearchL2CustomV1Merge(query); +// break; +// case 4: +// /// Basic search using optimized index with one-time normalized cosine similarity +// /// Only for cosine similarity! +// if (useOld) +// const_cast(this)->SearchCosineNormalizedOld(query); +// else +// const_cast(this)->SearchCosineNormalizedV1Merge(query); +// break; + }; } template @@ -1274,6 +1310,136 @@ namespace similarity { visitedlistpool->releaseVisitedList(vl); } + + template + void + Hnsw::baseSearchAlgorithmV1Merge(RangeQuery *query) + { + VisitedList *vl = visitedlistpool->getFreeVisitedList(); + vl_type *massVisited = vl->mass; + vl_type currentV = vl->curV; + + HnswNode *provider; + int maxlevel1 = enterpoint_->level; + provider = enterpoint_; + + const Object *currObj = provider->getData(); + + dist_t d = query->DistanceObjLeft(currObj); + dist_t curdist = d; + HnswNode *curNode = provider; + for (int i = maxlevel1; i > 0; i--) { + bool changed = true; + while (changed) { + changed = false; + + const vector &neighbor = curNode->getAllFriends(i); + for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) { + _mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0); + } + for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) { + currObj = (*iter)->getData(); + d = query->DistanceObjLeft(currObj); + if (d < curdist) { + curdist = d; + curNode = *iter; + changed = true; + } + } + } + } + + SortArrBI sortedArr(ef_); // max(ef_, query->GetK()) + sortedArr.push_unsorted_grow(curdist, curNode); + + int_fast32_t currElem = 0; + + typedef typename SortArrBI::Item QueueItem; + vector &queueData = sortedArr.get_data(); + vector itemBuff(1 + max(maxM_, maxM0_)); + + massVisited[curNode->getId()] = currentV; + // visitedQueue.insert(curNode->getId()); + + //////////////////////////////////////////////////////////////////////////////// + // PHASE TWO OF THE SEARCH + // Extraction of the neighborhood to find k nearest neighbors. + //////////////////////////////////////////////////////////////////////////////// + + while (currElem < min(sortedArr.size(), ef_)) { + auto &e = queueData[currElem]; + CHECK(!e.used); + e.used = true; + HnswNode *initNode = e.data; + ++currElem; + + size_t itemQty = 0; + dist_t topKey = sortedArr.top_key(); + + const vector &neighbor = (initNode)->getAllFriends(0); + + size_t curId; + + for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) { + _mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0); + IdType curId = (*iter)->getId(); + CHECK(curId >= 0 && curId < this->data_.size()); + _mm_prefetch((char *)(massVisited + curId), _MM_HINT_T0); + } + // calculate distance to each neighbor + for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) { + curId = (*iter)->getId(); + + if (!(massVisited[curId] == currentV)) { + massVisited[curId] = currentV; + currObj = (*iter)->getData(); + d = query->DistanceObjLeft(currObj); + + if (d < topKey || sortedArr.size() < ef_) { + CHECK_MSG(itemBuff.size() > itemQty, + "Perhaps a bug: buffer size is not enough " + + ConvertToString(itemQty) + " >= " + ConvertToString(itemBuff.size())); + itemBuff[itemQty++] = QueueItem(d, *iter); + } + } + } + + if (itemQty) { + _mm_prefetch(const_cast(reinterpret_cast(&itemBuff[0])), _MM_HINT_T0); + std::sort(itemBuff.begin(), itemBuff.begin() + itemQty); + + size_t insIndex = 0; + if (itemQty > MERGE_BUFFER_ALGO_SWITCH_THRESHOLD) { + insIndex = sortedArr.merge_with_sorted_items(&itemBuff[0], itemQty); + + if (insIndex < currElem) { + // LOG(LIB_INFO) << "@@@ " << currElem << " -> " << insIndex; + currElem = insIndex; + } + } else { + for (size_t ii = 0; ii < itemQty; ++ii) { + size_t insIndex = sortedArr.push_or_replace_non_empty_exp(itemBuff[ii].key, itemBuff[ii].data); + + if (insIndex < currElem) { + // LOG(LIB_INFO) << "@@@ " << currElem << " -> " << insIndex; + currElem = insIndex; + } + } + } + } + // To ensure that we either reach the end of the unexplored queue or currElem points to the first unused element + while (currElem < sortedArr.size() && queueData[currElem].used == true) + ++currElem; + } + + for (uint_fast32_t i = 0; i < sortedArr.size(); ++i) { // i < query->GetK() && + query->CheckAndAddToResult(queueData[i].key, queueData[i].data->getData()); + } + + visitedlistpool->releaseVisitedList(vl); + } + + // Experimental search algorithm template void