|
| 1 | +use std::io::Write; |
| 2 | +use std::path::Path; |
| 3 | +use std::sync::{Arc, Mutex}; |
| 4 | + |
| 5 | +use aws_config::BehaviorVersion; |
| 6 | +use aws_sdk_s3::Client; |
| 7 | +use bytes::Buf; |
| 8 | +use rabitq::metrics::METRICS; |
| 9 | +use rabitq::utils::l2_squared_distance; |
| 10 | +use rusqlite::{Connection, OptionalExtension}; |
| 11 | + |
/// Upper bound, in bytes, on one cache block — i.e. one S3 `Range` request
/// against `base.fvecs`. 1 << 19 = 512 KiB.
const BLOCK_BYTE_LIMIT: u32 = 1 << 19; // 512KiB

| 14 | +fn parse_fvecs(bytes: &mut impl Buf) -> Vec<Vec<f32>> { |
| 15 | + let mut vecs = Vec::new(); |
| 16 | + while bytes.has_remaining() { |
| 17 | + let dim = bytes.get_u32_le() as usize; |
| 18 | + vecs.push((0..dim).map(|_| bytes.get_f32_le()).collect()); |
| 19 | + } |
| 20 | + vecs |
| 21 | +} |
| 22 | + |
| 23 | +/// Download rabitq meta data from S3. |
| 24 | +pub async fn download_meta_from_s3(bucket: &str, prefix: &str, path: &Path) -> anyhow::Result<()> { |
| 25 | + let s3_config = aws_config::defaults(BehaviorVersion::v2024_03_28()) |
| 26 | + .load() |
| 27 | + .await; |
| 28 | + let client = Client::new(&s3_config); |
| 29 | + for filename in [ |
| 30 | + "centroids.fvecs", |
| 31 | + "orthogonal.fvecs", |
| 32 | + "factors.fvecs", |
| 33 | + "offsets_ids.ivecs", |
| 34 | + "x_binary_vec.u64vecs", |
| 35 | + ] { |
| 36 | + if path.join(filename).is_file() { |
| 37 | + continue; |
| 38 | + } |
| 39 | + let mut object = client |
| 40 | + .get_object() |
| 41 | + .bucket(bucket) |
| 42 | + .key(format!("{}/{}", prefix, filename)) |
| 43 | + .send() |
| 44 | + .await?; |
| 45 | + let mut file = std::fs::File::create(path.join(filename))?; |
| 46 | + while let Some(chunk) = object.body.try_next().await? { |
| 47 | + file.write_all(chunk.as_ref())?; |
| 48 | + } |
| 49 | + } |
| 50 | + |
| 51 | + Ok(()) |
| 52 | +} |
| 53 | + |
/// Cached vector.
///
/// Vectors live remotely in S3 (`<prefix>/base.fvecs`) and are cached locally
/// in a SQLite table (`matrix`, keyed by vector id) as they are fetched.
#[derive(Debug)]
pub struct CachedVector {
    /// Dimensionality of each vector.
    dim: u32,
    /// How many vectors one cache block (one S3 range request) holds.
    num_per_block: u32,
    /// Total number of vectors in `base.fvecs`.
    total_num: u32,
    /// Total number of blocks (ceil(total_num / num_per_block)).
    total_block: u32,
    /// S3 bucket containing `base.fvecs`.
    s3_bucket: Arc<String>,
    /// Full S3 key of `base.fvecs`.
    s3_key: Arc<String>,
    s3_client: Arc<Client>,
    /// Local cache; the Mutex serializes access from concurrent queries.
    sqlite_conn: Mutex<Connection>,
}
| 66 | + |
| 67 | +impl CachedVector { |
| 68 | + /// init the cached vector. |
| 69 | + pub async fn new( |
| 70 | + dim: u32, |
| 71 | + num: u32, |
| 72 | + local_path: String, |
| 73 | + s3_bucket: String, |
| 74 | + s3_prefix: String, |
| 75 | + // _mem_cache_num: u32, |
| 76 | + // _disk_cache_mb: u32, |
| 77 | + ) -> Self { |
| 78 | + let s3_config = aws_config::defaults(BehaviorVersion::v2024_03_28()) |
| 79 | + .load() |
| 80 | + .await; |
| 81 | + let s3_client = Arc::new(Client::new(&s3_config)); |
| 82 | + let num_per_block = BLOCK_BYTE_LIMIT / (4 * (dim + 1)); |
| 83 | + let total_num = num; |
| 84 | + let total_block = (total_num + num_per_block - 1) / num_per_block; |
| 85 | + let sqlite_conn = Connection::open(Path::new(&local_path)).expect("failed to open sqlite"); |
| 86 | + sqlite_conn |
| 87 | + .execute( |
| 88 | + "CREATE TABLE IF NOT EXISTS matrix ( |
| 89 | + id INTEGER PRIMARY KEY, |
| 90 | + vec BLOB |
| 91 | + )", |
| 92 | + (), |
| 93 | + ) |
| 94 | + .expect("failed to create table"); |
| 95 | + Self { |
| 96 | + dim, |
| 97 | + num_per_block, |
| 98 | + total_num, |
| 99 | + total_block, |
| 100 | + s3_bucket: Arc::new(s3_bucket), |
| 101 | + s3_key: Arc::new(format!("{}/base.fvecs", s3_prefix)), |
| 102 | + s3_client, |
| 103 | + sqlite_conn: Mutex::new(sqlite_conn), |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + fn block_range_bytes(&self, block: usize) -> (usize, usize) { |
| 108 | + let start = 4 * block * (self.dim as usize + 1) * self.num_per_block as usize; |
| 109 | + let end = if block == self.total_block as usize - 1 { |
| 110 | + 4 * (self.dim as usize + 1) * self.total_num as usize |
| 111 | + } else { |
| 112 | + 4 * (block + 1) * (self.dim as usize + 1) * self.num_per_block as usize |
| 113 | + }; |
| 114 | + (start, end - 1) |
| 115 | + } |
| 116 | + |
| 117 | + async fn fetch_from_s3(&self, index: usize, query: &[f32]) -> anyhow::Result<f32> { |
| 118 | + let block = index / self.num_per_block as usize; |
| 119 | + let (start, end) = self.block_range_bytes(block); |
| 120 | + let object = self |
| 121 | + .s3_client |
| 122 | + .get_object() |
| 123 | + .bucket(self.s3_bucket.as_ref()) |
| 124 | + .key(self.s3_key.as_ref()) |
| 125 | + .range(format!("bytes={}-{}", start, end)) |
| 126 | + .send() |
| 127 | + .await?; |
| 128 | + let mut bytes = object.body.collect().await?; |
| 129 | + let vecs = parse_fvecs(&mut bytes); |
| 130 | + METRICS.add_cache_miss_count(1); |
| 131 | + let offset_id = index % self.num_per_block as usize; |
| 132 | + let start_id = index - offset_id; |
| 133 | + |
| 134 | + { |
| 135 | + let conn = self.sqlite_conn.lock().unwrap(); |
| 136 | + let mut statement = conn.prepare( |
| 137 | + "INSERT INTO matrix (id, vec) VALUES (?1, ?2) ON CONFLICT(id) DO NOTHING", |
| 138 | + )?; |
| 139 | + for (i, vec) in vecs.iter().enumerate() { |
| 140 | + statement.execute((start_id + i, bytemuck::cast_slice(vec)))?; |
| 141 | + } |
| 142 | + } |
| 143 | + |
| 144 | + let distance = l2_squared_distance(&vecs[offset_id], query); |
| 145 | + |
| 146 | + Ok(distance) |
| 147 | + } |
| 148 | + |
| 149 | + /// Get the vector l2 square distance. |
| 150 | + pub async fn get_query_vec_distance(&self, query: &[f32], index: u32) -> anyhow::Result<f32> { |
| 151 | + { |
| 152 | + let conn = self.sqlite_conn.lock().unwrap(); |
| 153 | + let mut statement = conn.prepare("SELECT vec FROM matrix WHERE id = ?1")?; |
| 154 | + if let Some(raw) = statement |
| 155 | + .query_row([index], |res| res.get::<_, Vec<u8>>(0)) |
| 156 | + .optional()? |
| 157 | + { |
| 158 | + let res = l2_squared_distance(bytemuck::cast_slice(&raw), query); |
| 159 | + return Ok(res); |
| 160 | + } |
| 161 | + } |
| 162 | + self.fetch_from_s3(index as usize, query).await |
| 163 | + } |
| 164 | +} |
0 commit comments