From fc5e69590a5930e84a1c7c21f9bb39f863ce9b8d Mon Sep 17 00:00:00 2001 From: "zhongxiaoyao.zxy" Date: Fri, 10 Jan 2025 17:15:05 +0800 Subject: [PATCH 1/2] support basic optimizer and searcher Signed-off-by: zhongxiaoyao.zxy --- include/vsag/constants.h | 4 + src/algorithm/hnswlib/hnswalg.cpp | 6 + src/algorithm/hnswlib/hnswalg.h | 11 ++ src/constants.cpp | 4 + src/data_cell/adapter_graph_datacell.h | 55 ++++++ src/data_cell/adapter_graph_datacell_test.cpp | 68 +++++++ src/data_cell/flatten_datacell.h | 16 ++ src/data_cell/flatten_interface.h | 20 ++ src/impl/basic_optimizer.h | 88 +++++++++ src/impl/basic_searcher.cpp | 181 ++++++++++++++++++ src/impl/basic_searcher.h | 103 ++++++++++ src/impl/basic_searcher_test.cpp | 123 ++++++++++++ src/impl/runtime_parameter.h | 102 ++++++++++ 13 files changed, 781 insertions(+) create mode 100644 src/data_cell/adapter_graph_datacell.h create mode 100644 src/data_cell/adapter_graph_datacell_test.cpp create mode 100644 src/impl/basic_optimizer.h create mode 100644 src/impl/basic_searcher.cpp create mode 100644 src/impl/basic_searcher.h create mode 100644 src/impl/basic_searcher_test.cpp create mode 100644 src/impl/runtime_parameter.h diff --git a/include/vsag/constants.h b/include/vsag/constants.h index cbc68825..cac3496f 100644 --- a/include/vsag/constants.h +++ b/include/vsag/constants.h @@ -50,6 +50,10 @@ extern const char* const PARAMETER_METRIC_TYPE; extern const char* const PARAMETER_USE_CONJUGATE_GRAPH; extern const char* const PARAMETER_USE_CONJUGATE_GRAPH_SEARCH; +extern const char* const PREFETCH_NEIGHBOR_VISIT_NUM; +extern const char* const PREFETCH_NEIGHBOR_CODE_NUM; +extern const char* const PREFETCH_CACHE_LINE; + extern const char* const DISKANN_PARAMETER_L; extern const char* const DISKANN_PARAMETER_R; extern const char* const DISKANN_PARAMETER_P_VAL; diff --git a/src/algorithm/hnswlib/hnswalg.cpp b/src/algorithm/hnswlib/hnswalg.cpp index e473fa51..03580ff1 100644 --- a/src/algorithm/hnswlib/hnswalg.cpp +++ b/src/algorithm/hnswlib/hnswalg.cpp @@ -1548,4 +1548,10 @@ HierarchicalNSW::checkIntegrity() { std::cout << "integrity ok, checked " << connections_checked << " connections\n"; } +template MaxHeap +HierarchicalNSW::searchBaseLayerST(InnerIdType ep_id, + const void* data_point, + size_t ef, + vsag::BaseFilterFunctor* isIdAllowed) const; + } // namespace hnswlib diff --git a/src/algorithm/hnswlib/hnswalg.h b/src/algorithm/hnswlib/hnswalg.h index f880e746..cd7dabdb 100644 --- a/src/algorithm/hnswlib/hnswalg.h +++ b/src/algorithm/hnswlib/hnswalg.h @@ -149,6 +149,17 @@ class HierarchicalNSW : public AlgorithmInterface { bool isValidLabel(LabelType label) override; + size_t + getMaxDegree() { + return maxM0_; + }; + + linklistsizeint* + get_linklist0(InnerIdType internal_id) const { + // only for test now + return (linklistsizeint*)(data_level0_memory_->GetElementPtr(internal_id, offsetLevel0_)); + } + inline LabelType getExternalLabel(InnerIdType internal_id) const { std::shared_lock lock(points_locks_[internal_id]); diff --git a/src/constants.cpp b/src/constants.cpp index d791cc52..5c285344 100644 --- a/src/constants.cpp +++ b/src/constants.cpp @@ -52,6 +52,10 @@ const char* const PARAMETER_METRIC_TYPE = "metric_type"; const char* const PARAMETER_USE_CONJUGATE_GRAPH = "use_conjugate_graph"; const char* const PARAMETER_USE_CONJUGATE_GRAPH_SEARCH = "use_conjugate_graph_search"; +const char* const PREFETCH_NEIGHBOR_VISIT_NUM = "prefetch_neighbor_visit_num"; +const char* const PREFETCH_NEIGHBOR_CODE_NUM = "prefetch_neighbor_codes_num"; +const char* const PREFETCH_CACHE_LINE = "prefetch_cache_line"; + const char* const DISKANN_PARAMETER_L = "ef_construction"; const char* const DISKANN_PARAMETER_R = "max_degree"; const char* const DISKANN_PARAMETER_P_VAL = "pq_sample_rate"; diff --git a/src/data_cell/adapter_graph_datacell.h b/src/data_cell/adapter_graph_datacell.h new file mode 100644 index 00000000..e4bc1a08 --- /dev/null +++ b/src/data_cell/adapter_graph_datacell.h @@ -0,0 +1,55 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "algorithm/hnswlib/hnswalg.h" +#include "algorithm/hnswlib/space_l2.h" + +namespace vsag { +class AdaptGraphDataCell { +public: + AdaptGraphDataCell(std::shared_ptr alg_hnsw) : alg_hnsw_(alg_hnsw){}; + + void + GetNeighbors(InnerIdType id, Vector& neighbor_ids) { + int* data = (int*)alg_hnsw_->get_linklist0(id); + uint32_t size = alg_hnsw_->getListCount((hnswlib::linklistsizeint*)data); + neighbor_ids.resize(size); + for (uint32_t i = 0; i < size; i++) { + neighbor_ids[i] = *(data + i + 1); + } + } + + uint32_t + GetNeighborSize(InnerIdType id) { + int* data = (int*)alg_hnsw_->get_linklist0(id); + return alg_hnsw_->getListCount((hnswlib::linklistsizeint*)data); + } + + void + Prefetch(InnerIdType id, InnerIdType neighbor_i) { + int* data = (int*)alg_hnsw_->get_linklist0(id); + _mm_prefetch(data + neighbor_i + 1, _MM_HINT_T0); + } + + uint32_t + MaximumDegree() { + return alg_hnsw_->getMaxDegree(); + } + +private: + std::shared_ptr alg_hnsw_; +}; +} // namespace vsag \ No newline at end of file diff --git a/src/data_cell/adapter_graph_datacell_test.cpp b/src/data_cell/adapter_graph_datacell_test.cpp new file mode 100644 index 00000000..c8dd4a9a --- /dev/null +++ b/src/data_cell/adapter_graph_datacell_test.cpp @@ -0,0 +1,68 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "adapter_graph_datacell.h" + +#include "catch2/catch_template_test_macros.hpp" +#include "fixtures.h" +#include "fmt/format-inl.h" +#include "graph_interface_test.h" +#include "io/io_headers.h" +#include "safe_allocator.h" + +using namespace vsag; + +TEST_CASE("basic usage for graph data cell (adapter of hnsw)", "[ut][GraphDataCell]") { + uint32_t M = 32; + uint32_t data_size = 1000; + uint32_t ef_construction = 100; + uint64_t DEFAULT_MAX_ELEMENT = 1; + uint64_t dim = 960; + auto vectors = fixtures::generate_vectors(data_size, dim); + std::vector ids(data_size); + std::iota(ids.begin(), ids.end(), 0); + + auto allocator = SafeAllocator::FactoryDefaultAllocator(); + auto space = std::make_shared(dim); + auto io = std::make_shared(allocator.get()); + auto alg_hnsw = + std::make_shared(space.get(), + DEFAULT_MAX_ELEMENT, + allocator.get(), + M / 2, + ef_construction, + Options::Instance().block_size_limit()); + alg_hnsw->init_memory_space(); + for (int64_t i = 0; i < data_size; ++i) { + auto successful_insert = + alg_hnsw->addPoint((const void*)(vectors.data() + i * dim), ids[i]); + REQUIRE(successful_insert == true); + } + + auto graph_data_cell = std::make_shared(alg_hnsw); + + for (uint32_t i = 0; i < data_size; i++) { + auto neighbor_size = graph_data_cell->GetNeighborSize(i); + Vector neighbor_ids(neighbor_size, allocator.get()); + graph_data_cell->GetNeighbors(i, neighbor_ids); + + int* data = (int*)alg_hnsw->get_linklist0(i); + REQUIRE(neighbor_size == alg_hnsw->getListCount((hnswlib::linklistsizeint*)data)); + + for (uint32_t j = 0; j < neighbor_size; j++) { + REQUIRE(neighbor_ids[j] == *(data + j + 1)); + } + } +} diff --git a/src/data_cell/flatten_datacell.h b/src/data_cell/flatten_datacell.h index a206d698..0aec8577 100644 --- a/src/data_cell/flatten_datacell.h +++ b/src/data_cell/flatten_datacell.h @@ -70,6 +70,11 @@ class FlattenDataCell : public FlattenInterface { io_->Prefetch(id * code_size_); }; + bool + Decode(const uint8_t* codes, DataType* data) override { + return this->quantizer_->DecodeOne(codes, data); + } + [[nodiscard]] std::string GetQuantizerName() override; @@ -226,7 +231,18 @@ FlattenDataCell::query(float* result_dists, const std::shared_ptr>& computer, const InnerIdType* idx, InnerIdType id_count) { + for (uint32_t i = 0; i < this->prefetch_neighbor_codes_num_ and i < id_count; i++) { + this->io_->Prefetch(static_cast(idx[i]) * static_cast(code_size_), + this->prefetch_cache_line_); + } + for (int64_t i = 0; i < id_count; ++i) { + if (i + this->prefetch_neighbor_codes_num_ < id_count) { + this->io_->Prefetch(static_cast(idx[i + this->prefetch_neighbor_codes_num_]) * + static_cast(code_size_), + this->prefetch_cache_line_); + } + bool release = false; const auto* codes = this->GetCodesById(idx[i], release); computer->ComputeDist(codes, result_dists + i); diff --git a/src/data_cell/flatten_interface.h b/src/data_cell/flatten_interface.h index cb169cb7..39be9bcd 100644 --- a/src/data_cell/flatten_interface.h +++ b/src/data_cell/flatten_interface.h @@ -18,11 +18,13 @@ #include #include +#include "impl/runtime_parameter.h" #include "index/index_common_param.h" #include "quantization/computer.h" #include "stream_reader.h" #include "stream_writer.h" #include "typing.h" +#include "vsag/constants.h" namespace vsag { class FlattenInterface; @@ -83,6 +85,11 @@ class FlattenInterface { return false; } + virtual bool + Decode(const uint8_t* codes, DataType* vector) { + return false; + } + [[nodiscard]] virtual InnerIdType TotalCount() const { return this->total_count_; @@ -102,10 +109,23 @@ class FlattenInterface { StreamReader::ReadObj(reader, this->code_size_); } + virtual void + SetRuntimeParameters(const UnorderedMap& new_params) { + if (new_params.find(PREFETCH_NEIGHBOR_CODE_NUM) != new_params.end()) { + prefetch_neighbor_codes_num_ = std::get(new_params.at(PREFETCH_NEIGHBOR_CODE_NUM)); + } + + if (new_params.find(PREFETCH_CACHE_LINE) != new_params.end()) { + prefetch_cache_line_ = std::get(new_params.at(PREFETCH_CACHE_LINE)); + } + } + public: InnerIdType total_count_{0}; InnerIdType max_capacity_{1000000}; uint32_t code_size_{0}; + uint32_t prefetch_neighbor_codes_num_{1}; + uint32_t prefetch_cache_line_{1}; }; } // namespace vsag diff --git a/src/impl/basic_optimizer.h b/src/impl/basic_optimizer.h new file mode 100644 index 00000000..a9da32f1 --- /dev/null +++ b/src/impl/basic_optimizer.h @@ -0,0 +1,88 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "../utils.h" +#include "logger.h" +#include "runtime_parameter.h" + +namespace vsag { + +template +class Optimizer { + Optimizer(std::shared_ptr allocator, int trials = 100) + : parameters_(allocator.get()), + best_params_(allocator.get()), + n_trials_(trials), + best_loss_(std::numeric_limits::max()) { + allocator_ = allocator.get(); + std::random_device rd; + gen_.seed(rd()); + } + + void + RegisterParameter(const std::shared_ptr& runtime_parameter) { + parameters_.push_back(runtime_parameter); + } + + void + Optimize(OptimizableOBJ& obj) { + double original_loss = obj.MockRun(); + + for (int i = 0; i < n_trials_; ++i) { + // generate a group of runtime params + UnorderedMap current_params(allocator_); + for (auto& param : parameters_) { + current_params[param->name_] = param->sample(gen_); + } + obj.SetRuntimeParameters(current_params); + + // evaluate + double loss = obj.MockRun(); + + // update + if (loss < best_loss_) { + best_loss_ = loss; + best_params_ = current_params; + vsag::logger::debug(fmt::format("Trial {}: new best loss = {}, improving = {}", + i + 1, + best_loss_, + (original_loss - best_loss_) / original_loss)); + } + } + } + + UnorderedMap + GetBestParameters() const { + return best_params_; + } + + double + GetBestLoss() const { + return best_loss_; + } + +private: + Allocator* allocator_{nullptr}; + + Vector> parameters_; + int n_trials_{0}; + std::mt19937 gen_; + + UnorderedMap best_params_; + double best_loss_{0}; +}; + +} // namespace vsag \ No newline at end of file diff --git a/src/impl/basic_searcher.cpp b/src/impl/basic_searcher.cpp new file mode 100644 index 00000000..266ababe --- /dev/null +++ b/src/impl/basic_searcher.cpp @@ -0,0 +1,181 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "basic_searcher.h" + +namespace vsag { + +template +BasicSearcher::BasicSearcher(std::shared_ptr graph, + std::shared_ptr vector, + const IndexCommonParam& common_param) { + this->graph_ = graph; + this->vector_data_cell_ = vector; + this->allocator_ = common_param.allocator_.get(); + this->dim_ = common_param.dim_; + pool_ = std::make_shared(vector_data_cell_->TotalCount(), allocator_); +} + +template +void +BasicSearcher::SetRuntimeParameters( + const UnorderedMap& new_params) { + if (new_params.find(PREFETCH_NEIGHBOR_VISIT_NUM) != new_params.end()) { + prefetch_neighbor_visit_num_ = std::get(new_params.at(PREFETCH_NEIGHBOR_VISIT_NUM)); + } + this->vector_data_cell_->SetRuntimeParameters(new_params); +} + +template +void +BasicSearcher::Resize(uint64_t new_size) { + pool_ = std::make_shared(new_size, allocator_); +} + +template +uint32_t +BasicSearcher::visit(hnswlib::VisitedListPtr vl, + std::pair& current_node_pair, + Vector& to_be_visited_rid, + Vector& to_be_visited_id) const { + // to_be_visited_rid is used in redundant storage + // to_be_visited_id is used in flatten storage + uint32_t count_no_visited = 0; + Vector neighbors(allocator_); + + graph_->GetNeighbors(current_node_pair.second, neighbors); + +#ifdef USE_SSE + for (uint32_t i = 0; i < prefetch_neighbor_visit_num_; i++) { + _mm_prefetch(vl->mass + neighbors[i], _MM_HINT_T0); + } +#endif + + for (uint32_t i = 0; i < neighbors.size(); i++) { +#ifdef USE_SSE + if (i + prefetch_neighbor_visit_num_ < neighbors.size()) { + _mm_prefetch(vl->mass + neighbors[i + prefetch_neighbor_visit_num_], _MM_HINT_T0); + } +#endif + if (vl->mass[neighbors[i]] != vl->curV) { + to_be_visited_rid[count_no_visited] = i; + to_be_visited_id[count_no_visited] = neighbors[i]; + count_no_visited++; + vl->mass[neighbors[i]] = vl->curV; + } + } + return count_no_visited; +} + +template +MaxHeap +BasicSearcher::Search(const float* query, + InnerSearchParam& inner_search_param) const { + MaxHeap top_candidates(allocator_); + MaxHeap candidate_set(allocator_); + + auto computer = vector_data_cell_->FactoryComputer(query); + auto vl = pool_->getFreeVisitedList(); + + float lower_bound; + float dist; + uint64_t candidate_id; + uint32_t hops = 0; + uint32_t dist_cmp = 0; + uint32_t count_no_visited = 0; + Vector to_be_visited_rid(graph_->MaximumDegree(), allocator_); + Vector to_be_visited_id(graph_->MaximumDegree(), allocator_); + Vector line_dists(graph_->MaximumDegree(), allocator_); + + InnerIdType ep_id = inner_search_param.ep_; + vector_data_cell_->Query(&dist, computer, &ep_id, 1); + top_candidates.emplace(dist, ep_id); + candidate_set.emplace(-dist, ep_id); + vl->mass[ep_id] = vl->curV; + + while (!candidate_set.empty()) { + hops++; + std::pair current_node_pair = candidate_set.top(); + + if ((-current_node_pair.first) > lower_bound && + (top_candidates.size() == inner_search_param.ef_)) { + break; + } + candidate_set.pop(); + if (not candidate_set.empty()) { + graph_->Prefetch(candidate_set.top().second, 0); + } + + count_no_visited = visit(vl, current_node_pair, to_be_visited_rid, to_be_visited_id); + + dist_cmp += count_no_visited; + + // TODO(ZXY): implement mix storage query line + vector_data_cell_->Query( + line_dists.data(), computer, to_be_visited_id.data(), count_no_visited); + + for (uint32_t i = 0; i < count_no_visited; i++) { + dist = line_dists[i]; + candidate_id = to_be_visited_id[i]; + if (top_candidates.size() < inner_search_param.ef_ || lower_bound > dist) { + candidate_set.emplace(-dist, candidate_id); + + top_candidates.emplace(dist, candidate_id); + + if (top_candidates.size() > inner_search_param.ef_) + top_candidates.pop(); + + if (!top_candidates.empty()) + lower_bound = top_candidates.top().first; + } + } + } + + while (top_candidates.size() > inner_search_param.topk_) { + top_candidates.pop(); + } + + pool_->releaseVisitedList(vl); + return top_candidates; +} + +template +double +BasicSearcher::MockRun() const { + uint64_t sample_size = std::min(SAMPLE_SIZE, vector_data_cell_->TotalCount()); + + InnerSearchParam search_param; + search_param.ep_ = 0; + search_param.ef_ = 80; // experience value in benchmark + search_param.is_id_allowed_ = nullptr; + + auto st = std::chrono::high_resolution_clock::now(); + for (uint32_t i = 0; i < sample_size; ++i) { + bool release = false; + const auto* codes = vector_data_cell_->GetCodesById(i, release); + Vector raw_data(dim_, allocator_); + vector_data_cell_->Decode(codes, raw_data.data()); + Search(raw_data.data(), search_param); + } + auto ed = std::chrono::high_resolution_clock::now(); + double time_cost = std::chrono::duration(ed - st).count(); + return time_cost; +} + +template class BasicSearcher< + AdaptGraphDataCell, + FlattenDataCell, MemoryIO>>; + +} // namespace vsag \ No newline at end of file diff --git a/src/impl/basic_searcher.h b/src/impl/basic_searcher.h new file mode 100644 index 00000000..9a483b08 --- /dev/null +++ b/src/impl/basic_searcher.h @@ -0,0 +1,103 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +#include "../utils.h" +#include "ThreadPool.h" +#include "algorithm/hnswlib/algorithm_interface.h" +#include "algorithm/hnswlib/visited_list_pool.h" +#include "common.h" +#include "data_cell/adapter_graph_datacell.h" +#include "data_cell/flatten_datacell.h" +#include "data_cell/flatten_interface.h" +#include "index/index_common_param.h" +#include "io/memory_io.h" +#include "quantization/fp32_quantizer.h" +#include "runtime_parameter.h" + +namespace vsag { + +static const InnerIdType SAMPLE_SIZE = 10000; +static const uint32_t CENTROID_EF = 500; +static const uint32_t PREFETCH_DEGREE_DIVIDE = 3; +static const uint32_t PREFETCH_MAXIMAL_DEGREE = 1; +static const uint32_t PREFETCH_MAXIMAL_LINES = 1; + +class InnerSearchParam { +public: + int topk_{0}; + float radius_{0.0f}; + InnerIdType ep_{0}; + uint64_t ef_{10}; + BaseFilterFunctor* is_id_allowed_{nullptr}; +}; + +struct CompareByFirst { + constexpr bool + operator()(std::pair const& a, + std::pair const& b) const noexcept { + return a.first < b.first; + } +}; + +using MaxHeap = std::priority_queue, + Vector>, + CompareByFirst>; + +template +class BasicSearcher { +public: + BasicSearcher(std::shared_ptr graph, + std::shared_ptr vector, + const IndexCommonParam& common_param); + + virtual MaxHeap + Search(const float* query, InnerSearchParam& inner_search_param) const; + + virtual double + MockRun() const; + + virtual void + Resize(uint64_t new_size); + + virtual void + SetRuntimeParameters(const UnorderedMap& new_params); + +private: + uint32_t + visit(hnswlib::VisitedListPtr vl, + std::pair& current_node_pair, + Vector& to_be_visited_rid, + Vector& to_be_visited_id) const; + +private: + Allocator* allocator_{nullptr}; + + std::shared_ptr graph_; + + std::shared_ptr vector_data_cell_; + + std::shared_ptr pool_{nullptr}; + + int64_t dim_{0}; + + uint32_t prefetch_neighbor_visit_num_{1}; +}; + +} // namespace vsag \ No newline at end of file diff --git a/src/impl/basic_searcher_test.cpp b/src/impl/basic_searcher_test.cpp new file mode 100644 index 00000000..748f440c --- /dev/null +++ b/src/impl/basic_searcher_test.cpp @@ -0,0 +1,123 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "basic_searcher.h" + +#include "algorithm/hnswlib/hnswalg.h" +#include "algorithm/hnswlib/space_l2.h" +#include "catch2/catch_template_test_macros.hpp" +#include "data_cell/adapter_graph_datacell.h" +#include "data_cell/flatten_datacell.h" +#include "default_allocator.h" +#include "fixtures.h" +#include "io/memory_io.h" +#include "quantization/fp32_quantizer.h" +#include "safe_allocator.h" + +using namespace vsag; + +TEST_CASE("search with alg_hnsw", "[ut][basic_searcher]") { + // data attr + uint32_t base_size = 1000; + uint32_t query_size = 100; + uint64_t dim = 960; + + // build and search attr + uint32_t M = 32; + uint32_t ef_construction = 100; + uint32_t ef_search = 300; + uint32_t k = ef_search; + InnerIdType fixed_entry_point_id = 0; + uint64_t DEFAULT_MAX_ELEMENT = 1; + + // data preparation + auto base_vectors = fixtures::generate_vectors(base_size, dim, true); + std::vector ids(base_size); + std::iota(ids.begin(), ids.end(), 0); + + // alg_hnsw + auto allocator = SafeAllocator::FactoryDefaultAllocator(); + auto space = std::make_shared(dim); + auto io = std::make_shared(allocator.get()); + auto alg_hnsw = + std::make_shared(space.get(), + DEFAULT_MAX_ELEMENT, + allocator.get(), + M / 2, + ef_construction, + Options::Instance().block_size_limit()); + alg_hnsw->init_memory_space(); + for (int64_t i = 0; i < base_size; ++i) { + auto successful_insert = + alg_hnsw->addPoint((const void*)(base_vectors.data() + i * dim), ids[i]); + REQUIRE(successful_insert == true); + } + + // graph data cell + auto graph_data_cell = std::make_shared(alg_hnsw); + using GraphTmpl = std::remove_pointer_t; + + // vector data cell + auto fp32_param = JsonType::parse("{}"); + auto io_param = JsonType::parse("{}"); + IndexCommonParam common; + common.dim_ = dim; + common.allocator_ = allocator; + common.metric_ = vsag::MetricType::METRIC_TYPE_L2SQR; + + auto vector_data_cell = std::make_shared< + FlattenDataCell, MemoryIO>>( + fp32_param, io_param, common); + vector_data_cell->SetQuantizer( + std::make_shared>(dim, allocator.get())); + vector_data_cell->SetIO(std::make_unique(allocator.get())); + + vector_data_cell->Train(base_vectors.data(), base_size); + vector_data_cell->BatchInsertVector(base_vectors.data(), base_size, ids.data()); + using VectorDataTmpl = std::remove_pointer_t; + + // searcher + auto searcher = std::make_shared>( + graph_data_cell, vector_data_cell, common); + + // search + InnerSearchParam search_param; + search_param.ep_ = fixed_entry_point_id; + search_param.ef_ = ef_search; + search_param.topk_ = k; + search_param.is_id_allowed_ = nullptr; + for (int i = 0; i < query_size; i++) { + std::unordered_set valid_set, set; + auto result = searcher->Search(base_vectors.data() + i * dim, search_param); + auto valid_result = alg_hnsw->searchBaseLayerST( + fixed_entry_point_id, base_vectors.data() + i * dim, ef_search, nullptr); + REQUIRE(result.size() == valid_result.size()); + + for (int j = 0; j < k - 1; j++) { + valid_set.insert(valid_result.top().second); + set.insert(result.top().second); + result.pop(); + valid_result.pop(); + } + for (auto id : set) { + REQUIRE(valid_set.find(id) != valid_set.end()); + } + for (auto id : valid_set) { + REQUIRE(set.find(id) != set.end()); + } + REQUIRE(result.top().second == valid_result.top().second); + REQUIRE(result.top().second == ids[i]); + } +} \ No newline at end of file diff --git a/src/impl/runtime_parameter.h b/src/impl/runtime_parameter.h new file mode 100644 index 00000000..001fac52 --- /dev/null +++ b/src/impl/runtime_parameter.h @@ -0,0 +1,102 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace vsag { +using ParamValue = std::variant; + +struct RuntimeParameter { +public: + RuntimeParameter(const std::string& name) : name_(name) { + } + virtual ~RuntimeParameter() = default; + + virtual ParamValue + sample(std::mt19937& gen) = 0; + + virtual ParamValue + next() = 0; + + virtual void + reset() = 0; + + virtual bool + is_end() = 0; + +public: + std::string name_; +}; + +struct IntRuntimeParameter : RuntimeParameter { +public: + IntRuntimeParameter(const std::string& name, int min, int max, int step = -1.0) + : RuntimeParameter(name), min_(min), max_(max) { + cur_ = min_; + is_end_ = (cur_ < max_); + if (step < 0) { + step_ = (max_ - min_) / 10.0; + } + if (step_ == 0) { + step_ = 1; + } + } + + ParamValue + sample(std::mt19937& gen) override { + std::uniform_real_distribution<> dis(min_, max_); + cur_ = int(dis(gen)); + return cur_; + } + + ParamValue + next() override { + cur_ += step_; + if (cur_ > max_) { + cur_ -= (max_ - min_); + } + return cur_; + } + + void + reset() override { + cur_ = min_; + } + + bool + is_end() override { + return is_end_; + } + +private: + int min_{0}; + int max_{0}; + int step_{0}; + int cur_{0}; + bool is_end_{false}; +}; +} // namespace vsag \ No newline at end of file From 8b7d19492f4c3d931f05f1cd528739663729a6ab Mon Sep 17 00:00:00 2001 From: "zhongxiaoyao.zxy" Date: Fri, 10 Jan 2025 17:56:49 +0800 Subject: [PATCH 2/2] update Signed-off-by: zhongxiaoyao.zxy --- src/allocator_wrapper.h | 13 +++++++++++++ src/impl/basic_optimizer.cpp | 22 ++++++++++++++++++++++ src/impl/basic_optimizer.h | 20 ++++++++++++-------- src/impl/basic_searcher.h | 2 +- src/impl/basic_searcher_test.cpp | 13 +++++++++++-- 5 files changed, 59 insertions(+), 11 deletions(-) create mode 100644 src/impl/basic_optimizer.cpp diff --git a/src/allocator_wrapper.h b/src/allocator_wrapper.h index 5479aae8..7fd09258 100644 --- a/src/allocator_wrapper.h +++ b/src/allocator_wrapper.h @@ -71,4 +71,17 @@ class AllocatorWrapper { Allocator* allocator_{}; }; + +template +bool +operator==(const AllocatorWrapper&, const AllocatorWrapper&) noexcept { + return true; +} + +template +bool +operator!=(const AllocatorWrapper& a, const AllocatorWrapper& b) noexcept { + return !(a == b); +} + } // namespace vsag diff --git a/src/impl/basic_optimizer.cpp b/src/impl/basic_optimizer.cpp new file mode 100644 index 00000000..8691282b --- /dev/null +++ b/src/impl/basic_optimizer.cpp @@ -0,0 +1,22 @@ + +// Copyright 2024-present the vsag project +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "basic_optimizer.h" + +namespace vsag { +template class Optimizer< + BasicSearcher, MemoryIO>>>; +} \ No newline at end of file diff --git a/src/impl/basic_optimizer.h b/src/impl/basic_optimizer.h index a9da32f1..5eef2df9 100644 --- a/src/impl/basic_optimizer.h +++ b/src/impl/basic_optimizer.h @@ -15,6 +15,7 @@ #pragma once #include "../utils.h" +#include "basic_searcher.h" #include "logger.h" #include "runtime_parameter.h" @@ -22,12 +23,13 @@ namespace vsag { template class Optimizer { - Optimizer(std::shared_ptr allocator, int trials = 100) - : parameters_(allocator.get()), - best_params_(allocator.get()), +public: + Optimizer(const IndexCommonParam& common_param, int trials = 100) + : parameters_(common_param.allocator_.get()), + best_params_(common_param.allocator_.get()), n_trials_(trials), best_loss_(std::numeric_limits::max()) { - allocator_ = allocator.get(); + allocator_ = common_param.allocator_.get(); std::random_device rd; gen_.seed(rd()); } @@ -38,8 +40,8 @@ class Optimizer { } void - Optimize(OptimizableOBJ& obj) { - double original_loss = obj.MockRun(); + Optimize(std::shared_ptr obj) { + double original_loss = obj->MockRun(); for (int i = 0; i < n_trials_; ++i) { // generate a group of runtime params @@ -47,10 +49,10 @@ class Optimizer { for (auto& param : parameters_) { current_params[param->name_] = param->sample(gen_); } - obj.SetRuntimeParameters(current_params); + obj->SetRuntimeParameters(current_params); // evaluate - double loss = obj.MockRun(); + double loss = obj->MockRun(); // update if (loss < best_loss_) { @@ -62,6 +64,8 @@ class Optimizer { (original_loss - best_loss_) / original_loss)); } } + + obj->SetRuntimeParameters(best_params_); } UnorderedMap diff --git a/src/impl/basic_searcher.h b/src/impl/basic_searcher.h index 9a483b08..e488af58 100644 --- a/src/impl/basic_searcher.h +++ b/src/impl/basic_searcher.h @@ -33,7 +33,7 @@ namespace vsag { -static const InnerIdType SAMPLE_SIZE = 10000; +static const InnerIdType SAMPLE_SIZE = 1000; static const uint32_t CENTROID_EF = 500; static const uint32_t PREFETCH_DEGREE_DIVIDE = 3; static const uint32_t PREFETCH_MAXIMAL_DEGREE = 1; diff --git a/src/impl/basic_searcher_test.cpp b/src/impl/basic_searcher_test.cpp index 748f440c..cc2a9bc9 100644 --- a/src/impl/basic_searcher_test.cpp +++ b/src/impl/basic_searcher_test.cpp @@ -17,6 +17,7 @@ #include "algorithm/hnswlib/hnswalg.h" #include "algorithm/hnswlib/space_l2.h" +#include "basic_optimizer.h" #include "catch2/catch_template_test_macros.hpp" #include "data_cell/adapter_graph_datacell.h" #include "data_cell/flatten_datacell.h" @@ -28,7 +29,7 @@ using namespace vsag; -TEST_CASE("search with alg_hnsw", "[ut][basic_searcher]") { +TEST_CASE("search with alg_hnsw and optimizer", "[ut][basic_searcher]") { // data attr uint32_t base_size = 1000; uint32_t query_size = 100; @@ -88,9 +89,17 @@ TEST_CASE("search with alg_hnsw", "[ut][basic_searcher]") { vector_data_cell->BatchInsertVector(base_vectors.data(), base_size, ids.data()); using VectorDataTmpl = std::remove_pointer_t; - // searcher + // init searcher and optimizer auto searcher = std::make_shared>( graph_data_cell, vector_data_cell, common); + auto optimizer = + std::make_shared>>(common, 1); + optimizer->RegisterParameter(std::make_shared(PREFETCH_CACHE_LINE, 1, 10)); + optimizer->RegisterParameter( + std::make_shared(PREFETCH_NEIGHBOR_CODE_NUM, 1, 10)); + optimizer->RegisterParameter( + std::make_shared(PREFETCH_NEIGHBOR_VISIT_NUM, 1, 10)); + optimizer->Optimize(searcher); // search InnerSearchParam search_param;