Skip to content

Commit

Permalink
remove the excess locks (#212)
Browse files Browse the repository at this point in the history
Signed-off-by: jinjiabao.jjb <[email protected]>
  • Loading branch information
inabao authored and jinjiabao.jjb committed Dec 19, 2024
1 parent d241af1 commit 3baa6e1
Showing 1 changed file with 0 additions and 25 deletions.
25 changes: 0 additions & 25 deletions src/algorithm/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ const static float THRESHOLD_ERROR = 1e-6;

class HierarchicalNSW : public AlgorithmInterface<float> {
private:
static const tableint MAX_LABEL_OPERATION_LOCKS = 65536;
static const unsigned char DELETE_MARK = 0x01;

size_t max_elements_{0};
Expand All @@ -76,9 +75,6 @@ class HierarchicalNSW : public AlgorithmInterface<float> {

VisitedListPool* visited_list_pool_{nullptr};

// Locks operations with element by label value
mutable vsag::Vector<std::mutex> label_op_locks_;

std::mutex global_{};
vsag::Vector<std::recursive_mutex> link_list_locks_;

Expand Down Expand Up @@ -137,7 +133,6 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
bool allow_replace_deleted = false)
: allocator_(allocator),
link_list_locks_(max_elements, allocator),
label_op_locks_(MAX_LABEL_OPERATION_LOCKS, allocator),
allow_replace_deleted_(allow_replace_deleted),
use_reversed_edges_(use_reversed_edges),
normalize_(normalize),
Expand Down Expand Up @@ -221,7 +216,6 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
throw std::runtime_error("Label not found");
}
tableint internal_id = search->second;
lock_table.unlock();
std::shared_ptr<float[]> normalize_query;
normalize_vector(data_point, normalize_query);
float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_);
Expand All @@ -232,7 +226,6 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
isValidLabel(labeltype label) override {
std::unique_lock<std::mutex> lock_table(label_lookup_lock_);
bool is_valid = (label_lookup_.find(label) != label_lookup_.end());
lock_table.unlock();
return is_valid;
}

Expand All @@ -244,13 +237,6 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
}
};

inline std::mutex&
getLabelOpMutex(labeltype label) const {
// calculate hash
size_t lock_id = label & (MAX_LABEL_OPERATION_LOCKS - 1);
return label_op_locks_[lock_id];
}

inline labeltype
getExternalLabel(tableint internal_id) const {
labeltype value;
Expand Down Expand Up @@ -1070,7 +1056,6 @@ class HierarchicalNSW : public AlgorithmInterface<float> {

size_links_level0_ = maxM0_ * sizeof(tableint) + sizeof(linklistsizeint);
vsag::Vector<std::recursive_mutex>(max_elements, allocator_).swap(link_list_locks_);
vsag::Vector<std::mutex>(MAX_LABEL_OPERATION_LOCKS, allocator_).swap(label_op_locks_);

revSize_ = 1.0 / mult_;
for (size_t i = 0; i < cur_element_count_; i++) {
Expand Down Expand Up @@ -1119,15 +1104,12 @@ class HierarchicalNSW : public AlgorithmInterface<float> {

const float*
getDataByLabel(labeltype label) const override {
std::lock_guard<std::mutex> lock_label(getLabelOpMutex(label));

std::unique_lock<std::mutex> lock_table(label_lookup_lock_);
auto search = label_lookup_.find(label);
if (search == label_lookup_.end() || isMarkedDeleted(search->second)) {
throw std::runtime_error("Label not found");
}
tableint internalId = search->second;
lock_table.unlock();

char* data_ptrv = getDataByInternalId(internalId);
float* data_ptr = (float*)data_ptrv;
Expand All @@ -1141,16 +1123,13 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
void
markDelete(labeltype label) {
// lock all operations with element by label
std::unique_lock<std::mutex> lock_label(getLabelOpMutex(label));

std::unique_lock<std::mutex> lock_table(label_lookup_lock_);
auto search = label_lookup_.find(label);
if (search == label_lookup_.end()) {
throw std::runtime_error("Label not found");
}
tableint internalId = search->second;
label_lookup_.erase(search);
lock_table.unlock();
markDeletedInternal(internalId);
}

Expand Down Expand Up @@ -1183,15 +1162,12 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
void
unmarkDelete(labeltype label) {
// lock all operations with element by label
std::unique_lock<std::mutex> lock_label(getLabelOpMutex(label));

std::unique_lock<std::mutex> lock_table(label_lookup_lock_);
auto search = label_lookup_.find(label);
if (search == label_lookup_.end()) {
throw std::runtime_error("Label not found");
}
tableint internalId = search->second;
lock_table.unlock();

unmarkDeletedInternal(internalId);
}
Expand Down Expand Up @@ -1239,7 +1215,6 @@ class HierarchicalNSW : public AlgorithmInterface<float> {
*/
bool
addPoint(const void* data_point, labeltype label) override {
std::lock_guard<std::mutex> lock_label(getLabelOpMutex(label));
if (addPoint(data_point, label, -1) == -1) {
return false;
}
Expand Down

0 comments on commit 3baa6e1

Please sign in to comment.