Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: zhongxiaoyao.zxy <[email protected]>
  • Loading branch information
ShawnShawnYou committed Dec 24, 2024
1 parent bdb4716 commit 5bc7f48
Show file tree
Hide file tree
Showing 10 changed files with 338 additions and 4 deletions.
25 changes: 25 additions & 0 deletions include/vsag/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,31 @@ class Index {
throw std::runtime_error("Index not support delete vector");
}

/**
 * Update the id of a base point in the index
 *
 * This default implementation performs no update: it unconditionally throws,
 * so an index type that supports the operation has to override it.
 *
 * @param old_id indicates the current id of a base point in the index
 * @param new_id is the new id to assign to that base point
 * @return result indicates whether the update operation is successful.
 * @throws std::runtime_error always, in this default implementation
 */
virtual tl::expected<bool, Error>
UpdateId(int64_t old_id, int64_t new_id) {
throw std::runtime_error("Index not support update id");
}

/**
 * Update the vector of a base point in the index
 *
 * This default implementation performs no update: it unconditionally throws,
 * so an index type that supports the operation has to override it.
 *
 * @param id indicates the id of the base point whose vector is replaced
 * @param new_base is the dataset carrying the new vector of the base point
 * @param need_fine_tune indicates whether the connection of the base point needs to be fine-tuned
 * @return result indicates whether the update operation is successful.
 * @throws std::runtime_error always, in this default implementation
 */
virtual tl::expected<bool, Error>
UpdateVector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune = false) {
throw std::runtime_error("Index not support update vector");
}

/**
* Performing single KNN search on index
*
Expand Down
2 changes: 2 additions & 0 deletions include/vsag/index_feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ enum IndexFeature {
SUPPORT_BATCH_SEARCH_WITH_MULTI_THREAD, /**< Supports batch searching with multi-threading */

SUPPORT_ADD_CONCURRENT, /**< Supports concurrent addition of elements */
SUPPORT_UPDATE_ID_CONCURRENT, /**< Supports concurrent update id of elements */
SUPPORT_UPDATE_VECTOR_CONCURRENT, /**< Supports concurrent update vector of elements */
SUPPORT_SEARCH_CONCURRENT, /**< Supports concurrent searching */
SUPPORT_DELETE_CONCURRENT, /**< Supports concurrent deletion */
SUPPORT_ADD_SEARCH_CONCURRENT, /**< Supports concurrent addition and searching */
Expand Down
35 changes: 35 additions & 0 deletions src/algorithm/hnswlib/hnswalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,41 @@ HierarchicalNSW::dealNoInEdge(InnerIdType id, int level, int m_curmax, int skip_
}
}

// Overwrites the stored vector of the point registered under `label`.
//
// Holds `global_` for the whole call. Throws std::runtime_error when `label`
// is not present in `label_lookup_`. The graph links of the point are not
// touched here; only `data_size_` bytes of vector data are replaced.
void
HierarchicalNSW::updateVector(LabelType label, const void* data_point) {
    std::unique_lock<std::mutex> guard(global_);

    const auto found = label_lookup_.find(label);
    if (found == label_lookup_.end()) {
        throw std::runtime_error(fmt::format("no label {} in HNSW", label));
    }
    const InnerIdType slot = found->second;

    // NOTE(review): normalizeVector is defined elsewhere; presumably it may
    // redirect `data_point` to `normalized_buffer` when normalization is
    // enabled - confirm. Whatever `data_point` refers to afterwards is copied.
    std::shared_ptr<float[]> normalized_buffer;
    normalizeVector(data_point, normalized_buffer);
    memcpy(getDataByInternalId(slot), data_point, data_size_);
}

