Skip to content

Commit 761d305

Browse files
committed
feat: use workspace, add cli & disk cache & http service
Signed-off-by: Keming <[email protected]>
1 parent b79994d commit 761d305

19 files changed

+2505
-39
lines changed

Cargo.lock

+1,728-16
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+19-6
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,32 @@
1-
[package]
2-
name = "rabitq"
1+
[workspace]
2+
members = ["crates/*"]
3+
4+
[workspace.package]
35
version = "0.2.0"
46
edition = "2021"
57
description = "A Rust implementation of the RaBitQ vector search algorithm."
68
license = "AGPL-3.0"
9+
authors = ["Keming <[email protected]>"]
10+
11+
[workspace.dependencies]
12+
log = "0.4"
13+
faer = "0.19"
14+
15+
# root package
16+
[package]
17+
name = "rabitq"
18+
version.workspace = true
19+
edition.workspace = true
20+
authors.workspace = true
21+
license.workspace = true
722
documentation = "https://docs.rs/rabitq"
823
repository = "https://github.com/kemingy/rabitq"
924
keywords = ["vector-search", "quantization", "binary-dot-product"]
1025
categories = ["algorithms", "science"]
1126

1227
[dependencies]
13-
argh = "0.1"
14-
env_logger = "0.11"
15-
faer = "0.19"
16-
log = "0.4"
28+
faer = { workspace = true }
29+
log = { workspace = true }
1730
num-traits = "0.2"
1831
rand = "0.8"
1932
rand_distr = "0.4.3"

README.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,6 @@
1313
- [ ] RaBitQ with fastscan
1414
- [x] x86_64 SIMD support
1515
- [ ] integrate with K-means clustering
16-
- [ ] disk-based RaBitQ (WIP)
17-
- [ ] HTTP service (WIP)
16+
- [x] disk-based RaBitQ
17+
- [x] HTTP service
18+
- [ ] insert & update & delete

crates/cli/Cargo.toml

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
[package]
2+
name = "cli"
3+
version.workspace = true
4+
edition.workspace = true
5+
description.workspace = true
6+
license.workspace = true
7+
authors.workspace = true
8+
9+
[dependencies]
10+
argh = "0.1"
11+
env_logger = "0.11"
12+
log = {workspace = true}
13+
rabitq = {path = "../.."}

src/main.rs renamed to crates/cli/src/main.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use rabitq::utils::{calculate_recall, read_vecs};
99
use rabitq::RaBitQ;
1010

