diff --git a/benchs/indexes/hgraph-95.yml b/benchs/indexes/hgraph-95.yml index 039be1821..2886add5d 100644 --- a/benchs/indexes/hgraph-95.yml +++ b/benchs/indexes/hgraph-95.yml @@ -25,8 +25,8 @@ HGRAPH/GIST/95: datapath: "/tmp/data/gist-960-euclidean.hdf5" type: "build,search" # build, search index_name: "hgraph" - create_params: '{"dim":960,"dtype":"float32","metric_type":"l2","index_param":{"base_quantization_type":"sq8_uniform","max_degree":96,"ef_construction":400, "precise_quantization_type":"fp32", "use_reorder":true}}' - search_params: '{"hgraph":{"ef_search":120}}' + create_params: '{"dim":960,"dtype":"float32","metric_type":"l2","index_param":{"base_quantization_type":"sq8_uniform","max_degree":96,"ef_construction":400,"tau":0.02, "precise_quantization_type":"fp32", "use_reorder":true}}' + search_params: '{"hgraph":{"ef_search":88}}' index_path: "/tmp/gist-960-euclidean/index/hgraph_index" topk: 10 search_mode: "knn" # ["knn", "range", "knn_filter", "range_filter"] diff --git a/include/vsag/constants.h b/include/vsag/constants.h index 5d12711dd..859cf03db 100644 --- a/include/vsag/constants.h +++ b/include/vsag/constants.h @@ -162,6 +162,7 @@ extern const char* const HGRAPH_BASE_QUANTIZATION_TYPE; extern const char* const HGRAPH_GRAPH_MAX_DEGREE; extern const char* const HGRAPH_BUILD_EF_CONSTRUCTION; extern const char* const HGRAPH_BUILD_ALPHA; +extern const char* const HGRAPH_BUILD_TAU; extern const char* const HGRAPH_INIT_CAPACITY; extern const char* const HGRAPH_GRAPH_TYPE; extern const char* const HGRAPH_GRAPH_STORAGE_TYPE; diff --git a/src/algorithm/hgraph.cpp b/src/algorithm/hgraph.cpp index 3589b0a0f..d3152d1a7 100644 --- a/src/algorithm/hgraph.cpp +++ b/src/algorithm/hgraph.cpp @@ -50,6 +50,8 @@ HGraph::HGraph(const HGraphParameterPtr& hgraph_param, const vsag::IndexCommonPa build_by_base_(hgraph_param->build_by_base), ef_construct_(hgraph_param->ef_construction), alpha_(hgraph_param->alpha), + tau_(hgraph_param->tau), + 
select_edge_param(hgraph_param->selectedgeparam), odescent_param_(hgraph_param->odescent_param), graph_type_(hgraph_param->graph_type), hierarchical_datacell_param_(hgraph_param->hierarchical_graph_param), @@ -1537,13 +1539,24 @@ HGraph::graph_add_one(const void* data, int level, InnerIdType inner_id) { label_table_->SetDuplicateId(static_cast(param.duplicate_id), inner_id); return false; } - mutually_connect_new_element(inner_id, - result, - this->bottom_graph_, - flatten_codes, - neighbors_mutex_, - allocator_, - alpha_); + if (select_edge_param == "alpha") { + mutually_connect_new_element(inner_id, + result, + this->bottom_graph_, + flatten_codes, + neighbors_mutex_, + allocator_, + alpha_); + } else if (select_edge_param == "tau") { + mutually_connect_new_element(inner_id, + result, + this->bottom_graph_, + flatten_codes, + neighbors_mutex_, + allocator_, + tau_); + } + } else { bottom_graph_->InsertNeighborsById(inner_id, Vector(allocator_)); } @@ -1557,13 +1570,23 @@ HGraph::graph_add_one(const void* data, int level, InnerIdType inner_id) { // to specify which overloaded function to call (VisitedListPtr) nullptr, discard_stats); - mutually_connect_new_element(inner_id, - result, - route_graphs_[j], - flatten_codes, - neighbors_mutex_, - allocator_, - alpha_); + if (select_edge_param == "alpha") { + mutually_connect_new_element(inner_id, + result, + route_graphs_[j], // upper-level pass links into route graph j, as the removed code did + flatten_codes, + neighbors_mutex_, + allocator_, + alpha_); + } else if (select_edge_param == "tau") { + mutually_connect_new_element(inner_id, + result, + route_graphs_[j], // same target graph as the alpha branch + flatten_codes, + neighbors_mutex_, + allocator_, + tau_); + } } else { route_graphs_[j]->InsertNeighborsById(inner_id, Vector(allocator_)); } diff --git a/src/algorithm/hgraph.h b/src/algorithm/hgraph.h index 1e35be265..8f14419a5 100644 --- a/src/algorithm/hgraph.h +++ b/src/algorithm/hgraph.h @@ -355,6 +355,8 @@ class HGraph : public InnerIndexInterface { uint64_t ef_construct_{400}; float alpha_{1.0}; + float 
tau_{0.0F}; + std::string select_edge_param{"alpha"}; std::atomic total_count_{0}; diff --git a/src/algorithm/hgraph_parameter.cpp b/src/algorithm/hgraph_parameter.cpp index 4a2b188a9..9a6ab9bd6 100644 --- a/src/algorithm/hgraph_parameter.cpp +++ b/src/algorithm/hgraph_parameter.cpp @@ -104,6 +104,10 @@ HGraphParameter::FromJson(const JsonType& json) { if (json.Contains(ALPHA_KEY)) { this->alpha = json[ALPHA_KEY].GetFloat(); + this->selectedgeparam = "alpha"; + } else if (json.Contains(TAU_KEY)) { + this->tau = json[TAU_KEY].GetFloat(); + this->selectedgeparam = "tau"; } if (json.Contains(BUILD_THREAD_COUNT_KEY)) { @@ -136,6 +140,7 @@ HGraphParameter::ToJson() const { json[GRAPH_KEY].SetJson(this->bottom_graph_param->ToJson()); json[EF_CONSTRUCTION_KEY].SetInt(this->ef_construction); json[ALPHA_KEY].SetFloat(this->alpha); + json[TAU_KEY].SetFloat(this->tau); json[SUPPORT_DUPLICATE].SetBool(this->support_duplicate); return json; } diff --git a/src/algorithm/hgraph_parameter.h b/src/algorithm/hgraph_parameter.h index 96eddc355..ff3c1ec5e 100644 --- a/src/algorithm/hgraph_parameter.h +++ b/src/algorithm/hgraph_parameter.h @@ -59,6 +59,8 @@ class HGraphParameter : public InnerIndexParameter { uint64_t ef_construction{400}; float alpha{1.0F}; + float tau{0.0F}; + std::string selectedgeparam{"alpha"}; bool support_duplicate{false}; bool support_tombstone{false}; diff --git a/src/algorithm/pyramid.cpp b/src/algorithm/pyramid.cpp index fce892466..5f75d249b 100644 --- a/src/algorithm/pyramid.cpp +++ b/src/algorithm/pyramid.cpp @@ -669,8 +669,13 @@ Pyramid::add_one_point(const std::shared_ptr& node, auto results = searcher_->Search( node->graph_, codes, vl, vector, search_param, (LabelTablePtr) nullptr, discard_stats); pool_->ReturnOne(vl); - mutually_connect_new_element( - inner_id, results, node->graph_, codes, points_mutex_, allocator_, alpha_); + if (select_edge_param == "alpha") { + mutually_connect_new_element( + inner_id, results, node->graph_, codes, points_mutex_, 
allocator_, alpha_); + } else { + mutually_connect_new_element( + inner_id, results, node->graph_, codes, points_mutex_, allocator_, tau_); + } if (update_entry_point) { node->entry_point_ = inner_id; } diff --git a/src/algorithm/pyramid.h b/src/algorithm/pyramid.h index fc77f7372..479324e64 100644 --- a/src/algorithm/pyramid.h +++ b/src/algorithm/pyramid.h @@ -97,6 +97,8 @@ class Pyramid : public InnerIndexInterface { Pyramid(const PyramidParamPtr& pyramid_param, const IndexCommonParam& common_param) : InnerIndexInterface(pyramid_param, common_param), alpha_(pyramid_param->alpha), + tau_(pyramid_param->tau), + select_edge_param(pyramid_param->selectedgeparam), no_build_levels_(common_param.allocator_.get()), odescent_param_(pyramid_param->odescent_param), ef_construction_(pyramid_param->ef_construction), @@ -223,6 +225,8 @@ class Pyramid : public InnerIndexInterface { int64_t max_capacity_{0}; int64_t cur_element_count_{0}; float alpha_{1.0F}; + float tau_{0.0F}; + std::string select_edge_param{"alpha"}; std::shared_mutex resize_mutex_; std::mutex cur_element_count_mutex_; diff --git a/src/algorithm/pyramid_zparameters.cpp b/src/algorithm/pyramid_zparameters.cpp index f58114dcb..1a9566d41 100644 --- a/src/algorithm/pyramid_zparameters.cpp +++ b/src/algorithm/pyramid_zparameters.cpp @@ -47,6 +47,14 @@ PyramidParameters::FromJson(const JsonType& json) { this->base_codes_param = CreateFlattenParam(json[BASE_CODES_KEY]); + if (json.Contains(ALPHA_KEY)) { + this->alpha = json[ALPHA_KEY].GetFloat(); + this->selectedgeparam = "alpha"; + } else if (json.Contains(TAU_KEY)) { + this->tau = json[TAU_KEY].GetFloat(); + this->selectedgeparam = "tau"; + } + if (json.Contains(NO_BUILD_LEVELS)) { const auto& no_build_levels_json = json[NO_BUILD_LEVELS]; CHECK_ARGUMENT(no_build_levels_json.IsArray(), diff --git a/src/algorithm/pyramid_zparameters.h b/src/algorithm/pyramid_zparameters.h index 5913da22d..7eba22066 100644 --- a/src/algorithm/pyramid_zparameters.h +++ 
b/src/algorithm/pyramid_zparameters.h @@ -49,10 +49,13 @@ struct PyramidParameters : public InnerIndexParameter { ODescentParameterPtr odescent_param{nullptr}; std::vector no_build_levels; + uint64_t ef_construction{400}; int64_t max_degree{64}; std::string graph_type{GRAPH_TYPE_VALUE_NSW}; float alpha{1.2F}; + float tau{0.0F}; + std::string selectedgeparam{"alpha"}; uint32_t index_min_size{0}; }; diff --git a/src/constants.cpp b/src/constants.cpp index 17c2caaf3..f78ff2bed 100644 --- a/src/constants.cpp +++ b/src/constants.cpp @@ -148,6 +148,7 @@ const char* const HGRAPH_BASE_QUANTIZATION_TYPE = "base_quantization_type"; const char* const HGRAPH_GRAPH_MAX_DEGREE = "max_degree"; const char* const HGRAPH_BUILD_EF_CONSTRUCTION = "ef_construction"; const char* const HGRAPH_BUILD_ALPHA = "alpha"; +const char* const HGRAPH_BUILD_TAU = "tau"; const char* const HGRAPH_INIT_CAPACITY = "hgraph_init_capacity"; const char* const HGRAPH_GRAPH_TYPE = "graph_type"; const char* const HGRAPH_GRAPH_STORAGE_TYPE = "graph_storage_type"; diff --git a/src/impl/pruning_strategy.cpp b/src/impl/pruning_strategy.cpp index c4f8f43c8..633d66ac0 100644 --- a/src/impl/pruning_strategy.cpp +++ b/src/impl/pruning_strategy.cpp @@ -21,12 +21,13 @@ #include "utils/lock_strategy.h" namespace vsag { +template void select_edges_by_heuristic(const DistHeapPtr& edges, uint64_t max_size, const FlattenInterfacePtr& flatten, Allocator* allocator, - float alpha) { + float param_value) { if (edges->Size() < max_size) { return; } @@ -49,9 +50,21 @@ select_edges_by_heuristic(const DistHeapPtr& edges, for (const auto& second_pair : return_list) { float curdist = flatten->ComputePairVectors(second_pair.second, current_pair.second); - if (alpha * curdist < float_query) { - good = false; - break; + + if constexpr (param == EdgeSelectionParam::ALPHA) { + if (param_value * curdist < float_query) { + good = false; + break; + } + } else { + if (curdist < (float_query - 3 * param_value)) { + good = false; + break; + } 
+ if (float_query <= 3 * param_value) { + good = true; + break; + } } } if (good) { @@ -64,6 +77,7 @@ select_edges_by_heuristic(const DistHeapPtr& edges, } } +template InnerIdType mutually_connect_new_element(InnerIdType cur_c, const DistHeapPtr& top_candidates, @@ -71,9 +85,11 @@ mutually_connect_new_element(InnerIdType cur_c, const FlattenInterfacePtr& flatten, const MutexArrayPtr& neighbors_mutexes, Allocator* allocator, - float alpha) { + float param_value) { const size_t max_size = graph->MaximumDegree(); - select_edges_by_heuristic(top_candidates, max_size, flatten, allocator, alpha); + + select_edges_by_heuristic(top_candidates, max_size, flatten, allocator, param_value); + if (top_candidates->Size() > max_size) { throw VsagException( ErrorType::INTERNAL_ERROR, @@ -123,7 +139,7 @@ mutually_connect_new_element(InnerIdType cur_c, neighbors[j]); } - select_edges_by_heuristic(candidates, max_size, flatten, allocator, alpha); + select_edges_by_heuristic(candidates, max_size, flatten, allocator, param_value); Vector cand_neighbors(allocator); while (not candidates->Empty()) { @@ -136,4 +152,30 @@ mutually_connect_new_element(InnerIdType cur_c, return next_closest_entry_point; } +template void +select_edges_by_heuristic( + const DistHeapPtr&, uint64_t, const FlattenInterfacePtr&, Allocator*, float); + +template void +select_edges_by_heuristic( + const DistHeapPtr&, uint64_t, const FlattenInterfacePtr&, Allocator*, float); + +template InnerIdType +mutually_connect_new_element(InnerIdType, + const DistHeapPtr&, + const GraphInterfacePtr&, + const FlattenInterfacePtr&, + const MutexArrayPtr&, + Allocator*, + float); + +template InnerIdType +mutually_connect_new_element(InnerIdType, + const DistHeapPtr&, + const GraphInterfacePtr&, + const FlattenInterfacePtr&, + const MutexArrayPtr&, + Allocator*, + float); + } // namespace vsag diff --git a/src/impl/pruning_strategy.h b/src/impl/pruning_strategy.h index a1a6bf4c2..129ba021a 100644 --- a/src/impl/pruning_strategy.h 
+++ b/src/impl/pruning_strategy.h @@ -15,6 +15,7 @@ #pragma once +#include "inner_search_param.h" #include "typing.h" #include "utils/pointer_define.h" @@ -24,13 +25,17 @@ DEFINE_POINTER(FlattenInterface); DEFINE_POINTER(GraphInterface); DEFINE_POINTER(MutexArray); +enum class EdgeSelectionParam { ALPHA, TAU }; + +template void select_edges_by_heuristic(const DistHeapPtr& edges, uint64_t max_size, const FlattenInterfacePtr& flatten, Allocator* allocator, - float alpha = 1.0F); + float param_value = (Param == EdgeSelectionParam::ALPHA) ? 1.0F : 0.0F); +template InnerIdType mutually_connect_new_element(InnerIdType cur_c, const DistHeapPtr& top_candidates, @@ -38,6 +43,33 @@ mutually_connect_new_element(InnerIdType cur_c, const FlattenInterfacePtr& flatten, const MutexArrayPtr& neighbors_mutexes, Allocator* allocator, - float alpha = 1.0F); + float param_value = (Param == EdgeSelectionParam::ALPHA) ? 1.0F + : 0.0F); + +extern template void +select_edges_by_heuristic( + const DistHeapPtr&, uint64_t, const FlattenInterfacePtr&, Allocator*, float); + +extern template void +select_edges_by_heuristic( + const DistHeapPtr&, uint64_t, const FlattenInterfacePtr&, Allocator*, float); + +extern template InnerIdType +mutually_connect_new_element(InnerIdType, + const DistHeapPtr&, + const GraphInterfacePtr&, + const FlattenInterfacePtr&, + const MutexArrayPtr&, + Allocator*, + float); + +extern template InnerIdType +mutually_connect_new_element(InnerIdType, + const DistHeapPtr&, + const GraphInterfacePtr&, + const FlattenInterfacePtr&, + const MutexArrayPtr&, + Allocator*, + float); } // namespace vsag diff --git a/src/impl/pruning_strategy_test.cpp b/src/impl/pruning_strategy_test.cpp index 4d8d7511d..47ae3db53 100644 --- a/src/impl/pruning_strategy_test.cpp +++ b/src/impl/pruning_strategy_test.cpp @@ -93,7 +93,8 @@ TEST_CASE("Pruning Strategy Select Edges With Heuristic", "[ut][pruning_strategy // - Compare with ID1: 1.0 * 25.0 (ID1-ID2 distance) = 25.0 < 34.0 -> PRUNE ID2 // - 
return_list remains [ID1] // Final return_list size: 1 - select_edges_by_heuristic(edges, 3, flatten, allocator.get(), 1.0F); + select_edges_by_heuristic( + edges, 3, flatten, allocator.get(), 1.0F); REQUIRE(edges->Size() == 1); std::vector kept; @@ -120,7 +121,8 @@ TEST_CASE("Pruning Strategy Select Edges With Heuristic", "[ut][pruning_strategy // Step 4: Process ID2: 1.5 * 25.0 = 37.5 < 34.0-> NO, but check ID3 // - 1.5 * 13.0 (ID2-ID3 distance) = 19.5 < 34.0-> PRUNE ID2 // Final return_list size: 2 - select_edges_by_heuristic(edges, 3, flatten, allocator.get(), 1.5F); + select_edges_by_heuristic( + edges, 3, flatten, allocator.get(), 1.5F); REQUIRE(edges->Size() == 2); std::vector kept; @@ -140,7 +142,8 @@ TEST_CASE("Pruning Strategy Select Edges With Heuristic", "[ut][pruning_strategy edges->Push(d04, 4); //similar process - select_edges_by_heuristic(edges, 3, flatten, allocator.get(), 2.0F); + select_edges_by_heuristic( + edges, 3, flatten, allocator.get(), 2.0F); REQUIRE(edges->Size() == 2); std::vector kept; @@ -159,7 +162,8 @@ TEST_CASE("Pruning Strategy Select Edges With Heuristic", "[ut][pruning_strategy edges->Push(d03, 3); edges->Push(d04, 4); - select_edges_by_heuristic(edges, 3, flatten, allocator.get(), 3.5F); + select_edges_by_heuristic( + edges, 3, flatten, allocator.get(), 3.5F); REQUIRE(edges->Size() == 3); std::vector kept; @@ -185,8 +189,8 @@ TEST_CASE("Pruning Strategy Select Edges With Heuristic", "[ut][pruning_strategy auto mutexes = std::make_shared(); MutexArrayPtr mutex_array = std::make_shared(); - auto entry_point = - mutually_connect_new_element(0, candidates, graph, flatten, mutexes, allocator.get()); + auto entry_point = mutually_connect_new_element( + 0, candidates, graph, flatten, mutexes, allocator.get(), 1.0F); REQUIRE(entry_point == 1); @@ -196,4 +200,197 @@ TEST_CASE("Pruning Strategy Select Edges With Heuristic", "[ut][pruning_strategy } } +TEST_CASE("Pruning Strategy Select Edges With Tau-MNG", "[ut][pruning_strategy]") { + 
auto allocator = Engine::CreateDefaultAllocator(); + auto flatten_param = std::make_shared(); + flatten_param->quantizer_parameter = std::make_shared(); + flatten_param->io_parameter = std::make_shared(); + IndexCommonParam common_param; + common_param.allocator_ = allocator; + common_param.metric_ = MetricType::METRIC_TYPE_L2SQR; + common_param.dim_ = 128; + + auto flatten = FlattenInterface::MakeInstance(flatten_param, common_param); + REQUIRE(flatten != nullptr); + + float vectors[5][128] = {0}; + vectors[0][0] = 0.0F; + vectors[1][0] = 1.0F; + vectors[2][0] = 5.0F; + vectors[3][0] = 10.0F; + vectors[4][0] = 20.0F; + + flatten->Train(vectors, 5); + flatten->BatchInsertVector(vectors, 5); + + const float d01 = 1.0F; // distance to ID1 + const float d02 = 25.0F; // distance to ID2 + const float d03 = 100.0F; // distance to ID3 + const float d04 = 400.0F; // distance to ID4 + + SECTION("Tau-MNG with small tau value") { + auto edges = std::make_shared>(allocator.get(), -1); + edges->Push(d01, 1); + edges->Push(d02, 2); + edges->Push(d03, 3); + edges->Push(d04, 4); + + // τ-MNG pruning with tau=2.0 (max_size=3); the TAU variant does not use alpha + // τ-MNG rule: if curdist < (float_query - 3 * tau), prune + // For tau=2.0, 3*tau=6.0 + // Step 1: Keep ID1 (distance=1.0), return_list = [ID1] + // Step 2: Process ID2 (distance=25.0) + // - Compare with ID1: dist(ID1,ID2)=16.0 (assumes squared L2 pair distances - TODO confirm) + // - Check: 16.0 < (25.0 - 6.0)=19.0? YES -> prune ID2 + // - return_list = [ID1] + // Step 3: Process ID3 (distance=100.0) + // - return_list is still [ID1], so ID1 is the only neighbor to compare against + // - Compare with ID1: dist(ID1,ID3)=81.0 < (100.0-6.0)=94.0? YES -> prune ID3 + // Step 4: Process ID4 (distance=400.0) + // - return_list is still [ID1] + // - Compare with ID1: dist(ID1,ID4)=361.0 < (400.0-6.0)=394.0? 
YES -> prune ID4 + // Final return_list: [ID1] + select_edges_by_heuristic( + edges, 3, flatten, allocator.get(), 2.0F); + + REQUIRE(edges->Size() == 1); + std::vector kept; + while (!edges->Empty()) { + kept.push_back(edges->Top().second); + edges->Pop(); + } + std::sort(kept.begin(), kept.end()); + REQUIRE(kept == std::vector{1}); + } + + SECTION("Tau-MNG with larger tau value") { + auto edges = std::make_shared>(allocator.get(), -1); + edges->Push(d01, 1); + edges->Push(d02, 2); + edges->Push(d03, 3); + edges->Push(d04, 4); + + // τ-MNG pruning with tau=5.0 (max_size=3); the TAU variant does not use alpha + // For tau=5.0, 3*tau=15.0 (pair distances assume squared L2 - TODO confirm) + // Step 1: Keep ID1, return_list = [ID1] + // Step 2: Process ID2 (distance=25.0) + // - dist(ID1,ID2)=16.0 < (25.0-15.0)=10.0? NO -> keep ID2 + // - return_list = [ID1, ID2] + // Step 3: Process ID3 (distance=100.0) + // - dist(ID1,ID3)=81.0 < (100.0-15.0)=85.0? YES + // - dist(ID2,ID3)=25.0 < (100.0-15.0)=85.0? YES -> prune ID3 + // Step 4: Process ID4 (distance=400.0) + // - dist(ID1,ID4)=361.0 < (400.0-15.0)=385.0? YES + // - dist(ID2,ID4)=225.0 < (400.0-15.0)=385.0? 
YES -> prune ID4 + // Final return_list: [ID1, ID2] + select_edges_by_heuristic( + edges, 3, flatten, allocator.get(), 5.0F); + + REQUIRE(edges->Size() == 2); + std::vector kept; + while (!edges->Empty()) { + kept.push_back(edges->Top().second); + edges->Pop(); + } + std::sort(kept.begin(), kept.end()); + REQUIRE(kept == std::vector{1, 2}); + } + + SECTION("Tau-MNG with very small tau (approaching original heuristic)") { + auto edges = std::make_shared>(allocator.get(), -1); + edges->Push(d01, 1); + edges->Push(d02, 2); + edges->Push(d03, 3); + edges->Push(d04, 4); + + // τ-MNG with very small tau=0.1 should behave similarly to original heuristic + // For tau=0.1, 3*tau=0.3 + // This will be very close to original heuristic behavior + select_edges_by_heuristic( + edges, 3, flatten, allocator.get(), 0.1F); + + REQUIRE(edges->Size() == 1); + std::vector kept; + while (!edges->Empty()) { + kept.push_back(edges->Top().second); + edges->Pop(); + } + std::sort(kept.begin(), kept.end()); + REQUIRE(kept == std::vector{1}); + } + + SECTION("Tau-MNG with short edge forced connection") { + auto edges = std::make_shared>(allocator.get(), -1); + + // Add some very short edges that should be forced to connect + const float short_dist_1 = 2.0F; // <= 3*tau for tau=1.0 + const float short_dist_2 = 2.5F; // <= 3*tau for tau=1.0 + + edges->Push(short_dist_1, 1); + edges->Push(short_dist_2, 2); + edges->Push(d03, 3); + edges->Push(d04, 4); + + // τ-MNG with tau=1.0, short edges should be forced to connect + // For tau=1.0, 3*tau=3.0 + // ID1 (dist=2.0) and ID2 (dist=2.5) are both <= 3.0, so they should be kept + // regardless of other conditions due to τ-MNG forced connection rule + select_edges_by_heuristic( + edges, 3, flatten, allocator.get(), 1.0F); + + REQUIRE(edges->Size() == 2); + std::vector kept; + while (!edges->Empty()) { + kept.push_back(edges->Top().second); + edges->Pop(); + } + std::sort(kept.begin(), kept.end()); + REQUIRE(kept == std::vector{1, 2}); + } + + 
SECTION("Mutual connection with Tau-MNG algorithm") { + auto graph_param = std::make_shared(); + graph_param->io_parameter_ = std::make_shared(); + graph_param->max_degree_ = 4; + auto graph = GraphInterface::MakeInstance(graph_param, common_param); + + auto candidates = std::make_shared>(allocator.get(), -1); + candidates->Push(d01, 1); + candidates->Push(d02, 2); + candidates->Push(d03, 3); + candidates->Push(d04, 4); + + auto mutexes = std::make_shared(); + MutexArrayPtr mutex_array = std::make_shared(); + + auto entry_point = mutually_connect_new_element( + 0, candidates, graph, flatten, mutex_array, allocator.get(), 2.0F); + REQUIRE(entry_point == 1); + Vector neighbors_0(allocator.get()); + graph->GetNeighbors(0, neighbors_0); + REQUIRE(neighbors_0.size() == 1); // ID1 should be kept + } + + SECTION("Tau-MNG with zero tau falls back to original heuristic") { + auto edges = std::make_shared>(allocator.get(), -1); + edges->Push(d01, 1); + edges->Push(d02, 2); + edges->Push(d03, 3); + edges->Push(d04, 4); + + // tau=0 should behave exactly like original heuristic with alpha=1.0 + select_edges_by_heuristic( + edges, 3, flatten, allocator.get(), 0.0F); + + REQUIRE(edges->Size() == 1); + std::vector kept; + while (!edges->Empty()) { + kept.push_back(edges->Top().second); + edges->Pop(); + } + std::sort(kept.begin(), kept.end()); + REQUIRE(kept == std::vector{1}); + } +} + } // namespace vsag diff --git a/src/inner_string_params.h b/src/inner_string_params.h index ea74d2354..c2c7ddef9 100644 --- a/src/inner_string_params.h +++ b/src/inner_string_params.h @@ -46,6 +46,7 @@ const char* const HGRAPH_IGNORE_REORDER_KEY = "ignore_reorder"; const char* const HGRAPH_BUILD_BY_BASE_QUANTIZATION_KEY = "build_by_base"; const char* const GRAPH_KEY = "graph"; const char* const ALPHA_KEY = "alpha"; +const char* const TAU_KEY = "tau"; // IO param key const char* const IO_PARAMS_KEY = "io_params";