diff --git a/crates/base/src/index.rs b/crates/base/src/index.rs index 5cddf16fe..77765b43d 100644 --- a/crates/base/src/index.rs +++ b/crates/base/src/index.rs @@ -149,9 +149,9 @@ impl IndexOptions { } } IndexingOptions::Rabitq(_) => { - if !matches!(self.vector.d, DistanceKind::L2) { + if !matches!(self.vector.d, DistanceKind::L2 | DistanceKind::Dot) { return Err(ValidationError::new( - "rabitq is not support for distance that is not l2", + "rabitq is not support for distance that is not l2 or dot", )); } if !matches!(self.vector.v, VectorKind::Vecf32) { @@ -446,8 +446,10 @@ pub struct RabitqIndexingOptions { #[serde(default = "RabitqIndexingOptions::default_nlist")] #[validate(range(min = 1, max = 1_000_000))] pub nlist: u32, - #[serde(default = "IvfIndexingOptions::default_spherical_centroids")] + #[serde(default = "RabitqIndexingOptions::default_spherical_centroids")] pub spherical_centroids: bool, + #[serde(default = "RabitqIndexingOptions::default_residual_quantization")] + pub residual_quantization: bool, } impl RabitqIndexingOptions { @@ -457,6 +459,9 @@ impl RabitqIndexingOptions { fn default_spherical_centroids() -> bool { false } + fn default_residual_quantization() -> bool { + false + } } impl Default for RabitqIndexingOptions { @@ -464,6 +469,7 @@ impl Default for RabitqIndexingOptions { Self { nlist: Self::default_nlist(), spherical_centroids: Self::default_spherical_centroids(), + residual_quantization: Self::default_residual_quantization(), } } } diff --git a/crates/rabitq/src/lib.rs b/crates/rabitq/src/lib.rs index 60116ebc8..a90433b18 100644 --- a/crates/rabitq/src/lib.rs +++ b/crates/rabitq/src/lib.rs @@ -31,6 +31,7 @@ pub struct Rabitq { offsets: Json>, projected_centroids: Json>, projection: Json>>, + is_residual: Json, } impl Rabitq { @@ -74,21 +75,16 @@ impl Rabitq { opts.rabitq_nprobe as usize, ); let mut heap = Vec::new(); - for &(_, i) in lists.iter() { + for &(dis_v2, i) in lists.iter() { + let trans_vector = if *self.is_residual { + &O::residual(&projected_query, &self.projected_centroids[(i,)]) + } else { + &projected_query + }; let preprocessed = if opts.rabitq_fast_scan { - self.quantization - .fscan_preprocess(&O::residual( - &projected_query, - &self.projected_centroids[(i,)], - )) - .into() + self.quantization.fscan_preprocess(trans_vector, dis_v2) } else { - self.quantization - .preprocess(&O::residual( - &projected_query, - &self.projected_centroids[(i,)], - )) - .into() + self.quantization.preprocess(trans_vector, dis_v2) }; let start = self.offsets[i]; let end = self.offsets[i + 1]; @@ -116,6 +112,7 @@ fn from_nothing( let RabitqIndexingOptions { nlist, spherical_centroids, + residual_quantization, } = options.indexing.clone().unwrap_rabitq(); let projection = { use nalgebra::{DMatrix, QR}; @@ -137,6 +134,7 @@ fn from_nothing( } projection }; + let is_residual = residual_quantization && O::SUPPORT_RESIDUAL; rayon::check(); let samples = O::sample(collection, nlist); rayon::check(); @@ -174,16 +172,30 @@ fn from_nothing( let collection = RemappedCollection::from_collection(collection, remap); rayon::check(); let storage = O::Storage::create(path.as_ref().join("storage"), &collection); - let quantization = Quantization::create( - path.as_ref().join("quantization"), - options.vector, - collection.len(), - |vector| { - let vector = O::cast(collection.vector(vector)); - let target = k_means_lookup(vector, ¢roids); - O::proj(&projection, &O::residual(vector, ¢roids[(target,)])) - }, - ); + + let quantization = if is_residual { + Quantization::create( + path.as_ref().join("quantization"), + options.vector, + collection.len(), + |vector| { + let vector = O::cast(collection.vector(vector)); + let target = k_means_lookup(vector, ¢roids); + O::proj(&projection, &O::residual(vector, ¢roids[(target,)])) + }, + ) + } else { + Quantization::create( + path.as_ref().join("quantization"), + options.vector, + collection.len(), + |vector| { + let vector = O::cast(collection.vector(vector)); + O::proj(&projection, vector) + }, + ) + }; + let projected_centroids = Vec2::from_vec( (centroids.shape_0(), centroids.shape_1()), (0..centroids.shape_0()) @@ -200,6 +212,7 @@ fn from_nothing( projected_centroids, ); let projection = Json::create(path.as_ref().join("projection"), projection); + let is_residual = Json::create(path.as_ref().join("is_residual"), is_residual); Rabitq { storage, payloads, @@ -207,6 +220,7 @@ fn from_nothing( projected_centroids, quantization, projection, + is_residual, } } @@ -217,6 +231,7 @@ fn open(path: impl AsRef) -> Rabitq { let offsets = Json::open(path.as_ref().join("offsets")); let projected_centroids = Json::open(path.as_ref().join("projected_centroids")); let projection = Json::open(path.as_ref().join("projection")); + let is_residual = Json::open(path.as_ref().join("is_residual")); Rabitq { storage, quantization, @@ -224,10 +239,11 @@ fn open(path: impl AsRef) -> Rabitq { offsets, projected_centroids, projection, + is_residual, } } -fn select(mut lists: Vec<(f32, usize)>, n: usize) -> Vec<(f32, usize)> { +fn select(mut lists: Vec<(f32, T)>, n: usize) -> Vec<(f32, T)> { if lists.is_empty() || n == 0 { return Vec::new(); } diff --git a/crates/rabitq/src/operator.rs b/crates/rabitq/src/operator.rs index 8fad76b12..fb1db69e0 100644 --- a/crates/rabitq/src/operator.rs +++ b/crates/rabitq/src/operator.rs @@ -1,3 +1,5 @@ +use std::ops::Index; + use base::distance::Distance; use base::operator::Borrowed; use base::operator::*; @@ -10,42 +12,38 @@ use storage::OperatorStorage; pub trait OperatorRabitq: OperatorStorage { fn sample(vectors: &impl Vectors, nlist: u32) -> Vec2; fn cast(vector: Borrowed<'_, Self>) -> &[f32]; + + const SUPPORT_RESIDUAL: bool; fn residual(lhs: &[f32], rhs: &[f32]) -> Vec; fn proj(projection: &[Vec], vector: &[f32]) -> Vec; - type Params; + type VectorParams: IntoIterator; + type QvectorParams; + type QvectorLookup; - type Preprocessed; - - fn preprocess(vector: &[f32]) -> (Self::Params, Self::Preprocessed); + fn train_encode(dims: u32, vector: Vec) -> Self::VectorParams; + fn train_decode + ?Sized>(u: u32, meta: &T) + -> Self::VectorParams; + fn preprocess(trans_vector: &[f32], dis_v_2: f32) + -> (Self::QvectorParams, Self::QvectorLookup); fn process( - dis_u_2: f32, - factor_ppc: f32, - factor_ip: f32, - factor_err: f32, - code: &[u8], - p0: &Self::Params, - p1: &Self::Preprocessed, + vector_params: &Self::VectorParams, + qvector_code: &[u8], + qvector_params: &Self::QvectorParams, + qvector_lookup: &Self::QvectorLookup, ) -> Distance; fn process_lowerbound( - dis_u_2: f32, - factor_ppc: f32, - factor_ip: f32, - factor_err: f32, - code: &[u8], - p0: &Self::Params, - p1: &Self::Preprocessed, + vector_params: &Self::VectorParams, + qvector_code: &[u8], + qvector_params: &Self::QvectorParams, + qvector_lookup: &Self::QvectorLookup, epsilon: f32, ) -> Distance; - - fn fscan_preprocess(vector: &[f32]) -> (Self::Params, Vec); + fn fscan_preprocess(trans_vector: &[f32], dis_v_2: f32) -> (Self::QvectorParams, Vec); fn fscan_process_lowerbound( - dis_u_2: f32, - factor_ppc: f32, - factor_ip: f32, - factor_err: f32, - p0: &Self::Params, - param: u16, + vector_params: &Self::VectorParams, + qvector_params: &Self::QvectorParams, + binary_prod: u16, epsilon: f32, ) -> Distance; } @@ -62,6 +60,7 @@ impl OperatorRabitq for VectL2 { fn cast(vector: Borrowed<'_, Self>) -> &[f32] { vector.slice() } + const SUPPORT_RESIDUAL: bool = true; fn residual(lhs: &[f32], rhs: &[f32]) -> Vec { f32::vector_sub(lhs, rhs) } @@ -73,63 +72,111 @@ impl OperatorRabitq for VectL2 { .collect() } - type Params = (f32, f32, f32, f32); + // [dis_u_2, factor_ppc, factor_ip, factor_err] + type VectorParams = [f32; 4]; + // (dis_v_2, b, k, qvector_sum) + type QvectorParams = (f32, f32, f32, f32); + type QvectorLookup = ((Vec, Vec, Vec, Vec), Vec); + + fn train_encode(dims: u32, vector: Vec) -> Self::VectorParams { + let sum_of_abs_x = f32::reduce_sum_of_abs_x(&vector); + let dis_u_2 = f32::reduce_sum_of_x2(&vector); + let dis_u = dis_u_2.sqrt(); + let x0 = sum_of_abs_x / (dis_u_2 * (dims as f32)).sqrt(); + let x_x0 = dis_u / x0; + let fac_norm = (dims as f32).sqrt(); + let max_x1 = 1.0f32 / (dims as f32 - 1.0).sqrt(); + let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); + let factor_ip = -2.0f32 / fac_norm * x_x0; + let cnt_pos = vector + .iter() + .map(|x| x.is_sign_positive() as i32) + .sum::(); + let cnt_neg = vector + .iter() + .map(|x| x.is_sign_negative() as i32) + .sum::(); + let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; + [dis_u_2, factor_ppc, factor_ip, factor_err] + } - type Preprocessed = ((Vec, Vec, Vec, Vec), Vec); + fn train_decode + ?Sized>( + u: u32, + meta: &T, + ) -> Self::VectorParams { + let dis_u_2 = meta[4 * u as usize + 0]; + let factor_ppc = meta[4 * u as usize + 1]; + let factor_ip = meta[4 * u as usize + 2]; + let factor_err = meta[4 * u as usize + 3]; + [dis_u_2, factor_ppc, factor_ip, factor_err] + } fn preprocess( - vector: &[f32], - ) -> ( - (f32, f32, f32, f32), - ((Vec, Vec, Vec, Vec), Vec), - ) { + trans_vector: &[f32], + dis_v_2: f32, + ) -> (Self::QvectorParams, Self::QvectorLookup) { use quantization::quantize; - let dis_v_2 = f32::reduce_sum_of_x2(vector); - let (k, b, qvector) = quantize::quantize::<15>(vector); - let qvector_sum = if vector.len() <= 4369 { + let (k, b, qvector) = quantize::quantize::<15>(trans_vector); + let qvector_sum = if trans_vector.len() <= 4369 { quantize::reduce_sum_of_x_as_u16(&qvector) as f32 } else { quantize::reduce_sum_of_x_as_u32(&qvector) as f32 }; + let blut = binarize(&qvector); let lut = gen(qvector); ((dis_v_2, b, k, qvector_sum), (blut, lut)) } fn process( - dis_u_2: f32, - factor_ppc: f32, - factor_ip: f32, - factor_err: f32, - code: &[u8], - p0: &(f32, f32, f32, f32), - p1: &((Vec, Vec, Vec, Vec), Vec), + vector_params: &Self::VectorParams, + qvector_code: &[u8], + qvector_params: &Self::QvectorParams, + qvector_lookup: &Self::QvectorLookup, ) -> Distance { - let abdp = asymmetric_binary_dot_product(code, &p1.0) as u16; - let (rough, _) = rabitq_l2(dis_u_2, factor_ppc, factor_ip, factor_err, *p0, abdp); + let (blut, _) = qvector_lookup; + let binary_prod = asymmetric_binary_dot_product(qvector_code, blut) as u16; + let (dis_u_2, factor_ppc, factor_ip, factor_err) = match vector_params { + [a, b, c, d] => (*a, *b, *c, *d), + }; + let (rough, _) = rabitq_l2( + dis_u_2, + factor_ppc, + factor_ip, + factor_err, + *qvector_params, + binary_prod, + ); Distance::from_f32(rough) } fn process_lowerbound( - dis_u_2: f32, - factor_ppc: f32, - factor_ip: f32, - factor_err: f32, - code: &[u8], - p0: &(f32, f32, f32, f32), - p1: &((Vec, Vec, Vec, Vec), Vec), + vector_params: &Self::VectorParams, + qvector_code: &[u8], + qvector_params: &Self::QvectorParams, + qvector_lookup: &Self::QvectorLookup, epsilon: f32, ) -> Distance { - let abdp = asymmetric_binary_dot_product(code, &p1.0) as u16; - let (rough, err) = rabitq_l2(dis_u_2, factor_ppc, factor_ip, factor_err, *p0, abdp); + let (blut, _) = qvector_lookup; + let binary_prod = asymmetric_binary_dot_product(qvector_code, blut) as u16; + let (dis_u_2, factor_ppc, factor_ip, factor_err) = match vector_params { + [a, b, c, d] => (*a, *b, *c, *d), + }; + let (rough, err) = rabitq_l2( + dis_u_2, + factor_ppc, + factor_ip, + factor_err, + *qvector_params, + binary_prod, + ); Distance::from_f32(rough - epsilon * err) } - fn fscan_preprocess(vector: &[f32]) -> ((f32, f32, f32, f32), Vec) { + fn fscan_preprocess(trans_vector: &[f32], dis_v_2: f32) -> (Self::QvectorParams, Vec) { use quantization::quantize; - let dis_v_2 = f32::reduce_sum_of_x2(vector); - let (k, b, qvector) = quantize::quantize::<15>(vector); - let qvector_sum = if vector.len() <= 4369 { + let (k, b, qvector) = quantize::quantize::<15>(trans_vector); + let qvector_sum = if trans_vector.len() <= 4369 { quantize::reduce_sum_of_x_as_u16(&qvector) as f32 } else { quantize::reduce_sum_of_x_as_u32(&qvector) as f32 @@ -139,15 +186,134 @@ impl OperatorRabitq for VectL2 { } fn fscan_process_lowerbound( - dis_u_2: f32, - factor_ppc: f32, - factor_ip: f32, - factor_err: f32, - p0: &Self::Params, - param: u16, + vector_params: &Self::VectorParams, + qvector_params: &Self::QvectorParams, + binary_prod: u16, epsilon: f32, ) -> Distance { - let (rough, err) = rabitq_l2(dis_u_2, factor_ppc, factor_ip, factor_err, *p0, param); + let (dis_u_2, factor_ppc, factor_ip, factor_err) = match vector_params { + [a, b, c, d] => (*a, *b, *c, *d), + }; + let (rough, err) = rabitq_l2( + dis_u_2, + factor_ppc, + factor_ip, + factor_err, + *qvector_params, + binary_prod, + ); + Distance::from_f32(rough - epsilon * err) + } +} + +impl OperatorRabitq for VectDot { + fn sample(vectors: &impl Vectors, nlist: u32) -> Vec2 { + VectL2::::sample(vectors, nlist) + } + fn cast(vector: Borrowed<'_, Self>) -> &[f32] { + VectL2::::cast(vector) + } + const SUPPORT_RESIDUAL: bool = false; + fn residual(_lhs: &[f32], _rhs: &[f32]) -> Vec { + unimplemented!() + } + fn proj(projection: &[Vec], vector: &[f32]) -> Vec { + VectL2::::proj(projection, vector) + } + + // [factor_ppc, factor_ip, factor_err] + type VectorParams = [f32; 3]; + // (dis_v_2, b, k, qvector_sum) + type QvectorParams = (f32, f32, f32, f32); + type QvectorLookup = ((Vec, Vec, Vec, Vec), Vec); + + fn train_encode(dims: u32, vector: Vec) -> Self::VectorParams { + let (factor_ppc, factor_ip, factor_err) = match VectL2::::train_encode(dims, vector) { + [_, b, c, d] => (b, c, d), + }; + + [factor_ppc, factor_ip, factor_err] + } + + fn train_decode + ?Sized>( + u: u32, + meta: &T, + ) -> Self::VectorParams { + let factor_ppc = meta[4 * u as usize + 0]; + let factor_ip = meta[4 * u as usize + 1]; + let factor_err = meta[4 * u as usize + 2]; + [factor_ppc, factor_ip, factor_err] + } + + fn preprocess( + trans_vector: &[f32], + dis_v_2: f32, + ) -> (Self::QvectorParams, Self::QvectorLookup) { + VectL2::::preprocess(trans_vector, dis_v_2) + } + + fn process( + vector_params: &Self::VectorParams, + qvector_code: &[u8], + qvector_params: &Self::QvectorParams, + qvector_lookup: &Self::QvectorLookup, + ) -> Distance { + let (blut, _) = qvector_lookup; + let binary_prod = asymmetric_binary_dot_product(qvector_code, blut) as u16; + let (factor_ppc, factor_ip, factor_err) = match vector_params { + [a, b, c] => (*a, *b, *c), + }; + let (rough, _) = rabitq_dot( + factor_ppc, + factor_ip, + factor_err, + *qvector_params, + binary_prod, + ); + Distance::from_f32(rough) + } + + fn process_lowerbound( + vector_params: &Self::VectorParams, + qvector_code: &[u8], + qvector_params: &Self::QvectorParams, + qvector_lookup: &Self::QvectorLookup, + epsilon: f32, + ) -> Distance { + let (blut, _) = qvector_lookup; + let binary_prod = asymmetric_binary_dot_product(qvector_code, blut) as u16; + let (factor_ppc, factor_ip, factor_err) = match vector_params { + [a, b, c] => (*a, *b, *c), + }; + let (rough, err) = rabitq_dot( + factor_ppc, + factor_ip, + factor_err, + *qvector_params, + binary_prod, + ); + Distance::from_f32(rough - epsilon * err) + } + fn fscan_preprocess(trans_vector: &[f32], dis_v_2: f32) -> (Self::QvectorParams, Vec) { + VectL2::::fscan_preprocess(trans_vector, dis_v_2) + } + + fn fscan_process_lowerbound( + vector_params: &Self::VectorParams, + qvector_params: &Self::QvectorParams, + binary_prod: u16, + epsilon: f32, + ) -> Distance { + let (factor_ppc, factor_ip, factor_err) = match vector_params { + [a, b, c] => (*a, *b, *c), + }; + let (rough, err) = rabitq_dot( + factor_ppc, + factor_ip, + factor_err, + *qvector_params, + binary_prod, + ); Distance::from_f32(rough - epsilon * err) } } @@ -163,6 +329,7 @@ macro_rules! unimpl_operator_rabitq { unimplemented!() } + const SUPPORT_RESIDUAL: bool = false; fn residual(_: &[f32], _: &[f32]) -> Vec { unimplemented!() } @@ -171,48 +338,50 @@ macro_rules! unimpl_operator_rabitq { unimplemented!() } - type Params = std::convert::Infallible; - type Preprocessed = std::convert::Infallible; + type VectorParams = [f32; 0]; + type QvectorParams = std::convert::Infallible; + type QvectorLookup = std::convert::Infallible; - fn preprocess(_: &[f32]) -> (Self::Params, Self::Preprocessed) { + fn train_encode(_: u32, _: Vec) -> Self::VectorParams { + unimplemented!() + } + + fn train_decode + ?Sized>( + _: u32, + _: &T, + ) -> Self::VectorParams { + unimplemented!() + } + + fn preprocess(_: &[f32], _: f32) -> (Self::QvectorParams, Self::QvectorLookup) { unimplemented!() } fn process( - _: f32, - _: f32, - _: f32, - _: f32, + _: &Self::VectorParams, _: &[u8], - _: &Self::Params, - _: &Self::Preprocessed, + _: &Self::QvectorParams, + _: &Self::QvectorLookup, ) -> Distance { unimplemented!() } fn process_lowerbound( - _: f32, - _: f32, - _: f32, - _: f32, + _: &Self::VectorParams, _: &[u8], - _: &Self::Params, - _: &Self::Preprocessed, + _: &Self::QvectorParams, + _: &Self::QvectorLookup, _: f32, ) -> Distance { unimplemented!() } - fn fscan_preprocess(_: &[f32]) -> (Self::Params, Vec) { + fn fscan_preprocess(_: &[f32], _: f32) -> (Self::QvectorLookup, Vec) { unimplemented!() } - fn fscan_process_lowerbound( - _: f32, - _: f32, - _: f32, - _: f32, - _: &Self::Params, + _: &Self::VectorParams, + _: &Self::QvectorParams, _: u16, _: f32, ) -> Distance { @@ -222,8 +391,6 @@ macro_rules! unimpl_operator_rabitq { }; } -unimpl_operator_rabitq!(VectDot); - unimpl_operator_rabitq!(VectDot); unimpl_operator_rabitq!(VectL2); @@ -241,14 +408,30 @@ pub fn rabitq_l2( factor_ip: f32, factor_err: f32, (dis_v_2, b, k, qvector_sum): (f32, f32, f32, f32), - abdp: u16, + binary_prod: u16, ) -> (f32, f32) { - let rough = - dis_u_2 + dis_v_2 + b * factor_ppc + ((2.0 * abdp as f32) - qvector_sum) * factor_ip * k; + let rough = dis_u_2 + + dis_v_2 + + b * factor_ppc + + ((2.0 * binary_prod as f32) - qvector_sum) * factor_ip * k; let err = factor_err * dis_v_2.sqrt(); (rough, err) } +#[inline(always)] +pub fn rabitq_dot( + factor_ppc: f32, + factor_ip: f32, + factor_err: f32, + (dis_v_2, b, k, qvector_sum): (f32, f32, f32, f32), + binary_prod: u16, +) -> (f32, f32) { + let rough = + 0.5 * b * factor_ppc + 0.5 * ((2.0 * binary_prod as f32) - qvector_sum) * factor_ip * k; + let err = factor_err * dis_v_2.sqrt() * 0.5; + (rough, err) +} + fn binarize(vector: &[u8]) -> (Vec, Vec, Vec, Vec) { let n = vector.len(); let t0 = { @@ -335,3 +518,177 @@ fn asymmetric_binary_dot_product(x: &[u8], y: &(Vec, Vec, Vec, Vec, + centroid: Vec, + trans_vector: Vec, + } + + static PREPROCESS_O: LazyLock = LazyLock::new(|| { + let original: Vec = [(); LENGTH] + .into_iter() + .map(|_| thread_rng().gen_range((-1.0 * LENGTH as f32)..(LENGTH as f32))) + .collect(); + let centroid: Vec = vec![0.0; LENGTH].into_iter().collect(); + Case { + original: original.clone(), + centroid: centroid.clone(), + trans_vector: VectL2::::residual(&original, ¢roid), + } + }); + + #[test] + fn vector_f32l2_encode_decode() { + let path = env::temp_dir().join("meta_l2"); + let _ = std::fs::remove_file(path.clone()); + let case = &*PREPROCESS_O; + + let meta = + VectL2::::train_encode(case.trans_vector.len() as u32, case.trans_vector.clone()); + let mmap = MmapArray::create(path.clone(), Box::new(meta.into_iter())); + let params = VectL2::::train_decode(0, &mmap); + assert_eq!( + meta, params, + "Vecf32L2 encode and decode failed {:?} != {:?}", + meta, params + ); + std::fs::remove_file(path.clone()).unwrap(); + } + + #[test] + fn vector_f32dot_encode_decode() { + let path = env::temp_dir().join("meta_dot"); + let _ = std::fs::remove_file(path.clone()); + let case = &*PREPROCESS_O; + + let meta = + VectDot::::train_encode(case.trans_vector.len() as u32, case.trans_vector.clone()); + let mmap = MmapArray::create(path.clone(), Box::new(meta.into_iter())); + let params = VectDot::::train_decode(0, &mmap); + assert_eq!( + meta, params, + "Vecf32Dot encode and decode failed {:?} != {:?}", + meta, params + ); + std::fs::remove_file(path.clone()).unwrap(); + } + + #[test] + fn vector_f32l2_no_residual_estimate() { + let mut bad: usize = 0; + let case = &*PREPROCESS_O; + for _ in 0..ATTEMPTS { + let (query, trans_vector, dis_v_2, codes, estimate_failed) = + estimate_prepare_query(&case.centroid); + + let vector_params = VectL2::::train_encode( + case.trans_vector.len() as u32, + case.trans_vector.clone(), + ); + let (qvector_params, qvector_lookup) = + VectL2::::preprocess(&trans_vector, dis_v_2); + let est = + VectL2::::process(&vector_params, &codes, &qvector_params, &qvector_lookup); + let b = VectL2::::process_lowerbound( + &vector_params, + &codes, + &qvector_params, + &qvector_lookup, + EPSILON, + ); + + let real = f32::reduce_sum_of_d2(&query, &case.original); + if estimate_failed(est.to_f32(), b.to_f32(), real) { + bad += 1; + } + } + let error_rate = (bad as f32) / (ATTEMPTS as f32); + assert!( + error_rate < 0.02, + "too many errors: {} in {}", + bad, + ATTEMPTS, + ); + } + + #[test] + fn vector_f32dot_no_residual_estimate() { + let mut bad: usize = 0; + let case = &*PREPROCESS_O; + for _ in 0..ATTEMPTS { + let (query, trans_vector, dis_v_2, codes, estimate_failed) = + estimate_prepare_query(&case.centroid); + + let vector_params = VectDot::::train_encode( + case.trans_vector.len() as u32, + case.trans_vector.clone(), + ); + let (qvector_params, qvector_lookup) = + VectDot::::preprocess(&trans_vector, dis_v_2); + let est = + VectDot::::process(&vector_params, &codes, &qvector_params, &qvector_lookup); + let b = VectDot::::process_lowerbound( + &vector_params, + &codes, + &qvector_params, + &qvector_lookup, + EPSILON, + ); + + let real = -f32::reduce_sum_of_xy(&query, &case.original); + if estimate_failed(est.to_f32(), b.to_f32(), real) { + bad += 1; + } + } + let error_rate = (bad as f32) / (ATTEMPTS as f32); + assert!( + error_rate < 0.02, + "too many errors: {} in {}", + bad, + ATTEMPTS, + ); + } + + fn estimate_prepare_query( + centroid: &Vec, + ) -> ( + Vec, + Vec, + f32, + Vec, + impl Fn(f32, f32, f32) -> bool, + ) { + fn merge_8([b0, b1, b2, b3, b4, b5, b6, b7]: [u8; 8]) -> u8 { + b0 | (b1 << 1) | (b2 << 2) | (b3 << 3) | (b4 << 4) | (b5 << 5) | (b6 << 6) | (b7 << 7) + } + let query: Vec = [(); LENGTH] + .into_iter() + .map(|_| thread_rng().gen_range((-1.0 * LENGTH as f32)..(LENGTH as f32))) + .collect(); + let trans_vector = VectL2::::residual(&query, centroid); + let dis_v_2 = f32::reduce_sum_of_xy(&query, centroid); + let codes = + InfiniteByteChunks::new(trans_vector.iter().map(|e| e.is_sign_positive() as u8)) + .map(merge_8) + .take(trans_vector.len().div_ceil(8)) + .collect(); + fn estimate_failed(est: f32, b: f32, real: f32) -> bool { + let upper_bound = 2.0 * est - b; + b <= real && upper_bound >= real + } + (query, trans_vector, dis_v_2, codes, estimate_failed) + } +} diff --git a/crates/rabitq/src/quant/quantization.rs b/crates/rabitq/src/quant/quantization.rs index a504ec92a..a9e829e2d 100644 --- a/crates/rabitq/src/quant/quantization.rs +++ b/crates/rabitq/src/quant/quantization.rs @@ -1,4 +1,4 @@ -use super::quantizer::RabitqQuantizer; +use super::quantizer::{Qvector, RabitqQuantizer}; use crate::operator::OperatorRabitq; use base::always_equal::AlwaysEqual; use base::distance::Distance; @@ -24,42 +24,8 @@ impl Quantizer { } } -pub enum QuantizationPreprocessed { - Rabitq( - ( - ::Params, - ::Preprocessed, - ), - ), -} - -impl From> for QuantizationAnyPreprocessed { - fn from(value: QuantizationPreprocessed) -> Self { - match value { - QuantizationPreprocessed::Rabitq((param, blut)) => Self::Rabitq((param, Ok(blut))), - } - } -} - -pub enum QuantizationFscanPreprocessed { - Rabitq((::Params, Vec)), -} - -impl From> for QuantizationAnyPreprocessed { - fn from(value: QuantizationFscanPreprocessed) -> Self { - match value { - QuantizationFscanPreprocessed::Rabitq((param, lut)) => Self::Rabitq((param, Err(lut))), - } - } -} - -pub enum QuantizationAnyPreprocessed { - Rabitq( - ( - ::Params, - Result<::Preprocessed, Vec>, - ), - ), +pub enum RabitqPreprocessed { + Rabitq(Qvector), } pub struct Quantization { @@ -74,7 +40,7 @@ impl Quantization { path: impl AsRef, vector_options: VectorOptions, n: u32, - vectors: impl Fn(u32) -> Vec, + vector_fetch: impl Fn(u32) -> Vec, ) -> Self { std::fs::create_dir(path.as_ref()).unwrap(); fn merge_8([b0, b1, b2, b3, b4, b5, b6, b7]: [u8; 8]) -> u8 { @@ -91,7 +57,7 @@ impl Quantization { let codes = MmapArray::create(path.as_ref().join("codes"), { match &*train { Quantizer::Rabitq(x) => Box::new((0..n).flat_map(|i| { - let vector = vectors(i); + let vector = vector_fetch(i); let codes = x.encode(&vector); let bytes = x.bytes(); match x.bits() { @@ -123,8 +89,9 @@ impl Quantization { let t = x.dims().div_ceil(4); let raw = std::array::from_fn::<_, { BLOCK_SIZE as _ }, _>(|i| { let id = BLOCK_SIZE * block + i as u32; - let e = x.encode(&vectors(std::cmp::min(id, n - 1))); - InfiniteByteChunks::new(e.into_iter()) + let vector = vector_fetch(std::cmp::min(id, n - 1)); + let codes = x.encode(&vector); + InfiniteByteChunks::new(codes.into_iter()) .map(|[b0, b1, b2, b3]| b0 | b1 << 1 | b2 << 2 | b3 << 3) .take(t as usize) .collect() @@ -138,8 +105,8 @@ impl Quantization { path.as_ref().join("meta"), match &*train { Quantizer::Rabitq(x) => Box::new((0..n).flat_map(|i| { - let (a, b, c, d) = x.encode_meta(&vectors(i)); - [a, b, c, d].into_iter() + let vector = vector_fetch(i); + O::train_encode(x.dims(), vector).into_iter() })), }, ); @@ -164,45 +131,44 @@ impl Quantization { } } - pub fn preprocess(&self, lhs: &[f32]) -> QuantizationPreprocessed { - match &*self.train { - Quantizer::Rabitq(x) => QuantizationPreprocessed::Rabitq(x.preprocess(lhs)), - } + pub fn preprocess(&self, trans_vector: &[f32], dis_v_2: f32) -> RabitqPreprocessed { + let (params, blut) = match &*self.train { + Quantizer::Rabitq(x) => x.preprocess(trans_vector, dis_v_2), + }; + RabitqPreprocessed::Rabitq(Qvector::Scan((params, blut))) } - pub fn fscan_preprocess(&self, lhs: &[f32]) -> QuantizationFscanPreprocessed { - match &*self.train { - Quantizer::Rabitq(x) => QuantizationFscanPreprocessed::Rabitq(x.fscan_preprocess(lhs)), - } + pub fn fscan_preprocess(&self, trans_vector: &[f32], dis_v_2: f32) -> RabitqPreprocessed { + let (params, lut) = match &*self.train { + Quantizer::Rabitq(x) => x.fscan_preprocess(trans_vector, dis_v_2), + }; + RabitqPreprocessed::Rabitq(Qvector::FastScan((params, lut))) } - pub fn process(&self, preprocessed: &QuantizationPreprocessed, u: u32) -> Distance { + pub fn process(&self, preprocessed: &RabitqPreprocessed, u: u32) -> Distance { match (&*self.train, preprocessed) { - (Quantizer::Rabitq(x), QuantizationPreprocessed::Rabitq(lhs)) => { + (Quantizer::Rabitq(x), RabitqPreprocessed::Rabitq(Qvector::Scan((params, blut)))) => { let bytes = x.bytes() as usize; let start = u as usize * bytes; let end = start + bytes; - let a = self.meta[4 * u as usize + 0]; - let b = self.meta[4 * u as usize + 1]; - let c = self.meta[4 * u as usize + 2]; - let d = self.meta[4 * u as usize + 3]; - let codes = &self.codes[start..end]; - x.process(&lhs.0, &lhs.1, (a, b, c, d, codes)) + let vector_params = O::train_decode(u, &self.meta); + let code = &self.codes[start..end]; + x.process(&vector_params, params, blut, code) } + _ => unreachable!(), } } pub fn push_batch( &self, - preprocessed: &QuantizationAnyPreprocessed, + preprocessed: &RabitqPreprocessed, range: Range, heap: &mut Vec<(Reverse, AlwaysEqual)>, rq_epsilon: f32, ) { match (&*self.train, preprocessed) { - (Quantizer::Rabitq(x), QuantizationAnyPreprocessed::Rabitq((a, b))) => x.push_batch( - a, - b, + (Quantizer::Rabitq(x), RabitqPreprocessed::Rabitq(qvector)) => x.push_batch( + qvector, range, heap, &self.codes, diff --git a/crates/rabitq/src/quant/quantizer.rs b/crates/rabitq/src/quant/quantizer.rs index c65739765..66819895b 100644 --- a/crates/rabitq/src/quant/quantizer.rs +++ b/crates/rabitq/src/quant/quantizer.rs @@ -3,13 +3,17 @@ use crate::operator::OperatorRabitq; use base::always_equal::AlwaysEqual; use base::distance::Distance; use base::index::VectorOptions; -use base::scalar::ScalarLike; use base::search::RerankerPop; use serde::{Deserialize, Serialize}; use std::cmp::Reverse; use std::marker::PhantomData; use std::ops::Range; +pub enum Qvector { + FastScan((O::QvectorParams, Vec)), + Scan((O::QvectorParams, O::QvectorLookup)), +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(bound = "")] pub struct RabitqQuantizer { @@ -42,28 +46,6 @@ impl RabitqQuantizer { self.dims } - pub fn encode_meta(&self, vector: &[f32]) -> (f32, f32, f32, f32) { - let sum_of_abs_x = f32::reduce_sum_of_abs_x(vector); - let sum_of_x_2 = f32::reduce_sum_of_x2(vector); - let dis_u = sum_of_x_2.sqrt(); - let x0 = sum_of_abs_x / (sum_of_x_2 * (self.dims as f32)).sqrt(); - let x_x0 = dis_u / x0; - let fac_norm = (self.dims as f32).sqrt(); - let max_x1 = 1.0f32 / (self.dims as f32 - 1.0).sqrt(); - let factor_err = 2.0f32 * max_x1 * (x_x0 * x_x0 - dis_u * dis_u).sqrt(); - let factor_ip = -2.0f32 / fac_norm * x_x0; - let cnt_pos = vector - .iter() - .map(|x| x.is_sign_positive() as i32) - .sum::(); - let cnt_neg = vector - .iter() - .map(|x| x.is_sign_negative() as i32) - .sum::(); - let factor_ppc = factor_ip * (cnt_pos - cnt_neg) as f32; - (sum_of_x_2, factor_ppc, factor_ip, factor_err) - } - pub fn encode(&self, vector: &[f32]) -> Vec { let mut codes = Vec::new(); for i in 0..self.dims { @@ -72,37 +54,52 @@ impl RabitqQuantizer { codes } - pub fn preprocess(&self, lhs: &[f32]) -> (O::Params, O::Preprocessed) { - O::preprocess(lhs) + pub fn preprocess( + &self, + trans_vector: &[f32], + dis_v_2: f32, + ) -> (O::QvectorParams, O::QvectorLookup) { + O::preprocess(trans_vector, dis_v_2) } - pub fn fscan_preprocess(&self, lhs: &[f32]) -> (O::Params, Vec) { - O::fscan_preprocess(lhs) + pub fn fscan_preprocess( + &self, + trans_vector: &[f32], + dis_v_2: f32, + ) -> (O::QvectorParams, Vec) { + O::fscan_preprocess(trans_vector, dis_v_2) } pub fn process( &self, - p0: &O::Params, - p1: &O::Preprocessed, - (a, b, c, d, e): (f32, f32, f32, f32, &[u8]), + vector_params: &O::VectorParams, + qvector_params: &O::QvectorParams, + qvector_lookup: &O::QvectorLookup, + qvector_code: &[u8], ) -> Distance { - O::process(a, b, c, d, e, p0, p1) + O::process(vector_params, qvector_code, qvector_params, qvector_lookup) } pub fn process_lowerbound( &self, - p0: &O::Params, - p1: &O::Preprocessed, - (a, b, c, d, e): (f32, f32, f32, f32, &[u8]), + vector_params: &O::VectorParams, + qvector_params: &O::QvectorParams, + qvector_lookup: &O::QvectorLookup, + qvector_code: &[u8], epsilon: f32, ) -> Distance { - O::process_lowerbound(a, b, c, d, e, p0, p1, epsilon) + O::process_lowerbound( + vector_params, + qvector_code, + qvector_params, + qvector_lookup, + epsilon, + ) } pub fn push_batch( &self, - alpha: &O::Params, - beta: &Result>, + qvector: &Qvector, range: Range, heap: &mut Vec<(Reverse, AlwaysEqual)>, codes: &[u8], @@ -110,105 +107,136 @@ impl RabitqQuantizer { meta: &[f32], epsilon: f32, ) { - match beta { - Err(lut) => { - use quantization::fast_scan::b4::{fast_scan_b4, BLOCK_SIZE}; - let s = range.start.next_multiple_of(BLOCK_SIZE); - let e = (range.end + 1 - BLOCK_SIZE).next_multiple_of(BLOCK_SIZE); - if range.start != s { - let i = s - BLOCK_SIZE; - let t = self.dims.div_ceil(4); - let bytes = (t * 16) as usize; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let res = fast_scan_b4(t, &packed_codes[start..end], lut); - heap.extend({ - (range.start..s).map(|u| { - ( - Reverse({ - let a = meta[4 * u as usize + 0]; - let b = meta[4 * u as usize + 1]; - let c = meta[4 * u as usize + 2]; - let d = meta[4 * u as usize + 3]; - let param = res[(u - i) as usize]; - O::fscan_process_lowerbound(a, b, c, d, alpha, param, epsilon) - }), - AlwaysEqual(u), - ) - }) - }); - } - for i in (s..e).step_by(BLOCK_SIZE as _) { - let t = self.dims.div_ceil(4); - let bytes = (t * 16) as usize; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let res = fast_scan_b4(t, &packed_codes[start..end], lut); - heap.extend({ - (i..i + BLOCK_SIZE).map(|u| { - ( - Reverse({ - let a = meta[4 * u as usize + 0]; - let b = meta[4 * u as usize + 1]; - let c = meta[4 * u as usize + 2]; - let d = meta[4 * u as usize + 3]; - let param = res[(u - i) as usize]; - O::fscan_process_lowerbound(a, b, c, d, alpha, param, epsilon) - }), - AlwaysEqual(u), + match qvector { + Qvector::FastScan((params, lut)) => { + self.push_back_fscan(params, lut, range, heap, packed_codes, meta, epsilon); + } + Qvector::Scan((params, blut)) => { + self.push_back_scan(params, blut, range, heap, codes, meta, epsilon); + } + } + } + + #[inline] + fn push_back_fscan( + &self, + qvector_params: &O::QvectorParams, + lut: &[u8], + rhs: Range, + heap: &mut Vec<(Reverse, AlwaysEqual)>, + packed_codes: &[u8], + meta: &[f32], + epsilon: f32, + ) { + use quantization::fast_scan::b4::{fast_scan_b4, BLOCK_SIZE}; + let s = rhs.start.next_multiple_of(BLOCK_SIZE); + let e = (rhs.end + 1 - BLOCK_SIZE).next_multiple_of(BLOCK_SIZE); + if rhs.start != s { + let i = s - BLOCK_SIZE; + let t = self.dims.div_ceil(4); + let bytes = (t * 16) as usize; + let start = (i / BLOCK_SIZE) as usize * bytes; + let end = start + bytes; + let all_binary_product = fast_scan_b4(t, &packed_codes[start..end], lut); + heap.extend({ + (rhs.start..s).map(|u| { + ( + Reverse({ + let params = &O::train_decode(u, meta); + let binary_prod = all_binary_product[(u - i) as usize]; + O::fscan_process_lowerbound( + params, + qvector_params, + binary_prod, + epsilon, ) - }) - }); - } - if e != range.end { - let i = e; - let t = self.dims.div_ceil(4); - let bytes = (t * 16) as usize; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let res = fast_scan_b4(t, &packed_codes[start..end], lut); - heap.extend({ - (e..range.end).map(|u| { - ( - Reverse({ - let a = meta[4 * u as usize + 0]; - let b = meta[4 * u as usize + 1]; - let c = meta[4 * u as usize + 2]; - let d = meta[4 * u as usize + 3]; - let param = res[(u - i) as usize]; - O::fscan_process_lowerbound(a, b, c, d, alpha, param, epsilon) - }), - AlwaysEqual(u), + }), + AlwaysEqual(u), + ) + }) + }); + } + for i in (s..e).step_by(BLOCK_SIZE as _) { + let t = self.dims.div_ceil(4); + let bytes = (t * 16) as usize; + let start = (i / BLOCK_SIZE) as usize * bytes; + let end = start + bytes; + let all_binary_product = fast_scan_b4(t, &packed_codes[start..end], lut); + heap.extend({ + (i..i + BLOCK_SIZE).map(|u| { + ( + Reverse({ + let params = &O::train_decode(u, meta); + let binary_prod = all_binary_product[(u - i) as usize]; + O::fscan_process_lowerbound( + params, + qvector_params, + binary_prod, + epsilon, ) - }) - }); - } - } - Ok(blut) => { - heap.extend(range.map(|u| { + }), + AlwaysEqual(u), + ) + }) + }); + } + if e != rhs.end { + let i = e; + let t = self.dims.div_ceil(4); + let bytes = (t * 16) as usize; + let start = (i / BLOCK_SIZE) as usize * bytes; + let end = start + bytes; + let all_binary_product = fast_scan_b4(t, &packed_codes[start..end], lut); + heap.extend({ + (e..rhs.end).map(|u| { ( - Reverse(self.process_lowerbound( - alpha, - blut, - { - let bytes = self.bytes() as usize; - let start = u as usize * bytes; - let end = start + bytes; - let a = meta[4 * u as usize + 0]; - let b = meta[4 * u as usize + 1]; - let c = meta[4 * u as usize + 2]; - let d = meta[4 * u as usize + 3]; - (a, b, c, d, &codes[start..end]) - }, - epsilon, - )), + Reverse({ + let params = &O::train_decode(u, meta); + let binary_prod = all_binary_product[(u - i) as usize]; + O::fscan_process_lowerbound( + params, + qvector_params, + binary_prod, + epsilon, + ) + }), AlwaysEqual(u), ) - })); - } + }) + }); } } + #[inline] + fn push_back_scan( + &self, + qvector_params: &O::QvectorParams, + qvector_lookup: &O::QvectorLookup, + rhs: Range, + heap: &mut Vec<(Reverse, AlwaysEqual)>, + codes: &[u8], + meta: &[f32], + epsilon: f32, + ) { + heap.extend(rhs.map(|u| { + ( + Reverse(self.process_lowerbound( + &O::train_decode(u, meta), + qvector_params, + qvector_lookup, + { + let bytes = self.bytes() as usize; + let start = u as usize * bytes; + let end = start + bytes; + &codes[start..end] + }, + epsilon, + )), + AlwaysEqual(u), + ) + })); + } + pub fn rerank<'a, T: 'a>( &'a self, heap: Vec<(Reverse, AlwaysEqual)>,