1111
#[derive(FromArgs, Debug)]
12-
/// RaBitQ
12+
/// RaBitQ CLI args
1313
struct Args {
1414
/// base path
1515
#[argh(option, short = 'b')]

crates/disk/Cargo.toml

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
[package]
2+
name = "disk"
3+
version.workspace = true
4+
edition.workspace = true
5+
authors.workspace = true
6+
license.workspace = true
7+
8+
[dependencies]
9+
anyhow = "1.0.89"
10+
aws-config = "1.5.8"
11+
aws-sdk-s3 = "1.54.0"
12+
bytemuck = "1.18.0"
13+
bytes = "1.7.2"
14+
faer.workspace = true
15+
rabitq = { path = "../.." }
16+
rusqlite = "0.32.1"

crates/disk/src/cache.rs

+164
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
use std::io::Write;
2+
use std::path::Path;
3+
use std::sync::{Arc, Mutex};
4+
5+
use aws_config::BehaviorVersion;
6+
use aws_sdk_s3::Client;
7+
use bytes::Buf;
8+
use rabitq::metrics::METRICS;
9+
use rabitq::utils::l2_squared_distance;
10+
use rusqlite::{Connection, OptionalExtension};
11+
12+
const BLOCK_BYTE_LIMIT: u32 = 1 << 19; // 512KiB
13+
14+
fn parse_fvecs(bytes: &mut impl Buf) -> Vec<Vec<f32>> {
15+
let mut vecs = Vec::new();
16+
while bytes.has_remaining() {
17+
let dim = bytes.get_u32_le() as usize;
18+
vecs.push((0..dim).map(|_| bytes.get_f32_le()).collect());
19+
}
20+
vecs
21+
}
22+
23+
/// Download rabitq meta data from S3.
24+
pub async fn download_meta_from_s3(bucket: &str, prefix: &str, path: &Path) -> anyhow::Result<()> {
25+
let s3_config = aws_config::defaults(BehaviorVersion::v2024_03_28())
26+
.load()
27+
.await;
28+
let client = Client::new(&s3_config);
29+
for filename in [
30+
"centroids.fvecs",
31+
"orthogonal.fvecs",
32+
"factors.fvecs",
33+
"offsets_ids.ivecs",
34+
"x_binary_vec.u64vecs",
35+
] {
36+
if path.join(filename).is_file() {
37+
continue;
38+
}
39+
let mut object = client
40+
.get_object()
41+
.bucket(bucket)
42+
.key(format!("{}/{}", prefix, filename))
43+
.send()
44+
.await?;
45+
let mut file = std::fs::File::create(path.join(filename))?;
46+
while let Some(chunk) = object.body.try_next().await? {
47+
file.write_all(chunk.as_ref())?;
48+
}
49+
}
50+
51+
Ok(())
52+
}
53+
54+
/// Cached vector store: lookups hit a local SQLite cache first; on a miss the
/// block containing the vector is range-fetched from an fvecs object in S3 and
/// written back into the cache.
#[derive(Debug)]
pub struct CachedVector {
    /// dimension of each vector
    dim: u32,
    /// vectors per S3 range-fetch block, derived from `BLOCK_BYTE_LIMIT`
    num_per_block: u32,
    /// total number of vectors in the dataset
    total_num: u32,
    /// number of blocks covering the dataset (ceiling of total_num / num_per_block)
    total_block: u32,
    /// S3 bucket holding the base vectors
    s3_bucket: Arc<String>,
    /// object key of the base vectors file (`<prefix>/base.fvecs`)
    s3_key: Arc<String>,
    /// shared S3 client
    s3_client: Arc<Client>,
    /// local SQLite cache connection; the Mutex allows use through `&self`
    sqlite_conn: Mutex<Connection>,
}
66+
67+
impl CachedVector {
    /// Init the cached vector store.
    ///
    /// Opens (or creates) the SQLite database at `local_path` and ensures the
    /// `matrix` cache table (id -> raw vector bytes) exists. Panics if the
    /// database cannot be opened or the table cannot be created.
    ///
    /// NOTE(review): `num_per_block` is an integer division and becomes 0 when
    /// `4 * (dim + 1)` exceeds `BLOCK_BYTE_LIMIT` (dim >= 131071), which would
    /// make the divisions in `new`/`fetch_from_s3` panic — confirm callers
    /// bound `dim`.
    pub async fn new(
        dim: u32,
        num: u32,
        local_path: String,
        s3_bucket: String,
        s3_prefix: String,
        // _mem_cache_num: u32,
        // _disk_cache_mb: u32,
    ) -> Self {
        let s3_config = aws_config::defaults(BehaviorVersion::v2024_03_28())
            .load()
            .await;
        let s3_client = Arc::new(Client::new(&s3_config));
        // each fvecs record occupies 4 * (dim + 1) bytes: a u32 dim header
        // plus dim f32 components (see `parse_fvecs`)
        let num_per_block = BLOCK_BYTE_LIMIT / (4 * (dim + 1));
        let total_num = num;
        // ceiling division: last block may be partially filled
        let total_block = (total_num + num_per_block - 1) / num_per_block;
        let sqlite_conn = Connection::open(Path::new(&local_path)).expect("failed to open sqlite");
        sqlite_conn
            .execute(
                "CREATE TABLE IF NOT EXISTS matrix (
                id INTEGER PRIMARY KEY,
                vec BLOB
            )",
                (),
            )
            .expect("failed to create table");
        Self {
            dim,
            num_per_block,
            total_num,
            total_block,
            s3_bucket: Arc::new(s3_bucket),
            s3_key: Arc::new(format!("{}/base.fvecs", s3_prefix)),
            s3_client,
            sqlite_conn: Mutex::new(sqlite_conn),
        }
    }

