From 5bdf0534bc53eafdad8ee3c992e1ac7b97564ddf Mon Sep 17 00:00:00 2001 From: ChunjiangZhu Date: Mon, 28 Oct 2019 12:42:46 -0400 Subject: [PATCH] fix bug --- .../include/method/small_world_rand.h | 1 + .../src/method/small_world_rand.cc | 84 ++++++++++++++++++- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/similarity_search/include/method/small_world_rand.h b/similarity_search/include/method/small_world_rand.h index f8fc350..9e3ebee 100644 --- a/similarity_search/include/method/small_world_rand.h +++ b/similarity_search/include/method/small_world_rand.h @@ -302,6 +302,7 @@ class SmallWorldRand : public Index { void SearchOld(KNNQuery* query) const; void SearchV1Merge(KNNQuery* query) const; + void SearchOld(RangeQuery* query) const; void SearchV1Merge(RangeQuery* query) const; void UpdateNextNodeId(size_t newNextNodeId); diff --git a/similarity_search/src/method/small_world_rand.cc b/similarity_search/src/method/small_world_rand.cc index 32aa605..4498f99 100644 --- a/similarity_search/src/method/small_world_rand.cc +++ b/similarity_search/src/method/small_world_rand.cc @@ -601,8 +601,8 @@ void SmallWorldRand::addCriticalSection(MSWNode *newElement){ template void SmallWorldRand::Search(RangeQuery* query, IdType) const { // throw runtime_error("Range search is not supported!"); - if (searchAlgoType_ == kV1Merge) { std::cerr << "call SearchV1Merge" << endl; SearchV1Merge(query);} -// else SearchOld(query); + if (searchAlgoType_ == kV1Merge) SearchV1Merge(query); + else SearchOld(query); } template @@ -909,6 +909,86 @@ void SmallWorldRand::SearchOld(KNNQuery* query) const { } } +template +void SmallWorldRand::SearchOld(RangeQuery* query) const { + // Range search by graph traversal starting at pEntryPoint_: every object whose distance gets computed is passed to query->CheckAndAddToResult(); the loop ends when candidateQueue is empty or its top is farther than closestDistQueue.top(). + if (ElList_.empty()) return; + CHECK_MSG(efSearch_ > 0, "efSearch should be > 0"); +/* + * The trick of using large dense bitsets instead of unordered_set was + * borrowed from Wei Dong's kgraph: https://github.com/aaalgo/kgraph + * + * This trick works really well even in a multi-threaded mode. 
Indeed, the amount + * of allocated memory is small. For example, if one has 8M entries, the size of + * the bitmap is merely 1 MB. Furthermore, setting 1MB of entries to zero via memset would take only + * a fraction of millisecond. + */ + vector visitedBitset(NextNodeId_); + + MSWNode* provider = pEntryPoint_; + CHECK_MSG(provider != nullptr, "Bug: there is not entry point set!") + + priority_queue closestDistQueue; // Distances accepted so far (only distances are stored, not elements); trimmed to at most efSearch_ entries in the loop below + priority_queue > candidateQueue; // Evaluated nodes whose neighbors have not been expanded yet + + const Object* currObj = provider->getData(); + dist_t d = query->DistanceObjLeft(currObj); + query->CheckAndAddToResult(d, currObj); // This should be done before the object goes to the queue: otherwise it will not be compared to the query at all! + + EvaluatedMSWNodeReverse ev(d, provider); + candidateQueue.push(ev); + closestDistQueue.emplace(d); + + IdType nodeId = provider->getId(); + CHECK_MSG(nodeId < NextNodeId_, "Bug: nodeId (" + ConvertToString(nodeId) + ") > NextNodeId_ (" +ConvertToString(NextNodeId_) + ")"); + visitedBitset[nodeId] = true; + + while(!candidateQueue.empty()){ + + auto iter = candidateQueue.top(); // This one was already compared to the query + const EvaluatedMSWNodeReverse& currEv = iter; + + // Did we reach a local minimum? + if (currEv.getDistance() > closestDistQueue.top()) { + break; + } + + for (MSWNode* neighbor : (currEv.getMSWNode())->getAllFriends()) { + _mm_prefetch(reinterpret_cast(const_cast(neighbor->getData())), _MM_HINT_T0); + } + for (MSWNode* neighbor : (currEv.getMSWNode())->getAllFriends()) { + _mm_prefetch(const_cast(neighbor->getData()->data()), _MM_HINT_T0); + } + + const vector& neighbor = (currEv.getMSWNode())->getAllFriends(); + + // NOTE(review): currEv aliases the local copy 'iter' (auto drops the reference), not the queue's top element, so this warning looks stale - confirm. Can't access currEv anymore! 
The reference would become invalid + candidateQueue.pop(); + + // Compute the distance to each neighbor that has not been visited yet. NOTE(review): the CHECK_MSG text below omits the closing ')' that the same check before the loop has. + for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter){ + nodeId = (*iter)->getId(); + CHECK_MSG(nodeId < NextNodeId_, "Bug: nodeId (" + ConvertToString(nodeId) + ") > NextNodeId_ (" +ConvertToString(NextNodeId_)); + if (!visitedBitset[nodeId]) { + currObj = (*iter)->getData(); + d = query->DistanceObjLeft(currObj); + visitedBitset[nodeId] = true; + + if (closestDistQueue.size() < efSearch_ || d < closestDistQueue.top()) { + closestDistQueue.emplace(d); + if (closestDistQueue.size() > efSearch_) { + closestDistQueue.pop(); + } + + candidateQueue.emplace(d, *iter); + } + + query->CheckAndAddToResult(d, currObj); + } + } + } +} + template void SmallWorldRand::SaveIndex(const string &location) { CHECK_MSG(!changedAfterCreateIndex_,