Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support in-place update operation for hnsw #196

Merged
merged 1 commit into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions include/vsag/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,31 @@ class Index {
throw std::runtime_error("Index not support delete vector");
}

/**
* Update the id of a base point from the index
*
* @param old_id indicates the old id of a base point in index
* @param new_id is the updated new id of the base point
* @return result indicates whether the update operation is successful.
*/
LHT129 marked this conversation as resolved.
Show resolved Hide resolved
virtual tl::expected<bool, Error>
UpdateId(int64_t old_id, int64_t new_id) {
jiaweizone marked this conversation as resolved.
Show resolved Hide resolved
throw std::runtime_error("Index not support update id");
}

/**
LHT129 marked this conversation as resolved.
Show resolved Hide resolved
* Update the vector of a base point from the index
*
* @param id indicates the old id of a base point in index
* @param new_base is the updated new vector of the base point
* @param need_fine_tune indicates whether the connection of the base point needs to be fine-tuned
* @return result indicates whether the update operation is successful.
*/
virtual tl::expected<bool, Error>
UpdateVector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune = false) {
throw std::runtime_error("Index not support update vector");
}

/**
* Performing single KNN search on index
*
Expand Down
2 changes: 2 additions & 0 deletions include/vsag/index_feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ enum IndexFeature {
SUPPORT_BATCH_SEARCH_WITH_MULTI_THREAD, /**< Supports batch searching with multi-threading */

SUPPORT_ADD_CONCURRENT, /**< Supports concurrent addition of elements */
SUPPORT_UPDATE_ID_CONCURRENT, /**< Supports concurrent update id of elements */
SUPPORT_UPDATE_VECTOR_CONCURRENT, /**< Supports concurrent update vector of elements */
SUPPORT_SEARCH_CONCURRENT, /**< Supports concurrent searching */
SUPPORT_DELETE_CONCURRENT, /**< Supports concurrent deletion */
SUPPORT_ADD_SEARCH_CONCURRENT, /**< Supports concurrent addition and searching */
Expand Down
35 changes: 35 additions & 0 deletions src/algorithm/hnswlib/hnswalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,41 @@ HierarchicalNSW::dealNoInEdge(InnerIdType id, int level, int m_curmax, int skip_
}
}

void
HierarchicalNSW::updateVector(LabelType label, const void* data_point) {
std::unique_lock<std::mutex> lock(global_);
auto iter = label_lookup_.find(label);
if (iter == label_lookup_.end()) {
throw std::runtime_error(fmt::format("no label {} in HNSW", label));
} else {
InnerIdType internal_id = iter->second;

// reset data
std::shared_ptr<float[]> normalize_data;
normalizeVector(data_point, normalize_data);
LHT129 marked this conversation as resolved.
Show resolved Hide resolved
jiaweizone marked this conversation as resolved.
Show resolved Hide resolved
memcpy(getDataByInternalId(internal_id), data_point, data_size_);
LHT129 marked this conversation as resolved.
Show resolved Hide resolved
}
}

void
HierarchicalNSW::updateLabel(LabelType old_label, LabelType new_label) {
std::unique_lock<std::mutex> lock(global_);
auto iter_old = label_lookup_.find(old_label);
auto iter_new = label_lookup_.find(new_label);
if (iter_old == label_lookup_.end()) {
throw std::runtime_error(fmt::format("no old label {} in HNSW", old_label));
} else if (iter_new != label_lookup_.end()) {
throw std::runtime_error(fmt::format("new label {} has been in HNSW", new_label));
} else {
InnerIdType internal_id = iter_old->second;

// reset label
label_lookup_.erase(iter_old);
label_lookup_[new_label] = internal_id;
setExternalLabel(internal_id, new_label);
}
}

void
HierarchicalNSW::removePoint(LabelType label) {
InnerIdType cur_c = 0;
Expand Down
6 changes: 6 additions & 0 deletions src/algorithm/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,12 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
void
dealNoInEdge(InnerIdType id, int level, int m_curmax, int skip_c);

void
updateLabel(LabelType old_label, LabelType new_label);

void
updateVector(LabelType label, const void* data_point);

void
removePoint(LabelType label);

Expand Down
45 changes: 45 additions & 0 deletions src/index/hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,51 @@ HNSW::GetStats() const {
return j.dump();
}

tl::expected<bool, Error>
HNSW::update_id(int64_t old_id, int64_t new_id) {
if (use_static_) {
LOG_ERROR_AND_RETURNS(ErrorType::UNSUPPORTED_INDEX_OPERATION,
"static hnsw does not support update");
}

try {
// note that the validation of old_id is handled within updateLabel.
std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_)->updateLabel(old_id,
new_id);
} catch (const std::runtime_error& e) {
spdlog::warn(
"update error for replace old_id {} to new_id {}: {}", old_id, new_id, e.what());
return false;
}

return true;
}

