Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Doc2doc search to improve index performance on hard queries. #417

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
.vscode/
Release/
Debug/
datasets/
build/
data/
# Prerequisites
*.d

Expand Down
78 changes: 76 additions & 2 deletions AnnService/inc/Core/Common/QueryResultSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
#ifndef _SPTAG_COMMON_QUERYRESULTSET_H_
#define _SPTAG_COMMON_QUERYRESULTSET_H_

#include "inc/Core/Common.h"
#include "inc/Core/CommonDataStructure.h"
#include "inc/Core/SearchQuery.h"
#include "DistanceUtils.h"
#include <algorithm>
#include <memory>
#include "IQuantizer.h"

namespace SPTAG
Expand Down Expand Up @@ -36,9 +39,19 @@ class QueryResultSet : public QueryResult
}

QueryResultSet(const QueryResultSet& other) : QueryResult(other)
{
{
}

// QueryResultSet(const T*_target, int _K, bool _withResult) {
// m_withResultVector = _withResult;
// if(m_withResultVector) {
// m_resultVectors.resize(_K);
// }
// for(auto& ptr : m_resultVectors) {
// ptr = std::shared_ptr<T>(new T(), [](T* p) { delete p; });
// }
// }

~QueryResultSet()
{
}
Expand Down Expand Up @@ -86,11 +99,46 @@ class QueryResultSet : public QueryResult
return false;
}

// bool NeedResultVector() const {
// return m_withResultVector;
// }

// void RemoveResultVector() {
// if(m_withResultVector) {
// m_withResultVector = false;
// m_resultVectors.clear();
// }
// }

// Spread-search variant of AddPoint: in addition to the id and distance, the
// candidate's vector buffer is stored in the result entry so that it can be
// read back later through GetVector().
// Returns true when the candidate replaced the entry at index 0, false when
// it was rejected.
bool AddPoint(const SizeType index, float dist, ByteArray& vector) {
    // Accept the candidate when it is strictly closer than the entry at
    // index 0, or equally distant but with a smaller vector id.
    const bool closer = dist < m_results[0].Dist;
    const bool winsTie = (dist == m_results[0].Dist) && (index < m_results[0].VID);
    if (!closer && !winsTie)
    {
        return false;
    }

    // Overwrite the entry at index 0, then let Heapify re-establish the
    // ordering over the m_resultNum stored entries.
    m_results[0].Vector = vector;
    m_results[0].Dist = dist;
    m_results[0].VID = index;
    Heapify(m_resultNum);
    return true;
}

// Orders the first m_resultNum entries in place. Each pass swaps the entry at
// index 0 with the last slot of the still-unsorted prefix and then calls
// Heapify on that shortened prefix.
inline void SortResult()
{
    int unsorted = m_resultNum;
    while (unsorted-- > 0)
    {
        // `unsorted` now names the last slot of the prefix being processed.
        std::swap(m_results[0], m_results[unsorted]);
        Heapify(unsorted);
    }
}
Expand All @@ -100,6 +148,17 @@ class QueryResultSet : public QueryResult
std::reverse(m_results.Data(), m_results.Data() + m_resultNum);
}

// std::shared_ptr<T> GetVector(int idx) const
// {
// if (idx < m_resultNum) return m_resultVectors[idx];
// return nullptr;
// }
// Returns the vector buffer stored with the result at position idx.
// An empty ByteArray is returned when idx lies outside [0, m_resultNum).
ByteArray GetVector(int idx) const
{
    // Negative positions are rejected too: they would index before the start
    // of m_results instead of yielding the empty fallback.
    if (idx < 0 || idx >= m_resultNum) return ByteArray();
    return m_results[idx].Vector;
}

private:
void Heapify(int count)
{
Expand All @@ -110,15 +169,30 @@ class QueryResultSet : public QueryResult
if (m_results[parent] < m_results[next])
{
std::swap(m_results[next], m_results[parent]);
// if(m_withResultVector)
// {
// std::swap(m_resultVectors[next], m_resultVectors[parent]);
// }
parent = next;
next = (parent << 1) + 1;
}
else break;
}
if (next == maxidx && m_results[parent] < m_results[next]) std::swap(m_results[parent], m_results[next]);
if (next == maxidx && m_results[parent] < m_results[next])
{
std::swap(m_results[parent], m_results[next]);
// if(m_withResultVector)
// {
// std::swap(m_resultVectors[parent], m_resultVectors[next]);
// }
}
}

