Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a batch interface for getDistanceByLabel #309

Closed
wants to merge 19 commits into from

Conversation

Carrot-77
Copy link
Contributor

@Carrot-77 Carrot-77 commented Jan 9, 2025

Add a batch interface for getDistanceByLabel
param count is the count of vids
param vids is the unique identifier of the vector to be calculated in the index.
param vector is the embedding of query
param distances is the distances between the query and the vector of the given ID
return result is valid distance of input vids.
virtual tl::expected<int64_t, Error>CalcBatchDistanceById(int64_t count, const int64_t vids, const float vector, float *&distances)

@@ -249,6 +249,23 @@ class Index {
throw std::runtime_error("Index doesn't support get distance by id");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no need to keep this interface anymore; it is just a special form of batch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This involves a lot of test files. Can we leave it unchanged for now?

CalcBatchDistanceById(int64_t count,
const int64_t *vids,
const float* vector,
float *&distances) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can return a Dataset as the result set, and let the Dataset automatically manage the allocated memory.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified

@@ -81,6 +81,7 @@ class StaticHierarchicalNSW : public AlgorithmInterface<float> {
void* dist_func_param_{nullptr};

mutable std::mutex label_lookup_lock; // lock for label_lookup_
mutable std::shared_mutex shared_label_lookup_lock;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why here add a new lock?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified

for (int i = 0; i < count; i++) {
auto search = label_lookup_.find(vids[i]);
if (search == label_lookup_.end()) {
distances[i] = -1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

explicitly specify at the interface that -1 indicates an invalid distance

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already marked

@inabao inabao mentioned this pull request Jan 9, 2025
valid_cnt++;
}
}
result->NumElements(valid_cnt)->Owner(true, allocator_);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's best to directly include the distances in the dataset to avoid potential memory leaks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified

}
}
result->NumElements(valid_cnt)->Owner(true, allocator_);
result->Distances(distances);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

distances[i] = -1;
} else {
InnerIdType internal_id = search->second;
float dist = fstdistfunc_(data_point, getDataByInternalId(internal_id), dist_func_param_);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we have cal distance by id interface, why not use it

float
HierarchicalNSW::getDistanceByLabel(LabelType label, const void* data_point) 

Copy link
Contributor Author

@Carrot-77 Carrot-77 Jan 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to calculate distances in batches. Calling getDistanceByLabel requires a shared lock for each LabelType, which may have a certain impact on performance in large batch scenarios.

std::unique_lock<std::mutex> lock_table(label_lookup_lock);
int64_t valid_cnt = 0;
auto result = vsag::Dataset::Make();
auto *distances = (float *)allocator_->Allocate(sizeof(float) * count);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Collaborator

@LHT129 LHT129 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Collaborator

@wxyucs wxyucs left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

Copy link
Collaborator

@inabao inabao left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Carrot-77
Copy link
Contributor Author

Forgot to add the signature, modified to: #337

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants