Skip to content

Commit c81a417

Browse files
committed
update
Signed-off-by: zhongxiaoyao.zxy <[email protected]>
1 parent 55dcb8a commit c81a417

File tree

10 files changed

+337
-4
lines changed

10 files changed

+337
-4
lines changed

include/vsag/index.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,31 @@ class Index {
8888
throw std::runtime_error("Index not support delete vector");
8989
}
9090

91+
/**
92+
* Update the id of a base point from the index
93+
*
94+
* @param old_id indicates the old id of a base point in index
95+
* @param new_id is the updated new id of the base point
96+
* @return result indicates whether the update operation is successful.
97+
*/
98+
virtual tl::expected<bool, Error>
99+
UpdateId(int64_t old_id, int64_t new_id) {
100+
throw std::runtime_error("Index not support update id");
101+
}
102+
103+
/**
104+
* Update the vector of a base point from the index
105+
*
106+
* @param id indicates the old id of a base point in index
107+
* @param new_base is the updated new vector of the base point
108+
* @param need_fine_tune indicates whether the connection of the base point needs to be fine-tuned
109+
* @return result indicates whether the update operation is successful.
110+
*/
111+
virtual tl::expected<bool, Error>
112+
UpdateVector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune = false) {
113+
throw std::runtime_error("Index not support update vector");
114+
}
115+
91116
/**
92117
* Performing single KNN search on index
93118
*

include/vsag/index_feature.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ enum IndexFeature {
4848
SUPPORT_BATCH_SEARCH_WITH_MULTI_THREAD, /**< Supports batch searching with multi-threading */
4949

5050
SUPPORT_ADD_CONCURRENT, /**< Supports concurrent addition of elements */
51+
SUPPORT_UPDATE_ID_CONCURRENT, /**< Supports concurrent update id of elements */
52+
SUPPORT_UPDATE_VECTOR_CONCURRENT, /**< Supports concurrent update vector of elements */
5153
SUPPORT_SEARCH_CONCURRENT, /**< Supports concurrent searching */
5254
SUPPORT_DELETE_CONCURRENT, /**< Supports concurrent deletion */
5355
SUPPORT_ADD_SEARCH_CONCURRENT, /**< Supports concurrent addition and searching */

src/algorithm/hnswlib/hnswalg.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1136,6 +1136,41 @@ HierarchicalNSW::dealNoInEdge(InnerIdType id, int level, int m_curmax, int skip_
11361136
}
11371137
}
11381138

