Skip to content

Commit

Permalink
updates for #249 and sort of a fix for #274
Browse files Browse the repository at this point in the history
  • Loading branch information
searchivairus committed Jan 29, 2018
1 parent 2a4ee7b commit c6ac418
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 144 deletions.
177 changes: 45 additions & 132 deletions similarity_search/include/experiments.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "eval_results.h"
#include "meta_analysis.h"
#include "query_creator.h"
#include "thread_pool.h"

namespace similarity {

Expand Down Expand Up @@ -104,102 +105,6 @@ class Experiments {
if (LogInfo) LOG(LIB_INFO) << "experiment done at " << LibGetCurrentTime();
}

template <typename QueryType, typename QueryCreatorType>
struct BenchmarkThreadParams {
BenchmarkThreadParams(
mutex& UpdateStat,
unsigned ThreadQty,
unsigned QueryPart,
size_t TestSetId,
std::vector<MetaAnalysis*>& ExpRes,
const ExperimentConfig<dist_t>& config,
const QueryCreatorType& QueryCreator,
const Index<dist_t>& Method,
unsigned MethNum,
vector<uint64_t>& SearchTime,
vector<double>& AvgNumDistComp,
vector<unsigned>& max_result_size,
vector<double>& avg_result_size,
vector<uint64_t>& DistCompQty) :
UpdateStat_(UpdateStat),
ThreadQty_(ThreadQty),
QueryPart_(QueryPart),
TestSetId_(TestSetId),
ExpRes_(ExpRes),
config_(config),
QueryCreator_(QueryCreator),
Method_(Method),
MethNum_(MethNum),
SearchTime_(SearchTime),

AvgNumDistComp_(AvgNumDistComp),
max_result_size_(max_result_size),
avg_result_size_(avg_result_size),
DistCompQty_(DistCompQty)
{}

mutex& UpdateStat_;
unsigned ThreadQty_;
unsigned QueryPart_;
size_t TestSetId_;
std::vector<MetaAnalysis*>& ExpRes_;
const ExperimentConfig<dist_t>& config_;
const QueryCreatorType& QueryCreator_;
const Index<dist_t>& Method_;
unsigned MethNum_;
vector<uint64_t>& SearchTime_;

vector<double>& AvgNumDistComp_;
vector<unsigned>& max_result_size_;
vector<double>& avg_result_size_;
vector<uint64_t>& DistCompQty_;

vector<size_t> queryIds;
vector<unique_ptr<QueryType>> queries; // queries with results
};

template <typename QueryType, typename QueryCreatorType>
struct BenchmarkThread {
void operator ()(BenchmarkThreadParams<QueryType, QueryCreatorType>& prm) {
size_t numquery = prm.config_.GetQueryObjects().size();

WallClockTimer wtm;

wtm.reset();

unsigned MethNum = prm.MethNum_;
unsigned QueryPart = prm.QueryPart_;
unsigned ThreadQty = prm.ThreadQty_;

for (size_t q = 0; q < numquery; ++q) {
if ((q % ThreadQty) == QueryPart) {
unique_ptr<QueryType> query(prm.QueryCreator_(prm.config_.GetSpace(),
prm.config_.GetQueryObjects()[q]));
uint64_t t1 = wtm.split();
prm.Method_.Search(query.get());
uint64_t t2 = wtm.split();

{
lock_guard<mutex> g(prm.UpdateStat_);

prm.ExpRes_[MethNum]->AddDistComp(prm.TestSetId_, query->DistanceComputations());
prm.ExpRes_[MethNum]->AddQueryTime(prm.TestSetId_, (1.0*t2 - t1)/1e3);


prm.DistCompQty_[MethNum] += query->DistanceComputations();
prm.avg_result_size_[MethNum] += query->ResultSize();

if (query->ResultSize() > prm.max_result_size_[MethNum]) {
prm.max_result_size_[MethNum] = query->ResultSize();
}

prm.queryIds.push_back(q);
prm.queries.push_back(std::move(query));
}
}
}
}
};

template <typename QueryType, typename QueryCreatorType>
static void Execute(bool LogInfo, unsigned ThreadTestQty, size_t TestSetId,
Expand Down Expand Up @@ -259,41 +164,51 @@ class Experiments {

if (!ThreadTestQty) ThreadTestQty = 1;

vector<BenchmarkThreadParams<QueryType, QueryCreatorType>*> ThreadParams(ThreadTestQty);
vector<thread> Threads(ThreadTestQty);
AutoVectDel<BenchmarkThreadParams<QueryType, QueryCreatorType>> DelThreadParams(ThreadParams);
vector<vector<size_t>> QueryIds;
vector<vector<unique_ptr<QueryType>>> Queries; // queries with results

QueryIds.resize(ThreadTestQty);
Queries.resize(ThreadTestQty);

for (unsigned QueryPart = 0; QueryPart < ThreadTestQty; ++QueryPart) {
ThreadParams[QueryPart] = new BenchmarkThreadParams<QueryType, QueryCreatorType>(
UpdateStat,
ThreadTestQty,
QueryPart,
TestSetId,
ExpRes,
config,
QueryCreator,
Method,
MethNum,
SearchTime,
AvgNumDistComp,
max_result_size,
avg_result_size,
DistCompQty);
}
/*
* Because each thread uses its own parameter set, we must use
* exactly ThreadTestQty sets.
*/
ParallelFor(0, ThreadTestQty, ThreadTestQty, [&](unsigned QueryPart) {
size_t numquery = config.GetQueryObjects().size();

if (ThreadTestQty> 1) {
for (unsigned QueryPart = 0; QueryPart < ThreadTestQty; ++QueryPart) {
Threads[QueryPart] = std::thread(BenchmarkThread<QueryType, QueryCreatorType>(),
ref(*ThreadParams[QueryPart]));
}
for (unsigned QueryPart = 0; QueryPart < ThreadTestQty; ++QueryPart) {
Threads[QueryPart].join();
WallClockTimer wtm;

wtm.reset();

for (size_t q = 0; q < numquery; ++q) {
if ((q % ThreadTestQty) == QueryPart) {
unique_ptr<QueryType> query(QueryCreator(config.GetSpace(),
config.GetQueryObjects()[q]));
uint64_t t1 = wtm.split();
Method.Search(query.get());
uint64_t t2 = wtm.split();

{
lock_guard<mutex> g(UpdateStat);

ExpRes[MethNum]->AddDistComp(TestSetId, query->DistanceComputations());
ExpRes[MethNum]->AddQueryTime(TestSetId, (1.0*t2 - t1)/1e3);


DistCompQty[MethNum] += query->DistanceComputations();
avg_result_size[MethNum] += query->ResultSize();

if (query->ResultSize() > max_result_size[MethNum]) {
max_result_size[MethNum] = query->ResultSize();
}

QueryIds[QueryPart].push_back(q);
Queries[QueryPart].push_back(std::move(query));
}
}
}
} else {
CHECK(ThreadTestQty == 1);
BenchmarkThread<QueryType, QueryCreatorType>()(*ThreadParams[0]);
}
});

wtm.split();

Expand All @@ -309,11 +224,9 @@ class Experiments {
if (LogInfo) LOG(LIB_INFO) << ">>>> Computing effectiveness metrics for " << Method.StrDesc();

for (unsigned QueryPart = 0; QueryPart < ThreadTestQty; ++QueryPart) {
const BenchmarkThreadParams<QueryType, QueryCreatorType>* params = ThreadParams[QueryPart];

for (size_t qi = 0; qi < params->queries.size(); ++qi) {
size_t q = params->queryIds[qi] ;
const QueryType* pQuery = params->queries[qi].get();
for (size_t qi = 0; qi < Queries[QueryPart].size(); ++qi) {
size_t q = QueryIds[QueryPart][qi] ;
const QueryType* pQuery = Queries[QueryPart][qi].get();

unique_ptr<QueryType> queryGS(QueryCreator(config.GetSpace(), config.GetQueryObjects()[q]));

Expand Down
5 changes: 5 additions & 0 deletions similarity_search/include/ported_boost_progress.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ class ProgressDisplay {
return _count;
}

// Effects: increments enough to display a complete progress
void finish() {
operator+=(expected_count() - count());
}

unsigned long operator++() { return operator+=( 1 ); }
unsigned long count() const { return _count; }
unsigned long expected_count() const { return _expected_count; }
Expand Down
21 changes: 13 additions & 8 deletions similarity_search/include/sort_arr_bi.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ class SortArrBI {
}

// Sorts the first num_elems_ entries of v_ in ascending order
// (using the element type's operator<).
void sort() {
  // Touch the storage only when there is some: taking &v_[0] on an empty
  // vector is undefined behavior, so the prefetch hint must be guarded.
  // The guard has to test the member v_ (there is no variable named v).
  if (!v_.empty())
    _mm_prefetch(reinterpret_cast<const char*>(&v_[0]), _MM_HINT_T0);
  std::sort(v_.begin(), v_.begin() + num_elems_);
}

Expand All @@ -86,7 +87,7 @@ class SortArrBI {
// it also assumes a non-empty array
size_t push_or_replace_non_empty(const KeyType& key, const DataType& data) {
// num_elems_ > 0
unsigned curr = num_elems_ - 1;
size_t curr = num_elems_ - 1;
if (v_[curr].key <= key) {
if (num_elems_ < v_.size()) {
v_[num_elems_].used = false;
Expand All @@ -99,15 +100,17 @@ class SortArrBI {
}

while (curr > 0) {
unsigned j = curr - 1;
size_t j = curr - 1;
if (v_[j].key <= key) break;
curr = j;
}

if (num_elems_ < v_.size()) num_elems_++;
// curr + 1 <= num_elems_
_mm_prefetch((char *)&v_[curr], _MM_HINT_T0);
memmove((char *)&v_[curr+1], &v_[curr], (num_elems_ - (1 + curr)) * sizeof(v_[0]));

if (num_elems_ - (1 + curr) > 0)
memmove((char *)&v_[curr+1], &v_[curr], (num_elems_ - (1 + curr)) * sizeof(v_[0]));

v_[curr].used = false;
v_[curr].key = key;
Expand Down Expand Up @@ -150,9 +153,11 @@ class SortArrBI {
return ret;
}

// Checking for duplicate IDs isn't the responsibility of this function
// it also assumes a non-empty array
size_t push_or_replace_non_empty_exp(const KeyType& key, const DataType& data) {
// num_elems_ > 0
unsigned curr = num_elems_ - 1;
size_t curr = num_elems_ - 1;
if (v_[curr].key <= key) {
if (num_elems_ < v_.size()) {
v_[num_elems_].used = false;
Expand All @@ -163,9 +168,9 @@ class SortArrBI {
return num_elems_;
}
}
unsigned prev = curr;
size_t prev = curr;

unsigned d=1;
size_t d=1;
// always curr >= d
while (curr > 0 && v_[curr].key > key) {
prev = curr;
Expand All @@ -182,7 +187,7 @@ class SortArrBI {
if (num_elems_ < v_.size()) num_elems_++;
// curr + 1 <= num_elems_

if(num_elems_ - (1 + curr) > 0)
if (num_elems_ - (1 + curr) > 0)
memmove(&v_[curr+1], &v_[curr], (num_elems_ - (1 + curr)) * sizeof(v_[0]));


Expand Down
14 changes: 10 additions & 4 deletions similarity_search/src/method/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,12 @@ namespace similarity {
{
unique_lock<mutex> lock(ElListGuard_);
ElList_[id] = node;
if (progress_bar)
++(*progress_bar);
}
if (progress_bar)
++(*progress_bar);
});
if (progress_bar)
progress_bar->finish();

if (post_ == 1 || post_ == 2) {
vector<HnswNode *> temp;
Expand All @@ -239,9 +241,11 @@ namespace similarity {
{
unique_lock<mutex> lock(ElListGuard_);
ElList_[id] = node;
if (progress_bar1)
++(*progress_bar1);
}
if (progress_bar1)
++(*progress_bar1);
progress_bar1->finish();
});
int maxF = 0;

Expand Down Expand Up @@ -1209,7 +1213,9 @@ namespace similarity {

for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
_mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0);
_mm_prefetch((char *)(massVisited + (*iter)->getId()), _MM_HINT_T0);
IdType curId = (*iter)->getId();
CHECK(curId >= 0 && curId < this->data_.size());
_mm_prefetch((char *)(massVisited + curId), _MM_HINT_T0);
}
// calculate distance to each neighbor
for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) {
Expand Down

0 comments on commit c6ac418

Please sign in to comment.