Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor param to replace JsonType internal #290

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@

add_subdirectory (simd)
add_subdirectory (io)
add_subdirectory (quantization)

file (GLOB CPP_SRCS "*.cpp")
file (GLOB CPP_FACTORY_SRCS "factory/*.cpp")
file (GLOB CPP_CONJUGATE_GRAPH_SRCS "impl/*.cpp")
file (GLOB CPP_INDEX_SRCS "index/*.cpp")
file (GLOB CPP_HNSWLIB_SRCS "algorithm/hnswlib/*.cpp")
file (GLOB CPP_QUANTIZATION_SRCS "quantization/*.cpp")
file (GLOB CPP_DATA_CELL_SRCS "data_cell/*.cpp")
file (GLOB CPP_ALGORITHM_SRCS "algorithm/*.cpp")
list (FILTER CPP_SRCS EXCLUDE REGEX "_test.cpp")
list (FILTER CPP_FACTORY_SRCS EXCLUDE REGEX "_test.cpp")
list (FILTER CPP_CONJUGATE_GRAPH_SRCS EXCLUDE REGEX "_test.cpp")
list (FILTER CPP_INDEX_SRCS EXCLUDE REGEX "_test.cpp")
list (FILTER CPP_QUANTIZATION_SRCS EXCLUDE REGEX "_test.cpp")
list (FILTER CPP_DATA_CELL_SRCS EXCLUDE REGEX "_test.cpp")
list (FILTER CPP_ALGORITHM_SRCS EXCLUDE REGEX "_test.cpp")

set (VSAG_SRCS ${CPP_SRCS} ${CPP_FACTORY_SRCS} ${CPP_INDEX_SRCS} ${CPP_CONJUGATE_GRAPH_SRCS}
${CPP_HNSWLIB_SRCS} ${CPP_QUANTIZATION_SRCS} ${CPP_DATA_CELL_SRCS} ${CPP_ALGORITHM_SRCS})
${CPP_HNSWLIB_SRCS} ${CPP_DATA_CELL_SRCS} ${CPP_ALGORITHM_SRCS})
add_library (vsag SHARED ${VSAG_SRCS})
add_library (vsag_static STATIC ${VSAG_SRCS})

set (VSAG_DEP_LIBS diskann pthread m dl simd fmt::fmt-header-only nlohmann_json::nlohmann_json roaring)
set (VSAG_DEP_LIBS diskann pthread m dl simd io quantizer fmt::fmt-header-only nlohmann_json::nlohmann_json roaring)
target_link_libraries (vsag ${VSAG_DEP_LIBS} coverage_config)
target_link_libraries (vsag_static ${VSAG_DEP_LIBS} coverage_config)

maybe_add_dependencies (vsag spdlog roaring openblas boost mkl)
maybe_add_dependencies (vsag_static spdlog roaring openblas boost mkl)
84 changes: 24 additions & 60 deletions src/algorithm/hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

#include "common.h"
#include "data_cell/sparse_graph_datacell.h"
#include "index/hgraph_zparameters.h"
#include "index/hgraph_index_zparameters.h"