// Re-registers the point known as `old_label` under `new_label`.
//
// Holds `global_` for the whole call. Throws std::runtime_error when
// `old_label` is unknown, or when `new_label` is already present (which
// includes the case old_label == new_label).
void
HierarchicalNSW::updateLabel(LabelType old_label, LabelType new_label) {
    std::unique_lock<std::mutex> guard(global_);

    auto source = label_lookup_.find(old_label);
    auto target = label_lookup_.find(new_label);

    if (source == label_lookup_.end()) {
        throw std::runtime_error(fmt::format("no old label {} in HNSW", old_label));
    }
    if (target != label_lookup_.end()) {
        throw std::runtime_error(fmt::format("new label {} has been in HNSW", new_label));
    }

    // Move the lookup entry to the new key, then record the new label on the point itself.
    const InnerIdType internal_id = source->second;
    label_lookup_.erase(source);
    label_lookup_[new_label] = internal_id;
    setExternalLabel(internal_id, new_label);
}

void
HierarchicalNSW::removePoint(LabelType label) {
InnerIdType cur_c = 0;
Expand Down
6 changes: 6 additions & 0 deletions src/algorithm/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,12 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
void
dealNoInEdge(InnerIdType id, int level, int m_curmax, int skip_c);

// Re-registers the point stored under old_label as new_label.
// Throws std::runtime_error if old_label is absent or new_label is already in use.
void
updateLabel(LabelType old_label, LabelType new_label);

// Replaces the stored vector data of the point registered under label with data_point.
// Throws std::runtime_error if label is not present in the index.
void
updateVector(LabelType label, const void* data_point);

void
removePoint(LabelType label);

Expand Down
46 changes: 46 additions & 0 deletions src/index/hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,52 @@ HNSW::GetStats() const {
return j.dump();
}

// Renames the point `old_id` to `new_id`.
//
// Reports UNSUPPORTED_INDEX_OPERATION for a static HNSW index. Otherwise
// returns true on success and false (after logging a warning) when the
// underlying updateLabel call throws std::runtime_error.
tl::expected<bool, Error>
HNSW::update_id(int64_t old_id, int64_t new_id) {
    if (use_static_) {
        LOG_ERROR_AND_RETURNS(ErrorType::UNSUPPORTED_INDEX_OPERATION,
                              "static hnsw does not support update");
    }

    try {
        // Exclusive access to the index while the label is rewritten.
        std::unique_lock lock(rw_mutex_);

        // updateLabel itself validates both ids and throws on a bad one.
        auto graph = std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_);
        graph->updateLabel(old_id, new_id);
        return true;
    } catch (const std::runtime_error& e) {
        spdlog::warn(
            "update error for replace old_id {} to new_id {}: {}", old_id, new_id, e.what());
        return false;
    }
}

// Replaces the vector stored for point `id` with the vector carried by `new_base`.
//
// @param id             id of the point whose vector is replaced
// @param new_base       dataset holding the replacement vector
// @param need_fine_tune accepted for interface compatibility; it is not read
//                       here, so the graph connections are left unchanged
// @return an UNSUPPORTED_INDEX_OPERATION error for a static HNSW index;
//         otherwise true on success and false (after logging a warning) when
//         no vector could be extracted from `new_base` or when the underlying
//         updateVector call throws std::runtime_error (e.g. unknown id).
tl::expected<bool, Error>
HNSW::update_vector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune) {
    if (use_static_) {
        LOG_ERROR_AND_RETURNS(ErrorType::UNSUPPORTED_INDEX_OPERATION,
                              "static hnsw does not support update");
    }

    try {
        // Extract the raw vector; data_size is only needed as an out-parameter.
        void* new_base_vec = nullptr;
        size_t data_size = 0;
        get_vectors(new_base, &new_base_vec, &data_size);
        if (new_base_vec == nullptr) {
            // Nothing was extracted: refuse instead of copying from a null pointer.
            spdlog::warn(
                "update error for replace vector of id {}: {}", id, "new vector is missing");
            return false;
        }

        // Same exclusive lock as update_id, so the overwrite cannot interleave
        // with other operations guarded by rw_mutex_.
        std::unique_lock lock(rw_mutex_);

        // note that the validation of id is handled within updateVector.
        std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_)->updateVector(
            id, new_base_vec);
    } catch (const std::runtime_error& e) {
        spdlog::warn("update error for replace vector of id {}: {}", id, e.what());
        return false;
    }

    return true;
}