// bool m_withResultVector = false;
// std::vector<std::shared_ptr<T>> m_resultVectors;
};
}
}


#endif // _SPTAG_COMMON_QUERYRESULTSET_H_
9 changes: 7 additions & 2 deletions AnnService/inc/Core/SPANN/ExtraFullGraphSearcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#ifndef _SPTAG_SPANN_EXTRASEARCHER_H_
#define _SPTAG_SPANN_EXTRASEARCHER_H_

#include "inc/Core/CommonDataStructure.h"
#include "inc/Helper/VectorSetReader.h"
#include "inc/Helper/AsyncFileReader.h"
#include "IExtraSearcher.h"
Expand Down Expand Up @@ -125,7 +126,9 @@ namespace SPTAG
if (p_exWorkSpace->m_deduper.CheckAndSet(vectorID)) continue; \
(this->*m_parseEncoding)(p_index, listInfo, (ValueType*)(p_postingListFullData + offsetVector));\
auto distance2leaf = p_index->ComputeDistance(queryResults.GetQuantizedTarget(), p_postingListFullData + offsetVector); \
queryResults.AddPoint(vectorID, distance2leaf); \
SPTAG::ByteArray tmpVector = SPTAG::ByteArray::Alloc(sizeof(ValueType)* (m_vectorInfoSize - sizeof(int))); \
memcpy(tmpVector.Data(), p_postingListFullData + offsetVector, sizeof(ValueType)* (m_vectorInfoSize - sizeof(int))); \
queryResults.AddPoint(vectorID, distance2leaf, tmpVector); \
} \

#define ProcessPostingOffset() \
Expand All @@ -137,7 +140,9 @@ namespace SPTAG
if (p_exWorkSpace->m_deduper.CheckAndSet(vectorID)) continue; \
(this->*m_parseEncoding)(p_index, listInfo, (ValueType*)(p_postingListFullData + offsetVector));\
auto distance2leaf = p_index->ComputeDistance(queryResults.GetQuantizedTarget(), p_postingListFullData + offsetVector); \
queryResults.AddPoint(vectorID, distance2leaf); \
SPTAG::ByteArray tmpVector = SPTAG::ByteArray::Alloc(sizeof(ValueType)* (m_vectorInfoSize - sizeof(int))); \
memcpy(tmpVector.Data(), p_postingListFullData + offsetVector, sizeof(ValueType)* (m_vectorInfoSize - sizeof(int))); \
queryResults.AddPoint(vectorID, distance2leaf, tmpVector);\
foundResult = true;\
break;\
} \
Expand Down
2 changes: 2 additions & 0 deletions AnnService/inc/Core/SPANN/IExtraSearcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "inc/Core/VectorIndex.h"
#include "inc/Helper/AsyncFileReader.h"
#include "inc/Helper/VectorSetReader.h"

#include <memory>
#include <vector>
Expand Down Expand Up @@ -229,6 +230,7 @@ namespace SPTAG {
virtual bool CheckValidPosting(SizeType postingID) = 0;

virtual ErrorCode GetPostingDebug(ExtraWorkSpace* p_exWorkSpace, std::shared_ptr<VectorIndex> p_index, SizeType vid, std::vector<SizeType>& VIDs, std::shared_ptr<VectorSet>& vecs) = 0;

};
} // SPANN
} // SPTAG
Expand Down
1 change: 1 addition & 0 deletions AnnService/inc/Core/SPANN/Index.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ namespace SPTAG

