From fc2cfd19eb9a05def032cabc54db45db4d0aa9a9 Mon Sep 17 00:00:00 2001 From: ChunjiangZhu Date: Fri, 25 Oct 2019 00:59:15 -0400 Subject: [PATCH] hnsw add rangequery --- .DS_Store | Bin 10244 -> 10244 bytes similarity_search/.DS_Store | Bin 10244 -> 10244 bytes similarity_search/src/.DS_Store | Bin 10244 -> 10244 bytes similarity_search/src/method/hnsw.cc | 168 ++++++++++++++++++++++++++- 4 files changed, 167 insertions(+), 1 deletion(-) diff --git a/.DS_Store b/.DS_Store index 150336f9f1e4f6fd85679b81cb8f2e5ebfaad3d2..f522d6dcf805a403a1a0a6495bae0b3fe41941fc 100644 GIT binary patch delta 863 zcmZ{iOH30{6o$`_f_FwJ)ADeTmJ&*1(Sp1Lm4~z>#zc@BL8vHNI)jXqL1}raX<`U5 zN_5eni5eGfjA5go3m2NG8&`_TLfse_Zj74f1|J*mbVziev$)BeKj++gzwb^3rUJJg z%cxB(TK$RCyk?8MGh5u%rGVn&HDys~*F%Hql4Qxu$XzGPT_e5Pi4iR}A-Nbc z8|`v|Al+jJl~7D=9a4-=C7H-}xrniDquxBM^3^PhZh3C+`&8+a2F5oFfpCqXK`?=#s9SG_+#}= zx}nl)%gW9vaCy8HHFZsiNqc9!L(>&?ykksLV*SBTG_2~1rXTGb)uLK&M3vSU0NE1v zo98TaI3$Cd`CIIEh!VJ6A(@uAd7844m6q{r6&Gn{l~tlcO6AJT&f2Y_D6QsFc^c<~ z>l=8S#`T%8&(C8r-k0zffBGN%TQpCPIQX~pjy`bc-)I?xG^8UJ1t^3I#c-nx<)}mz zsyXf*@WG!%z6V*hatv5fqS7i8#wosXY8 z%q5%GjCXqD)fQVTH#^+v=C6#VdY>^E40IXOdcP18A^ng(><=50`d)un4+$vqN$IXE zPi|3JWo_e1a!RVOZRY%_q3MCJF=XiD=BnEf=??fubX(lAqO6FN)MVAJI+^e=-^741 zHerOPES0;}OhqiPK21K|9@L`Owb*&Lx}C9OQQbSP%T=>FW$#dVcSueq{7kWwynN2P zPZ|MDgk-DeG{q%c?THNg2lQc~Da_K!E4bP_sgH+@U_i93i&fQJ^$i9`M&#fj8w#at zjasKRF!7Qdk2c@$4~08M^}u@nCby~{iRz7-rfZBbPeU$QpQoWTwo!MhV zbQ+~8nxm_9jc(IDdPEDfNH6FWEzvT)ruX!LKGG_Ep>On^e$Y<<9EnJR1KY3@xyVC4 z%29_IJ>(S~*$KnHv{fdLFc!w4oZg=w6|EH2;@F5^}T?%+P=@c<9;6pQA?jG=f} zb6U%vkjwrP^1Gy1vRl&9Gd(#4rMqhC%^7#E`N`>s^hn(LDHcU05>}jchnlt3^7niXrR*{=6y*TNpc88M}xk?lnp4ecOt}U7Dw7lKC5YD@lK*uafneZ^+{i?ZBwc_)6k)fd zT`hTgCG8fpq7w(vg>D>1f0TK)gsvliAjU9`FwWpC&S3`UF^4O-jvFv>^GRann(`a- Cn9M`~ diff --git a/similarity_search/.DS_Store b/similarity_search/.DS_Store index 7f148cd61ce96fec3e71651ed7ae944364b381e1..803ba2066393739bb1490f3cf8e05e7fa8fab7ab 100644 GIT binary patch delta 1593 zcmeIwOKc2r90&0Kx81fg9opH|cGr}hR`sFU+G-=a~DL z`&PcidEVFVS5(dAdX1=POH6DW!^U&mbaB<{H7PY&`mnVkzlBd!TRVb&zdNWZni32& zHEQxctxnl5{}eUqJxO`CghbwIn~{{9Vo%G=Ehs7}t<*;>4*gxyl8|GjkUDENCpIVn zN#576SCO=xO2pMpXZR#b2&}7srBuSPS=1(x`o*A*B z-)nkEo{=%~ioB(J KWdnMjUHJ|wbBBxo delta 1590 zcmeIw?{5@E7zgm@+aB%A(83(_=*@y$;i#?jiU-m1qYd6!N?L)~z%?z1CD-l%tJiDo zUMqs}q9T!ml<1oHLJb6bhY2Qxgz$z$P)Q(w@i1zlFZ7KGAt4Y+3@>!{wlO6k;U8dM z?CkT*?9ODrpM$*zd%tcCS%X7kImay|lRu-PGBm5InsG7P>e?u}z2kKeZ!omNdl0Vc z-X-OPtGDW0w%|djfV@bZ|>d$;RaB0K%gmQmZ;?P$BSpgyXsg~DryG!<2SbV01avLut287uka zvPM%@$g`&OSBm9Ug3BzEmBEY8ij_>-T+Rw+pA&6dx?F~W*|nmb$xU2V1+(i#iiaz1 za7pD)^F~WB&f=+NDqopmK0Bg$Je3S;Vx(vVrKp>BlBO|wpN`NJ9i+f@Bt3tLwtmf;o%c}ffG0x9IOJ<2ke<1rw4rH#li{iX0-FngQ{Lp zewYd`=RwE11?$sC#jSc?5}Rkn{n}&VK2;NWYU$H0$>-O-u%Q$rFT@VVHgWrzj`zLk z*oD2@v+l6Y+fIH@|D@wOLph!C|F0ha_w%S#mRjU3iA0M`#@WEckn_fbQ+TT+=gh4Y zF|~x7&>yDa�JcWb756Q0KTdce{Lore-BFuCFo6c)62k6(aMxledkBptEw7P<^@V zkj@z~th}ggwNU-VNj>g5qj@7G=(cwV^}pXR5~f|=I;}c=m=B5e_cu(tVoGiI50$oU zZ8J7Q)nic=167~VmvoxW(gpgSuF`e7VRZdUf6!lam+sNubf5kO1IKLCKw&B3ScW9l zVm(qw<3(et7hAE-$a>Y7(q&U)hSGih)N6n<|LU^i=`gCXvxKs`nr1(c9#xSlvRr5^w!Ml z>F)1!_v>%wz3v47?8=+100{ul=;BgeKsQC=_58Xi#ln(Cl1O-f04&Hs3i6bTQe=V1 z0+9tG3q%%(EN~}UfP6Mj+I%Uadt`yg0+9t~EWp1H3A(sUhH_F$@#&zO*aDC&r(%&P zPx}CYNrW;P%1J2&lyu771BOrxw-_klRG;9@NhU)%DW!x1N;qKHGln}96ulEJA({gw zrHt;81tJUFYyrM^uZ9dbv}gB=_U|^8t)rSukK<+1j+dsE{N?JTtNBlnCqX9sCIL9k$$o(ynRBfupkT51(qxZY85 z8ZoTY@ob5X8ifvPrc7Z6yk4TAYNqv#xf~~JD6lc31iM&)b)b#VC&hV+PqR5n7FFr` zxo`D7uwOG^-yb>D$!D>m`zTs<_Y4HM>EQ)MpM4DM2~7+ebQhM zGYz#4!Ok}d*ttS67N6$9@jme*MfV#TQ)Fd>N|crc&fY@ZHH050UHr_;Ccaw4*FenB zB|Qpz3+5JTV8nBWERq&;QXx1`3la0ECzmee2n8n#ThFJr(xu-?db{~3PpL*F*jR|F z5MvsqvGKl??@x@gt`Nh-{Igg(|7kYYSg{Co$rS$m3#Kui6iPqMRGcAdU*0xP; zEs2gz>(88F+LET$tvyGwr%q2idFJe!G$BeDMlbZqEl%g@vB)5fKiVxP8Ou#i^4sZ2 zGB!WHvu@Y!l&oLnQ0(;cjZUmN}(yM)%ZBNE?_vfrpn%O|cbMjf2YHBiulgU$K zYn$Wr+NUjlFtEl0o*VMLT+l5V4%&p}WhuRutej7hv8e@%maS}TUfb5O zW!sg?s%o9pvbup$+qct>)jwvN!LfsepRruic87<^T5ez;wXGbhiCt!Wu93B_=vCE~ zivxQ_<;bFqtX(pwng$-uB zY?j9D%i@RPNlpJ;dxH19$MA!FBbH0u^I3G&$KwZi`8LzEOij}olg)Bv8&|K7vo-Gx!p|g3E9XzK0*+NB9YTfnVV_xQ>7f zD{ww8!35rgjkpR|V;gS3P1uf|xE*)kF5Hce-~bNd5FWw|I+(=~Jc+0AX?zBs!{_mJ zd>b#|MV!L-@e+P4M;4oAV~>2EzT+p9D`jK&%Ebm$Y-}JluKHKlcx?fvU`_49xBw2~5OP4KY4gEwXYrK@;)5O>3apGuKc9IJ3K#&J2;V*}oU&A0~d z!xo}h2X4j(1=V&6sy&Lm*pCN@a7QqWCR%tDJ))bB0p{^6PU2I7c<1mXd>LQCSMfD` zOOk7ba=WL=CMCXCO1aVAoa=b*FqLzfve0K6-v!1fA+L7BOe*5>(5T2*n)m~iFTRrMJUl9DFW;f)9$DbFEKnv}C;0jQ zj@keJzb&0nV!Z diff --git a/similarity_search/src/method/hnsw.cc b/similarity_search/src/method/hnsw.cc index f5c7fca..4813619 100644 --- a/similarity_search/src/method/hnsw.cc +++ b/similarity_search/src/method/hnsw.cc @@ -677,7 +677,43 @@ namespace similarity { void Hnsw::Search(RangeQuery *query, IdType) const { - throw runtime_error("Range search is not supported!"); +// throw runtime_error("Range search is not supported!"); + if (this->data_.empty() && this->data_rearranged_.empty()) { + return; + } + bool useOld = searchAlgoType_ == kOld || (searchAlgoType_ == kHybrid && ef_ >= 1000); + // cout << "Ef = " << ef_ << " use old = " << useOld << endl; + switch (searchMethod_) { + default: + throw runtime_error("Invalid searchMethod: " + ConvertToString(searchMethod_)); + break; + case 0: + /// Basic search using Nmslib data structure: +// if (useOld) +// const_cast(this)->baseSearchAlgorithmOld(query); +// else + const_cast(this)->baseSearchAlgorithmV1Merge(query); + break; +// case 1: +// /// Experimental search using Nmslib data structure (should not be used): +// const_cast(this)->listPassingModifiedAlgorithm(query); +// break; +// case 3: +// /// Basic search using optimized index(cosine+L2) +// if (useOld) +// const_cast(this)->SearchL2CustomOld(query); +// else +// const_cast(this)->SearchL2CustomV1Merge(query); +// break; +// case 4: +// /// Basic search using optimized index with one-time normalized cosine similarity +// /// Only for cosine similarity! +// if (useOld) +// const_cast(this)->SearchCosineNormalizedOld(query); +// else +// const_cast(this)->SearchCosineNormalizedV1Merge(query); +// break; + }; } template @@ -1274,6 +1310,136 @@ namespace similarity { visitedlistpool->releaseVisitedList(vl); } + + template + void + Hnsw::baseSearchAlgorithmV1Merge(RangeQuery *query) + { + VisitedList *vl = visitedlistpool->getFreeVisitedList(); + vl_type *massVisited = vl->mass; + vl_type currentV = vl->curV; + + HnswNode *provider; + int maxlevel1 = enterpoint_->level; + provider = enterpoint_; + + const Object *currObj = provider->getData(); + + dist_t d = query->DistanceObjLeft(currObj); + dist_t curdist = d; + HnswNode *curNode = provider; + for (int i = maxlevel1; i > 0; i--) { + bool changed = true; + while (changed) { + changed = false; + + const vector &neighbor = curNode->getAllFriends(i); + for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) { + _mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0); + } + for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) { + currObj = (*iter)->getData(); + d = query->DistanceObjLeft(currObj); + if (d < curdist) { + curdist = d; + curNode = *iter; + changed = true; + } + } + } + } + + SortArrBI sortedArr(ef_); // max(ef_, query->GetK()) + sortedArr.push_unsorted_grow(curdist, curNode); + + int_fast32_t currElem = 0; + + typedef typename SortArrBI::Item QueueItem; + vector &queueData = sortedArr.get_data(); + vector itemBuff(1 + max(maxM_, maxM0_)); + + massVisited[curNode->getId()] = currentV; + // visitedQueue.insert(curNode->getId()); + + //////////////////////////////////////////////////////////////////////////////// + // PHASE TWO OF THE SEARCH + // Extraction of the neighborhood to find k nearest neighbors. + //////////////////////////////////////////////////////////////////////////////// + + while (currElem < min(sortedArr.size(), ef_)) { + auto &e = queueData[currElem]; + CHECK(!e.used); + e.used = true; + HnswNode *initNode = e.data; + ++currElem; + + size_t itemQty = 0; + dist_t topKey = sortedArr.top_key(); + + const vector &neighbor = (initNode)->getAllFriends(0); + + size_t curId; + + for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) { + _mm_prefetch((char *)(*iter)->getData(), _MM_HINT_T0); + IdType curId = (*iter)->getId(); + CHECK(curId >= 0 && curId < this->data_.size()); + _mm_prefetch((char *)(massVisited + curId), _MM_HINT_T0); + } + // calculate distance to each neighbor + for (auto iter = neighbor.begin(); iter != neighbor.end(); ++iter) { + curId = (*iter)->getId(); + + if (!(massVisited[curId] == currentV)) { + massVisited[curId] = currentV; + currObj = (*iter)->getData(); + d = query->DistanceObjLeft(currObj); + + if (d < topKey || sortedArr.size() < ef_) { + CHECK_MSG(itemBuff.size() > itemQty, + "Perhaps a bug: buffer size is not enough " + + ConvertToString(itemQty) + " >= " + ConvertToString(itemBuff.size())); + itemBuff[itemQty++] = QueueItem(d, *iter); + } + } + } + + if (itemQty) { + _mm_prefetch(const_cast(reinterpret_cast(&itemBuff[0])), _MM_HINT_T0); + std::sort(itemBuff.begin(), itemBuff.begin() + itemQty); + + size_t insIndex = 0; + if (itemQty > MERGE_BUFFER_ALGO_SWITCH_THRESHOLD) { + insIndex = sortedArr.merge_with_sorted_items(&itemBuff[0], itemQty); + + if (insIndex < currElem) { + // LOG(LIB_INFO) << "@@@ " << currElem << " -> " << insIndex; + currElem = insIndex; + } + } else { + for (size_t ii = 0; ii < itemQty; ++ii) { + size_t insIndex = sortedArr.push_or_replace_non_empty_exp(itemBuff[ii].key, itemBuff[ii].data); + + if (insIndex < currElem) { + // LOG(LIB_INFO) << "@@@ " << currElem << " -> " << insIndex; + currElem = insIndex; + } + } + } + } + // To ensure that we either reach the end of the unexplored queue or currElem points to the first unused element + while (currElem < sortedArr.size() && queueData[currElem].used == true) + ++currElem; + } + + for (uint_fast32_t i = 0; i < sortedArr.size(); ++i) { // i < query->GetK() && + query->CheckAndAddToResult(queueData[i].key, queueData[i].data->getData()); + } + + visitedlistpool->releaseVisitedList(vl); + } + + // Experimental search algorithm template void