tl::expected<bool, Error>
HNSW::remove(int64_t id) {
if (use_static_) {
Expand Down
16 changes: 16 additions & 0 deletions src/index/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ class HNSW : public Index {
SAFE_CALL(return this->remove(id));
}

// Public entry point for changing a point's id; forwards to update_id via SAFE_CALL.
// NOTE(review): SAFE_CALL is defined elsewhere; presumably it turns escaping
// exceptions into an Error result - confirm.
tl::expected<bool, Error>
UpdateId(int64_t old_id, int64_t new_id) override {
SAFE_CALL(return this->update_id(old_id, new_id));
}

// Public entry point for replacing a point's vector; forwards all three
// arguments to update_vector via SAFE_CALL.
// NOTE(review): SAFE_CALL is defined elsewhere; presumably it turns escaping
// exceptions into an Error result - confirm.
tl::expected<bool, Error>
UpdateVector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune = false) override {
SAFE_CALL(return this->update_vector(id, new_base, need_fine_tune));
}

tl::expected<DatasetPtr, Error>
KnnSearch(const DatasetPtr& query,
int64_t k,
Expand Down Expand Up @@ -192,6 +202,12 @@ class HNSW : public Index {
tl::expected<bool, Error>
remove(int64_t id);

// Renames point old_id to new_id. Reports UNSUPPORTED_INDEX_OPERATION for a
// static index; otherwise yields true on success and false when the
// underlying label update throws std::runtime_error.
tl::expected<bool, Error>
update_id(int64_t old_id, int64_t new_id);

// Replaces the stored vector of point id with the one in new_base. Reports
// UNSUPPORTED_INDEX_OPERATION for a static index; otherwise yields true on
// success and false when the underlying update throws std::runtime_error.
// need_fine_tune is accepted but not read by the current implementation.
tl::expected<bool, Error>
update_vector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune);

template <typename FilterType>
tl::expected<DatasetPtr, Error>
knn_search_internal(const DatasetPtr& query,
Expand Down
34 changes: 34 additions & 0 deletions tests/test_hnsw_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,40 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Add", "[ft][hnsw]")
}
}

// Builds an HNSW index for every metric type and dimension, then runs the
// id-update checks against it.
TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Update Id", "[ft][hnsw]") {
    // Remember the global block size limit so it can be put back after each dimension.
    auto saved_limit = vsag::Options::Instance().block_size_limit();
    auto block_limit = GENERATE(1024 * 1024 * 2);
    auto metric = GENERATE("l2", "ip", "cosine");
    const std::string index_name = "hnsw";
    auto search_param = fmt::format(search_param_tmp, 100);

    for (auto& dim : dims) {
        vsag::Options::Instance().set_block_size_limit(block_limit);

        auto build_param = GenerateHNSWBuildParametersString(metric, dim);
        auto index = TestFactory(index_name, build_param, true);
        auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric);

        TestBuildIndex(index, dataset, true);
        TestUpdateId(index, dataset, search_param, true);

        vsag::Options::Instance().set_block_size_limit(saved_limit);
    }
}

// Builds an HNSW index (l2 metric only) for every dimension, then runs the
// vector-update checks against it.
TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Update Vector", "[ft][hnsw]") {
    // Remember the global block size limit so it can be put back after each dimension.
    auto saved_limit = vsag::Options::Instance().block_size_limit();
    auto block_limit = GENERATE(1024 * 1024 * 2);
    auto metric = GENERATE("l2");
    const std::string index_name = "hnsw";
    auto search_param = fmt::format(search_param_tmp, 100);

    for (auto& dim : dims) {
        vsag::Options::Instance().set_block_size_limit(block_limit);

        auto build_param = GenerateHNSWBuildParametersString(metric, dim);
        auto index = TestFactory(index_name, build_param, true);
        auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric);

        TestBuildIndex(index, dataset, true);
        TestUpdateVector(index, dataset, search_param, true);

        vsag::Options::Instance().set_block_size_limit(saved_limit);
    }
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Serialize File", "[ft][hnsw]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
Expand Down
131 changes: 131 additions & 0 deletions tests/test_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,137 @@ TestIndex::TestAddIndex(const IndexPtr& index,
}
}

