Skip to content

Commit

Permalink
Add GetSize() to Index
Browse files Browse the repository at this point in the history
* Adds a protected variable called data_ to the Index class to reduce boilerplate
* Implements GetSize in the Index class
  • Loading branch information
Will Sackfield committed Nov 30, 2017
1 parent f8343d1 commit f8edffe
Show file tree
Hide file tree
Showing 45 changed files with 172 additions and 186 deletions.
7 changes: 7 additions & 0 deletions similarity_search/include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ class KNNQuery;
template <typename dist_t>
class Index {
public:
Index(const ObjectVector& data) : data_(data) {}

// Create an index using given parameters
virtual void CreateIndex(const AnyParams& indexParams) = 0;
// SaveIndex is not necessarily implemented
Expand Down Expand Up @@ -96,6 +98,11 @@ class Index {
bool checkIDs = false/* this is a debug flag only, turning it on may affect performance */) {
throw runtime_error("DeleteBatch is not implemented!");
}

virtual size_t GetSize() const { return data_.size(); }
protected:
const ObjectVector& data_;

private:
template <typename QueryType>
void GenericSearch(QueryType* query, IdType) const;
Expand Down
2 changes: 0 additions & 2 deletions similarity_search/include/method/bbtree.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ class BBTree : public Index<dist_t> {
DISABLE_COPY_AND_ASSIGN(BBNode);
};

const ObjectVector& data_;

unique_ptr<BBNode> root_node_;
size_t BucketSize_;
int MaxLeavesToVisit_;
Expand Down
3 changes: 1 addition & 2 deletions similarity_search/include/method/dummy.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class DummyMethod : public Index<dist_t> {
* So, we can memorize them safely.
*/
DummyMethod(Space<dist_t>& space,
const ObjectVector& data) : data_(data), space_(space) {}
const ObjectVector& data) : Index<dist_t>(data), space_(space) {}

/*
* This function is supposed to create a search index (or call a
Expand Down Expand Up @@ -105,7 +105,6 @@ class DummyMethod : public Index<dist_t> {

private:
bool data_duplicate_;
const ObjectVector& data_;
Space<dist_t>& space_;
bool bDoSeqSearch_;
// disable copy and assign
Expand Down
1 change: 0 additions & 1 deletion similarity_search/include/method/ghtree.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ class GHTree : public Index<dist_t> {
};

const Space<dist_t>& space_;
const ObjectVector& data_;
bool use_random_center_;
unique_ptr<GHNode> root_;

Expand Down
1 change: 0 additions & 1 deletion similarity_search/include/method/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,6 @@ namespace similarity {
unsigned int enterpointId_;
unsigned int totalElementsStored_;

const ObjectVector &data_; // We do not copy objects
ObjectVector data_rearranged_;

VisitedListPool *visitedlistpool;
Expand Down
1 change: 0 additions & 1 deletion similarity_search/include/method/list_clusters.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ class ListClusters : public Index<dist_t> {
};

const Space<dist_t>& space_;
const ObjectVector& data_;

std::vector<Cluster*> cluster_list_;

Expand Down
1 change: 0 additions & 1 deletion similarity_search/include/method/multi_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ class MultiIndex : public Index<dist_t> {

std::vector<Index<dist_t>*> indices_;
Space<dist_t>& space_;
const ObjectVector& data_;
string SpaceType_;
bool PrintProgress_;
size_t IndexQty_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ class MultiVantagePointTree : public Index<dist_t> {
void GenericSearch(Node* node, QueryType* query, Dists& path, size_t query_path_len, int& MaxLeavesToVisit) const;

const Space<dist_t>& space_;
const ObjectVector& data_;
unique_ptr<Node> root_; // root node

size_t MaxPathLength_; // the number of distances for the data
Expand Down
3 changes: 1 addition & 2 deletions similarity_search/include/method/nonmetr_list_clust.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class NonMetrListClust : public Index<dist_t> {
public:
NonMetrListClust(bool printProgress,
Space<dist_t>& space,
const ObjectVector& data) : printProgress_(printProgress), data_(data), space_(space) {
const ObjectVector& data) : Index<dist_t>(data), printProgress_(printProgress), space_(space) {
maxObjId_ = 0;
for (const Object* o: data) {
maxObjId_ = max(maxObjId_, o->id());
Expand Down Expand Up @@ -79,7 +79,6 @@ class NonMetrListClust : public Index<dist_t> {

private:
bool printProgress_;
const ObjectVector& data_;
Space<dist_t>& space_;

size_t db_scan_;
Expand Down
5 changes: 2 additions & 3 deletions similarity_search/include/method/omedrank.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ class OMedRank : public Index<dist_t> {
void IndexChunk(size_t chunkId, ProgressDisplay* displayBar);

const Space<dist_t>& space_;
const ObjectVector& data_;
bool PrintProgress_;

size_t num_pivot_;
Expand Down Expand Up @@ -117,8 +116,8 @@ class OMedRank : public Index<dist_t> {
// Heuristics: try to read db_scan_fraction/index_qty entries from each index part
// or alternatively K * knn_amp_ entries, for KNN-search
size_t computeDbScan(size_t K) const {
if (knn_amp_) { return min(K * knn_amp_, data_.size()); }
return static_cast<size_t>(db_scan_frac_ * data_.size());
if (knn_amp_) { return min(K * knn_amp_, this->data_.size()); }
return static_cast<size_t>(db_scan_frac_ * this->data_.size());
}

template <typename QueryType> void GenSearch(QueryType* query, size_t K) const;
Expand Down
1 change: 0 additions & 1 deletion similarity_search/include/method/perm_bin_vptree.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ class PermBinVPTree : public Index<dist_t> {
private:

Space<dist_t>& space_;
const ObjectVector& data_;
bool PrintProgress_;
size_t bin_threshold_;
size_t bin_perm_word_qty_;
Expand Down
5 changes: 2 additions & 3 deletions similarity_search/include/method/perm_index_incr_bin.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ class PermutationIndexIncrementalBin : public Index<dist_t> {

private:
const Space<dist_t>& space_;
const ObjectVector& data_;
bool PrintProgress_;

ObjectVector pivot_;
Expand All @@ -71,8 +70,8 @@ class PermutationIndexIncrementalBin : public Index<dist_t> {
std::vector<uint32_t> permtable_;

size_t computeDbScan(size_t K) const {
if (knn_amp_) { return min(K * knn_amp_, data_.size()); }
return static_cast<size_t>(db_scan_frac_ * data_.size());
if (knn_amp_) { return min(K * knn_amp_, this->data_.size()); }
return static_cast<size_t>(db_scan_frac_ * this->data_.size());
}

template <typename QueryType> void GenSearch(QueryType* query, size_t K) const;
Expand Down
1 change: 0 additions & 1 deletion similarity_search/include/method/perm_lsh_bin.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ class PermutationIndexLSHBin : public Index<dist_t> {
void SetQueryTimeParams(const AnyParams &) override {}
private:
const Space<dist_t>& space_;
const ObjectVector& data_;
bool printProgress_;

size_t num_pivot_;
Expand Down
5 changes: 2 additions & 3 deletions similarity_search/include/method/permutation_inverted_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ class PermutationInvertedIndex : public Index<dist_t> {

private:
const Space<dist_t>& space_;
const ObjectVector& data_;
bool PrintProgress_;

float db_scan_frac_;
Expand All @@ -67,8 +66,8 @@ class PermutationInvertedIndex : public Index<dist_t> {
ObjectVector pivot_;

size_t computeDbScan(size_t K) const {
if (knn_amp_) { return min(K * knn_amp_, data_.size()); }
return static_cast<size_t>(db_scan_frac_ * data_.size());
if (knn_amp_) { return min(K * knn_amp_, this->data_.size()); }
return static_cast<size_t>(db_scan_frac_ * this->data_.size());
}

struct ObjectInvEntry {
Expand Down
5 changes: 2 additions & 3 deletions similarity_search/include/method/permutation_prefix_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,16 +56,15 @@ class PermutationPrefixIndex : public Index<dist_t> {
private:

size_t computeDbScan(size_t K) const {
if (knn_amp_) { return min(K * knn_amp_, data_.size()); }
return static_cast<size_t>(min(min_candidate_, data_.size()));
if (knn_amp_) { return min(K * knn_amp_, this->data_.size()); }
return static_cast<size_t>(min(min_candidate_, this->data_.size()));
}


template <typename QueryType>
void GenSearch(QueryType* query, size_t K) const;

const Space<dist_t>& space_;
const ObjectVector& data_;
bool PrintProgress_;

// permutation prefix length (l in the original paper) in (0, num_pivot]
Expand Down
5 changes: 2 additions & 3 deletions similarity_search/include/method/pivot_neighb_invindx.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ class PivotNeighbInvertedIndex : public Index<dist_t> {
void SetQueryTimeParams(const AnyParams& QueryTimeParams) override;
private:

const ObjectVector& data_;
const Space<dist_t>& space_;
bool PrintProgress_;
bool recreate_points_;
Expand Down Expand Up @@ -130,11 +129,11 @@ class PivotNeighbInvertedIndex : public Index<dist_t> {
ObjectVector genPivot_; // generated pivots

size_t computeDbScan(size_t K, size_t chunkQty) const {
size_t totalDbScan = static_cast<size_t>(db_scan_frac_ * data_.size());
size_t totalDbScan = static_cast<size_t>(db_scan_frac_ * this->data_.size());
if (knn_amp_) {
totalDbScan = K * knn_amp_;
}
totalDbScan = min(totalDbScan, data_.size());
totalDbScan = min(totalDbScan, this->data_.size());
CHECK_MSG(chunkQty, "Bug or inconsistent parameters: the number of index chunks cannot be zero!");
return (totalDbScan + chunkQty - 1) / chunkQty;
}
Expand Down
5 changes: 2 additions & 3 deletions similarity_search/include/method/proj_vptree.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,15 @@ class ProjectionVPTree : public Index<dist_t> {


Space<dist_t>& space_;
const ObjectVector& data_;
bool PrintProgress_;

size_t K_;
size_t knn_amp_;
float db_scan_frac_;

size_t computeDbScan(size_t K) const {
if (knn_amp_) { return min(K * knn_amp_, data_.size()); }
return static_cast<size_t>(db_scan_frac_ * data_.size());
if (knn_amp_) { return min(K * knn_amp_, this->data_.size()); }
return static_cast<size_t>(db_scan_frac_ * this->data_.size());
}

unique_ptr<Projection<dist_t> > projObj_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ class ProjectionIndexIncremental : public Index<dist_t> {
private:

const Space<dist_t>& space_;
const ObjectVector& data_;
bool PrintProgress_;

float max_proj_dist_;
Expand All @@ -80,8 +79,8 @@ class ProjectionIndexIncremental : public Index<dist_t> {
#endif

size_t computeDbScan(size_t K) const {
if (knn_amp_) { return min(K * knn_amp_, data_.size()); }
return static_cast<size_t>(db_scan_frac_ * data_.size());
if (knn_amp_) { return min(K * knn_amp_, this->data_.size()); }
return static_cast<size_t>(db_scan_frac_ * this->data_.size());
}

template <typename QueryType> void GenSearch(QueryType* query, size_t K) const;
Expand Down
5 changes: 3 additions & 2 deletions similarity_search/include/method/seqsearch.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,18 @@ class SeqSearch : public Index<dist_t> {
void Search(KNNQuery<dist_t>* query, IdType) const override;

void SetQueryTimeParams(const AnyParams& params) override {}

size_t GetSize() const override { return getData().size(); }
private:
Space<dist_t>& space_;
const ObjectVector& origData_;
char* cacheOptimizedBucket_;

ObjectVector* pData_;
bool multiThread_;
IdTypeUnsign threadQty_;
vector<ObjectVector> vvThreadData;

const ObjectVector& getData() const { return pData_ != NULL ? *pData_ : origData_; }
const ObjectVector& getData() const { return pData_ != NULL ? *pData_ : this->data_; }
// disable copy and assign
DISABLE_COPY_AND_ASSIGN(SeqSearch);
};
Expand Down
3 changes: 1 addition & 2 deletions similarity_search/include/method/simple_inverted_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class SimplInvIndex : public Index<dist_t> {
* So, we can memorize them safely.
*/
SimplInvIndex(Space<dist_t>& space,
const ObjectVector& data) : data_(data),
const ObjectVector& data) : Index<dist_t>(data),
pSpace_(dynamic_cast<SpaceSparseNegativeScalarProductFast*>(&space)) {
if (pSpace_ == nullptr) {
PREPARE_RUNTIME_ERR(err) <<
Expand Down Expand Up @@ -119,7 +119,6 @@ class SimplInvIndex : public Index<dist_t> {
: post_(&pl), post_pos_(0), qval_(qval), qval_x_docval_(qval_x_docval) {}
};

const ObjectVector& data_;
SpaceSparseNegativeScalarProductFast* pSpace_;
std::unordered_map<unsigned, std::unique_ptr<PostList>> index_;
// disable copy and assign
Expand Down
1 change: 0 additions & 1 deletion similarity_search/include/method/small_world_rand.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ class SmallWorldRand : public Index<dist_t> {
ObjectVector pivots_;

const Space<dist_t>& space_;
const ObjectVector& data_; // We don't copy data
bool PrintProgress_;
bool use_proxy_dist_;

Expand Down
1 change: 0 additions & 1 deletion similarity_search/include/method/spatial_approx_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ class SpatialApproxTree : public Index<dist_t> {
class SATNode;

const Space<dist_t>& space_;
const ObjectVector data_;

unique_ptr<SATNode> root_;
};
Expand Down
1 change: 0 additions & 1 deletion similarity_search/include/method/vptree.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ class VPTree : public Index<dist_t> {
};

Space<dist_t>& space_;
const ObjectVector& data_;
bool PrintProgress_;
bool use_random_center_;
size_t max_pivot_select_attempts_;
Expand Down
4 changes: 2 additions & 2 deletions similarity_search/src/method/bbtree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ using std::unique_ptr;
template <typename dist_t>
BBTree<dist_t>::BBTree(
const Space<dist_t>& space,
const ObjectVector& data) : data_(data) {
const ObjectVector& data) : Index<dist_t>(data) {
BregmanDivSpace_ = BregmanDiv<dist_t>::ConvertFrom(&space); // Should be the special space!
}

Expand All @@ -61,7 +61,7 @@ void BBTree<dist_t>::CreateIndex(const AnyParams& MethParams) {

pmgr.CheckUnused();

root_node_.reset(new BBNode(BregmanDivSpace_, data_, BucketSize_, ChunkBucket_));
root_node_.reset(new BBNode(BregmanDivSpace_, this->data_, BucketSize_, ChunkBucket_));
}

template <typename dist_t>
Expand Down
8 changes: 4 additions & 4 deletions similarity_search/src/method/dummy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ namespace similarity {
template <typename dist_t>
void DummyMethod<dist_t>::Search(RangeQuery<dist_t>* query, IdType) const {
if (bDoSeqSearch_) {
for (size_t i = 0; i < data_.size(); ++i) {
query->CheckAndAddToResult(data_[i]);
for (size_t i = 0; i < this->data_.size(); ++i) {
query->CheckAndAddToResult(this->data_[i]);
}
} else {
for (int i =0; i < 100000; ++i);
Expand All @@ -35,8 +35,8 @@ void DummyMethod<dist_t>::Search(RangeQuery<dist_t>* query, IdType) const {
template <typename dist_t>
void DummyMethod<dist_t>::Search(KNNQuery<dist_t>* query, IdType) const {
if (bDoSeqSearch_) {
for (size_t i = 0; i < data_.size(); ++i) {
query->CheckAndAddToResult(data_[i]);
for (size_t i = 0; i < this->data_.size(); ++i) {
query->CheckAndAddToResult(this->data_[i]);
}
} else {
for (int i =0; i < 100000; ++i);
Expand Down
6 changes: 3 additions & 3 deletions similarity_search/src/method/ghtree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ template <typename dist_t>
GHTree<dist_t>::GHTree(const Space<dist_t>& space,
const ObjectVector& data,
bool use_random_center) :
space_(space),
data_(data),
Index<dist_t>(data),
space_(space),
use_random_center_(use_random_center) {
}

Expand All @@ -46,7 +46,7 @@ void GHTree<dist_t>::CreateIndex(const AnyParams& IndexParams) {
pmgr.CheckUnused();
this->ResetQueryTimeParams();

root_.reset(new GHNode(space_, data_,
root_.reset(new GHNode(space_, this->data_,
BucketSize_, ChunkBucket_,
use_random_center_ /* random center */));
}
Expand Down
Loading

0 comments on commit f8edffe

Please sign in to comment.