ErrorCode SearchDiskIndex(QueryResult& p_query, SearchStats* p_stats = nullptr) const;
bool SearchDiskIndexIterative(QueryResult& p_headQuery, QueryResult& p_query, ExtraWorkSpace* extraWorkspace) const;
ErrorCode SearchDoc2Doc(QueryResult &p_query, COMMON::QueryResultSet<T>* p_results, std::unique_ptr<ExtraWorkSpace> workSpace) const;
ErrorCode DebugSearchDiskIndex(QueryResult& p_query, int p_subInternalResultNum, int p_internalResultNum,
SearchStats* p_stats = nullptr, std::set<int>* truth = nullptr, std::map<int, std::set<int>>* found = nullptr) const;
ErrorCode UpdateIndex();
Expand Down
3 changes: 3 additions & 0 deletions AnnService/inc/Core/SPANN/Options.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ namespace SPTAG {
int m_debugBuildInternalResultNum;
bool m_enableADC;
int m_iotimeout;
// Doc2doc ("spread") search options. Each is registered through
// DefineSSDParameter in ParameterDefinitionList.h.
// NOTE(review): exact semantics are not visible from this header — confirm against SearchDoc2Doc.
bool m_spreadSearch;   // registered under the key "SpreadSearch"
int m_doc2docRounds;   // registered under the key "doc2docRounds"; presumably the number of doc2doc rounds
int m_doc2docResults;  // registered under the key "Doc2DocResults"; presumably results kept per round — TODO confirm

// Iterative
int m_headBatch;
Expand Down
4 changes: 3 additions & 1 deletion AnnService/inc/Core/SPANN/ParameterDefinitionList.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,9 @@ DefineSSDParameter(m_enableADC, bool, false, "EnableADC")
DefineSSDParameter(m_recall_analysis, bool, false, "RecallAnalysis")
DefineSSDParameter(m_debugBuildInternalResultNum, int, 64, "DebugBuildInternalResultNum")
DefineSSDParameter(m_iotimeout, int, 30, "IOTimeout")

DefineSSDParameter(m_spreadSearch, bool, true, "SpreadSearch")
DefineSSDParameter(m_doc2docRounds, int, 2, "doc2docRounds")
DefineSSDParameter(m_doc2docResults, int, 32, "Doc2DocResults")
// Iterative
DefineSSDParameter(m_headBatch, int, 32, "IterativeSearchHeadBatch")

Expand Down
5 changes: 5 additions & 0 deletions AnnService/inc/Core/SearchResult.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#define _SPTAG_SEARCHRESULT_H_

#include "CommonDataStructure.h"
#include "inc/Core/Common.h"
#include <memory>

namespace SPTAG
{
Expand Down Expand Up @@ -67,6 +69,7 @@ namespace SPTAG
SizeType VID;
float Dist;
ByteArray Meta;
ByteArray Vector;
bool RelaxedMono;

BasicResult() : VID(-1), Dist(MaxDist), RelaxedMono(false) {}
Expand All @@ -75,6 +78,8 @@ namespace SPTAG

// Id, distance and metadata; Vector is left default-constructed and
// RelaxedMono is set to false.
BasicResult(SizeType p_vid, float p_dist, ByteArray p_meta)
    : VID(p_vid),
      Dist(p_dist),
      Meta(p_meta),
      RelaxedMono(false)
{
}

// Id, distance and metadata with an explicit RelaxedMono flag; Vector is
// left default-constructed.
BasicResult(SizeType p_vid, float p_dist, ByteArray p_meta, bool p_relaxedMono)
    : VID(p_vid),
      Dist(p_dist),
      Meta(p_meta),
      RelaxedMono(p_relaxedMono)
{
}

// Id, distance, metadata and the result's vector buffer; RelaxedMono is set
// to false.
BasicResult(SizeType p_vid, float p_dist, ByteArray p_meta, ByteArray p_vector)
    : VID(p_vid),
      Dist(p_dist),
      Meta(p_meta),
      Vector(p_vector),
      RelaxedMono(false)
{
}

// Id, distance, metadata and vector buffer with an explicit RelaxedMono flag.
BasicResult(SizeType p_vid, float p_dist, ByteArray p_meta, ByteArray p_vector, bool p_relaxedMono)
    : VID(p_vid),
      Dist(p_dist),
      Meta(p_meta),
      Vector(p_vector),
      RelaxedMono(p_relaxedMono)
{
}
};

} // namespace SPTAG
Expand Down
Loading