Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Bugfixes for index loading and AddBatch, optional single threaded Del…
…eteBatch
  • Loading branch information
Лесцов Борис committed Nov 16, 2018
1 parent c2c46d5 commit ac07969
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 17 deletions.
5 changes: 5 additions & 0 deletions similarity_search/include/method/small_world_rand.h
Expand Up @@ -274,6 +274,11 @@ public:
void SetQueryTimeParams(const AnyParams& ) override;

enum PatchingStrategy { kNone = 0, kNeighborsOnly = 1 };

// Call this method before LoadIndex to initialize the parameters
// that are normally set by CreateIndex.
void InitParamsManually(const AnyParams& IndexParams);

private:

size_t NN_;
Expand Down
75 changes: 60 additions & 15 deletions similarity_search/src/method/small_world_rand.cc
Expand Up @@ -151,7 +151,21 @@ void SmallWorldRand<dist_t>::AddBatch(const ObjectVector& batchData,
<< " futureNextNodeId + 1 after batch addition: " << futureNextNodeId;

// 2) One entry should be added before all the threads are started, or else add() will not work properly
addCriticalSection(new MSWNode(batchData[0], NextNodeId_));


bool isEmpty = false;

{
unique_lock<mutex> lock(ElListGuard_);
isEmpty = ElList_.empty();
}
int start_add=0;

if (isEmpty){
addCriticalSection(new MSWNode(batchData[0], NextNodeId_));
start_add = 1;
}


unique_ptr<ProgressDisplay> progress_bar(bPrintProgress ?
new ProgressDisplay(batchData.size(), cerr)
Expand All @@ -160,7 +174,7 @@ void SmallWorldRand<dist_t>::AddBatch(const ObjectVector& batchData,
if (indexThreadQty_ <= 1) {
// Start at start_add: the first element is skipped only when it was already added above (start_add == 1)
if (progress_bar) ++(*progress_bar);
for (size_t id = 1; id < batchData.size(); ++id) {
for (size_t id = start_add; id < batchData.size(); ++id) {
MSWNode* node = new MSWNode(batchData[id], id + NextNodeId_);
add(node, futureNextNodeId);
if (progress_bar) ++(*progress_bar);
Expand Down Expand Up @@ -250,21 +264,32 @@ void SmallWorldRand<dist_t>::DeleteBatch(const vector<IdType>& batchData, int de
mutex mtx;
vector<thread> threads;

for (size_t i = 0; i < indexThreadQty_; ++i) {
threads.push_back(thread(
[&]() {
MSWNode* node = nullptr;
vector<MSWNode*> cacheDelNode;
while(GetNextQueueObj(mtx, toPatchQueue, node)) {
if (kNone == patchStrat) node->removeGivenFriends(delNodesBitset);
else node->removeGivenFriendsPatchWithClosestNeighbor<dist_t>(space_, use_proxy_dist_,
delNodesBitset, cacheDelNode);
if (indexThreadQty_ <= 1) {
LOG(LIB_INFO) << "Single threaded batch delete: " << vToPatchNodes.size();
MSWNode* node = nullptr;
vector<MSWNode*> cacheDelNode;
while(GetNextQueueObj(mtx, toPatchQueue, node)) {
if (kNone == patchStrat) node->removeGivenFriends(delNodesBitset);
else node->removeGivenFriendsPatchWithClosestNeighbor<dist_t>(space_, use_proxy_dist_,
delNodesBitset, cacheDelNode);
}

} else {
for (size_t i = 0; i < indexThreadQty_; ++i) {
threads.push_back(thread(
[&]() {
MSWNode* node = nullptr;
vector<MSWNode*> cacheDelNode;
while(GetNextQueueObj(mtx, toPatchQueue, node)) {
if (kNone == patchStrat) node->removeGivenFriends(delNodesBitset);
else node->removeGivenFriendsPatchWithClosestNeighbor<dist_t>(space_, use_proxy_dist_,
delNodesBitset, cacheDelNode);
}
}
}
));
));
}
for (auto& thread : threads) thread.join();
}
for (auto& thread : threads) thread.join();


if (checkIDs) {
for (auto it : ElList_) {
Expand Down Expand Up @@ -337,6 +362,26 @@ void SmallWorldRand<dist_t>::CheckIDs() const
}
}

/*
 * Reads the index-time parameters from IndexParams into the corresponding
 * members without building an index. Intended to be called before LoadIndex,
 * so that members normally filled in by CreateIndex have values.
 *
 * Recognized parameters and their fallbacks:
 *   NN              -> NN_              (10)
 *   efConstruction  -> efConstruction_  (value of NN_ after it has been read)
 *   indexThreadQty  -> indexThreadQty_  (thread::hardware_concurrency())
 *   useProxyDist    -> use_proxy_dist_  (false)
 *
 * efSearch_ is set to NN_ as well.
 */
template <typename dist_t>
void SmallWorldRand<dist_t>::InitParamsManually(const AnyParams& IndexParams)
{
  AnyParamManager paramMgr(IndexParams);

  // NN must be read first: it is the fallback for efConstruction
  // and the value copied into efSearch_.
  paramMgr.GetParamOptional("NN", NN_, 10);
  paramMgr.GetParamOptional("efConstruction", efConstruction_, NN_);
  efSearch_ = NN_;

  paramMgr.GetParamOptional("indexThreadQty", indexThreadQty_, thread::hardware_concurrency());
  paramMgr.GetParamOptional("useProxyDist", use_proxy_dist_, false);

  // Report the effective values.
  LOG(LIB_INFO) << "NN = " << NN_;
  LOG(LIB_INFO) << "efConstruction_ = " << efConstruction_;
  LOG(LIB_INFO) << "indexThreadQty = " << indexThreadQty_;
  LOG(LIB_INFO) << "useProxyDist = " << use_proxy_dist_;

  // NOTE(review): presumably flags parameters that were supplied but never
  // read above -- confirm against AnyParamManager.
  paramMgr.CheckUnused();
}


template <typename dist_t>
void SmallWorldRand<dist_t>::CreateIndex(const AnyParams& IndexParams)
{
Expand Down
6 changes: 4 additions & 2 deletions test_batch_app/test_batch_mod.cc
Expand Up @@ -194,10 +194,12 @@ void doWork(int argc, char* argv[]) {

CHECK_MSG(knnK > 0, "k-NN k should be > 0!");

int seed = 0;

if (LogFile != "")
initLibrary(LIB_LOGFILE, LogFile.c_str());
initLibrary(seed, LIB_LOGFILE, LogFile.c_str());
else
initLibrary(LIB_LOGSTDERR, NULL); // Use STDERR for logging
initLibrary(seed, LIB_LOGSTDERR, NULL); // Use STDERR for logging

unique_ptr<Space<float>> space(SpaceFactoryRegistry<float>::Instance().CreateSpace(SpaceType, *SpaceParams));

Expand Down

0 comments on commit ac07969

Please sign in to comment.