1139+
void
1140+
HierarchicalNSW::updateVector(LabelType label, const void* data_point) {
1141+
std::unique_lock<std::mutex> lock(global_);
1142+
auto iter = label_lookup_.find(label);
1143+
if (iter == label_lookup_.end()) {
1144+
throw std::runtime_error(fmt::format("no label {} in HNSW", label));
1145+
} else {
1146+
InnerIdType internal_id = iter->second;
1147+
1148+
// reset data
1149+
std::shared_ptr<float[]> normalize_data;
1150+
normalizeVector(data_point, normalize_data);
1151+
memcpy(getDataByInternalId(internal_id), data_point, data_size_);
1152+
}
1153+
}
1154+
1155+
void
1156+
HierarchicalNSW::updateLabel(LabelType old_label, LabelType new_label) {
1157+
std::unique_lock<std::mutex> lock(global_);
1158+
auto iter_old = label_lookup_.find(old_label);
1159+
auto iter_new = label_lookup_.find(new_label);
1160+
if (iter_old == label_lookup_.end()) {
1161+
throw std::runtime_error(fmt::format("no old label {} in HNSW", old_label));
1162+
} else if (iter_new != label_lookup_.end()) {
1163+
throw std::runtime_error(fmt::format("new label {} has been in HNSW", new_label));
1164+
} else {
1165+
InnerIdType internal_id = iter_old->second;
1166+
1167+
// reset label
1168+
label_lookup_.erase(iter_old);
1169+
label_lookup_[new_label] = internal_id;
1170+
setExternalLabel(internal_id, new_label);
1171+
}
1172+
}
1173+
11391174
void
11401175
HierarchicalNSW::removePoint(LabelType label) {
11411176
InnerIdType cur_c = 0;

src/algorithm/hnswlib/hnswalg.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,12 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
348348
void
349349
dealNoInEdge(InnerIdType id, int level, int m_curmax, int skip_c);
350350

351+
void
352+
updateLabel(LabelType old_label, LabelType new_label);
353+
354+
void
355+
updateVector(LabelType label, const void* data_point);
356+
351357
void
352358
removePoint(LabelType label);
353359

src/index/hnsw.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,51 @@ HNSW::GetStats() const {
616616
return j.dump();
617617
}
618618

619+
tl::expected<bool, Error>
620+
HNSW::update_id(int64_t old_id, int64_t new_id) {
621+
if (use_static_) {
622+
LOG_ERROR_AND_RETURNS(ErrorType::UNSUPPORTED_INDEX_OPERATION,
623+
"static hnsw does not support update");
624+
}
625+
626+
try {
627+
// note that the validation of old_id is handled within updateLabel.
628+
std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_)->updateLabel(old_id,
629+
new_id);
630+
} catch (const std::runtime_error& e) {
631+
spdlog::warn(
632+
"update error for replace old_id {} to new_id {}: {}", old_id, new_id, e.what());
633+
return false;
634+
}
635+
636+
return true;
637+
}
638+
639+
tl::expected<bool, Error>
640+
HNSW::update_vector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune) {
641+
// TODO(ZXY): implement need_fine_tune to allow update with distant vector
642+
if (use_static_) {
643+
LOG_ERROR_AND_RETURNS(ErrorType::UNSUPPORTED_INDEX_OPERATION,
644+
"static hnsw does not support update");
645+
}
646+
647+
try {
648+
// the validation of the new vector
649+
void* new_base_vec = nullptr;
650+
size_t data_size = 0;
651+
get_vectors(new_base, &new_base_vec, &data_size);
652+
653+
// note that the validation of old_id is handled within updatePoint.
654+
std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_)->updateVector(
655+
id, new_base_vec);
656+
} catch (const std::runtime_error& e) {
657+
spdlog::warn("update error for replace vector of id {}: {}", id, e.what());
658+
return false;
659+
}
660+
661+
return true;
662+
}
663+
619664
tl::expected<bool, Error>
620665
HNSW::remove(int64_t id) {
621666
if (use_static_) {

src/index/hnsw.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,16 @@ class HNSW : public Index {
7070
SAFE_CALL(return this->remove(id));
7171
}
7272

73+
tl::expected<bool, Error>
74+
UpdateId(int64_t old_id, int64_t new_id) override {
75+
SAFE_CALL(return this->update_id(old_id, new_id));
76+
}
77+
78+
tl::expected<bool, Error>
79+
UpdateVector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune = false) override {
80+
SAFE_CALL(return this->update_vector(id, new_base, need_fine_tune));
81+
}
82+
7383
tl::expected<DatasetPtr, Error>
7484
KnnSearch(const DatasetPtr& query,
7585
int64_t k,
@@ -192,6 +202,12 @@ class HNSW : public Index {
192202
tl::expected<bool, Error>
193203
remove(int64_t id);
194204

205+
tl::expected<bool, Error>
206+
update_id(int64_t old_id, int64_t new_id);
207+
208+
tl::expected<bool, Error>
209+
update_vector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune);
210+
195211
template <typename FilterType>
196212
tl::expected<DatasetPtr, Error>
197213
knn_search_internal(const DatasetPtr& query,

tests/test_hnsw_new.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,40 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Concurrent Add", "[f
285285
}
286286
}
287287

288+
TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Update Id", "[ft][hnsw]") {
289+
auto origin_size = vsag::Options::Instance().block_size_limit();
290+
auto size = GENERATE(1024 * 1024 * 2);
291+
auto metric_type = GENERATE("l2", "ip", "cosine");
292+
const std::string name = "hnsw";
293+
auto search_param = fmt::format(search_param_tmp, 100);
294+
for (auto& dim : dims) {
295+
vsag::Options::Instance().set_block_size_limit(size);
296+
auto param = GenerateHNSWBuildParametersString(metric_type, dim);
297+
auto index = TestFactory(name, param, true);
298+
auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
299+
TestBuildIndex(index, dataset, true);
300+
TestUpdateId(index, dataset, search_param, true);
301+
vsag::Options::Instance().set_block_size_limit(origin_size);
302+
}
303+
}
304+
305+
TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Update Vector", "[ft][hnsw]") {
306+
auto origin_size = vsag::Options::Instance().block_size_limit();
307+
auto size = GENERATE(1024 * 1024 * 2);
308+
auto metric_type = GENERATE("l2");
309+
const std::string name = "hnsw";
310+
auto search_param = fmt::format(search_param_tmp, 100);
311+
for (auto& dim : dims) {
312+
vsag::Options::Instance().set_block_size_limit(size);
313+
auto param = GenerateHNSWBuildParametersString(metric_type, dim);
314+
auto index = TestFactory(name, param, true);
315+
auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
316+
TestBuildIndex(index, dataset, true);
317+
TestUpdateVector(index, dataset, search_param, true);
318+
vsag::Options::Instance().set_block_size_limit(origin_size);
319+
}
320+
}
321+
288322
TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Serialize File", "[ft][hnsw]") {
289323
auto origin_size = vsag::Options::Instance().block_size_limit();
290324
auto size = GENERATE(1024 * 1024 * 2);

