Skip to content

Commit

Permalink
add index_feature for hnsw (#284)
Browse files Browse the repository at this point in the history
- some code cleanup in hnsw.cpp

Signed-off-by: LHT129 <[email protected]>
  • Loading branch information
LHT129 authored Jan 2, 2025
1 parent 028c00f commit 3534971
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 15 deletions.
50 changes: 42 additions & 8 deletions src/index/hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ HNSW::HNSW(HnswParameters hnsw_params, const IndexCommonParam& index_common_para
hnsw_params.ef_construction,
Options::Instance().block_size_limit());
}

this->init_feature_list();
}

tl::expected<std::vector<int64_t>, Error>
Expand Down Expand Up @@ -195,7 +197,7 @@ HNSW::knn_search(const DatasetPtr& query,
int64_t k,
const std::string& parameters,
BaseFilterFunctor* filter_ptr) const {
SlowTaskTimer t("hnsw knnsearch", 20);
SlowTaskTimer t_total("hnsw knnsearch", 20);

try {
// cannot perform search on empty index
Expand All @@ -219,7 +221,7 @@ HNSW::knn_search(const DatasetPtr& query,
CHECK_ARGUMENT(k > 0, fmt::format("k({}) must be greater than 0", k))
k = std::min(k, GetNumElements());

std::shared_lock lock(rw_mutex_);
std::shared_lock lock_global(rw_mutex_);

// check search parameters
auto params = HnswSearchParameters::FromJson(parameters);
Expand Down Expand Up @@ -249,7 +251,7 @@ HNSW::knn_search(const DatasetPtr& query,

// return result
auto result = Dataset::Make();
if (results.size() == 0) {
if (results.empty()) {
result->Dim(0)->NumElements(1);
return result;
}
Expand All @@ -274,9 +276,9 @@ HNSW::knn_search(const DatasetPtr& query,

result->Dim(results.size())->NumElements(1)->Owner(true, allocator_.get());

int64_t* ids = (int64_t*)allocator_->Allocate(sizeof(int64_t) * results.size());
auto* ids = (int64_t*)allocator_->Allocate(sizeof(int64_t) * results.size());
result->Ids(ids);
float* dists = (float*)allocator_->Allocate(sizeof(float) * results.size());
auto* dists = (float*)allocator_->Allocate(sizeof(float) * results.size());
result->Distances(dists);

for (int64_t j = results.size() - 1; j >= 0; --j) {
Expand Down Expand Up @@ -376,17 +378,17 @@ HNSW::range_search(const DatasetPtr& query,
// return result
auto result = Dataset::Make();
size_t target_size = results.size();
if (results.size() == 0) {
if (results.empty()) {
result->Dim(0)->NumElements(1);
return result;
}
if (limited_size >= 1) {
target_size = std::min((size_t)limited_size, target_size);
}
result->Dim(target_size)->NumElements(1)->Owner(true, allocator_.get());
int64_t* ids = (int64_t*)allocator_->Allocate(sizeof(int64_t) * target_size);
auto* ids = (int64_t*)allocator_->Allocate(sizeof(int64_t) * target_size);
result->Ids(ids);
float* dists = (float*)allocator_->Allocate(sizeof(float) * target_size);
auto* dists = (float*)allocator_->Allocate(sizeof(float) * target_size);
result->Distances(dists);
for (int64_t j = results.size() - 1; j >= 0; --j) {
if (j < target_size) {
Expand Down Expand Up @@ -915,5 +917,37 @@ HNSW::set_dataset(const DatasetPtr& base, const void* vectors_ptr, uint32_t num_
throw std::invalid_argument(fmt::format("no support for this type: {}", (int)type_));
}
}
// Looks up `feature` in this index's feature list.
//
// @param feature  the IndexFeature value to query
// @return whatever feature_list_.CheckFeature(feature) reports
bool
HNSW::CheckFeature(IndexFeature feature) const {
    const auto& features = feature_list_;
    return features.CheckFeature(feature);
}

void
HNSW::init_feature_list() {
// Add & Build
feature_list_.SetFeatures({IndexFeature::SUPPORT_BUILD,
IndexFeature::SUPPORT_BUILD_WITH_MULTI_THREAD,
IndexFeature::SUPPORT_ADD_AFTER_BUILD,
IndexFeature::SUPPORT_ADD_FROM_EMPTY});
// Search
feature_list_.SetFeatures({IndexFeature::SUPPORT_KNN_SEARCH,
IndexFeature::SUPPORT_RANGE_SEARCH,
IndexFeature::SUPPORT_KNN_SEARCH_WITH_ID_FILTER,
IndexFeature::SUPPORT_RANGE_SEARCH_WITH_ID_FILTER});
// concurrency
feature_list_.SetFeatures({IndexFeature::SUPPORT_SEARCH_CONCURRENT,
IndexFeature::SUPPORT_ADD_SEARCH_CONCURRENT,
IndexFeature::SUPPORT_ADD_CONCURRENT,
IndexFeature::SUPPORT_UPDATE_ID_CONCURRENT,
IndexFeature::SUPPORT_UPDATE_VECTOR_CONCURRENT});
// serialize
feature_list_.SetFeatures({IndexFeature::SUPPORT_DESERIALIZE_BINARY_SET,
IndexFeature::SUPPORT_DESERIALIZE_FILE,
IndexFeature::SUPPORT_DESERIALIZE_READER_SET,
IndexFeature::SUPPORT_SERIALIZE_BINARY_SET,
IndexFeature::SUPPORT_SERIALIZE_FILE});
// other
feature_list_.SetFeature(IndexFeature::SUPPORT_CAL_DISTANCE_BY_ID);
}

} // namespace vsag
22 changes: 15 additions & 7 deletions src/index/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@
#include <utility>
#include <vector>

#include "../algorithm/hnswlib/hnswlib.h"
#include "../common.h"
#include "../data_type.h"
#include "../default_allocator.h"
#include "../impl/conjugate_graph.h"
#include "../logger.h"
#include "../safe_allocator.h"
#include "../utils.h"
#include "algorithm/hnswlib/hnswlib.h"
#include "base_filter_functor.h"
#include "common.h"
#include "data_type.h"
#include "hnsw_zparameters.h"
#include "impl/conjugate_graph.h"
#include "index_common_param.h"
#include "index_feature_list.h"
#include "logger.h"
#include "safe_allocator.h"
#include "typing.h"
#include "vsag/binaryset.h"
#include "vsag/errors.h"
Expand Down Expand Up @@ -145,6 +145,9 @@ class HNSW : public Index {
SAFE_CALL(return alg_hnsw_->getDistanceByLabel(id, vector));
};

// Queries this index's feature list for `feature` and returns the result.
// Overrides the base-class virtual; the result must not be discarded.
[[nodiscard]] bool
CheckFeature(IndexFeature feature) const override;

public:
tl::expected<BinarySet, Error>
Serialize() const override {
Expand Down Expand Up @@ -278,6 +281,9 @@ class HNSW : public Index {
BinarySet
empty_binaryset() const;

// Fills feature_list_ with the IndexFeature flags set for this index
// (build/add, search, concurrency, serialization, distance-by-id).
void
init_feature_list();

private:
std::shared_ptr<hnswlib::AlgorithmInterface<float>> alg_hnsw_;
std::shared_ptr<hnswlib::SpaceInterface> space_;
Expand All @@ -298,6 +304,8 @@ class HNSW : public Index {
mutable std::map<std::string, WindowResultQueue> result_queues_;

mutable std::shared_mutex rw_mutex_;

// Feature flags of this index: written by init_feature_list(), read by CheckFeature().
IndexFeatureList feature_list_{};
};

} // namespace vsag

0 comments on commit 3534971

Please sign in to comment.