// Checks Index::UpdateId on an already-built index.
//
// Round 0: every base vector is searched, a top-1 hit on its original id is
// counted, and the id is then moved to `id + 2 * max_id` (a value no original
// id can take). Two follow-up updates must be rejected: one whose old id no
// longer exists, and one whose new id is already taken.
// Round 1: every base vector is searched again and a top-1 hit on its new id
// is counted. The two counts must match, i.e. renaming must not change which
// point is found.
//
// @param index            index under test, already populated with `dataset`
// @param dataset          provides base vectors, their ids and top_k
// @param search_param     search parameter string passed to KnnSearch
// @param expected_success when true, the first update of each id must succeed
void
TestIndex::TestUpdateId(const IndexPtr& index,
                        const TestDatasetPtr& dataset,
                        const std::string& search_param,
                        bool expected_success) {
    auto ids = dataset->base_->GetIds();
    auto num_vectors = dataset->base_->GetNumElements();
    auto dim = dataset->base_->GetDim();
    auto gt_topK = dataset->top_k;
    auto base = dataset->base_->GetFloat32Vectors();

    int64_t max_id = num_vectors;
    for (int64_t i = 0; i < num_vectors; i++) {
        if (ids[i] > max_id) {
            max_id = ids[i];
        }
    }
    // Shifting by twice the largest id keeps the new ids disjoint from the old ones.
    const int64_t id_shift = 2 * max_id;

    std::vector<int> correct_num = {0, 0};
    for (int round = 0; round < 2; round++) {
        // round 0 for update, round 1 for validate update results
        for (int64_t i = 0; i < num_vectors; i++) {
            const int64_t old_id = ids[i];
            const int64_t new_id = old_id + id_shift;

            auto query = vsag::Dataset::Make();
            query->NumElements(1)->Dim(dim)->Float32Vectors(base + i * dim)->Owner(false);

            auto result = index->KnnSearch(query, gt_topK, search_param);
            REQUIRE(result.has_value());

            if (round == 0) {
                if (result.value()->GetIds()[0] == old_id) {
                    correct_num[round] += 1;
                }

                auto succ_update_res = index->UpdateId(old_id, new_id);
                REQUIRE(succ_update_res.has_value());
                if (expected_success) {
                    REQUIRE(succ_update_res.value());
                }

                // the old id no longer exists
                auto failed_old_res = index->UpdateId(old_id, new_id);
                REQUIRE(failed_old_res.has_value());
                REQUIRE(not failed_old_res.value());

                // the new id is already in use
                auto failed_new_res = index->UpdateId(new_id, new_id);
                REQUIRE(failed_new_res.has_value());
                REQUIRE(not failed_new_res.value());
            } else {
                if (result.value()->GetIds()[0] == new_id) {
                    correct_num[round] += 1;
                }
            }
        }
    }

    REQUIRE(correct_num[0] == correct_num[1]);
}