tests/test_index.cpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,137 @@ TestIndex::TestAddIndex(const IndexPtr& index,
6060
}
6161
}
6262

63+
void
64+
TestIndex::TestUpdateId(const IndexPtr& index,
65+
const TestDatasetPtr& dataset,
66+
const std::string& search_param,
67+
bool expected_success) {
68+
auto ids = dataset->base_->GetIds();
69+
auto num_vectors = dataset->base_->GetNumElements();
70+
auto dim = dataset->base_->GetDim();
71+
auto gt_topK = dataset->top_k;
72+
auto base = dataset->base_->GetFloat32Vectors();
73+
74+
std::unordered_map<int64_t, int64_t> update_id_map;
75+
std::unordered_map<int64_t, int64_t> reverse_id_map;
76+
int64_t max_id = num_vectors;
77+
for (int i = 0; i < num_vectors; i++) {
78+
if (ids[i] > max_id) {
79+
max_id = ids[i];
80+
}
81+
}
82+
for (int i = 0; i < num_vectors; i++) {
83+
update_id_map[ids[i]] = ids[i] + 2 * max_id;
84+
}
85+
86+
std::vector<int> correct_num = {0, 0};
87+
for (int round = 0; round < 2; round++) {
88+
// round 0 for update, round 1 for validate update results
89+
for (int i = 0; i < num_vectors; i++) {
90+
auto query = vsag::Dataset::Make();
91+
query->NumElements(1)->Dim(dim)->Float32Vectors(base + i * dim)->Owner(false);
92+
93+
auto result = index->KnnSearch(query, gt_topK, search_param);
94+
REQUIRE(result.has_value());
95+
96+
if (round == 0) {
97+
if (result.value()->GetIds()[0] == ids[i]) {
98+
correct_num[round] += 1;
99+
}
100+
101+
auto succ_update_res = index->UpdateId(ids[i], update_id_map[ids[i]]);
102+
REQUIRE(succ_update_res.has_value());
103+
if (expected_success) {
104+
REQUIRE(succ_update_res.value());
105+
}
106+
107+
// old id don't exist
108+
auto failed_old_res = index->UpdateId(ids[i], update_id_map[ids[i]]);
109+
REQUIRE(failed_old_res.has_value());
110+
REQUIRE(not failed_old_res.value());
111+
112+
// new id is used
113+
auto failed_new_res = index->UpdateId(update_id_map[ids[i]], update_id_map[ids[i]]);
114+
REQUIRE(failed_new_res.has_value());
115+
REQUIRE(not failed_new_res.value());
116+
} else {
117+
if (result.value()->GetIds()[0] == update_id_map[ids[i]]) {
118+
correct_num[round] += 1;
119+
}
120+
}
121+
}
122+
}
123+
124+
REQUIRE(correct_num[0] == correct_num[1]);
125+
}
126+
127+
void
128+
TestIndex::TestUpdateVector(const IndexPtr& index,
129+
const TestDatasetPtr& dataset,
130+
const std::string& search_param,
131+
bool expected_success) {
132+
auto ids = dataset->base_->GetIds();
133+
auto num_vectors = dataset->base_->GetNumElements();
134+
auto dim = dataset->base_->GetDim();
135+
auto gt_topK = dataset->top_k;
136+
auto base = dataset->base_->GetFloat32Vectors();
137+
138+
int64_t max_id = num_vectors;
139+
for (int i = 0; i < num_vectors; i++) {
140+
if (ids[i] > max_id) {
141+
max_id = ids[i];
142+
}
143+
}
144+
145+
std::vector<int> correct_num = {0, 0};
146+
for (int round = 0; round < 2; round++) {
147+
// round 0 for update, round 1 for validate update results
148+
for (int i = 0; i < num_vectors; i++) {
149+
auto query = vsag::Dataset::Make();
150+
query->NumElements(1)->Dim(dim)->Float32Vectors(base + i * dim)->Owner(false);
151+
152+
auto result = index->KnnSearch(query, gt_topK, search_param);
153+
REQUIRE(result.has_value());
154+
155+
if (round == 0) {
156+
if (result.value()->GetIds()[0] == ids[i]) {
157+
correct_num[round] += 1;
158+
}
159+
160+
std::vector<float> update_vecs(dim);
161+
for (int d = 0; d < dim; d++) {
162+
update_vecs[d] = base[i * dim + d] + 0.001f;
163+
}
164+
auto new_base = vsag::Dataset::Make();
165+
new_base->NumElements(1)
166+
->Dim(dim)
167+
->Float32Vectors(update_vecs.data())
168+
->Owner(false);
169+
170+
auto before_update_dist = *index->CalcDistanceById(base + i * dim, ids[i]);
171+
auto succ_vec_res = index->UpdateVector(ids[i], new_base);
172+
REQUIRE(succ_vec_res.has_value());
173+
if (expected_success) {
174+
REQUIRE(succ_vec_res.value());
175+
}
176+
auto after_update_dist = *index->CalcDistanceById(base + i * dim, ids[i]);
177+
REQUIRE(before_update_dist < after_update_dist);
178+
179+
// old id don't exist
180+
auto failed_old_res = index->UpdateVector(ids[i] + 2 * max_id, new_base);
181+
REQUIRE(failed_old_res.has_value());
182+
REQUIRE(not failed_old_res.value());
183+
} else {
184+
if (result.value()->GetIds()[0] == ids[i]) {
185+
correct_num[round] += 1;
186+
}
187+
}
188+
}
189+
}
190+
191+
REQUIRE(correct_num[0] == correct_num[1]);
192+
}
193+
63194
void
64195
TestIndex::TestContinueAdd(const IndexPtr& index,
65196
const TestDatasetPtr& dataset,

tests/test_index.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,18 @@ class TestIndex {
6565
const TestDatasetPtr& dataset,
6666
bool expected_success = true);
6767

68+
static void
69+
TestUpdateId(const IndexPtr& index,
70+
const TestDatasetPtr& dataset,
71+
const std::string& search_param,
72+
bool expected_success = true);
73+
74+
static void
75+
TestUpdateVector(const IndexPtr& index,
76+
const TestDatasetPtr& dataset,
77+
const std::string& search_param,
78+
bool expected_success = true);
79+
6880
static void
6981
TestContinueAdd(const IndexPtr& index,
7082
const TestDatasetPtr& dataset,

0 commit comments

Comments
 (0)