    /// Byte range of `block` inside the remote fvecs object, returned as an
    /// inclusive (start, end) pair suitable for an HTTP `Range` header.
    fn block_range_bytes(&self, block: usize) -> (usize, usize) {
        let start = 4 * block * (self.dim as usize + 1) * self.num_per_block as usize;
        // the last block may contain fewer than `num_per_block` vectors, so
        // its end is bounded by the total vector count instead
        let end = if block == self.total_block as usize - 1 {
            4 * (self.dim as usize + 1) * self.total_num as usize
        } else {
            4 * (block + 1) * (self.dim as usize + 1) * self.num_per_block as usize
        };
        // HTTP byte ranges are inclusive, hence `end - 1`
        (start, end - 1)
    }

    /// Cache miss path: range-fetch the whole block containing `index` from
    /// S3, insert every vector of the block into the SQLite cache, and return
    /// the l2 squared distance between the requested vector and `query`.
    async fn fetch_from_s3(&self, index: usize, query: &[f32]) -> anyhow::Result<f32> {
        let block = index / self.num_per_block as usize;
        let (start, end) = self.block_range_bytes(block);
        let object = self
            .s3_client
            .get_object()
            .bucket(self.s3_bucket.as_ref())
            .key(self.s3_key.as_ref())
            .range(format!("bytes={}-{}", start, end))
            .send()
            .await?;
        let mut bytes = object.body.collect().await?;
        let vecs = parse_fvecs(&mut bytes);
        METRICS.add_cache_miss_count(1);
        // position of the requested vector within the fetched block, and the
        // global id of the block's first vector
        let offset_id = index % self.num_per_block as usize;
        let start_id = index - offset_id;

        // scoped so the mutex guard is dropped before returning (no await is
        // performed while the lock is held)
        {
            let conn = self.sqlite_conn.lock().unwrap();
            let mut statement = conn.prepare(
                "INSERT INTO matrix (id, vec) VALUES (?1, ?2) ON CONFLICT(id) DO NOTHING",
            )?;
            for (i, vec) in vecs.iter().enumerate() {
                // cached BLOB is the raw f32 bytes only — the u32 dim header
                // from the fvecs format is NOT stored
                statement.execute((start_id + i, bytemuck::cast_slice(vec)))?;
            }
        }

        let distance = l2_squared_distance(&vecs[offset_id], query);

        Ok(distance)
    }

    /// Get the vector l2 square distance between `query` and the vector at
    /// `index`, consulting the local SQLite cache before falling back to S3.
    pub async fn get_query_vec_distance(&self, query: &[f32], index: u32) -> anyhow::Result<f32> {
        {
            let conn = self.sqlite_conn.lock().unwrap();
            let mut statement = conn.prepare("SELECT vec FROM matrix WHERE id = ?1")?;
            if let Some(raw) = statement
                .query_row([index], |res| res.get::<_, Vec<u8>>(0))
                .optional()?
            {
                // NOTE(review): `cast_slice` panics if the Vec<u8> allocation
                // is not 4-byte aligned; allocators generally align small
                // allocations to >= 8 bytes, but this is not guaranteed by the
                // type — confirm or use a checked/unaligned conversion.
                let res = l2_squared_distance(bytemuck::cast_slice(&raw), query);
                return Ok(res);
            }
        }
        self.fetch_from_s3(index as usize, query).await
    }
}

0 commit comments

Comments
 (0)