Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ChunjiangZhu committed Oct 28, 2019
1 parent 0bcd267 commit 5bdf053
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 2 deletions.
1 change: 1 addition & 0 deletions similarity_search/include/method/small_world_rand.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ class SmallWorldRand : public Index<dist_t> {
void SearchOld(KNNQuery<dist_t>* query) const;
void SearchV1Merge(KNNQuery<dist_t>* query) const;

void SearchOld(RangeQuery<dist_t>* query) const;
void SearchV1Merge(RangeQuery<dist_t>* query) const;

void UpdateNextNodeId(size_t newNextNodeId);
Expand Down
84 changes: 82 additions & 2 deletions similarity_search/src/method/small_world_rand.cc
Original file line number Diff line number Diff line change
Expand Up @@ -601,8 +601,8 @@ void SmallWorldRand<dist_t>::addCriticalSection(MSWNode *newElement){
template <typename dist_t>
void SmallWorldRand<dist_t>::Search(RangeQuery<dist_t>* query, IdType) const {
// throw runtime_error("Range search is not supported!");
if (searchAlgoType_ == kV1Merge) { std::cerr << "call SearchV1Merge" << endl; SearchV1Merge(query);}
// else SearchOld(query);
if (searchAlgoType_ == kV1Merge) SearchV1Merge(query);
else SearchOld(query);
}

template <typename dist_t>
Expand Down Expand Up @@ -909,6 +909,86 @@ void SmallWorldRand<dist_t>::SearchOld(KNNQuery<dist_t>* query) const {
}
}

template <typename dist_t>
void SmallWorldRand<dist_t>::SearchOld(RangeQuery<dist_t>* query) const {

if (ElList_.empty()) return;
CHECK_MSG(efSearch_ > 0, "efSearch should be > 0");
/*
* The trick of using large dense bitsets instead of unordered_set was
* borrowed from Wei Dong's kgraph: https://github.com/aaalgo/kgraph
*
* This trick works really well even in a multi-threaded mode. Indeed, the amount
* of allocated memory is small. For example, if one has 8M entries, the size of
* the bitmap is merely 1 MB. Furthermore, setting 1MB of entries to zero via memset would take only
* a fraction of millisecond.
*/
vector<bool> visitedBitset(NextNodeId_);

MSWNode* provider = pEntryPoint_;
CHECK_MSG(provider != nullptr, "Bug: there is not entry point set!")

priority_queue <dist_t> closestDistQueue; //The set of all elements which distance was calculated
priority_queue <EvaluatedMSWNodeReverse<dist_t>> candidateQueue; //the set of elements which we can use to evaluate

const Object* currObj = provider->getData();
dist_t d = query->DistanceObjLeft(currObj);
query->CheckAndAddToResult(d, currObj); // This should be done before the object goes to the queue: otherwise it will not be compared to the query at all!

EvaluatedMSWNodeReverse<dist_t> ev(d, provider);
candidateQueue.push(ev);
closestDistQueue.emplace(d);

IdType nodeId = provider->getId();
CHECK_MSG(nodeId < NextNodeId_, "Bug: nodeId (" + ConvertToString(nodeId) + ") > NextNodeId_ (" +ConvertToString(NextNodeId_) + ")");
visitedBitset[nodeId] = true;

while(!candidateQueue.empty()){

auto iter = candidateQueue.top(); // This one was already compared to the query
const EvaluatedMSWNodeReverse<dist_t>& currEv = iter;

// Did we reach a local minimum?
if (currEv.getDistance() > closestDistQueue.top()) {
break;
}

for (MSWNode* neighbor : (currEv.getMSWNode())->getAllFriends()) {
_mm_prefetch(reinterpret_cast<const char*>(const_cast<const Object*>(neighbor->getData())), _MM_HINT_T0);
}
for (MSWNode* neighbor : (currEv.getMSWNode())->getAllFriends()) {
_mm_prefetch(const_cast<const char*>(neighbor->getData()->data()), _MM_HINT_T0);
}

const vector<MSWNode*>& neighbor = (currEv.getMSWNode())->getAllFriends();

// Can't access curEv anymore! The reference would become invalid
candidateQueue.pop();

//calculate distance to each neighbor
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter){
nodeId = (*iter)->getId();
CHECK_MSG(nodeId < NextNodeId_, "Bug: nodeId (" + ConvertToString(nodeId) + ") > NextNodeId_ (" +ConvertToString(NextNodeId_));
if (!visitedBitset[nodeId]) {
currObj = (*iter)->getData();
d = query->DistanceObjLeft(currObj);
visitedBitset[nodeId] = true;

if (closestDistQueue.size() < efSearch_ || d < closestDistQueue.top()) {
closestDistQueue.emplace(d);
if (closestDistQueue.size() > efSearch_) {
closestDistQueue.pop();
}

candidateQueue.emplace(d, *iter);
}

query->CheckAndAddToResult(d, currObj);
}
}
}
}

template <typename dist_t>
void SmallWorldRand<dist_t>::SaveIndex(const string &location) {
CHECK_MSG(!changedAfterCreateIndex_,
Expand Down

0 comments on commit 5bdf053

Please sign in to comment.