// Checks Index::UpdateVector on an already-built index.
//
// Round 0: every base vector is searched and a top-1 hit on its own id is
// counted. The stored vector is then replaced by the original shifted by
// 0.001 in every dimension. When the update succeeded, the distance from the
// original vector to the stored point must have grown. Updating an id that
// was never inserted (`id + 2 * max_id`) must be rejected.
// Round 1: every base vector is searched again and top-1 hits are counted.
// The two counts must match, i.e. the small perturbation must not change
// which point is found.
//
// @param index            index under test, already populated with `dataset`
// @param dataset          provides base vectors, their ids and top_k
// @param search_param     search parameter string passed to KnnSearch
// @param expected_success when true, the update of each existing id must succeed
void
TestIndex::TestUpdateVector(const IndexPtr& index,
                            const TestDatasetPtr& dataset,
                            const std::string& search_param,
                            bool expected_success) {
    auto ids = dataset->base_->GetIds();
    auto num_vectors = dataset->base_->GetNumElements();
    auto dim = dataset->base_->GetDim();
    auto gt_topK = dataset->top_k;
    auto base = dataset->base_->GetFloat32Vectors();

    int64_t max_id = num_vectors;
    for (int64_t i = 0; i < num_vectors; i++) {
        if (ids[i] > max_id) {
            max_id = ids[i];
        }
    }

    std::vector<int> correct_num = {0, 0};
    for (int round = 0; round < 2; round++) {
        // round 0 for update, round 1 for validate update results
        for (int64_t i = 0; i < num_vectors; i++) {
            auto query = vsag::Dataset::Make();
            query->NumElements(1)->Dim(dim)->Float32Vectors(base + i * dim)->Owner(false);

            auto result = index->KnnSearch(query, gt_topK, search_param);
            REQUIRE(result.has_value());

            // Ids are unchanged by a vector update, so both rounds compare against ids[i].
            if (result.value()->GetIds()[0] == ids[i]) {
                correct_num[round] += 1;
            }
            if (round != 0) {
                continue;
            }

            std::vector<float> update_vecs(dim);
            for (int64_t d = 0; d < dim; d++) {
                update_vecs[d] = base[i * dim + d] + 0.001f;
            }
            // new_base only borrows update_vecs, which outlives every use below.
            auto new_base = vsag::Dataset::Make();
            new_base->NumElements(1)->Dim(dim)->Float32Vectors(update_vecs.data())->Owner(false);

            auto before_update_res = index->CalcDistanceById(base + i * dim, ids[i]);
            REQUIRE(before_update_res.has_value());

            auto succ_vec_res = index->UpdateVector(ids[i], new_base);
            REQUIRE(succ_vec_res.has_value());
            if (expected_success) {
                REQUIRE(succ_vec_res.value());
            }

            auto after_update_res = index->CalcDistanceById(base + i * dim, ids[i]);
            REQUIRE(after_update_res.has_value());
            // The distance can only be expected to change if the vector was really replaced.
            if (succ_vec_res.value()) {
                REQUIRE(before_update_res.value() < after_update_res.value());
            }

            // an id that was never inserted must be rejected
            auto failed_old_res = index->UpdateVector(ids[i] + 2 * max_id, new_base);
            REQUIRE(failed_old_res.has_value());
            REQUIRE(not failed_old_res.value());
        }
    }

    REQUIRE(correct_num[0] == correct_num[1]);
}

void
TestIndex::TestContinueAdd(const IndexPtr& index,
const TestDatasetPtr& dataset,
Expand Down
12 changes: 12 additions & 0 deletions tests/test_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ class TestIndex {
const TestDatasetPtr& dataset,
bool expected_success = true);

// Exercises Index::UpdateId on a built index: renames every base id, checks
// that repeating the rename and renaming onto a used id are both rejected,
// and requires the top-1 hit count to be the same before and after renaming.
// When expected_success is true, each first rename must succeed.
static void
TestUpdateId(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
bool expected_success = true);

// Exercises Index::UpdateVector on a built index: replaces every base vector
// with a copy shifted by 0.001 per dimension, checks that updating an id that
// was never inserted is rejected, and requires the top-1 hit count to be the
// same before and after the updates.
// When expected_success is true, each update of an existing id must succeed.
static void
TestUpdateVector(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
bool expected_success = true);

static void
TestContinueAdd(const IndexPtr& index,
const TestDatasetPtr& dataset,
Expand Down
Loading

0 comments on commit 5bc7f48

Please sign in to comment.