tl::expected<bool, Error>
HNSW::update_vector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune) {
jiaweizone marked this conversation as resolved.
Show resolved Hide resolved
// TODO(ZXY): implement need_fine_tune to allow update with distant vector
if (use_static_) {
LOG_ERROR_AND_RETURNS(ErrorType::UNSUPPORTED_INDEX_OPERATION,
"static hnsw does not support update");
}

try {
// the validation of the new vector
void* new_base_vec = nullptr;
size_t data_size = 0;
get_vectors(new_base, &new_base_vec, &data_size);

// note that the validation of old_id is handled within updatePoint.
std::reinterpret_pointer_cast<hnswlib::HierarchicalNSW>(alg_hnsw_)->updateVector(
id, new_base_vec);
} catch (const std::runtime_error& e) {
spdlog::warn("update error for replace vector of id {}: {}", id, e.what());
return false;
}

return true;
}

tl::expected<bool, Error>
HNSW::remove(int64_t id) {
if (use_static_) {
Expand Down
16 changes: 16 additions & 0 deletions src/index/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ class HNSW : public Index {
SAFE_CALL(return this->remove(id));
}

tl::expected<bool, Error>
UpdateId(int64_t old_id, int64_t new_id) override {
SAFE_CALL(return this->update_id(old_id, new_id));
}

tl::expected<bool, Error>
UpdateVector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune = false) override {
SAFE_CALL(return this->update_vector(id, new_base, need_fine_tune));
}

tl::expected<DatasetPtr, Error>
KnnSearch(const DatasetPtr& query,
int64_t k,
Expand Down Expand Up @@ -192,6 +202,12 @@ class HNSW : public Index {
tl::expected<bool, Error>
remove(int64_t id);

tl::expected<bool, Error>
update_id(int64_t old_id, int64_t new_id);

tl::expected<bool, Error>
update_vector(int64_t id, const DatasetPtr& new_base, bool need_fine_tune);

template <typename FilterType>
tl::expected<DatasetPtr, Error>
knn_search_internal(const DatasetPtr& query,
Expand Down
34 changes: 34 additions & 0 deletions tests/test_hnsw_new.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,40 @@ TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Concurrent Add", "[f
}
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Update Id", "[ft][hnsw]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2", "ip", "cosine");
const std::string name = "hnsw";
auto search_param = fmt::format(search_param_tmp, 100);
for (auto& dim : dims) {
vsag::Options::Instance().set_block_size_limit(size);
auto param = GenerateHNSWBuildParametersString(metric_type, dim);
auto index = TestFactory(name, param, true);
auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
TestBuildIndex(index, dataset, true);
TestUpdateId(index, dataset, search_param, true);
vsag::Options::Instance().set_block_size_limit(origin_size);
}
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Update Vector", "[ft][hnsw]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
auto metric_type = GENERATE("l2");
const std::string name = "hnsw";
auto search_param = fmt::format(search_param_tmp, 100);
for (auto& dim : dims) {
vsag::Options::Instance().set_block_size_limit(size);
auto param = GenerateHNSWBuildParametersString(metric_type, dim);
auto index = TestFactory(name, param, true);
auto dataset = pool.GetDatasetAndCreate(dim, base_count, metric_type);
TestBuildIndex(index, dataset, true);
TestUpdateVector(index, dataset, search_param, true);
vsag::Options::Instance().set_block_size_limit(origin_size);
}
}

TEST_CASE_PERSISTENT_FIXTURE(fixtures::HNSWTestIndex, "HNSW Serialize File", "[ft][hnsw]") {
auto origin_size = vsag::Options::Instance().block_size_limit();
auto size = GENERATE(1024 * 1024 * 2);
Expand Down
131 changes: 131 additions & 0 deletions tests/test_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,137 @@ TestIndex::TestAddIndex(const IndexPtr& index,
}
}

void
TestIndex::TestUpdateId(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
bool expected_success) {
auto ids = dataset->base_->GetIds();
auto num_vectors = dataset->base_->GetNumElements();
auto dim = dataset->base_->GetDim();
auto gt_topK = dataset->top_k;
auto base = dataset->base_->GetFloat32Vectors();

std::unordered_map<int64_t, int64_t> update_id_map;
std::unordered_map<int64_t, int64_t> reverse_id_map;
int64_t max_id = num_vectors;
for (int i = 0; i < num_vectors; i++) {
if (ids[i] > max_id) {
max_id = ids[i];
}
}
for (int i = 0; i < num_vectors; i++) {
update_id_map[ids[i]] = ids[i] + 2 * max_id;
}

std::vector<int> correct_num = {0, 0};
for (int round = 0; round < 2; round++) {
// round 0 for update, round 1 for validate update results
for (int i = 0; i < num_vectors; i++) {
auto query = vsag::Dataset::Make();
query->NumElements(1)->Dim(dim)->Float32Vectors(base + i * dim)->Owner(false);

auto result = index->KnnSearch(query, gt_topK, search_param);
REQUIRE(result.has_value());

if (round == 0) {
if (result.value()->GetIds()[0] == ids[i]) {
correct_num[round] += 1;
}

auto succ_update_res = index->UpdateId(ids[i], update_id_map[ids[i]]);
REQUIRE(succ_update_res.has_value());
if (expected_success) {
REQUIRE(succ_update_res.value());
}

// old id don't exist
auto failed_old_res = index->UpdateId(ids[i], update_id_map[ids[i]]);
REQUIRE(failed_old_res.has_value());
REQUIRE(not failed_old_res.value());

// new id is used
auto failed_new_res = index->UpdateId(update_id_map[ids[i]], update_id_map[ids[i]]);
REQUIRE(failed_new_res.has_value());
REQUIRE(not failed_new_res.value());
} else {
if (result.value()->GetIds()[0] == update_id_map[ids[i]]) {
correct_num[round] += 1;
}
}
}
}

REQUIRE(correct_num[0] == correct_num[1]);
}

void
TestIndex::TestUpdateVector(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
bool expected_success) {
auto ids = dataset->base_->GetIds();
auto num_vectors = dataset->base_->GetNumElements();
auto dim = dataset->base_->GetDim();
auto gt_topK = dataset->top_k;
auto base = dataset->base_->GetFloat32Vectors();

int64_t max_id = num_vectors;
for (int i = 0; i < num_vectors; i++) {
if (ids[i] > max_id) {
max_id = ids[i];
}
}

std::vector<int> correct_num = {0, 0};
for (int round = 0; round < 2; round++) {
// round 0 for update, round 1 for validate update results
for (int i = 0; i < num_vectors; i++) {
auto query = vsag::Dataset::Make();
query->NumElements(1)->Dim(dim)->Float32Vectors(base + i * dim)->Owner(false);

auto result = index->KnnSearch(query, gt_topK, search_param);
REQUIRE(result.has_value());

if (round == 0) {
if (result.value()->GetIds()[0] == ids[i]) {
correct_num[round] += 1;
}

std::vector<float> update_vecs(dim);
for (int d = 0; d < dim; d++) {
update_vecs[d] = base[i * dim + d] + 0.001f;
}
auto new_base = vsag::Dataset::Make();
new_base->NumElements(1)
->Dim(dim)
->Float32Vectors(update_vecs.data())
->Owner(false);

auto before_update_dist = *index->CalcDistanceById(base + i * dim, ids[i]);
auto succ_vec_res = index->UpdateVector(ids[i], new_base);
REQUIRE(succ_vec_res.has_value());
if (expected_success) {
REQUIRE(succ_vec_res.value());
}
auto after_update_dist = *index->CalcDistanceById(base + i * dim, ids[i]);
REQUIRE(before_update_dist < after_update_dist);

// old id don't exist
auto failed_old_res = index->UpdateVector(ids[i] + 2 * max_id, new_base);
REQUIRE(failed_old_res.has_value());
REQUIRE(not failed_old_res.value());
} else {
if (result.value()->GetIds()[0] == ids[i]) {
correct_num[round] += 1;
}
}
}
}

REQUIRE(correct_num[0] == correct_num[1]);
}

void
TestIndex::TestContinueAdd(const IndexPtr& index,
const TestDatasetPtr& dataset,
Expand Down
12 changes: 12 additions & 0 deletions tests/test_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@ class TestIndex {
const TestDatasetPtr& dataset,
bool expected_success = true);

static void
TestUpdateId(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
bool expected_success = true);

static void
TestUpdateVector(const IndexPtr& index,
const TestDatasetPtr& dataset,
const std::string& search_param,
bool expected_success = true);

static void
TestContinueAdd(const IndexPtr& index,
const TestDatasetPtr& dataset,
Expand Down
Loading
Loading