namespace vsag {
static BinarySet
Expand Down Expand Up @@ -50,69 +50,33 @@ next_multiple_of_power_of_two(uint64_t x, uint64_t n) {
return result;
}

HGraph::HGraph(const JsonType& index_param, const vsag::IndexCommonParam& common_param) noexcept
: index_param_(index_param),
common_param_(common_param),
HGraph::HGraph(const HGraphParameter& hgraph_param,
const vsag::IndexCommonParam& common_param) noexcept
: common_param_(common_param),
dim_(common_param.dim_),
metric_(common_param.metric_),
allocator_(common_param.allocator_.get()),
label_lookup_(common_param.allocator_.get()),
neighbors_mutex_(0, common_param.allocator_.get()),
route_graphs_(common_param.allocator_.get()),
labels_(common_param.allocator_.get()) {
this->dim_ = common_param.dim_;
this->metric_ = common_param.metric_;
this->allocator_ = common_param.allocator_.get();
}

tl::expected<void, Error>
HGraph::Init() {
try {
CHECK_ARGUMENT(this->index_param_.contains(HGRAPH_USE_REORDER_KEY),
fmt::format("hgraph parameters must contains {}", HGRAPH_USE_REORDER_KEY));
this->use_reorder_ = this->index_param_[HGRAPH_USE_REORDER_KEY];

CHECK_ARGUMENT(this->index_param_.contains(HGRAPH_BASE_CODES_KEY),
fmt::format("hgraph parameters must contains {}", HGRAPH_BASE_CODES_KEY));
const auto& base_codes_json_obj = this->index_param_[HGRAPH_BASE_CODES_KEY];
this->basic_flatten_codes_ =
FlattenInterface::MakeInstance(base_codes_json_obj, common_param_);

if (this->use_reorder_) {
CHECK_ARGUMENT(
this->index_param_.contains(HGRAPH_PRECISE_CODES_KEY),
fmt::format("hgraph parameters must contains {}", HGRAPH_PRECISE_CODES_KEY));
const auto& precise_codes_json_obj = this->index_param_[HGRAPH_PRECISE_CODES_KEY];
this->high_precise_codes_ =
FlattenInterface::MakeInstance(precise_codes_json_obj, common_param_);
}

CHECK_ARGUMENT(this->index_param_.contains(HGRAPH_GRAPH_KEY),
fmt::format("hgraph parameters must contains {}", HGRAPH_GRAPH_KEY));
const auto& graph_json_obj = this->index_param_[HGRAPH_GRAPH_KEY];
this->bottom_graph_ = GraphInterface::MakeInstance(graph_json_obj, common_param_);

mult_ = 1 / log(1.0 * static_cast<double>(this->bottom_graph_->MaximumDegree()));

resize(bottom_graph_->max_capacity_);

if (this->index_param_.contains(BUILD_PARAMS_KEY)) {
auto& build_params = this->index_param_[BUILD_PARAMS_KEY];
if (build_params.contains(BUILD_EF_CONSTRUCTION)) {
this->ef_construct_ = build_params[BUILD_EF_CONSTRUCTION];
}
if (build_params.contains(BUILD_THREAD_COUNT)) {
this->build_thread_count_ = build_params[BUILD_THREAD_COUNT];
}
}

if (this->build_thread_count_ > 1) {
this->build_pool_ = std::make_unique<progschj::ThreadPool>(this->build_thread_count_);
}

this->init_features();
} catch (const std::invalid_argument& e) {
LOG_ERROR_AND_RETURNS(
ErrorType::INVALID_ARGUMENT, "failed to init(invalid argument): ", e.what());
labels_(common_param.allocator_.get()),
use_reorder_(hgraph_param.use_reorder_),
ef_construct_(hgraph_param.ef_construction_),
build_thread_count_(hgraph_param.build_thread_count_) {
this->basic_flatten_codes_ =
FlattenInterface::MakeInstance(hgraph_param.base_codes_param_, common_param);
if (use_reorder_) {
this->high_precise_codes_ =
FlattenInterface::MakeInstance(hgraph_param.precise_codes_param_, common_param);
}
return {};
this->bottom_graph_ =
GraphInterface::MakeInstance(hgraph_param.bottom_graph_param_, common_param);
mult_ = 1 / log(1.0 * static_cast<double>(this->bottom_graph_->MaximumDegree()));
resize(bottom_graph_->max_capacity_);
if (this->build_thread_count_ > 1) {
this->build_pool_ = std::make_unique<progschj::ThreadPool>(this->build_thread_count_);
}
this->init_features();
}

tl::expected<std::vector<int64_t>, Error>
Expand Down
7 changes: 2 additions & 5 deletions src/algorithm/hgraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "common.h"
#include "data_cell/flatten_interface.h"
#include "data_cell/graph_interface.h"
#include "hgraph_parameter.h"
#include "index/index_common_param.h"
#include "index_feature_list.h"
#include "typing.h"
Expand All @@ -47,10 +48,7 @@ class HGraph {
Vector<std::pair<float, InnerIdType>>,
CompareByFirst>;

HGraph(const JsonType& index_param, const IndexCommonParam& common_param) noexcept;

tl::expected<void, Error>
Init();
HGraph(const HGraphParameter& param, const IndexCommonParam& common_param) noexcept;

tl::expected<std::vector<int64_t>, Error>
Build(const DatasetPtr& data);
Expand Down Expand Up @@ -208,7 +206,6 @@ class HGraph {
int64_t dim_{0};
MetricType metric_{MetricType::METRIC_TYPE_L2SQR};

const JsonType index_param_{};
const IndexCommonParam common_param_{};

std::default_random_engine level_generator_{2021};
Expand Down
85 changes: 85 additions & 0 deletions src/algorithm/hgraph_parameter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "hgraph_parameter.h"

#include <fmt/format-inl.h>

#include "data_cell/graph_interface_parameter.h"
#include "inner_string_params.h"

namespace vsag {

HGraphParameter::HGraphParameter(const JsonType& json) : HGraphParameter() {
this->FromJson(json);
}

HGraphParameter::HGraphParameter() : name_(INDEX_TYPE_HGRAPH) {
}

void
HGraphParameter::FromJson(const JsonType& json) {
CHECK_ARGUMENT(json.contains(HGRAPH_USE_REORDER_KEY),
fmt::format("hgraph parameters must contains {}", HGRAPH_USE_REORDER_KEY));
this->use_reorder_ = json[HGRAPH_USE_REORDER_KEY];

CHECK_ARGUMENT(json.contains(HGRAPH_BASE_CODES_KEY),
fmt::format("hgraph parameters must contains {}", HGRAPH_BASE_CODES_KEY));
const auto& base_codes_json = json[HGRAPH_BASE_CODES_KEY];
this->base_codes_param_ = std::make_shared<FlattenDataCellParameter>();
this->base_codes_param_->FromJson(base_codes_json);

if (use_reorder_) {
CHECK_ARGUMENT(json.contains(HGRAPH_PRECISE_CODES_KEY),
fmt::format("hgraph parameters must contains {}", HGRAPH_PRECISE_CODES_KEY));
const auto& precise_codes_json = json[HGRAPH_PRECISE_CODES_KEY];
this->precise_codes_param_ = std::make_shared<FlattenDataCellParameter>();
this->precise_codes_param_->FromJson(precise_codes_json);
}

CHECK_ARGUMENT(json.contains(HGRAPH_GRAPH_KEY),
fmt::format("hgraph parameters must contains {}", HGRAPH_GRAPH_KEY));
const auto& graph_json = json[HGRAPH_GRAPH_KEY];
this->bottom_graph_param_ = GraphInterfaceParameter::GetGraphParameterByJson(graph_json);

if (json.contains(BUILD_PARAMS_KEY)) {
auto& build_params = json[BUILD_PARAMS_KEY];
if (build_params.contains(BUILD_EF_CONSTRUCTION)) {
this->ef_construction_ = build_params[BUILD_EF_CONSTRUCTION];
}
if (build_params.contains(BUILD_THREAD_COUNT)) {
this->build_thread_count_ = build_params[BUILD_THREAD_COUNT];
}
}
}

JsonType
HGraphParameter::ToJson() {
JsonType json;
json["type"] = INDEX_TYPE_HGRAPH;

json[HGRAPH_USE_REORDER_KEY] = this->use_reorder_;
json[HGRAPH_BASE_CODES_KEY] = this->base_codes_param_->ToJson();
if (use_reorder_) {
json[HGRAPH_PRECISE_CODES_KEY] = this->precise_codes_param_->ToJson();
}
json[HGRAPH_GRAPH_KEY] = this->bottom_graph_param_->ToJson();

json[BUILD_PARAMS_KEY][BUILD_EF_CONSTRUCTION] = this->ef_construction_;
json[BUILD_PARAMS_KEY][BUILD_THREAD_COUNT] = this->build_thread_count_;
return json;
}

} // namespace vsag
48 changes: 48 additions & 0 deletions src/algorithm/hgraph_parameter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "data_cell/flatten_datacell_parameter.h"
#include "data_cell/graph_interface_parameter.h"
#include "parameter.h"

namespace vsag {

class HGraphParameter : public Parameter {
public:
explicit HGraphParameter(const JsonType& json);

HGraphParameter();

void
FromJson(const JsonType& json) override;

JsonType
ToJson() override;

public:
FlattenDataCellParamPtr base_codes_param_{nullptr};
FlattenDataCellParamPtr precise_codes_param_{nullptr};
GraphInterfaceParamPtr bottom_graph_param_{nullptr};

bool use_reorder_{false};
uint64_t ef_construction_{400};
uint64_t build_thread_count_{100};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Set the default value to the value in options?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to do one thing in this PR

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fine


std::string name_;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

struct's public variables do not end with an _

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how to distinguish between local variables and member variables

};

} // namespace vsag
8 changes: 4 additions & 4 deletions src/data_cell/flatten_datacell.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ class FlattenDataCell : public FlattenInterface {
public:
FlattenDataCell() = default;

explicit FlattenDataCell(const JsonType& quantization_param,
const JsonType& io_param,
explicit FlattenDataCell(const QuantizerParamPtr& quantization_param,
const IOParamPtr& io_param,
const IndexCommonParam& common_param);

void
Expand Down Expand Up @@ -127,8 +127,8 @@ class FlattenDataCell : public FlattenInterface {
};

template <typename QuantTmpl, typename IOTmpl>
FlattenDataCell<QuantTmpl, IOTmpl>::FlattenDataCell(const JsonType& quantization_param,
const JsonType& io_param,
FlattenDataCell<QuantTmpl, IOTmpl>::FlattenDataCell(const QuantizerParamPtr& quantization_param,
const IOParamPtr& io_param,
const IndexCommonParam& common_param)
: allocator_(common_param.allocator_.get()) {
this->quantizer_ = std::make_shared<QuantTmpl>(quantization_param, common_param);
Expand Down
46 changes: 46 additions & 0 deletions src/data_cell/flatten_datacell_parameter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "flatten_datacell_parameter.h"

#include <fmt/format-inl.h>

#include "inner_string_params.h"

namespace vsag {
FlattenDataCellParameter::FlattenDataCellParameter() {
}

void
FlattenDataCellParameter::FromJson(const JsonType& json) {
CHECK_ARGUMENT(json.contains(IO_PARAMS_KEY),
fmt::format("flatten interface parameters must contains {}", IO_PARAMS_KEY));
this->io_parameter_ = IOParameter::GetIOParameterByJson(json[IO_PARAMS_KEY]);

CHECK_ARGUMENT(
json.contains(QUANTIZATION_PARAMS_KEY),
fmt::format("flatten interface parameters must contains {}", QUANTIZATION_PARAMS_KEY));
this->quantizer_parameter_ =
QuantizerParameter::GetQuantizerParameterByJson(json[QUANTIZATION_PARAMS_KEY]);
}

JsonType
FlattenDataCellParameter::ToJson() {
JsonType json;
json[IO_PARAMS_KEY] = this->io_parameter_->ToJson();
json[QUANTIZATION_PARAMS_KEY] = this->quantizer_parameter_->ToJson();
return json;
}
} // namespace vsag
Loading
Loading