From 1ed47d8e00055a0b8826d1f0c345f7ce341282af Mon Sep 17 00:00:00 2001 From: usamoi Date: Wed, 18 Sep 2024 10:48:27 +0800 Subject: [PATCH] refactor: rework quantization abstraction (#591) Signed-off-by: usamoi --- Cargo.lock | 1 + crates/base/src/index.rs | 90 +--- crates/base/src/lib.rs | 1 - crates/base/src/pod.rs | 2 +- crates/base/src/search.rs | 17 +- crates/base/src/vector/bvect.rs | 5 + crates/base/src/vector/mod.rs | 2 + crates/base/src/vector/svect.rs | 5 + crates/base/src/vector/vect.rs | 5 + crates/cli/src/args.rs | 12 +- crates/flat/src/lib.rs | 38 +- crates/hnsw/src/lib.rs | 45 +- crates/index/src/segment/sealed.rs | 9 + crates/indexing/Cargo.toml | 1 + crates/indexing/src/lib.rs | 23 +- crates/indexing/src/sealed.rs | 105 +++- crates/interprocess-atomic-wait/src/lib.rs | 2 +- crates/inverted/src/operator.rs | 6 +- crates/ivf/src/lib.rs | 37 +- crates/ivf/src/operator.rs | 3 +- crates/quantization/src/lib.rs | 378 ++++---------- crates/quantization/src/operator.rs | 127 ----- crates/quantization/src/product.rs | 514 ++++++++++++++++++++ crates/quantization/src/product/mod.rs | 229 --------- crates/quantization/src/product/operator.rs | 101 ---- crates/quantization/src/quantize.rs | 5 - crates/quantization/src/quantizer.rs | 71 +++ crates/quantization/src/reranker/flat.rs | 28 -- crates/quantization/src/reranker/graph.rs | 35 +- crates/quantization/src/reranker/graph_2.rs | 43 ++ crates/quantization/src/reranker/mod.rs | 1 + crates/quantization/src/scalar.rs | 473 ++++++++++++++++++ crates/quantization/src/scalar/mod.rs | 208 -------- crates/quantization/src/scalar/operator.rs | 98 ---- crates/quantization/src/trivial.rs | 128 +++++ crates/quantization/src/trivial/mod.rs | 78 --- crates/quantization/src/trivial/operator.rs | 34 -- crates/quantization/src/utils.rs | 12 + crates/rabitq/src/quant/quantization.rs | 4 +- crates/rabitq/src/quant/quantizer.rs | 20 +- rust-toolchain.toml | 2 +- src/datatype/memory_bvector.rs | 6 +- src/datatype/memory_svecf32.rs | 10 +- src/datatype/memory_vecf16.rs | 6 +- src/datatype/memory_vecf32.rs | 6 +- src/datatype/subscript_bvector.rs | 2 +- src/datatype/subscript_svecf32.rs | 2 +- src/datatype/subscript_vecf16.rs | 2 +- src/datatype/subscript_vecf32.rs | 2 +- src/gucs/executing.rs | 92 +--- src/index/am_options.rs | 2 +- 51 files changed, 1638 insertions(+), 1490 deletions(-) delete mode 100644 crates/quantization/src/operator.rs create mode 100644 crates/quantization/src/product.rs delete mode 100644 crates/quantization/src/product/mod.rs delete mode 100644 crates/quantization/src/product/operator.rs create mode 100644 crates/quantization/src/quantizer.rs create mode 100644 crates/quantization/src/reranker/graph_2.rs create mode 100644 crates/quantization/src/scalar.rs delete mode 100644 crates/quantization/src/scalar/mod.rs delete mode 100644 crates/quantization/src/scalar/operator.rs create mode 100644 crates/quantization/src/trivial.rs delete mode 100644 crates/quantization/src/trivial/mod.rs delete mode 100644 crates/quantization/src/trivial/operator.rs diff --git a/Cargo.lock b/Cargo.lock index 2e821431b..5bc28158f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1521,6 +1521,7 @@ dependencies = [ "hnsw", "inverted", "ivf", + "quantization", "rabitq", "thiserror", ] diff --git a/crates/base/src/index.rs b/crates/base/src/index.rs index 35587598a..5cddf16fe 100644 --- a/crates/base/src/index.rs +++ b/crates/base/src/index.rs @@ -106,11 +106,11 @@ pub struct IndexOptions { impl IndexOptions { fn validate_self_quantization( &self, - quantization: &QuantizationOptions, + quantization: &Option, ) -> Result<(), ValidationError> { match quantization { - QuantizationOptions::Trivial(_) => Ok(()), - QuantizationOptions::Scalar(_) | QuantizationOptions::Product(_) => { + None => Ok(()), + Some(QuantizationOptions::Scalar(_) | QuantizationOptions::Product(_)) => { if !matches!(self.vector.v, VectorKind::Vecf32 | VectorKind::Vecf16) { return Err(ValidationError::new( "scalar quantization or product quantization is not support for vectors that are not dense vectors", @@ -356,13 +356,13 @@ impl Default for InvertedIndexingOptions { pub struct FlatIndexingOptions { #[serde(default)] #[validate(nested)] - pub quantization: QuantizationOptions, + pub quantization: Option, } impl Default for FlatIndexingOptions { fn default() -> Self { Self { - quantization: QuantizationOptions::default(), + quantization: Default::default(), } } } @@ -379,7 +379,7 @@ pub struct IvfIndexingOptions { pub residual_quantization: bool, #[serde(default)] #[validate(nested)] - pub quantization: QuantizationOptions, + pub quantization: Option, } impl IvfIndexingOptions { @@ -418,7 +418,7 @@ pub struct HnswIndexingOptions { pub ef_construction: u32, #[serde(default)] #[validate(nested)] - pub quantization: QuantizationOptions, + pub quantization: Option, } impl HnswIndexingOptions { @@ -472,7 +472,6 @@ impl Default for RabitqIndexingOptions { #[serde(deny_unknown_fields)] #[serde(rename_all = "snake_case")] pub enum QuantizationOptions { - Trivial(TrivialQuantizationOptions), Scalar(ScalarQuantizationOptions), Product(ProductQuantizationOptions), } @@ -480,29 +479,12 @@ pub enum QuantizationOptions { impl Validate for QuantizationOptions { fn validate(&self) -> Result<(), validator::ValidationErrors> { match self { - Self::Trivial(x) => x.validate(), Self::Scalar(x) => x.validate(), Self::Product(x) => x.validate(), } } } -impl Default for QuantizationOptions { - fn default() -> Self { - Self::Trivial(Default::default()) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -pub struct TrivialQuantizationOptions {} - -impl Default for TrivialQuantizationOptions { - fn default() -> Self { - Self {} - } -} - #[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[serde(deny_unknown_fields)] #[validate(schema(function = "Self::validate_self"))] @@ -569,26 +551,16 @@ impl Default for ProductQuantizationOptions { #[derive(Debug, Clone, Serialize, Deserialize, Validate, Alter)] #[serde(deny_unknown_fields)] pub struct SearchOptions { - #[serde(default = "SearchOptions::default_flat_sq_rerank_size")] - #[validate(range(min = 0, max = 65535))] - pub flat_sq_rerank_size: u32, - #[serde(default = "SearchOptions::default_flat_sq_fast_scan")] - pub flat_sq_fast_scan: bool, - #[serde(default = "SearchOptions::default_flat_pq_rerank_size")] - #[validate(range(min = 0, max = 65535))] - pub flat_pq_rerank_size: u32, - #[serde(default = "SearchOptions::default_flat_pq_fast_scan")] - pub flat_pq_fast_scan: bool, - #[serde(default = "SearchOptions::default_ivf_sq_rerank_size")] + #[serde(default = "SearchOptions::default_sq_rerank_size")] #[validate(range(min = 0, max = 65535))] - pub ivf_sq_rerank_size: u32, - #[serde(default = "SearchOptions::default_ivf_sq_fast_scan")] - pub ivf_sq_fast_scan: bool, - #[serde(default = "SearchOptions::default_ivf_pq_rerank_size")] + pub sq_rerank_size: u32, + #[serde(default = "SearchOptions::default_sq_fast_scan")] + pub sq_fast_scan: bool, + #[serde(default = "SearchOptions::default_pq_rerank_size")] #[validate(range(min = 0, max = 65535))] - pub ivf_pq_rerank_size: u32, - #[serde(default = "SearchOptions::default_ivf_pq_fast_scan")] - pub ivf_pq_fast_scan: bool, + pub pq_rerank_size: u32, + #[serde(default = "SearchOptions::default_pq_fast_scan")] + pub pq_fast_scan: bool, #[serde(default = "SearchOptions::default_ivf_nprobe")] #[validate(range(min = 1, max = 65535))] pub ivf_nprobe: u32, @@ -609,28 +581,16 @@ pub struct SearchOptions { } impl SearchOptions { - pub const fn default_flat_sq_rerank_size() -> u32 { - 0 - } - pub const fn default_flat_sq_fast_scan() -> bool { - false - } - pub const fn default_flat_pq_rerank_size() -> u32 { - 0 - } - pub const fn default_flat_pq_fast_scan() -> bool { - false - } - pub const fn default_ivf_sq_rerank_size() -> u32 { + pub const fn default_sq_rerank_size() -> u32 { 0 } - pub const fn default_ivf_sq_fast_scan() -> bool { + pub const fn default_sq_fast_scan() -> bool { false } - pub const fn default_ivf_pq_rerank_size() -> u32 { + pub const fn default_pq_rerank_size() -> u32 { 0 } - pub const fn default_ivf_pq_fast_scan() -> bool { + pub const fn default_pq_fast_scan() -> bool { false } pub const fn default_ivf_nprobe() -> u32 { @@ -656,14 +616,10 @@ impl SearchOptions { impl Default for SearchOptions { fn default() -> Self { Self { - flat_sq_rerank_size: Self::default_flat_sq_rerank_size(), - flat_sq_fast_scan: Self::default_flat_sq_fast_scan(), - flat_pq_rerank_size: Self::default_flat_pq_rerank_size(), - flat_pq_fast_scan: Self::default_flat_pq_fast_scan(), - ivf_sq_rerank_size: Self::default_ivf_sq_rerank_size(), - ivf_sq_fast_scan: Self::default_ivf_sq_fast_scan(), - ivf_pq_rerank_size: Self::default_ivf_pq_rerank_size(), - ivf_pq_fast_scan: Self::default_ivf_pq_fast_scan(), + sq_rerank_size: Self::default_sq_rerank_size(), + sq_fast_scan: Self::default_sq_fast_scan(), + pq_rerank_size: Self::default_pq_rerank_size(), + pq_fast_scan: Self::default_pq_fast_scan(), ivf_nprobe: Self::default_ivf_nprobe(), hnsw_ef_search: Self::default_hnsw_ef_search(), rabitq_nprobe: Self::default_rabitq_nprobe(), diff --git a/crates/base/src/lib.rs b/crates/base/src/lib.rs index d776e8b73..a303bbb6c 100644 --- a/crates/base/src/lib.rs +++ b/crates/base/src/lib.rs @@ -1,4 +1,3 @@ -#![feature(const_float_bits_conv)] #![feature(avx512_target_feature)] #![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))] #![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512_f16))] diff --git a/crates/base/src/pod.rs b/crates/base/src/pod.rs index e03b08a21..2e9c6d2bd 100644 --- a/crates/base/src/pod.rs +++ b/crates/base/src/pod.rs @@ -36,7 +36,7 @@ unsafe impl Pod for Distance {} unsafe impl Pod for Impossible {} pub fn bytes_of(t: &T) -> &[u8] { - unsafe { core::slice::from_raw_parts(std::ptr::addr_of!(*t) as *const u8, size_of::()) } + unsafe { core::slice::from_raw_parts(t as *const T as *const u8, size_of::()) } } pub fn zeroed_vec(length: usize) -> Vec { diff --git a/crates/base/src/search.rs b/crates/base/src/search.rs index 316fd48d4..86b6f3d83 100644 --- a/crates/base/src/search.rs +++ b/crates/base/src/search.rs @@ -3,6 +3,8 @@ use crate::distance::Distance; use crate::vector::VectorOwned; use serde::{Deserialize, Serialize}; use std::any::Any; +use std::cmp::Reverse; +use std::collections::BinaryHeap; use std::fmt::Display; #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] @@ -93,14 +95,13 @@ pub trait RerankerPop { fn pop(&mut self) -> Option<(Distance, u32, T)>; } -pub trait RerankerPush { - fn push(&mut self, u: u32); -} - -pub trait FlatReranker: RerankerPop {} - -impl<'a, T> RerankerPop for Box + 'a> { +impl RerankerPop for BinaryHeap<(Reverse, AlwaysEqual, AlwaysEqual)> { fn pop(&mut self) -> Option<(Distance, u32, T)> { - self.as_mut().pop() + let (Reverse(dis_u), AlwaysEqual(u), AlwaysEqual(pay_u)) = self.pop()?; + Some((dis_u, u, pay_u)) } } + +pub trait RerankerPush { + fn push(&mut self, u: u32); +} diff --git a/crates/base/src/vector/bvect.rs b/crates/base/src/vector/bvect.rs index b7a7007b1..9822989d3 100644 --- a/crates/base/src/vector/bvect.rs +++ b/crates/base/src/vector/bvect.rs @@ -55,6 +55,11 @@ impl VectorOwned for BVectOwned { data: &self.data, } } + + #[inline(always)] + fn zero(dims: u32) -> Self { + Self::new(dims, vec![0; dims.div_ceil(BVECTOR_WIDTH) as usize]) + } } #[derive(Debug, Clone, Copy)] diff --git a/crates/base/src/vector/mod.rs b/crates/base/src/vector/mod.rs index 74a5eff1c..d16d5e754 100644 --- a/crates/base/src/vector/mod.rs +++ b/crates/base/src/vector/mod.rs @@ -24,6 +24,8 @@ pub trait VectorOwned: Clone + Serialize + for<'a> Deserialize<'a> + 'static { type Borrowed<'a>: VectorBorrowed; fn as_borrowed(&self) -> Self::Borrowed<'_>; + + fn zero(dims: u32) -> Self; } pub trait VectorBorrowed: Copy + PartialEq + PartialOrd { diff --git a/crates/base/src/vector/svect.rs b/crates/base/src/vector/svect.rs index c53aeccf0..7f2a23325 100644 --- a/crates/base/src/vector/svect.rs +++ b/crates/base/src/vector/svect.rs @@ -77,6 +77,11 @@ impl VectorOwned for SVectOwned { values: &self.values, } } + + #[inline(always)] + fn zero(dims: u32) -> Self { + Self::new(dims, vec![], vec![]) + } } #[derive(Debug, Clone, Copy)] diff --git a/crates/base/src/vector/vect.rs b/crates/base/src/vector/vect.rs index ddc8acfc7..4f7db360d 100644 --- a/crates/base/src/vector/vect.rs +++ b/crates/base/src/vector/vect.rs @@ -48,6 +48,11 @@ impl VectorOwned for VectOwned { fn as_borrowed(&self) -> VectBorrowed<'_, S> { VectBorrowed(self.0.as_slice()) } + + #[inline(always)] + fn zero(dims: u32) -> Self { + Self::new(vec![S::zero(); dims as usize]) + } } #[derive(Debug, Clone, Copy)] diff --git a/crates/cli/src/args.rs b/crates/cli/src/args.rs index cfe63380a..1004995f6 100644 --- a/crates/cli/src/args.rs +++ b/crates/cli/src/args.rs @@ -131,17 +131,13 @@ pub struct QueryArguments { impl QueryArguments { pub fn get_search_options(&self) -> SearchOptions { SearchOptions { - flat_sq_rerank_size: 0, - flat_pq_rerank_size: 0, - ivf_sq_rerank_size: 0, - ivf_pq_rerank_size: 0, + sq_rerank_size: 0, + pq_rerank_size: 0, hnsw_ef_search: self.ef, ivf_nprobe: self.probe, diskann_ef_search: 100, - flat_sq_fast_scan: false, - flat_pq_fast_scan: false, - ivf_sq_fast_scan: false, - ivf_pq_fast_scan: false, + sq_fast_scan: false, + pq_fast_scan: false, rabitq_epsilon: 1.9, rabitq_fast_scan: true, rabitq_nprobe: self.probe, diff --git a/crates/flat/src/lib.rs b/crates/flat/src/lib.rs index b90312169..1b1424cb1 100644 --- a/crates/flat/src/lib.rs +++ b/crates/flat/src/lib.rs @@ -7,24 +7,24 @@ use base::search::*; use base::vector::VectorBorrowed; use common::mmap_array::MmapArray; use common::remap::RemappedCollection; -use quantization::operator::OperatorQuantization; +use quantization::quantizer::Quantizer; use quantization::Quantization; use std::fs::create_dir; use std::path::Path; use storage::OperatorStorage; use storage::Storage; -pub trait OperatorFlat: OperatorQuantization + OperatorStorage {} +pub trait OperatorFlat: OperatorStorage {} -impl OperatorFlat for T {} +impl OperatorFlat for T {} -pub struct Flat { +pub struct Flat> { storage: O::Storage, - quantization: Quantization, + quantization: Quantization, payloads: MmapArray, } -impl Flat { +impl> Flat { pub fn create( path: impl AsRef, options: IndexOptions, @@ -43,20 +43,14 @@ impl Flat { vector: Borrowed<'a, O>, opts: &'a SearchOptions, ) -> Box + 'a> { - let mut heap = Vec::new(); - let preprocessed = self.quantization.preprocess(vector); - self.quantization.push_batch( - &preprocessed, - 0..self.storage.len(), - &mut heap, - opts.flat_sq_fast_scan, - opts.flat_pq_fast_scan, - ); - let mut reranker = self.quantization.flat_rerank( + let mut heap = Q::flat_rerank_start(); + let lut = self.quantization.flat_rerank_preprocess(vector, opts); + self.quantization + .flat_rerank_continue(&lut, 0..self.storage.len(), &mut heap); + let mut reranker = self.quantization.flat_rerank_break( heap, move |u| (O::distance(vector, self.storage.vector(u)), ()), - opts.flat_sq_rerank_size, - opts.flat_pq_rerank_size, + opts, ); Box::new(std::iter::from_fn(move || { reranker.pop().map(|(dis_u, u, ())| Element { @@ -83,15 +77,15 @@ impl Flat { } } -fn from_nothing( +fn from_nothing>( path: impl AsRef, options: IndexOptions, collection: &(impl Vectors> + Collection + Sync), -) -> Flat { +) -> Flat { create_dir(path.as_ref()).unwrap(); let flat_indexing_options = options.indexing.clone().unwrap_flat(); let storage = O::Storage::create(path.as_ref().join("storage"), collection); - let quantization = Quantization::::create( + let quantization = Quantization::::create( path.as_ref().join("quantization"), options.vector, flat_indexing_options.quantization, @@ -109,7 +103,7 @@ fn from_nothing( } } -fn open(path: impl AsRef) -> Flat { +fn open>(path: impl AsRef) -> Flat { let storage = O::Storage::open(path.as_ref().join("storage")); let quantization = Quantization::open(path.as_ref().join("quantization")); let payloads = MmapArray::open(path.as_ref().join("payloads")); diff --git a/crates/hnsw/src/lib.rs b/crates/hnsw/src/lib.rs index 12c9a62ce..7ac72337f 100644 --- a/crates/hnsw/src/lib.rs +++ b/crates/hnsw/src/lib.rs @@ -12,7 +12,7 @@ use common::mmap_array::MmapArray; use common::remap::RemappedCollection; use graph::visited::VisitedPool; use parking_lot::RwLock; -use quantization::operator::OperatorQuantization; +use quantization::quantizer::Quantizer; use quantization::Quantization; use rayon::iter::{IntoParallelIterator, ParallelIterator}; use std::fs::create_dir; @@ -24,13 +24,13 @@ use stoppable_rayon as rayon; use storage::OperatorStorage; use storage::Storage; -pub trait OperatorHnsw: OperatorQuantization + OperatorStorage {} +pub trait OperatorHnsw: OperatorStorage {} -impl OperatorHnsw for T {} +impl OperatorHnsw for T {} -pub struct Hnsw { +pub struct Hnsw> { storage: O::Storage, - quantization: Quantization, + quantization: Quantization, payloads: MmapArray, base_graph_outs: MmapArray, base_graph_weights: MmapArray, @@ -41,7 +41,7 @@ pub struct Hnsw { visited: VisitedPool, } -impl Hnsw { +impl> Hnsw { pub fn create( path: impl AsRef, options: IndexOptions, @@ -113,11 +113,11 @@ impl Hnsw { } } -fn from_nothing( +fn from_nothing>( path: impl AsRef, options: IndexOptions, collection: &(impl Vectors> + Collection + Sync), -) -> Hnsw { +) -> Hnsw { create_dir(path.as_ref()).unwrap(); let HnswIndexingOptions { m, @@ -137,7 +137,7 @@ fn from_nothing( finish(&mut g, m); let storage = O::Storage::create(path.as_ref().join("storage"), collection); rayon::check(); - let quantization = Quantization::::create( + let quantization = Quantization::::create( path.as_ref().join("quantization"), options.vector, quantization_options, @@ -195,12 +195,12 @@ fn from_nothing( } } -fn from_main( +fn from_main>( path: impl AsRef, options: IndexOptions, remapped: &RemappedCollection, impl Vectors> + Collection + Sync>, - main: &Hnsw, -) -> Hnsw { + main: &Hnsw, +) -> Hnsw { create_dir(path.as_ref()).unwrap(); let HnswIndexingOptions { m, @@ -235,7 +235,7 @@ fn from_main( finish(&mut g, m); let storage = O::Storage::create(path.as_ref().join("storage"), remapped); rayon::check(); - let quantization = Quantization::::create( + let quantization = Quantization::::create( path.as_ref().join("quantization"), options.vector, quantization_options, @@ -294,7 +294,7 @@ fn from_main( } } -fn open(path: impl AsRef) -> Hnsw { +fn open>(path: impl AsRef) -> Hnsw { let storage = O::Storage::open(path.as_ref().join("storage")); let quantization = Quantization::open(path.as_ref().join("quantization")); let payloads = MmapArray::open(path.as_ref().join("payloads")); @@ -590,8 +590,8 @@ fn capacity_for_a_hierarchy(m: u32, level: u8) -> u32 { } } -fn base_edges( - hnsw: &Hnsw, +fn base_edges>( + hnsw: &Hnsw, u: u32, ) -> impl Iterator + '_ { let m = *hnsw.m; @@ -606,7 +606,10 @@ fn base_edges( edges_weights.zip(edges_outs) } -fn base_outs(hnsw: &Hnsw, u: u32) -> impl Iterator + '_ { +fn base_outs>( + hnsw: &Hnsw, + u: u32, +) -> impl Iterator + '_ { let m = *hnsw.m; let offset = 2 * m as usize * u as usize; hnsw.base_graph_outs[offset..offset + 2 * m as usize] @@ -615,8 +618,8 @@ fn base_outs(hnsw: &Hnsw, u: u32) -> impl Iterator( - hnsw: &Hnsw, +fn hyper_edges>( + hnsw: &Hnsw, u: u32, level: u8, ) -> impl Iterator + '_ { @@ -643,8 +646,8 @@ fn hyper_edges( edges_weights.zip(edges_outs) } -fn hyper_outs( - hnsw: &Hnsw, +fn hyper_outs>( + hnsw: &Hnsw, u: u32, level: u8, ) -> impl Iterator + '_ { diff --git a/crates/index/src/segment/sealed.rs b/crates/index/src/segment/sealed.rs index ffdcb5dc6..21bb1eb17 100644 --- a/crates/index/src/segment/sealed.rs +++ b/crates/index/src/segment/sealed.rs @@ -119,8 +119,17 @@ impl SealedSegment { pub fn indexing(&self) -> &dyn Any { match &self.indexing { SealedIndexing::Flat(x) => x, + SealedIndexing::FlatPq(x) => x, + SealedIndexing::FlatSq(x) => x, + SealedIndexing::Ivf(x) => x, + SealedIndexing::IvfPq(x) => x, + SealedIndexing::IvfSq(x) => x, + SealedIndexing::Hnsw(x) => x, + SealedIndexing::HnswPq(x) => x, + SealedIndexing::HnswSq(x) => x, + SealedIndexing::InvertedIndex(x) => x, SealedIndexing::Rabitq(x) => x, } diff --git a/crates/indexing/Cargo.toml b/crates/indexing/Cargo.toml index acebfd531..4bbec8a1d 100644 --- a/crates/indexing/Cargo.toml +++ b/crates/indexing/Cargo.toml @@ -13,6 +13,7 @@ flat = { path = "../flat" } hnsw = { path = "../hnsw" } inverted = { path = "../inverted" } ivf = { path = "../ivf" } +quantization = { path = "../quantization" } rabitq = { path = "../rabitq" } [lints] diff --git a/crates/indexing/src/lib.rs b/crates/indexing/src/lib.rs index efbf8ca3e..0ca325662 100644 --- a/crates/indexing/src/lib.rs +++ b/crates/indexing/src/lib.rs @@ -5,11 +5,28 @@ pub use sealed::SealedIndexing; use base::operator::Operator; use inverted::operator::OperatorInvertedIndex; use ivf::operator::OperatorIvf; +use quantization::product::OperatorProductQuantization; +use quantization::scalar::OperatorScalarQuantization; use rabitq::operator::OperatorRabitq; -pub trait OperatorIndexing: - Operator + OperatorIvf + OperatorInvertedIndex + OperatorRabitq +pub trait OperatorIndexing +where + Self: Operator, + Self: OperatorIvf, + Self: OperatorInvertedIndex, + Self: OperatorRabitq, + Self: OperatorScalarQuantization, + Self: OperatorProductQuantization, { } -impl OperatorIndexing for T {} +impl OperatorIndexing for T +where + Self: Operator, + Self: OperatorIvf, + Self: OperatorInvertedIndex, + Self: OperatorRabitq, + Self: OperatorScalarQuantization, + Self: OperatorProductQuantization, +{ +} diff --git a/crates/indexing/src/sealed.rs b/crates/indexing/src/sealed.rs index dab62bcea..23fda96c2 100644 --- a/crates/indexing/src/sealed.rs +++ b/crates/indexing/src/sealed.rs @@ -6,13 +6,22 @@ use flat::Flat; use hnsw::Hnsw; use inverted::InvertedIndex; use ivf::Ivf; +use quantization::product::ProductQuantizer; +use quantization::scalar::ScalarQuantizer; +use quantization::trivial::TrivialQuantizer; use rabitq::Rabitq; use std::path::Path; pub enum SealedIndexing { - Flat(Flat), - Ivf(Ivf), - Hnsw(Hnsw), + Flat(Flat>), + FlatSq(Flat>), + FlatPq(Flat>), + Ivf(Ivf>), + IvfSq(Ivf>), + IvfPq(Ivf>), + Hnsw(Hnsw>), + HnswSq(Hnsw>), + HnswPq(Hnsw>), InvertedIndex(InvertedIndex), Rabitq(Rabitq), } @@ -24,9 +33,39 @@ impl SealedIndexing { source: &(impl Vectors> + Collection + Source + Sync), ) -> Self { match options.indexing { - IndexingOptions::Flat(_) => Self::Flat(Flat::create(path, options, source)), - IndexingOptions::Ivf(_) => Self::Ivf(Ivf::create(path, options, source)), - IndexingOptions::Hnsw(_) => Self::Hnsw(Hnsw::create(path, options, source)), + IndexingOptions::Flat(FlatIndexingOptions { + ref quantization, .. + }) => match quantization { + None => Self::Flat(Flat::create(path, options, source)), + Some(QuantizationOptions::Scalar(_)) => { + Self::FlatSq(Flat::create(path, options, source)) + } + Some(QuantizationOptions::Product(_)) => { + Self::FlatPq(Flat::create(path, options, source)) + } + }, + IndexingOptions::Ivf(IvfIndexingOptions { + ref quantization, .. + }) => match quantization { + None => Self::Ivf(Ivf::create(path, options, source)), + Some(QuantizationOptions::Scalar(_)) => { + Self::IvfSq(Ivf::create(path, options, source)) + } + Some(QuantizationOptions::Product(_)) => { + Self::IvfPq(Ivf::create(path, options, source)) + } + }, + IndexingOptions::Hnsw(HnswIndexingOptions { + ref quantization, .. + }) => match quantization { + None => Self::Hnsw(Hnsw::create(path, options, source)), + Some(QuantizationOptions::Scalar(_)) => { + Self::HnswSq(Hnsw::create(path, options, source)) + } + Some(QuantizationOptions::Product(_)) => { + Self::HnswPq(Hnsw::create(path, options, source)) + } + }, IndexingOptions::InvertedIndex(_) => { Self::InvertedIndex(InvertedIndex::create(path, options, source)) } @@ -36,9 +75,27 @@ impl SealedIndexing { pub fn open(path: impl AsRef, options: IndexOptions) -> Self { match options.indexing { - IndexingOptions::Flat(_) => Self::Flat(Flat::open(path)), - IndexingOptions::Ivf(_) => Self::Ivf(Ivf::open(path)), - IndexingOptions::Hnsw(_) => Self::Hnsw(Hnsw::open(path)), + IndexingOptions::Flat(FlatIndexingOptions { + ref quantization, .. + }) => match quantization { + None => Self::Flat(Flat::open(path)), + Some(QuantizationOptions::Scalar(_)) => Self::FlatSq(Flat::open(path)), + Some(QuantizationOptions::Product(_)) => Self::FlatPq(Flat::open(path)), + }, + IndexingOptions::Ivf(IvfIndexingOptions { + ref quantization, .. + }) => match quantization { + None => Self::Ivf(Ivf::open(path)), + Some(QuantizationOptions::Scalar(_)) => Self::IvfSq(Ivf::open(path)), + Some(QuantizationOptions::Product(_)) => Self::IvfPq(Ivf::open(path)), + }, + IndexingOptions::Hnsw(HnswIndexingOptions { + ref quantization, .. + }) => match quantization { + None => Self::Hnsw(Hnsw::open(path)), + Some(QuantizationOptions::Scalar(_)) => Self::HnswSq(Hnsw::open(path)), + Some(QuantizationOptions::Product(_)) => Self::HnswPq(Hnsw::open(path)), + }, IndexingOptions::InvertedIndex(_) => Self::InvertedIndex(InvertedIndex::open(path)), IndexingOptions::Rabitq(_) => Self::Rabitq(Rabitq::open(path)), } @@ -51,8 +108,14 @@ impl SealedIndexing { ) -> Box + 'a> { match self { SealedIndexing::Flat(x) => x.vbase(vector, opts), + SealedIndexing::FlatPq(x) => x.vbase(vector, opts), + SealedIndexing::FlatSq(x) => x.vbase(vector, opts), SealedIndexing::Ivf(x) => x.vbase(vector, opts), + SealedIndexing::IvfPq(x) => x.vbase(vector, opts), + SealedIndexing::IvfSq(x) => x.vbase(vector, opts), SealedIndexing::Hnsw(x) => x.vbase(vector, opts), + SealedIndexing::HnswPq(x) => x.vbase(vector, opts), + SealedIndexing::HnswSq(x) => x.vbase(vector, opts), SealedIndexing::InvertedIndex(x) => x.vbase(vector, opts), SealedIndexing::Rabitq(x) => x.vbase(vector, opts), } @@ -63,8 +126,14 @@ impl Vectors> for SealedIndexing { fn dims(&self) -> u32 { match self { SealedIndexing::Flat(x) => x.dims(), + SealedIndexing::FlatSq(x) => x.dims(), + SealedIndexing::FlatPq(x) => x.dims(), SealedIndexing::Ivf(x) => x.dims(), + SealedIndexing::IvfSq(x) => x.dims(), + SealedIndexing::IvfPq(x) => x.dims(), SealedIndexing::Hnsw(x) => x.dims(), + SealedIndexing::HnswPq(x) => x.dims(), + SealedIndexing::HnswSq(x) => x.dims(), SealedIndexing::InvertedIndex(x) => x.dims(), SealedIndexing::Rabitq(x) => x.dims(), } @@ -73,8 +142,14 @@ impl Vectors> for SealedIndexing { fn len(&self) -> u32 { match self { SealedIndexing::Flat(x) => x.len(), + SealedIndexing::FlatPq(x) => x.len(), + SealedIndexing::FlatSq(x) => x.len(), SealedIndexing::Ivf(x) => x.len(), + SealedIndexing::IvfPq(x) => x.len(), + SealedIndexing::IvfSq(x) => x.len(), SealedIndexing::Hnsw(x) => x.len(), + SealedIndexing::HnswPq(x) => x.len(), + SealedIndexing::HnswSq(x) => x.len(), SealedIndexing::InvertedIndex(x) => x.len(), SealedIndexing::Rabitq(x) => x.len(), } @@ -83,8 +158,14 @@ impl Vectors> for SealedIndexing { fn vector(&self, i: u32) -> Borrowed<'_, O> { match self { SealedIndexing::Flat(x) => x.vector(i), + SealedIndexing::FlatPq(x) => x.vector(i), + SealedIndexing::FlatSq(x) => x.vector(i), SealedIndexing::Ivf(x) => x.vector(i), + SealedIndexing::IvfSq(x) => x.vector(i), + SealedIndexing::IvfPq(x) => x.vector(i), SealedIndexing::Hnsw(x) => x.vector(i), + SealedIndexing::HnswSq(x) => x.vector(i), + SealedIndexing::HnswPq(x) => x.vector(i), SealedIndexing::InvertedIndex(x) => x.vector(i), SealedIndexing::Rabitq(x) => x.vector(i), } @@ -95,8 +176,14 @@ impl Collection for SealedIndexing { fn payload(&self, i: u32) -> Payload { match self { SealedIndexing::Flat(x) => x.payload(i), + SealedIndexing::FlatPq(x) => x.payload(i), + SealedIndexing::FlatSq(x) => x.payload(i), SealedIndexing::Ivf(x) => x.payload(i), + SealedIndexing::IvfPq(x) => x.payload(i), + SealedIndexing::IvfSq(x) => x.payload(i), SealedIndexing::Hnsw(x) => x.payload(i), + SealedIndexing::HnswPq(x) => x.payload(i), + SealedIndexing::HnswSq(x) => x.payload(i), SealedIndexing::InvertedIndex(x) => x.payload(i), SealedIndexing::Rabitq(x) => x.payload(i), } diff --git a/crates/interprocess-atomic-wait/src/lib.rs b/crates/interprocess-atomic-wait/src/lib.rs index 950d13b7f..e567080a5 100644 --- a/crates/interprocess-atomic-wait/src/lib.rs +++ b/crates/interprocess-atomic-wait/src/lib.rs @@ -70,7 +70,7 @@ pub fn wait(futex: &AtomicU32, value: u32, timeout: Duration) { libc::UMTX_OP_WAIT_UINT, value as libc::c_ulong, std::mem::size_of_val(&timeout) as *mut std::ffi::c_void, - std::ptr::addr_of_mut!(timeout).cast(), + &mut timeout as *mut libc::timespec as *mut _, ); }; } diff --git a/crates/inverted/src/operator.rs b/crates/inverted/src/operator.rs index e43aa8207..a7111c857 100644 --- a/crates/inverted/src/operator.rs +++ b/crates/inverted/src/operator.rs @@ -1,10 +1,8 @@ use base::{operator::*, scalar::ScalarLike}; -use quantization::operator::OperatorQuantization; -use storage::OperatorStorage; - use std::iter::{zip, Empty}; +use storage::OperatorStorage; -pub trait OperatorInvertedIndex: OperatorQuantization + OperatorStorage { +pub trait OperatorInvertedIndex: OperatorStorage { fn to_index_vec(vec: Borrowed<'_, Self>) -> impl Iterator + '_; } diff --git a/crates/ivf/src/lib.rs b/crates/ivf/src/lib.rs index 75ad9e0e6..024bd7e2c 100644 --- a/crates/ivf/src/lib.rs +++ b/crates/ivf/src/lib.rs @@ -17,6 +17,7 @@ use k_means::k_means; use k_means::k_means_lookup; use k_means::k_means_lookup_many; use operator::OperatorIvf as Op; +use quantization::quantizer::Quantizer; use quantization::Quantization; use rayon::iter::IntoParallelIterator; use rayon::iter::ParallelIterator; @@ -25,16 +26,16 @@ use std::path::Path; use stoppable_rayon as rayon; use storage::Storage; -pub struct Ivf { +pub struct Ivf> { storage: O::Storage, - quantization: Quantization, + quantization: Quantization, payloads: MmapArray, offsets: Json>, centroids: Json::Scalar>>, is_residual: Json, } -impl Ivf { +impl> Ivf { pub fn create( path: impl AsRef, options: IndexOptions, @@ -73,28 +74,24 @@ impl Ivf { k_means_lookup_many(O::interpret(vector), &self.centroids), opts.ivf_nprobe as usize, ); - let mut heap = Vec::new(); - let mut preprocessed = self.quantization.preprocess(vector); + let mut heap = Q::flat_rerank_start(); + let mut lut = self.quantization.flat_rerank_preprocess(vector, opts); for i in lists.iter().map(|(_, i)| *i) { if *self.is_residual { let vector = O::residual(vector, &self.centroids[(i,)]); - preprocessed = self.quantization.preprocess(vector.as_borrowed()); + lut = self + .quantization + .flat_rerank_preprocess(vector.as_borrowed(), opts); } let start = self.offsets[i]; let end = self.offsets[i + 1]; - self.quantization.push_batch( - &preprocessed, - start..end, - &mut heap, - opts.ivf_sq_fast_scan, - opts.ivf_pq_fast_scan, - ); + self.quantization + .flat_rerank_continue(&lut, start..end, &mut heap); } - let mut reranker = self.quantization.flat_rerank( + let mut reranker = self.quantization.flat_rerank_break( heap, move |u| (O::distance(vector, self.storage.vector(u)), ()), - opts.ivf_sq_rerank_size, - opts.ivf_pq_rerank_size, + opts, ); Box::new(std::iter::from_fn(move || { reranker.pop().map(|(dis_u, u, ())| Element { @@ -105,11 +102,11 @@ impl Ivf { } } -fn from_nothing( +fn from_nothing>( path: impl AsRef, options: IndexOptions, collection: &(impl Vectors> + Collection + Sync), -) -> Ivf { +) -> Ivf { create_dir(path.as_ref()).unwrap(); let IvfIndexingOptions { nlist, @@ -158,7 +155,7 @@ fn from_nothing( let is_residual = residual_quantization && O::SUPPORT_RESIDUAL; rayon::check(); let storage = O::Storage::create(path.as_ref().join("storage"), &collection); - let quantization = Quantization::::create( + let quantization = Quantization::::create( path.as_ref().join("quantization"), options.vector, quantization_options, @@ -189,7 +186,7 @@ fn from_nothing( } } -fn open(path: impl AsRef) -> Ivf { +fn open>(path: impl AsRef) -> Ivf { let storage = O::Storage::open(path.as_ref().join("storage")); let quantization = Quantization::open(path.as_ref().join("quantization")); let payloads = MmapArray::open(path.as_ref().join("payloads")); diff --git a/crates/ivf/src/operator.rs b/crates/ivf/src/operator.rs index 6c40576ea..f99f4c0ef 100644 --- a/crates/ivf/src/operator.rs +++ b/crates/ivf/src/operator.rs @@ -4,10 +4,9 @@ use base::scalar::ScalarLike; use base::search::Vectors; use base::vector::*; use common::vec2::Vec2; -use quantization::operator::OperatorQuantization; use storage::OperatorStorage; -pub trait OperatorIvf: OperatorQuantization + OperatorStorage { +pub trait OperatorIvf: OperatorStorage { const SUPPORT: bool; type Scalar: ScalarLike; fn sample( diff --git a/crates/quantization/src/lib.rs b/crates/quantization/src/lib.rs index a44f3cc26..fe36a952e 100644 --- a/crates/quantization/src/lib.rs +++ b/crates/quantization/src/lib.rs @@ -6,18 +6,14 @@ #![allow(clippy::too_many_arguments)] pub mod fast_scan; -pub mod operator; pub mod product; pub mod quantize; +pub mod quantizer; pub mod reranker; pub mod scalar; pub mod trivial; pub mod utils; -use self::product::ProductQuantizer; -use self::scalar::ScalarQuantizer; -use crate::operator::OperatorQuantization; -use base::always_equal::AlwaysEqual; use base::distance::Distance; use base::index::*; use base::operator::*; @@ -25,202 +21,57 @@ use base::search::*; use base::vector::VectorOwned; use common::json::Json; use common::mmap_array::MmapArray; -use reranker::graph::GraphReranker; -use serde::Deserialize; -use serde::Serialize; -use std::cmp::Reverse; +use quantizer::Quantizer; +use std::marker::PhantomData; use std::ops::Range; use std::path::Path; -use trivial::TrivialQuantizer; -use utils::InfiniteByteChunks; -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "")] -pub enum Quantizer { - Trivial(TrivialQuantizer), - Scalar(ScalarQuantizer), - Product(ProductQuantizer), -} - -impl Quantizer { - pub fn train( - vector_options: VectorOptions, - quantization_options: QuantizationOptions, - vectors: &(impl Vectors> + Sync), - transform: impl Fn(Borrowed<'_, O>) -> Owned + Copy + Send + Sync, - ) -> Self { - use QuantizationOptions::*; - match quantization_options { - Trivial(trivial_quantization_options) => Self::Trivial(TrivialQuantizer::train( - vector_options, - trivial_quantization_options, - vectors, - transform, - )), - Scalar(scalar_quantization_options) => Self::Scalar(ScalarQuantizer::train( - vector_options, - scalar_quantization_options, - vectors, - transform, - )), - Product(product_quantization_options) => Self::Product(ProductQuantizer::train( - vector_options, - product_quantization_options, - vectors, - transform, - )), - } - } -} - -pub enum QuantizationPreprocessed { - Trivial(O::TrivialQuantizationPreprocessed), - Scalar(O::QuantizationPreprocessed), - Product(O::QuantizationPreprocessed), -} - -pub struct Quantization { - train: Json>, +pub struct Quantization { + train: Json, codes: MmapArray, packed_codes: MmapArray, - #[allow(unused)] - meta: MmapArray, + _maker: PhantomData O>, } -impl Quantization { +impl> Quantization { pub fn create( path: impl AsRef, vector_options: VectorOptions, - quantization_options: QuantizationOptions, + quantization_options: Option, vectors: &(impl Vectors> + Sync), transform: impl Fn(Borrowed<'_, O>) -> Owned + Copy + Send + Sync, ) -> Self { std::fs::create_dir(path.as_ref()).unwrap(); - fn merge_8([b0, b1, b2, b3, b4, b5, b6, b7]: [u8; 8]) -> u8 { - b0 | (b1 << 1) | (b2 << 2) | (b3 << 3) | (b4 << 4) | (b5 << 5) | (b6 << 6) | (b7 << 7) - } - fn merge_4([b0, b1, b2, b3]: [u8; 4]) -> u8 { - b0 | (b1 << 2) | (b2 << 4) | (b3 << 6) - } - fn merge_2([b0, b1]: [u8; 2]) -> u8 { - b0 | (b1 << 4) - } - let train = Quantizer::train(vector_options, quantization_options, vectors, transform); + let train = Q::train(vector_options, quantization_options, vectors, transform); let train = Json::create(path.as_ref().join("train"), train); let codes = MmapArray::create(path.as_ref().join("codes"), { - match &*train { - Quantizer::Trivial(_) => { - Box::new(std::iter::empty()) as Box> - } - Quantizer::Scalar(x) => Box::new((0..vectors.len()).flat_map(|i| { - let vector = transform(vectors.vector(i)); - let codes = x.encode(vector.as_borrowed()); - let bytes = x.bytes(); - match x.bits() { - 1 => InfiniteByteChunks::new(codes.into_iter()) - .map(merge_8) - .take(bytes as usize) - .collect(), - 2 => InfiniteByteChunks::new(codes.into_iter()) - .map(merge_4) - .take(bytes as usize) - .collect(), - 4 => InfiniteByteChunks::new(codes.into_iter()) - .map(merge_2) - .take(bytes as usize) - .collect(), - 8 => codes, - _ => unreachable!(), - } - })), - Quantizer::Product(x) => Box::new((0..vectors.len()).flat_map(|i| { - let vector = transform(vectors.vector(i)); - let codes = x.encode(vector.as_borrowed()); - let bytes = x.bytes(); - match x.bits() { - 1 => InfiniteByteChunks::new(codes.into_iter()) - .map(merge_8) - .take(bytes as usize) - .collect(), - 2 => InfiniteByteChunks::new(codes.into_iter()) - .map(merge_4) - .take(bytes as usize) - .collect(), - 4 => InfiniteByteChunks::new(codes.into_iter()) - .map(merge_2) - .take(bytes as usize) - .collect(), - 8 => codes, - _ => unreachable!(), - } - })), - } + (0..vectors.len()).flat_map(|i| { + let vector = transform(vectors.vector(i)); + train.encode(vector.as_borrowed()) + }) }); - let packed_codes = MmapArray::create( - path.as_ref().join("packed_codes"), - match &*train { - Quantizer::Trivial(_) => { - Box::new(std::iter::empty()) as Box> - } - Quantizer::Scalar(x) => match x.bits() { - 4 => { - use fast_scan::b4::{pack, BLOCK_SIZE}; - let blocks = vectors.len().div_ceil(BLOCK_SIZE); - Box::new((0..blocks).flat_map(|block| { - let width = x.width(); - let n = vectors.len(); - let raw = std::array::from_fn::<_, { BLOCK_SIZE as _ }, _>(|i| { - let id = BLOCK_SIZE * block + i as u32; - x.encode( - transform(vectors.vector(std::cmp::min(id, n - 1))) - .as_borrowed(), - ) - }); - pack(width, raw) - })) as Box> - } - _ => Box::new(std::iter::empty()) as Box>, - }, - Quantizer::Product(x) => match x.bits() { - 4 => { - use fast_scan::b4::{pack, BLOCK_SIZE}; - let blocks = vectors.len().div_ceil(BLOCK_SIZE); - Box::new((0..blocks).flat_map(|block| { - let width = x.width(); - let n = vectors.len(); - let raw = std::array::from_fn::<_, { BLOCK_SIZE as _ }, _>(|i| { - let id = BLOCK_SIZE * block + i as u32; - x.encode( - transform(vectors.vector(std::cmp::min(id, n - 1))) - .as_borrowed(), - ) - }); - pack(width, raw) - })) as Box> + let packed_codes = MmapArray::create(path.as_ref().join("packed_codes"), { + let d = vectors.dims(); + let n = vectors.len(); + let m = n.div_ceil(32); + let train = &train; + (0..m).flat_map(move |alpha| { + let vectors = std::array::from_fn(|beta| { + let i = 32 * alpha + beta as u32; + if i < n { + transform(vectors.vector(i)) + } else { + O::Vector::zero(d) } - _ => Box::new(std::iter::empty()) as Box>, - }, - }, - ); - let meta = MmapArray::create( - path.as_ref().join("meta"), - match &*train { - Quantizer::Trivial(_) => { - Box::new(std::iter::empty()) as Box> - } - Quantizer::Scalar(_) => { - Box::new(std::iter::empty()) as Box> - } - Quantizer::Product(_) => { - Box::new(std::iter::empty()) as Box> - } - }, - ); + }); + train.fscan_encode(vectors) + }) + }); Self { train, codes, packed_codes, - meta, + _maker: PhantomData, } } @@ -228,135 +79,90 @@ impl Quantization { let train = Json::open(path.as_ref().join("train")); let codes = MmapArray::open(path.as_ref().join("codes")); let packed_codes = MmapArray::open(path.as_ref().join("packed_codes")); - let meta = MmapArray::open(path.as_ref().join("meta")); Self { train, codes, packed_codes, - meta, + _maker: PhantomData, } } - pub fn project(&self, vector: Borrowed<'_, O>) -> Owned { - match &*self.train { - Quantizer::Trivial(x) => x.project(vector), - Quantizer::Scalar(x) => x.project(vector), - Quantizer::Product(x) => x.project(vector), - } + pub fn preprocess(&self, vector: Borrowed<'_, O>) -> Q::Lut { + Q::preprocess(&self.train, vector) } - pub fn preprocess(&self, lhs: Borrowed<'_, O>) -> QuantizationPreprocessed { - match &*self.train { - Quantizer::Trivial(x) => QuantizationPreprocessed::Trivial(x.preprocess(lhs)), - Quantizer::Scalar(x) => QuantizationPreprocessed::Scalar(x.preprocess(lhs)), - Quantizer::Product(x) => QuantizationPreprocessed::Product(x.preprocess(lhs)), - } + pub fn flat_rerank_preprocess( + &self, + vector: Borrowed<'_, O>, + opts: &SearchOptions, + ) -> Result { + Q::flat_rerank_preprocess(&self.train, vector, opts) } - pub fn process( - &self, - vectors: &impl Vectors>, - preprocessed: &QuantizationPreprocessed, - u: u32, - ) -> Distance { - match (&*self.train, preprocessed) { - (Quantizer::Trivial(x), QuantizationPreprocessed::Trivial(lhs)) => { - let rhs = vectors.vector(u); - x.process(lhs, rhs) - } - (Quantizer::Scalar(x), QuantizationPreprocessed::Scalar(lhs)) => { - let bytes = x.bytes() as usize; - let start = u as usize * bytes; - let end = start + bytes; - let rhs = &self.codes[start..end]; - x.process(lhs, rhs) - } - (Quantizer::Product(x), QuantizationPreprocessed::Product(lhs)) => { - let bytes = x.bytes() as usize; - let start = u as usize * bytes; - let end = start + bytes; - let rhs = &self.codes[start..end]; - x.process(lhs, rhs) - } - _ => unreachable!(), - } + pub fn process(&self, vectors: &impl Vectors>, lut: &Q::Lut, u: u32) -> Distance { + let locate = |i| { + let code_size = self.train.code_size() as usize; + let start = i as usize * code_size; + let end = start + code_size; + &self.codes[start..end] + }; + let vector = vectors.vector(u); + Q::process(&self.train, lut, locate(u), vector) } - pub fn push_batch( + pub fn flat_rerank_continue( &self, - preprocessed: &QuantizationPreprocessed, - rhs: Range, - heap: &mut Vec<(Reverse, AlwaysEqual)>, - sq_fast_scan: bool, - pq_fast_scan: bool, + frlut: &Result, + range: Range, + heap: &mut Q::FlatRerankVec, ) { - match (&*self.train, preprocessed) { - (Quantizer::Trivial(x), QuantizationPreprocessed::Trivial(lhs)) => { - x.push_batch(lhs, rhs, heap) - } - (Quantizer::Scalar(x), QuantizationPreprocessed::Scalar(lhs)) => x.push_batch( - lhs, - rhs, - heap, - &self.codes, - &self.packed_codes, - sq_fast_scan, - ), - (Quantizer::Product(x), QuantizationPreprocessed::Product(lhs)) => x.push_batch( - lhs, - rhs, - heap, - &self.codes, - &self.packed_codes, - pq_fast_scan, - ), - _ => unreachable!(), - } + Q::flat_rerank_continue( + &self.train, + |i| { + let code_size = self.train.code_size() as usize; + let start = i as usize * code_size; + let end = start + code_size; + &self.codes[start..end] + }, + |i| { + let fcode_size = self.train.fcode_size() as usize; + let start = i as usize * fcode_size; + let end = start + fcode_size; + &self.packed_codes[start..end] + }, + frlut, + range, + heap, + ) } - pub fn flat_rerank<'a, T: 'a>( + pub fn flat_rerank_break<'a, 'b, T: 'a, R>( &'a self, - heap: Vec<(Reverse, AlwaysEqual)>, - r: impl Fn(u32) -> (Distance, T) + 'a, - sq_rerank_size: u32, - pq_rerank_size: u32, - ) -> Box + 'a> { - use Quantizer::*; - match &*self.train { - Trivial(x) => Box::new(x.flat_rerank(heap, r)), - Scalar(x) => Box::new(x.flat_rerank(heap, r, sq_rerank_size)), - Product(x) => Box::new(x.flat_rerank(heap, r, pq_rerank_size)), - } + heap: Q::FlatRerankVec, + rerank: R, + opts: &'b SearchOptions, + ) -> impl RerankerPop + 'a + use<'a, 'b, T, O, Q, R> + where + R: Fn(u32) -> (Distance, T) + 'a, + { + Q::flat_rerank_break(&self.train, heap, rerank, opts) } pub fn graph_rerank<'a, T: 'a, R: Fn(u32) -> (Distance, T) + 'a>( &'a self, vector: Borrowed<'a, O>, - r: R, - ) -> GraphReranker<'a, T, R> { - use Quantizer::*; - match &*self.train { - Trivial(x) => x.graph_rerank(vector, r), - Scalar(x) => x.graph_rerank( - vector, - |u| { - let bytes = x.bytes() as usize; - let start = u as usize * bytes; - let end = start + bytes; - &self.codes[start..end] - }, - r, - ), - Product(x) => x.graph_rerank( - vector, - |u| { - let bytes = x.bytes() as usize; - let start = u as usize * bytes; - let end = start + bytes; - &self.codes[start..end] - }, - r, - ), - } + rerank: R, + ) -> impl RerankerPush + RerankerPop + 'a { + Q::graph_rerank( + &self.train, + |i| { + let code_size = self.train.code_size() as usize; + let start = i as usize * code_size; + let end = start + code_size; + &self.codes[start..end] + }, + vector, + rerank, + ) } } diff --git a/crates/quantization/src/operator.rs b/crates/quantization/src/operator.rs deleted file mode 100644 index 3d7d8136a..000000000 --- a/crates/quantization/src/operator.rs +++ /dev/null @@ -1,127 +0,0 @@ -use crate::product::operator::OperatorProductQuantization; -use crate::quantize::{dequantize, quantize}; -use crate::scalar::operator::OperatorScalarQuantization; -use crate::trivial::operator::OperatorTrivialQuantization; -use base::distance::Distance; -use base::operator::*; -use base::scalar::ScalarLike; - -pub trait OperatorQuantizationProcess: Operator { - type QuantizationPreprocessed; - - fn process( - dims: u32, - ratio: u32, - bits: u32, - preprocessed: &Self::QuantizationPreprocessed, - rhs: impl Fn(usize) -> usize, - ) -> Distance; - fn fscan_preprocess(preprocessed: &Self::QuantizationPreprocessed) -> (f32, f32, Vec); - fn fscan_process(width: u32, k: f32, b: f32, x: u16) -> Distance; -} - -impl OperatorQuantizationProcess for VectDot { - type QuantizationPreprocessed = Vec; - - fn process( - dims: u32, - ratio: u32, - bits: u32, - preprocessed: &Self::QuantizationPreprocessed, - rhs: impl Fn(usize) -> usize, - ) -> Distance { - let width = dims.div_ceil(ratio); - let xy = { - let mut xy = 0.0f32; - for i in 0..width as usize { - xy += preprocessed[i * (1 << bits) + rhs(i)]; - } - xy - }; - Distance::from(0.0f32 - xy) - } - - fn fscan_preprocess(preprocessed: &Self::QuantizationPreprocessed) -> (f32, f32, Vec) { - quantize::<255>(preprocessed) - } - - fn fscan_process(width: u32, k: f32, b: f32, x: u16) -> Distance { - Distance::from(-dequantize(width, k, b, x)) - } -} - -impl OperatorQuantizationProcess for VectL2 { - type QuantizationPreprocessed = Vec; - - fn process( - dims: u32, - ratio: u32, - bits: u32, - preprocessed: &Self::QuantizationPreprocessed, - rhs: impl Fn(usize) -> usize, - ) -> Distance { - let width = dims.div_ceil(ratio); - let mut d2 = 0.0f32; - for i in 0..width as usize { - d2 += preprocessed[i * (1 << bits) + rhs(i)]; - } - Distance::from(d2) - } - - fn fscan_preprocess(preprocessed: &Self::QuantizationPreprocessed) -> (f32, f32, Vec) { - quantize::<255>(preprocessed) - } - - fn fscan_process(width: u32, k: f32, b: f32, x: u16) -> Distance { - Distance::from(dequantize(width, k, b, x)) - } -} - -macro_rules! unimpl_operator_quantization_process { - ($t:ty) => { - impl OperatorQuantizationProcess for $t { - type QuantizationPreprocessed = std::convert::Infallible; - - fn process( - _: u32, - _: u32, - _: u32, - _: &Self::QuantizationPreprocessed, - _: impl Fn(usize) -> usize, - ) -> Distance { - unimplemented!() - } - - fn fscan_preprocess(_: &Self::QuantizationPreprocessed) -> (f32, f32, Vec) { - unimplemented!() - } - - fn fscan_process(_: u32, _: f32, _: f32, _: u16) -> Distance { - unimplemented!() - } - } - }; -} - -unimpl_operator_quantization_process!(BVectorDot); -unimpl_operator_quantization_process!(BVectorHamming); -unimpl_operator_quantization_process!(BVectorJaccard); - -unimpl_operator_quantization_process!(SVectDot); -unimpl_operator_quantization_process!(SVectL2); - -pub trait OperatorQuantization: - OperatorQuantizationProcess - + OperatorTrivialQuantization - + OperatorScalarQuantization - + OperatorProductQuantization -{ -} - -impl OperatorQuantization for BVectorDot {} -impl OperatorQuantization for BVectorJaccard {} -impl OperatorQuantization for BVectorHamming {} -impl OperatorQuantization for SVectDot {} -impl OperatorQuantization for SVectL2 {} -impl OperatorQuantization for VectDot {} -impl OperatorQuantization for VectL2 {} diff --git a/crates/quantization/src/product.rs b/crates/quantization/src/product.rs new file mode 100644 index 000000000..76c1ddadb --- /dev/null +++ b/crates/quantization/src/product.rs @@ -0,0 +1,514 @@ +use crate::fast_scan::b4::fast_scan_b4; +use crate::fast_scan::b4::pack; +use crate::quantize::quantize; +use crate::quantizer::Quantizer; +use crate::reranker::flat::WindowFlatReranker; +use crate::reranker::graph_2::Graph2Reranker; +use crate::utils::merge_2; +use crate::utils::merge_4; +use crate::utils::merge_8; +use crate::utils::InfiniteByteChunks; +use base::always_equal::AlwaysEqual; +use base::distance::Distance; +use base::index::*; +use base::operator::*; +use base::scalar::impossible::Impossible; +use base::scalar::ScalarLike; +use base::search::*; +use base::vector::VectorOwned; +use common::sample::sample; +use common::vec2::Vec2; +use k_means::k_means; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; +use serde::Deserialize; +use serde::Serialize; +use std::cmp::Reverse; +use std::ops::Range; +use stoppable_rayon as rayon; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct ProductQuantizer { + dims: u32, + ratio: u32, + bits: u32, + originals: Vec>, + centroids: Vec2, +} + +impl Quantizer for ProductQuantizer { + fn train( + vector_options: VectorOptions, + options: Option, + vectors: &(impl Vectors> + Sync), + transform: impl Fn(Borrowed<'_, O>) -> Owned + Copy + Sync, + ) -> Self { + let dims = vector_options.dims; + let options = if let Some(QuantizationOptions::Product(x)) = options { + x + } else { + panic!("inconsistent parameters: options and generics") + }; + let ratio = options.ratio; + let bits = options.bits; + let width = dims.div_ceil(ratio); + let originals = (0..width) + .into_par_iter() + .map(|p| { + let subdims = std::cmp::min(ratio, dims - ratio * p); + let start = p * ratio; + let end = start + subdims; + let subsamples = sample(vectors.len(), 65536, end - start, |i| { + O::subslice( + transform(vectors.vector(i)).as_borrowed(), + start, + end - start, + ) + .to_vec() + }); + k_means(1 << bits, subsamples, false, false, true) + }) + .collect::>(); + let mut centroids = Vec2::zeros((1 << bits, dims as usize)); + for p in 0..width { + let subdims = std::cmp::min(ratio, dims - ratio * p); + for j in 0_usize..(1 << bits) { + centroids[(j,)][(p * ratio) as usize..][..subdims as usize] + .copy_from_slice(&originals[p as usize][(j,)]); + } + } + Self { + dims, + ratio, + bits, + originals, + centroids, + } + } + + fn encode(&self, vector: Borrowed<'_, O>) -> Vec { + let dims = self.dims; + let ratio = self.ratio; + let bits = self.bits; + let width = dims.div_ceil(ratio); + let bytes = (dims.div_ceil(ratio) * bits).div_ceil(8); + let mut code = Vec::with_capacity(width.div_ceil(bits) as usize); + for p in 0..width { + let subdims = std::cmp::min(ratio, dims - ratio * p); + let left = O::subslice(vector, p * ratio, subdims); + let target = k_means::k_means_lookup(left, &self.originals[p as usize]); + code.push(target as u8); + } + match bits { + 1 => InfiniteByteChunks::new(code.into_iter()) + .map(merge_8) + .take(bytes as usize) + .collect(), + 2 => InfiniteByteChunks::new(code.into_iter()) + .map(merge_4) + .take(bytes as usize) + .collect(), + 4 => InfiniteByteChunks::new(code.into_iter()) + .map(merge_2) + .take(bytes as usize) + .collect(), + 8 => code, + _ => unreachable!(), + } + } + + fn fscan_encode(&self, vectors: [Owned; 32]) -> Vec { + let dims = self.dims; + let ratio = self.ratio; + let bits = self.bits; + if bits == 4 { + let width = dims.div_ceil(ratio); + let codes = vectors.map(|vector| { + let mut code = Vec::with_capacity(width.div_ceil(bits) as usize); + for p in 0..width { + let subdims = std::cmp::min(ratio, dims - ratio * p); + let left = O::subslice(vector.as_borrowed(), p * ratio, subdims); + let target = k_means::k_means_lookup(left, &self.originals[p as usize]); + code.push(target as u8); + } + code + }); + pack(width, codes).collect() + } else { + Vec::new() + } + } + + fn code_size(&self) -> u32 { + (self.dims * self.bits).div_ceil(8) + } + + fn fcode_size(&self) -> u32 { + if self.bits == 4 { + self.dims.div_ceil(self.ratio) * 16 + } else { + 0 + } + } + + type Lut = Vec; + + fn preprocess(&self, vector: Borrowed<'_, O>) -> Self::Lut { + O::preprocess( + self.dims, + self.ratio, + self.bits, + self.centroids.as_slice(), + vector, + ) + } + + fn process(&self, lut: &Self::Lut, code: &[u8], _: Borrowed<'_, O>) -> Distance { + O::process(self.dims, self.ratio, self.bits, lut, code) + } + + type FLut = ( + /* width */ u32, + /* k */ f32, + /* b */ f32, + Vec, + ); + + fn fscan_preprocess(&self, vector: Borrowed<'_, O>) -> Self::FLut { + O::fscan_preprocess( + self.dims, + self.ratio, + self.bits, + self.centroids.as_slice(), + vector, + ) + } + + fn fscan_process(flut: &Self::FLut, code: &[u8]) -> [Distance; 32] { + O::fscan_process(flut, code) + } + + type FlatRerankVec = Vec<(Reverse, AlwaysEqual)>; + + fn flat_rerank_start() -> Self::FlatRerankVec { + Vec::new() + } + + fn flat_rerank_preprocess( + &self, + vector: Borrowed<'_, O>, + opts: &SearchOptions, + ) -> Result { + if opts.pq_fast_scan && self.bits == 4 { + Ok(self.fscan_preprocess(vector)) + } else { + Err(self.preprocess(vector)) + } + } + + fn flat_rerank_continue( + &self, + locate_0: impl Fn(u32) -> C, + locate_1: impl Fn(u32) -> C, + frlut: &Result, + range: Range, + heap: &mut Vec<(Reverse, AlwaysEqual)>, + ) where + C: AsRef<[u8]>, + { + match frlut { + Ok(flut) => { + fn divide(r: Range) -> (Option, Range, Option) { + if r.start > r.end || r.start % 32 == 0 && r.end % 32 == 0 { + (None, r.start / 32..r.end / 32, None) + } else if r.start / 32 == r.end / 32 { + (Some(r.start / 32), 0..0, None) + } else { + let left = (r.start % 32 != 0).then_some(r.start / 32); + let right = (r.end % 32 != 0).then_some(r.end / 32); + (left, r.start / 32 + 1..r.end / 32, right) + } + } + let (left, main, right) = divide(range.clone()); + if let Some(i) = left { + let r = Self::fscan_process(flut, locate_1(i).as_ref()); + for j in 0..32 { + if range.contains(&(i * 32 + j)) { + heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); + } + } + } + for i in main { + let r = Self::fscan_process(flut, locate_1(i).as_ref()); + for j in 0..32 { + heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); + } + } + if let Some(i) = right { + let r = Self::fscan_process(flut, locate_1(i).as_ref()); + for j in 0..32 { + if range.contains(&(i * 32 + j)) { + heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); + } + } + } + } + Err(lut) => { + for j in range { + let r = O::process(self.dims, self.ratio, self.bits, lut, locate_0(j).as_ref()); + heap.push((Reverse(r), AlwaysEqual(j))); + } + } + } + } + + fn flat_rerank_break<'a, T: 'a, R>( + &'a self, + heap: Vec<(Reverse, AlwaysEqual)>, + rerank: R, + opts: &SearchOptions, + ) -> impl RerankerPop + 'a + where + R: Fn(u32) -> (Distance, T) + 'a, + { + WindowFlatReranker::new(heap, rerank, opts.pq_rerank_size) + } + + fn graph_rerank<'a, T, R, C>( + &'a self, + locate: impl Fn(u32) -> C + 'a, + vector: Borrowed<'a, O>, + rerank: R, + ) -> impl RerankerPush + RerankerPop + 'a + where + T: 'a, + R: Fn(u32) -> (Distance, T) + 'a, + C: AsRef<[u8]>, + { + let lut = self.preprocess(vector); + Graph2Reranker::new( + move |u| self.process(&lut, locate(u).as_ref(), vector), + rerank, + ) + } +} + +pub trait OperatorProductQuantization: Operator { + type Scalar: ScalarLike; + fn subslice(vector: Borrowed<'_, Self>, start: u32, len: u32) -> &[Self::Scalar]; + + fn preprocess( + dims: u32, + ratio: u32, + bits: u32, + centroids: &[Self::Scalar], + vector: Borrowed<'_, Self>, + ) -> Vec; + fn process(dims: u32, ratio: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance; + fn fscan_preprocess( + dims: u32, + ratio: u32, + bits: u32, + centroids: &[Self::Scalar], + vector: Borrowed<'_, Self>, + ) -> (u32, f32, f32, Vec); + fn fscan_process(flut: &(u32, f32, f32, Vec), code: &[u8]) -> [Distance; 32]; +} + +impl OperatorProductQuantization for VectDot { + type Scalar = S; + fn subslice(vector: Borrowed<'_, Self>, start: u32, len: u32) -> &[Self::Scalar] { + &vector.slice()[start as usize..][..len as usize] + } + + fn preprocess( + dims: u32, + ratio: u32, + bits: u32, + centroids: &[Self::Scalar], + vector: Borrowed<'_, Self>, + ) -> Vec { + let mut xy = Vec::with_capacity((dims.div_ceil(ratio) as usize) * (1 << bits)); + for p in 0..dims.div_ceil(ratio) { + let subdims = std::cmp::min(ratio, dims - ratio * p); + xy.extend((0_usize..1 << bits).map(|k| { + let mut xy = 0.0f32; + for i in ratio * p..ratio * p + subdims { + let x = vector.slice()[i as usize].to_f32(); + let y = centroids[(k as u32 * dims + i) as usize].to_f32(); + xy += x * y; + } + xy + })); + } + xy + } + fn process(dims: u32, ratio: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance { + fn internal( + dims: u32, + ratio: u32, + bits: u32, + t: &[f32], + f: impl Fn(usize) -> usize, + ) -> Distance { + let width = dims.div_ceil(ratio); + let xy = { + let mut xy = 0.0f32; + for i in 0..width as usize { + xy += t[i * (1 << bits) + f(i)]; + } + xy + }; + Distance::from(-xy) + } + match bits { + 1 => internal(dims, ratio, bits, lut, |i| { + ((code[i >> 3] >> ((i & 7) << 0)) & 1u8) as usize + }), + 2 => internal(dims, ratio, bits, lut, |i| { + ((code[i >> 2] >> ((i & 3) << 1)) & 3u8) as usize + }), + 4 => internal(dims, ratio, bits, lut, |i| { + ((code[i >> 1] >> ((i & 1) << 2)) & 15u8) as usize + }), + 8 => internal(dims, ratio, bits, lut, |i| code[i] as usize), + _ => unreachable!(), + } + } + + fn fscan_preprocess( + dims: u32, + ratio: u32, + bits: u32, + centroids: &[Self::Scalar], + vector: Borrowed<'_, Self>, + ) -> (u32, f32, f32, Vec) { + let (k, b, t) = quantize::<255>(&Self::preprocess(dims, ratio, bits, centroids, vector)); + (dims.div_ceil(ratio), k, b, t) + } + fn fscan_process(flut: &(u32, f32, f32, Vec), codes: &[u8]) -> [Distance; 32] { + let &(width, k, b, ref t) = flut; + let r = fast_scan_b4(width, codes, t); + std::array::from_fn(|i| Distance::from(-((width as f32) * b + (r[i] as f32) * k))) + } +} + +impl OperatorProductQuantization for VectL2 { + type Scalar = S; + fn subslice(vector: Borrowed<'_, Self>, start: u32, len: u32) -> &[Self::Scalar] { + &vector.slice()[start as usize..][..len as usize] + } + + fn preprocess( + dims: u32, + ratio: u32, + bits: u32, + centroids: &[Self::Scalar], + vector: Borrowed<'_, Self>, + ) -> Vec { + let mut d2 = Vec::with_capacity((dims.div_ceil(ratio) as usize) * (1 << bits)); + for p in 0..dims.div_ceil(ratio) { + let subdims = std::cmp::min(ratio, dims - ratio * p); + d2.extend((0_usize..1 << bits).map(|k| { + let mut d2 = 0.0f32; + for i in ratio * p..ratio * p + subdims { + let x = vector.slice()[i as usize].to_f32(); + let y = centroids[(k as u32 * dims + i) as usize].to_f32(); + let d = x - y; + d2 += d * d; + } + d2 + })); + } + d2 + } + fn process(dims: u32, ratio: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance { + fn internal( + dims: u32, + ratio: u32, + bits: u32, + t: &[f32], + f: impl Fn(usize) -> usize, + ) -> Distance { + let width = dims.div_ceil(ratio); + let mut d2 = 0.0f32; + for i in 0..width as usize { + d2 += t[i * (1 << bits) + f(i)]; + } + Distance::from(d2) + } + match bits { + 1 => internal(dims, ratio, bits, lut, |i| { + ((code[i >> 3] >> ((i & 7) << 0)) & 1u8) as usize + }), + 2 => internal(dims, ratio, bits, lut, |i| { + ((code[i >> 2] >> ((i & 3) << 1)) & 3u8) as usize + }), + 4 => internal(dims, ratio, bits, lut, |i| { + ((code[i >> 1] >> ((i & 1) << 2)) & 15u8) as usize + }), + 8 => internal(dims, ratio, bits, lut, |i| code[i] as usize), + _ => unreachable!(), + } + } + + fn fscan_preprocess( + dims: u32, + ratio: u32, + bits: u32, + centroids: &[Self::Scalar], + vector: Borrowed<'_, Self>, + ) -> (u32, f32, f32, Vec) { + let (k, b, t) = quantize::<255>(&Self::preprocess(dims, ratio, bits, centroids, vector)); + (dims.div_ceil(ratio), k, b, t) + } + fn fscan_process(flut: &(u32, f32, f32, Vec), codes: &[u8]) -> [Distance; 32] { + let &(width, k, b, ref t) = flut; + let r = fast_scan_b4(width, codes, t); + std::array::from_fn(|i| Distance::from((width as f32) * b + (r[i] as f32) * k)) + } +} + +macro_rules! unimpl_operator_product_quantization { + ($t:ty) => { + impl OperatorProductQuantization for $t { + type Scalar = Impossible; + fn subslice(_: Borrowed<'_, Self>, _: u32, _: u32) -> &[Self::Scalar] { + unimplemented!() + } + + fn preprocess( + _: u32, + _: u32, + _: u32, + _: &[Self::Scalar], + _: Borrowed<'_, Self>, + ) -> Vec { + unimplemented!() + } + fn process(_: u32, _: u32, _: u32, _: &[f32], _: &[u8]) -> Distance { + unimplemented!() + } + + fn fscan_preprocess( + _: u32, + _: u32, + _: u32, + _: &[Self::Scalar], + _: Borrowed<'_, Self>, + ) -> (u32, f32, f32, Vec) { + unimplemented!() + } + fn fscan_process(_: &(u32, f32, f32, Vec), _: &[u8]) -> [Distance; 32] { + unimplemented!() + } + } + }; +} + +unimpl_operator_product_quantization!(BVectorDot); +unimpl_operator_product_quantization!(BVectorHamming); +unimpl_operator_product_quantization!(BVectorJaccard); + +unimpl_operator_product_quantization!(SVectDot); +unimpl_operator_product_quantization!(SVectL2); diff --git a/crates/quantization/src/product/mod.rs b/crates/quantization/src/product/mod.rs deleted file mode 100644 index fcb0cffb9..000000000 --- a/crates/quantization/src/product/mod.rs +++ /dev/null @@ -1,229 +0,0 @@ -pub mod operator; - -use self::operator::OperatorProductQuantization; -use crate::reranker::flat::WindowFlatReranker; -use crate::reranker::graph::GraphReranker; -use base::always_equal::AlwaysEqual; -use base::distance::Distance; -use base::index::*; -use base::operator::*; -use base::search::*; -use base::vector::VectorBorrowed; -use base::vector::VectorOwned; -use common::sample::sample; -use common::vec2::Vec2; -use k_means::k_means; -use rayon::iter::{IntoParallelIterator, ParallelIterator}; -use serde::Deserialize; -use serde::Serialize; -use std::cmp::Reverse; -use std::ops::Range; -use stoppable_rayon as rayon; - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "")] -pub struct ProductQuantizer { - dims: u32, - ratio: u32, - bits: u32, - originals: Vec>, - centroids: Vec2, -} - -impl ProductQuantizer { - pub fn train( - vector_options: VectorOptions, - product_quantization_options: ProductQuantizationOptions, - vectors: &(impl Vectors> + Sync), - transform: impl Fn(Borrowed<'_, O>) -> Owned + Copy + Send + Sync, - ) -> Self { - let dims = vector_options.dims; - let ratio = product_quantization_options.ratio; - let bits = product_quantization_options.bits; - let width = dims.div_ceil(ratio); - let originals = (0..width) - .into_par_iter() - .map(|p| { - let subdims = std::cmp::min(ratio, dims - ratio * p); - let start = p * ratio; - let end = start + subdims; - let subsamples = sample(vectors.len(), 65536, end - start, |i| { - O::subslice( - transform(vectors.vector(i)).as_borrowed(), - start, - end - start, - ) - .to_vec() - }); - k_means(1 << bits, subsamples, false, false, true) - }) - .collect::>(); - let mut centroids = Vec2::zeros((1 << bits, dims as usize)); - for p in 0..width { - let subdims = std::cmp::min(ratio, dims - ratio * p); - for j in 0_usize..(1 << bits) { - centroids[(j,)][(p * ratio) as usize..][..subdims as usize] - .copy_from_slice(&originals[p as usize][(j,)]); - } - } - Self { - dims, - ratio, - bits, - originals, - centroids, - } - } - - pub fn bits(&self) -> u32 { - self.bits - } - - pub fn bytes(&self) -> u32 { - (self.dims.div_ceil(self.ratio) * self.bits).div_ceil(8) - } - - pub fn width(&self) -> u32 { - self.dims.div_ceil(self.ratio) - } - - pub fn encode(&self, vector: Borrowed<'_, O>) -> Vec { - let dims = self.dims; - let ratio = self.ratio; - let width = dims.div_ceil(ratio); - let mut codes = Vec::with_capacity(width.div_ceil(self.bits) as usize); - for p in 0..width { - let subdims = std::cmp::min(ratio, dims - ratio * p); - let left = O::subslice(vector, p * ratio, subdims); - let target = k_means::k_means_lookup(left, &self.originals[p as usize]); - codes.push(target as u8); - } - codes - } - - pub fn project(&self, vector: Borrowed<'_, O>) -> Owned { - vector.own() - } - - pub fn preprocess(&self, lhs: Borrowed<'_, O>) -> O::QuantizationPreprocessed { - O::product_quantization_preprocess( - self.dims, - self.ratio, - self.bits, - self.centroids.as_slice(), - lhs, - ) - } - - pub fn process(&self, preprocessed: &O::QuantizationPreprocessed, rhs: &[u8]) -> Distance { - let dims = self.dims; - let ratio = self.ratio; - match self.bits { - 1 => O::process(dims, ratio, 1, preprocessed, |i| { - ((rhs[i >> 3] >> ((i & 7) << 0)) & 1) as usize - }), - 2 => O::process(dims, ratio, 2, preprocessed, |i| { - ((rhs[i >> 2] >> ((i & 3) << 1)) & 3) as usize - }), - 4 => O::process(dims, ratio, 4, preprocessed, |i| { - ((rhs[i >> 1] >> ((i & 1) << 2)) & 15) as usize - }), - 8 => O::process(dims, ratio, 8, preprocessed, |i| rhs[i] as usize), - _ => unreachable!(), - } - } - - pub fn push_batch( - &self, - preprocessed: &O::QuantizationPreprocessed, - rhs: Range, - heap: &mut Vec<(Reverse, AlwaysEqual)>, - codes: &[u8], - packed_codes: &[u8], - fast_scan: bool, - ) { - let dims = self.dims; - let ratio = self.ratio; - let width = dims.div_ceil(ratio); - if fast_scan && self.bits == 4 { - use crate::fast_scan::b4::{fast_scan_b4, BLOCK_SIZE}; - let (k, b, lut) = O::fscan_preprocess(preprocessed); - let s = rhs.start.next_multiple_of(BLOCK_SIZE); - let e = (rhs.end + 1 - BLOCK_SIZE).next_multiple_of(BLOCK_SIZE); - if rhs.start != s { - let i = s - BLOCK_SIZE; - let bytes = width as usize * 16; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let res = fast_scan_b4(width, &packed_codes[start..end], &lut); - let r = res.map(|x| O::fscan_process(width, k, b, x)); - heap.extend({ - (rhs.start..s).map(|u| (Reverse(r[(u - i) as usize]), AlwaysEqual(u))) - }); - } - for i in (s..e).step_by(BLOCK_SIZE as _) { - let bytes = width as usize * 16; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let res = fast_scan_b4(width, &packed_codes[start..end], &lut); - let r = res.map(|x| O::fscan_process(width, k, b, x)); - heap.extend({ - (i..i + BLOCK_SIZE).map(|u| (Reverse(r[(u - i) as usize]), AlwaysEqual(u))) - }); - } - if e != rhs.end { - let i = e; - let bytes = width as usize * 16; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let res = fast_scan_b4(width, &packed_codes[start..end], &lut); - let r = res.map(|x| O::fscan_process(width, k, b, x)); - heap.extend({ - (e..rhs.end).map(|u| (Reverse(r[(u - i) as usize]), AlwaysEqual(u))) - }); - } - return; - } - heap.extend(rhs.map(|u| { - ( - Reverse(self.process(preprocessed, { - let bytes = self.bytes() as usize; - let start = u as usize * bytes; - let end = start + bytes; - &codes[start..end] - })), - AlwaysEqual(u), - ) - })); - } - - pub fn flat_rerank<'a, T: 'a, R: Fn(u32) -> (Distance, T) + 'a>( - &'a self, - heap: Vec<(Reverse, AlwaysEqual)>, - r: R, - rerank_size: u32, - ) -> impl RerankerPop + 'a { - WindowFlatReranker::new(heap, r, rerank_size) - } - - pub fn graph_rerank< - 'a, - T: 'a, - C: Fn(u32) -> &'a [u8] + 'a, - R: Fn(u32) -> (Distance, T) + 'a, - >( - &'a self, - vector: Borrowed<'a, O>, - c: C, - r: R, - ) -> GraphReranker<'a, T, R> { - let p = O::product_quantization_preprocess( - self.dims, - self.ratio, - self.bits, - self.centroids.as_slice(), - vector, - ); - GraphReranker::new(Some(Box::new(move |u| self.process(&p, c(u)))), r) - } -} diff --git a/crates/quantization/src/product/operator.rs b/crates/quantization/src/product/operator.rs deleted file mode 100644 index 1c7deb893..000000000 --- a/crates/quantization/src/product/operator.rs +++ /dev/null @@ -1,101 +0,0 @@ -use crate::operator::OperatorQuantizationProcess; -use base::scalar::impossible::Impossible; -use base::{operator::*, scalar::ScalarLike}; - -pub trait OperatorProductQuantization: OperatorQuantizationProcess { - type Scalar: ScalarLike; - fn subslice(vector: Borrowed<'_, Self>, start: u32, len: u32) -> &[Self::Scalar]; - fn product_quantization_preprocess( - dims: u32, - ratio: u32, - bits: u32, - centroids: &[Self::Scalar], - lhs: Borrowed<'_, Self>, - ) -> Self::QuantizationPreprocessed; -} - -impl OperatorProductQuantization for VectDot { - type Scalar = S; - fn subslice(vector: Borrowed<'_, Self>, start: u32, len: u32) -> &[Self::Scalar] { - &vector.slice()[start as usize..][..len as usize] - } - fn product_quantization_preprocess( - dims: u32, - ratio: u32, - bits: u32, - centroids: &[Self::Scalar], - lhs: Borrowed<'_, Self>, - ) -> Self::QuantizationPreprocessed { - let mut xy = Vec::with_capacity((dims.div_ceil(ratio) as usize) * (1 << bits)); - for p in 0..dims.div_ceil(ratio) { - let subdims = std::cmp::min(ratio, dims - ratio * p); - xy.extend((0_usize..1 << bits).map(|k| { - let mut xy = 0.0f32; - for i in ratio * p..ratio * p + subdims { - let x = lhs.slice()[i as usize].to_f32(); - let y = centroids[(k as u32 * dims + i) as usize].to_f32(); - xy += x * y; - } - xy - })); - } - xy - } -} - -impl OperatorProductQuantization for VectL2 { - type Scalar = S; - fn subslice(vector: Borrowed<'_, Self>, start: u32, len: u32) -> &[Self::Scalar] { - &vector.slice()[start as usize..][..len as usize] - } - fn product_quantization_preprocess( - dims: u32, - ratio: u32, - bits: u32, - centroids: &[Self::Scalar], - lhs: Borrowed<'_, Self>, - ) -> Self::QuantizationPreprocessed { - let mut d2 = Vec::with_capacity((dims.div_ceil(ratio) as usize) * (1 << bits)); - for p in 0..dims.div_ceil(ratio) { - let subdims = std::cmp::min(ratio, dims - ratio * p); - d2.extend((0_usize..1 << bits).map(|k| { - let mut d2 = 0.0f32; - for i in ratio * p..ratio * p + subdims { - let x = lhs.slice()[i as usize].to_f32(); - let y = centroids[(k as u32 * dims + i) as usize].to_f32(); - let d = x - y; - d2 += d * d; - } - d2 - })); - } - d2 - } -} - -macro_rules! unimpl_operator_product_quantization { - ($t:ty) => { - impl OperatorProductQuantization for $t { - type Scalar = Impossible; - fn subslice(_: Borrowed<'_, Self>, _: u32, _: u32) -> &[Self::Scalar] { - unimplemented!() - } - fn product_quantization_preprocess( - _: u32, - _: u32, - _: u32, - _: &[Self::Scalar], - _: Borrowed<'_, Self>, - ) -> Self::QuantizationPreprocessed { - unimplemented!() - } - } - }; -} - -unimpl_operator_product_quantization!(BVectorDot); -unimpl_operator_product_quantization!(BVectorHamming); -unimpl_operator_product_quantization!(BVectorJaccard); - -unimpl_operator_product_quantization!(SVectDot); -unimpl_operator_product_quantization!(SVectL2); diff --git a/crates/quantization/src/quantize.rs b/crates/quantization/src/quantize.rs index 1546b9958..d7e1ec41c 100644 --- a/crates/quantization/src/quantize.rs +++ b/crates/quantization/src/quantize.rs @@ -374,11 +374,6 @@ pub fn quantize(lut: &[f32]) -> (f32, f32, Vec) { (k, b, mul_add_round::mul_add_round(lut, 1.0 / k, -b / k)) } -#[inline(always)] -pub fn dequantize(sum_1: u32, k: f32, b: f32, sum_x: u16) -> f32 { - (sum_1 as f32) * b + (sum_x as f32) * k -} - #[inline(always)] pub fn reduce_sum_of_x_as_u16(vector: &[u8]) -> u16 { reduce_sum_of_x_as_u16::reduce_sum_of_x_as_u16(vector) diff --git a/crates/quantization/src/quantizer.rs b/crates/quantization/src/quantizer.rs new file mode 100644 index 000000000..0b27533e8 --- /dev/null +++ b/crates/quantization/src/quantizer.rs @@ -0,0 +1,71 @@ +use base::distance::Distance; +use base::index::{QuantizationOptions, SearchOptions, VectorOptions}; +use base::operator::Operator; +use base::operator::{Borrowed, Owned}; +use base::search::{RerankerPop, RerankerPush, Vectors}; +use serde::{Deserialize, Serialize}; +use std::ops::Range; + +pub trait Quantizer: + Serialize + for<'a> Deserialize<'a> + Send + Sync + 'static +{ + fn train( + vector_options: VectorOptions, + options: Option, + vectors: &(impl Vectors> + Sync), + transform: impl Fn(Borrowed<'_, O>) -> Owned + Copy + Sync, + ) -> Self; + + fn encode(&self, vector: Borrowed<'_, O>) -> Vec; + fn fscan_encode(&self, vectors: [Owned; 32]) -> Vec; + fn code_size(&self) -> u32; + fn fcode_size(&self) -> u32; + + type Lut; + fn preprocess(&self, vector: Borrowed<'_, O>) -> Self::Lut; + fn process(&self, lut: &Self::Lut, code: &[u8], vector: Borrowed<'_, O>) -> Distance; + + type FLut; + fn fscan_preprocess(&self, vector: Borrowed<'_, O>) -> Self::FLut; + fn fscan_process(flut: &Self::FLut, code: &[u8]) -> [Distance; 32]; + + type FlatRerankVec; + + fn flat_rerank_start() -> Self::FlatRerankVec; + + fn flat_rerank_preprocess( + &self, + vector: Borrowed<'_, O>, + opts: &SearchOptions, + ) -> Result; + + fn flat_rerank_continue( + &self, + locate_0: impl Fn(u32) -> C, + locate_1: impl Fn(u32) -> C, + frlut: &Result, + range: Range, + heap: &mut Self::FlatRerankVec, + ) where + C: AsRef<[u8]>; + + fn flat_rerank_break<'a, T: 'a, R>( + &'a self, + heap: Self::FlatRerankVec, + rerank: R, + opts: &SearchOptions, + ) -> impl RerankerPop + 'a + where + R: Fn(u32) -> (Distance, T) + 'a; + + fn graph_rerank<'a, T, R, C>( + &'a self, + locate: impl Fn(u32) -> C + 'a, + vector: Borrowed<'a, O>, + rerank: R, + ) -> impl RerankerPush + RerankerPop + 'a + where + T: 'a, + R: Fn(u32) -> (Distance, T) + 'a, + C: AsRef<[u8]>; +} diff --git a/crates/quantization/src/reranker/flat.rs b/crates/quantization/src/reranker/flat.rs index 202dee755..a73d7be77 100644 --- a/crates/quantization/src/reranker/flat.rs +++ b/crates/quantization/src/reranker/flat.rs @@ -4,34 +4,6 @@ use base::search::*; use std::cmp::Reverse; use std::collections::BinaryHeap; -pub struct DisabledFlatReranker { - heap: BinaryHeap<(Reverse, AlwaysEqual, AlwaysEqual)>, -} - -impl DisabledFlatReranker { - pub fn new(heap: Vec<(Reverse, AlwaysEqual)>, rerank: R) -> Self - where - R: Fn(u32) -> (Distance, T), - { - Self { - heap: heap - .into_iter() - .map(|(_, AlwaysEqual(u))| { - let (dis_u, pay_u) = rerank(u); - (Reverse(dis_u), AlwaysEqual(u), AlwaysEqual(pay_u)) - }) - .collect(), - } - } -} - -impl RerankerPop for DisabledFlatReranker { - fn pop(&mut self) -> Option<(Distance, u32, T)> { - let (Reverse(dis_u), AlwaysEqual(u), AlwaysEqual(pay_u)) = self.heap.pop()?; - Some((dis_u, u, pay_u)) - } -} - pub struct WindowFlatReranker { rerank: R, size: u32, diff --git a/crates/quantization/src/reranker/graph.rs b/crates/quantization/src/reranker/graph.rs index 1a513e48a..6967a7210 100644 --- a/crates/quantization/src/reranker/graph.rs +++ b/crates/quantization/src/reranker/graph.rs @@ -4,52 +4,37 @@ use base::search::*; use std::cmp::Reverse; use std::collections::BinaryHeap; -pub struct GraphReranker<'a, T, R> { - compute: Option Distance + 'a>>, +pub struct GraphReranker { rerank: R, - heap: BinaryHeap<(Reverse, AlwaysEqual)>, cache: BinaryHeap<(Reverse, AlwaysEqual, AlwaysEqual)>, } -impl<'a, T, R> GraphReranker<'a, T, R> { - pub fn new(compute: Option Distance + 'a>>, rerank: R) -> Self { +impl GraphReranker { + pub fn new(rerank: R) -> Self { Self { - compute, rerank, - heap: BinaryHeap::new(), cache: BinaryHeap::new(), } } } -impl<'a, T, R> RerankerPop for GraphReranker<'a, T, R> +impl RerankerPop for GraphReranker where R: Fn(u32) -> (Distance, T), { fn pop(&mut self) -> Option<(Distance, u32, T)> { - if self.compute.is_some() { - let (_, AlwaysEqual(u)) = self.heap.pop()?; - let (dis_u, pay_u) = (self.rerank)(u); - Some((dis_u, u, pay_u)) - } else { - let (Reverse(dis_u), AlwaysEqual(u), AlwaysEqual(pay_u)) = self.cache.pop()?; - Some((dis_u, u, pay_u)) - } + let (Reverse(dis_u), AlwaysEqual(u), AlwaysEqual(pay_u)) = self.cache.pop()?; + Some((dis_u, u, pay_u)) } } -impl<'a, T, R> RerankerPush for GraphReranker<'a, T, R> +impl RerankerPush for GraphReranker where R: Fn(u32) -> (Distance, T), { fn push(&mut self, u: u32) { - if let Some(compute) = self.compute.as_ref() { - let rough_u = (compute)(u); - self.heap.push((Reverse(rough_u), AlwaysEqual(u))); - } else { - let (dis_u, pay_u) = (self.rerank)(u); - self.cache - .push((Reverse(dis_u), AlwaysEqual(u), AlwaysEqual(pay_u))); - } + let (dis_u, pay_u) = (self.rerank)(u); + self.cache + .push((Reverse(dis_u), AlwaysEqual(u), AlwaysEqual(pay_u))); } } diff --git a/crates/quantization/src/reranker/graph_2.rs b/crates/quantization/src/reranker/graph_2.rs new file mode 100644 index 000000000..3768b57ea --- /dev/null +++ b/crates/quantization/src/reranker/graph_2.rs @@ -0,0 +1,43 @@ +use base::always_equal::AlwaysEqual; +use base::distance::Distance; +use base::search::*; +use std::cmp::Reverse; +use std::collections::BinaryHeap; + +pub struct Graph2Reranker { + compute: F, + rerank: R, + heap: BinaryHeap<(Reverse, AlwaysEqual)>, +} + +impl Graph2Reranker { + pub fn new(compute: F, rerank: R) -> Self { + Self { + compute, + rerank, + heap: BinaryHeap::new(), + } + } +} + +impl RerankerPop for Graph2Reranker +where + R: Fn(u32) -> (Distance, T), +{ + fn pop(&mut self) -> Option<(Distance, u32, T)> { + let (_, AlwaysEqual(u)) = self.heap.pop()?; + let (dis_u, pay_u) = (self.rerank)(u); + Some((dis_u, u, pay_u)) + } +} + +impl RerankerPush for Graph2Reranker +where + F: Fn(u32) -> Distance, + R: Fn(u32) -> (Distance, T), +{ + fn push(&mut self, u: u32) { + let rough_u = (self.compute)(u); + self.heap.push((Reverse(rough_u), AlwaysEqual(u))); + } +} diff --git a/crates/quantization/src/reranker/mod.rs b/crates/quantization/src/reranker/mod.rs index cd5917ee9..17dabb167 100644 --- a/crates/quantization/src/reranker/mod.rs +++ b/crates/quantization/src/reranker/mod.rs @@ -1,2 +1,3 @@ pub mod flat; pub mod graph; +pub mod graph_2; diff --git a/crates/quantization/src/scalar.rs b/crates/quantization/src/scalar.rs new file mode 100644 index 000000000..f6745121a --- /dev/null +++ b/crates/quantization/src/scalar.rs @@ -0,0 +1,473 @@ +use crate::fast_scan::b4::fast_scan_b4; +use crate::fast_scan::b4::pack; +use crate::quantize::quantize; +use crate::quantizer::Quantizer; +use crate::reranker::flat::WindowFlatReranker; +use crate::reranker::graph_2::Graph2Reranker; +use crate::utils::merge_2; +use crate::utils::merge_4; +use crate::utils::merge_8; +use crate::utils::InfiniteByteChunks; +use base::always_equal::AlwaysEqual; +use base::distance::Distance; +use base::index::*; +use base::operator::*; +use base::scalar::impossible::Impossible; +use base::scalar::ScalarLike; +use base::search::RerankerPop; +use base::search::RerankerPush; +use base::search::Vectors; +use base::vector::*; +use common::vec2::Vec2; +use serde::Deserialize; +use serde::Serialize; +use std::cmp::Reverse; +use std::marker::PhantomData; +use std::ops::Range; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct ScalarQuantizer { + dims: u32, + bits: u32, + max: Vec, + min: Vec, + centroids: Vec2, + _phantom: PhantomData O>, +} + +impl Quantizer for ScalarQuantizer { + fn train( + vector_options: VectorOptions, + options: Option, + vectors: &impl Vectors>, + transform: impl Fn(Borrowed<'_, O>) -> Owned + Copy, + ) -> Self { + let options = if let Some(QuantizationOptions::Scalar(x)) = options { + x + } else { + panic!("inconsistent parameters: options and generics") + }; + let dims = vector_options.dims; + let bits = options.bits; + let mut max = vec![f32::NEG_INFINITY; dims as usize]; + let mut min = vec![f32::INFINITY; dims as usize]; + let n = vectors.len(); + for i in 0..n { + let vector = transform(vectors.vector(i)); + let vector = vector.as_borrowed(); + for j in 0..dims { + min[j as usize] = min[j as usize].min(O::get(vector, j).to_f32()); + max[j as usize] = max[j as usize].max(O::get(vector, j).to_f32()); + } + } + let mut centroids = Vec2::zeros((1 << bits, dims as usize)); + for p in 0..dims { + let bas = min[p as usize]; + let del = max[p as usize] - min[p as usize]; + for j in 0_usize..(1 << bits) { + let val = j as f32 / ((1 << bits) - 1) as f32; + centroids[(j, p as usize)] = bas + val * del; + } + } + Self { + dims, + bits, + max, + min, + centroids, + _phantom: PhantomData, + } + } + + fn encode(&self, vector: Borrowed<'_, O>) -> Vec { + let dims = self.dims; + let bits = self.bits; + let code_size = (dims * bits).div_ceil(8); + let mut code = Vec::with_capacity(dims as usize); + for i in 0..dims { + let del = self.max[i as usize] - self.min[i as usize]; + let w = (((O::get(vector, i).to_f32() - self.min[i as usize]) / del).to_f32() + * (((1 << bits) - 1) as f32)) as u32; + code.push(w.clamp(0, 255) as u8); + } + match bits { + 1 => InfiniteByteChunks::new(code.into_iter()) + .map(merge_8) + .take(code_size as usize) + .collect(), + 2 => InfiniteByteChunks::new(code.into_iter()) + .map(merge_4) + .take(code_size as usize) + .collect(), + 4 => InfiniteByteChunks::new(code.into_iter()) + .map(merge_2) + .take(code_size as usize) + .collect(), + 8 => code, + _ => unreachable!(), + } + } + + fn fscan_encode(&self, vectors: [Owned; 32]) -> Vec { + let dims = self.dims; + let bits = self.bits; + if bits == 4 { + let codes = vectors.map(|vector| { + let mut code = Vec::with_capacity(dims as usize); + for i in 0..dims { + let del = self.max[i as usize] - self.min[i as usize]; + let w = (((O::get(vector.as_borrowed(), i).to_f32() - self.min[i as usize]) + / del) + .to_f32() + * (((1 << bits) - 1) as f32)) as u32; + code.push(w.clamp(0, 255) as u8); + } + code + }); + pack(dims, codes).collect() + } else { + Vec::new() + } + } + + fn code_size(&self) -> u32 { + (self.dims * self.bits).div_ceil(8) + } + + fn fcode_size(&self) -> u32 { + if self.bits == 4 { + self.dims * 16 + } else { + 0 + } + } + + type Lut = Vec; + + fn preprocess(&self, vector: Borrowed<'_, O>) -> Self::Lut { + O::preprocess(self.dims, self.bits, &self.max, &self.min, vector) + } + + fn process(&self, lut: &Self::Lut, code: &[u8], _: Borrowed<'_, O>) -> Distance { + O::process(self.dims, self.bits, lut, code) + } + + type FLut = ( + /* dims */ u32, + /* k */ f32, + /* b */ f32, + Vec, + ); + + fn fscan_preprocess(&self, vector: Borrowed<'_, O>) -> Self::FLut { + O::fscan_preprocess(self.dims, self.bits, &self.max, &self.min, vector) + } + + fn fscan_process(flut: &Self::FLut, code: &[u8]) -> [Distance; 32] { + O::fscan_process(flut, code) + } + + type FlatRerankVec = Vec<(Reverse, AlwaysEqual)>; + + fn flat_rerank_start() -> Self::FlatRerankVec { + Vec::new() + } + + fn flat_rerank_preprocess( + &self, + vector: Borrowed<'_, O>, + opts: &SearchOptions, + ) -> Result { + if opts.sq_fast_scan && self.bits == 4 { + Ok(self.fscan_preprocess(vector)) + } else { + Err(self.preprocess(vector)) + } + } + + fn flat_rerank_break<'a, T: 'a, R>( + &'a self, + heap: Vec<(Reverse, AlwaysEqual)>, + rerank: R, + opts: &SearchOptions, + ) -> impl RerankerPop + 'a + where + R: Fn(u32) -> (Distance, T) + 'a, + { + WindowFlatReranker::new(heap, rerank, opts.sq_rerank_size) + } + + fn flat_rerank_continue( + &self, + locate_0: impl Fn(u32) -> C, + locate_1: impl Fn(u32) -> C, + frlut: &Result, + range: Range, + heap: &mut Vec<(Reverse, AlwaysEqual)>, + ) where + C: AsRef<[u8]>, + { + match frlut { + Ok(flut) => { + fn divide(r: Range) -> (Option, Range, Option) { + if r.start > r.end || r.start % 32 == 0 && r.end % 32 == 0 { + (None, r.start / 32..r.end / 32, None) + } else if r.start / 32 == r.end / 32 { + (Some(r.start / 32), 0..0, None) + } else { + let left = (r.start % 32 != 0).then_some(r.start / 32); + let right = (r.end % 32 != 0).then_some(r.end / 32); + (left, r.start / 32 + 1..r.end / 32, right) + } + } + let (left, main, right) = divide(range.clone()); + if let Some(i) = left { + let r = Self::fscan_process(flut, locate_1(i).as_ref()); + for j in 0..32 { + if range.contains(&(i * 32 + j)) { + heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); + } + } + } + for i in main { + let r = Self::fscan_process(flut, locate_1(i).as_ref()); + for j in 0..32 { + heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); + } + } + if let Some(i) = right { + let r = Self::fscan_process(flut, locate_1(i).as_ref()); + for j in 0..32 { + if range.contains(&(i * 32 + j)) { + heap.push((Reverse(r[j as usize]), AlwaysEqual(i * 32 + j))); + } + } + } + } + Err(lut) => { + for j in range { + let r = O::process(self.dims, self.bits, lut, locate_0(j).as_ref()); + heap.push((Reverse(r), AlwaysEqual(j))); + } + } + } + } + + fn graph_rerank<'a, T, R, C>( + &'a self, + locate: impl Fn(u32) -> C + 'a, + vector: Borrowed<'a, O>, + rerank: R, + ) -> impl RerankerPush + RerankerPop + 'a + where + T: 'a, + R: Fn(u32) -> (Distance, T) + 'a, + C: AsRef<[u8]>, + { + let lut = self.preprocess(vector); + Graph2Reranker::new( + move |u| self.process(&lut, locate(u).as_ref(), vector), + rerank, + ) + } +} + +pub trait OperatorScalarQuantization: Operator { + type Scalar: ScalarLike; + fn get(vector: Borrowed<'_, Self>, i: u32) -> Self::Scalar; + + fn preprocess( + dims: u32, + bits: u32, + max: &[f32], + min: &[f32], + vector: Borrowed<'_, Self>, + ) -> Vec; + fn process(dims: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance; + + fn fscan_preprocess( + dims: u32, + bits: u32, + max: &[f32], + min: &[f32], + vector: Borrowed<'_, Self>, + ) -> (u32, f32, f32, Vec); + fn fscan_process(flut: &(u32, f32, f32, Vec), code: &[u8]) -> [Distance; 32]; +} + +impl OperatorScalarQuantization for VectDot { + type Scalar = S; + fn get(vector: Borrowed<'_, Self>, i: u32) -> Self::Scalar { + vector.slice()[i as usize] + } + + fn preprocess( + dims: u32, + bits: u32, + max: &[f32], + min: &[f32], + vector: Borrowed<'_, Self>, + ) -> Vec { + let mut xy = Vec::with_capacity(dims as _); + for i in 0..dims { + let bas = min[i as usize]; + let del = max[i as usize] - min[i as usize]; + xy.extend((0..1 << bits).map(|k| { + let x = vector.slice()[i as usize].to_f32(); + let val = k as f32 / ((1 << bits) - 1) as f32; + let y = bas + val * del; + x * y + })); + } + xy + } + fn process(dims: u32, bits: u32, lut: &[f32], rhs: &[u8]) -> Distance { + fn internal(dims: u32, bits: u32, t: &[f32], f: impl Fn(usize) -> usize) -> Distance { + let width = dims; + let xy = { + let mut xy = 0.0f32; + for i in 0..width as usize { + xy += t[i * (1 << bits) + f(i)]; + } + xy + }; + Distance::from(-xy) + } + match bits { + 1 => internal(dims, bits, lut, |i| { + ((rhs[i >> 3] >> ((i & 7) << 0)) & 1u8) as usize + }), + 2 => internal(dims, bits, lut, |i| { + ((rhs[i >> 2] >> ((i & 3) << 1)) & 3u8) as usize + }), + 4 => internal(dims, bits, lut, |i| { + ((rhs[i >> 1] >> ((i & 1) << 2)) & 15u8) as usize + }), + 8 => internal(dims, bits, lut, |i| rhs[i] as usize), + _ => unreachable!(), + } + } + + fn fscan_preprocess( + dims: u32, + bits: u32, + max: &[f32], + min: &[f32], + vector: Borrowed<'_, Self>, + ) -> (u32, f32, f32, Vec) { + let (k, b, t) = quantize::<255>(&Self::preprocess(dims, bits, max, min, vector)); + (dims, k, b, t) + } + fn fscan_process(flut: &(u32, f32, f32, Vec), codes: &[u8]) -> [Distance; 32] { + let &(dims, k, b, ref t) = flut; + let r = fast_scan_b4(dims, codes, t); + std::array::from_fn(|i| Distance::from(-((dims as f32) * b + (r[i] as f32) * k))) + } +} + +impl OperatorScalarQuantization for VectL2 { + type Scalar = S; + fn get(vector: Borrowed<'_, Self>, i: u32) -> Self::Scalar { + vector.slice()[i as usize] + } + + fn preprocess( + dims: u32, + bits: u32, + max: &[f32], + min: &[f32], + vector: Borrowed<'_, Self>, + ) -> Vec { + let mut d2 = Vec::with_capacity(dims as _); + for i in 0..dims { + let bas = min[i as usize]; + let del = max[i as usize] - min[i as usize]; + d2.extend((0..1 << bits).map(|k| { + let x = vector.slice()[i as usize].to_f32(); + let val = k as f32 / ((1 << bits) - 1) as f32; + let y = bas + val * del; + let d = x - y; + d * d + })); + } + d2 + } + fn process(dims: u32, bits: u32, lut: &[f32], rhs: &[u8]) -> Distance { + fn internal(dims: u32, bits: u32, t: &[f32], f: impl Fn(usize) -> usize) -> Distance { + let width = dims; + let mut d2 = 0.0f32; + for i in 0..width as usize { + d2 += t[i * (1 << bits) + f(i)]; + } + Distance::from(d2) + } + match bits { + 1 => internal(dims, bits, lut, |i| { + ((rhs[i >> 3] >> ((i & 7) << 0)) & 1u8) as usize + }), + 2 => internal(dims, bits, lut, |i| { + ((rhs[i >> 2] >> ((i & 3) << 1)) & 3u8) as usize + }), + 4 => internal(dims, bits, lut, |i| { + ((rhs[i >> 1] >> ((i & 1) << 2)) & 15u8) as usize + }), + 8 => internal(dims, bits, lut, |i| rhs[i] as usize), + _ => unreachable!(), + } + } + + fn fscan_preprocess( + dims: u32, + bits: u32, + max: &[f32], + min: &[f32], + vector: Borrowed<'_, Self>, + ) -> (u32, f32, f32, Vec) { + let (k, b, t) = quantize::<255>(&Self::preprocess(dims, bits, max, min, vector)); + (dims, k, b, t) + } + fn fscan_process(flut: &(u32, f32, f32, Vec), codes: &[u8]) -> [Distance; 32] { + let &(dims, k, b, ref t) = flut; + let r = fast_scan_b4(dims, codes, t); + std::array::from_fn(|i| Distance::from((dims as f32) * b + (r[i] as f32) * k)) + } +} + +macro_rules! unimpl_operator_scalar_quantization { + ($t:ty) => { + impl OperatorScalarQuantization for $t { + type Scalar = Impossible; + fn get(_: Borrowed<'_, Self>, _: u32) -> Self::Scalar { + unimplemented!() + } + + fn preprocess(_: u32, _: u32, _: &[f32], _: &[f32], _: Borrowed<'_, Self>) -> Vec { + unimplemented!() + } + fn process(_: u32, _: u32, _: &[f32], _: &[u8]) -> Distance { + unimplemented!() + } + + fn fscan_preprocess( + _: u32, + _: u32, + _: &[f32], + _: &[f32], + _: Borrowed<'_, Self>, + ) -> (u32, f32, f32, Vec) { + unimplemented!() + } + fn fscan_process(_: &(u32, f32, f32, Vec), _: &[u8]) -> [Distance; 32] { + unimplemented!() + } + } + }; +} + +unimpl_operator_scalar_quantization!(BVectorDot); +unimpl_operator_scalar_quantization!(BVectorHamming); +unimpl_operator_scalar_quantization!(BVectorJaccard); + +unimpl_operator_scalar_quantization!(SVectDot); +unimpl_operator_scalar_quantization!(SVectL2); diff --git a/crates/quantization/src/scalar/mod.rs b/crates/quantization/src/scalar/mod.rs deleted file mode 100644 index dcf722beb..000000000 --- a/crates/quantization/src/scalar/mod.rs +++ /dev/null @@ -1,208 +0,0 @@ -pub mod operator; - -use self::operator::OperatorScalarQuantization; -use crate::reranker::flat::WindowFlatReranker; -use crate::reranker::graph::GraphReranker; -use base::always_equal::AlwaysEqual; -use base::distance::Distance; -use base::index::*; -use base::operator::*; -use base::scalar::*; -use base::search::RerankerPop; -use base::search::Vectors; -use base::vector::*; -use common::vec2::Vec2; -use serde::Deserialize; -use serde::Serialize; -use std::cmp::Reverse; -use std::marker::PhantomData; -use std::ops::Range; - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "")] -pub struct ScalarQuantizer { - dims: u32, - bits: u32, - max: Vec, - min: Vec, - centroids: Vec2, - _phantom: PhantomData O>, -} - -impl ScalarQuantizer { - pub fn train( - vector_options: VectorOptions, - scalar_quantization_options: ScalarQuantizationOptions, - vectors: &impl Vectors>, - transform: impl Fn(Borrowed<'_, O>) -> Owned + Copy, - ) -> Self { - let dims = vector_options.dims; - let bits = scalar_quantization_options.bits; - let mut max = vec![f32::NEG_INFINITY; dims as usize]; - let mut min = vec![f32::INFINITY; dims as usize]; - let n = vectors.len(); - for i in 0..n { - let vector = transform(vectors.vector(i)); - let vector = vector.as_borrowed(); - for j in 0..dims { - min[j as usize] = min[j as usize].min(O::get(vector, j).to_f32()); - max[j as usize] = max[j as usize].max(O::get(vector, j).to_f32()); - } - } - let mut centroids = Vec2::zeros((1 << bits, dims as usize)); - for p in 0..dims { - let bas = min[p as usize]; - let del = max[p as usize] - min[p as usize]; - for j in 0_usize..(1 << bits) { - let val = j as f32 / ((1 << bits) - 1) as f32; - centroids[(j, p as usize)] = bas + val * del; - } - } - Self { - dims, - bits, - max, - min, - centroids, - _phantom: PhantomData, - } - } - - pub fn bits(&self) -> u32 { - self.bits - } - - pub fn bytes(&self) -> u32 { - (self.dims * self.bits).div_ceil(8) - } - - pub fn width(&self) -> u32 { - self.dims - } - - pub fn encode(&self, vector: Borrowed<'_, O>) -> Vec { - let dims = self.dims; - let bits = self.bits; - let mut codes = Vec::with_capacity(dims as usize); - for i in 0..dims { - let del = self.max[i as usize] - self.min[i as usize]; - let w = (((O::get(vector, i).to_f32() - self.min[i as usize]) / del).to_f32() - * (((1 << bits) - 1) as f32)) as u32; - codes.push(w.clamp(0, 255) as u8); - } - codes - } - - pub fn project(&self, vector: Borrowed<'_, O>) -> Owned { - vector.own() - } - - pub fn preprocess(&self, lhs: Borrowed<'_, O>) -> O::QuantizationPreprocessed { - O::scalar_quantization_preprocess(self.dims, self.bits, &self.max, &self.min, lhs) - } - - pub fn process(&self, preprocessed: &O::QuantizationPreprocessed, rhs: &[u8]) -> Distance { - let dims = self.dims; - match self.bits { - 1 => O::process(dims, 1, 1, preprocessed, |i| { - ((rhs[i >> 3] >> ((i & 7) << 0)) & 1) as usize - }), - 2 => O::process(dims, 1, 2, preprocessed, |i| { - ((rhs[i >> 2] >> ((i & 3) << 1)) & 3) as usize - }), - 4 => O::process(dims, 1, 4, preprocessed, |i| { - ((rhs[i >> 1] >> ((i & 1) << 2)) & 15) as usize - }), - 8 => O::process(dims, 1, 8, preprocessed, |i| rhs[i] as usize), - _ => unreachable!(), - } - } - - pub fn push_batch( - &self, - preprocessed: &O::QuantizationPreprocessed, - rhs: Range, - heap: &mut Vec<(Reverse, AlwaysEqual)>, - codes: &[u8], - packed_codes: &[u8], - fast_scan: bool, - ) { - let dims = self.dims; - let width = dims; - if fast_scan && self.bits == 4 { - use crate::fast_scan::b4::{fast_scan_b4, BLOCK_SIZE}; - let (k, b, lut) = O::fscan_preprocess(preprocessed); - let s = rhs.start.next_multiple_of(BLOCK_SIZE); - let e = (rhs.end + 1 - BLOCK_SIZE).next_multiple_of(BLOCK_SIZE); - if rhs.start != s { - let i = s - BLOCK_SIZE; - let bytes = width as usize * 16; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let res = fast_scan_b4(width, &packed_codes[start..end], &lut); - let r = res.map(|x| O::fscan_process(width, k, b, x)); - heap.extend({ - (rhs.start..s).map(|u| (Reverse(r[(u - i) as usize]), AlwaysEqual(u))) - }); - } - for i in (s..e).step_by(BLOCK_SIZE as _) { - let bytes = width as usize * 16; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let res = fast_scan_b4(width, &packed_codes[start..end], &lut); - let r = res.map(|x| O::fscan_process(width, k, b, x)); - heap.extend({ - (i..i + BLOCK_SIZE).map(|u| (Reverse(r[(u - i) as usize]), AlwaysEqual(u))) - }); - } - if e != rhs.end { - let i = e; - let bytes = width as usize * 16; - let start = (i / BLOCK_SIZE) as usize * bytes; - let end = start + bytes; - let res = fast_scan_b4(width, &packed_codes[start..end], &lut); - let r = res.map(|x| O::fscan_process(width, k, b, x)); - heap.extend({ - (e..rhs.end).map(|u| (Reverse(r[(u - i) as usize]), AlwaysEqual(u))) - }); - } - return; - } - heap.extend(rhs.map(|u| { - ( - Reverse(self.process(preprocessed, { - let bytes = self.bytes() as usize; - let start = u as usize * bytes; - let end = start + bytes; - &codes[start..end] - })), - AlwaysEqual(u), - ) - })); - } - - pub fn flat_rerank<'a, T: 'a, R: Fn(u32) -> (Distance, T) + 'a>( - &'a self, - heap: Vec<(Reverse, AlwaysEqual)>, - r: R, - rerank_size: u32, - ) -> impl RerankerPop + 'a { - WindowFlatReranker::new(heap, r, rerank_size) - } - - pub fn graph_rerank< - 'a, - T: 'a, - C: Fn(u32) -> &'a [u8] + 'a, - R: Fn(u32) -> (Distance, T) + 'a, - >( - &'a self, - vector: Borrowed<'a, O>, - c: C, - r: R, - ) -> GraphReranker<'a, T, R> { - let p = - O::scalar_quantization_preprocess(self.dims, self.bits, &self.max, &self.min, vector); - GraphReranker::new(Some(Box::new(move |u| self.process(&p, c(u)))), r) - } -} diff --git a/crates/quantization/src/scalar/operator.rs b/crates/quantization/src/scalar/operator.rs deleted file mode 100644 index 654b7e1b4..000000000 --- a/crates/quantization/src/scalar/operator.rs +++ /dev/null @@ -1,98 +0,0 @@ -use crate::operator::OperatorQuantizationProcess; -use base::operator::*; -use base::scalar::impossible::Impossible; -use base::scalar::ScalarLike; - -pub trait OperatorScalarQuantization: Operator + OperatorQuantizationProcess { - type Scalar: ScalarLike; - fn get(vector: Borrowed<'_, Self>, i: u32) -> Self::Scalar; - fn scalar_quantization_preprocess( - dims: u32, - bits: u32, - max: &[f32], - min: &[f32], - lhs: Borrowed<'_, Self>, - ) -> Self::QuantizationPreprocessed; -} - -impl OperatorScalarQuantization for VectDot { - type Scalar = S; - fn get(vector: Borrowed<'_, Self>, i: u32) -> Self::Scalar { - vector.slice()[i as usize] - } - fn scalar_quantization_preprocess( - dims: u32, - bits: u32, - max: &[f32], - min: &[f32], - lhs: Borrowed<'_, Self>, - ) -> Self::QuantizationPreprocessed { - let mut xy = Vec::with_capacity(dims as _); - for i in 0..dims { - let bas = min[i as usize]; - let del = max[i as usize] - min[i as usize]; - xy.extend((0..1 << bits).map(|k| { - let x = lhs.slice()[i as usize].to_f32(); - let val = k as f32 / ((1 << bits) - 1) as f32; - let y = bas + val * del; - x * y - })); - } - xy - } -} - -impl OperatorScalarQuantization for VectL2 { - type Scalar = S; - fn get(vector: Borrowed<'_, Self>, i: u32) -> Self::Scalar { - vector.slice()[i as usize] - } - fn scalar_quantization_preprocess( - dims: u32, - bits: u32, - max: &[f32], - min: &[f32], - lhs: Borrowed<'_, Self>, - ) -> Self::QuantizationPreprocessed { - let mut d2 = Vec::with_capacity(dims as _); - for i in 0..dims { - let bas = min[i as usize]; - let del = max[i as usize] - min[i as usize]; - d2.extend((0..1 << bits).map(|k| { - let x = lhs.slice()[i as usize].to_f32(); - let val = k as f32 / ((1 << bits) - 1) as f32; - let y = bas + val * del; - let d = x - y; - d * d - })); - } - d2 - } -} - -macro_rules! unimpl_operator_scalar_quantization { - ($t:ty) => { - impl OperatorScalarQuantization for $t { - type Scalar = Impossible; - fn get(_: Borrowed<'_, Self>, _: u32) -> Self::Scalar { - unimplemented!() - } - fn scalar_quantization_preprocess( - _: u32, - _: u32, - _: &[f32], - _: &[f32], - _: Borrowed<'_, Self>, - ) -> Self::QuantizationPreprocessed { - unimplemented!() - } - } - }; -} - -unimpl_operator_scalar_quantization!(BVectorDot); -unimpl_operator_scalar_quantization!(BVectorHamming); -unimpl_operator_scalar_quantization!(BVectorJaccard); - -unimpl_operator_scalar_quantization!(SVectDot); -unimpl_operator_scalar_quantization!(SVectL2); diff --git a/crates/quantization/src/trivial.rs b/crates/quantization/src/trivial.rs new file mode 100644 index 000000000..0094cb816 --- /dev/null +++ b/crates/quantization/src/trivial.rs @@ -0,0 +1,128 @@ +use crate::quantizer::Quantizer; +use crate::reranker::graph::GraphReranker; +use base::always_equal::AlwaysEqual; +use base::distance::Distance; +use base::index::*; +use base::operator::*; +use base::search::*; +use base::vector::VectorBorrowed; +use base::vector::VectorOwned; +use serde::Deserialize; +use serde::Serialize; +use std::cmp::Reverse; +use std::collections::BinaryHeap; +use std::marker::PhantomData; +use std::ops::Range; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct TrivialQuantizer { + _maker: PhantomData, +} + +impl Quantizer for TrivialQuantizer { + fn train( + _: VectorOptions, + _: Option, + _: &impl Vectors>, + _: impl Fn(Borrowed<'_, O>) -> Owned + Copy, + ) -> Self { + Self { + _maker: PhantomData, + } + } + + fn encode(&self, _: Borrowed<'_, O>) -> Vec { + Vec::new() + } + + fn fscan_encode(&self, _: [Owned; 32]) -> Vec { + Vec::new() + } + + fn code_size(&self) -> u32 { + 0 + } + + fn fcode_size(&self) -> u32 { + 0 + } + + type Lut = Owned; + + fn preprocess(&self, vector: Borrowed<'_, O>) -> Self::Lut { + vector.own() + } + + fn process(&self, lut: &Self::Lut, _: &[u8], vector: Borrowed<'_, O>) -> Distance { + O::distance(lut.as_borrowed(), vector) + } + + type FLut = std::convert::Infallible; + + fn fscan_preprocess(&self, _: Borrowed<'_, O>) -> Self::FLut { + unimplemented!() + } + + fn fscan_process(_: &Self::FLut, _: &[u8]) -> [Distance; 32] { + unimplemented!() + } + + type FlatRerankVec = Vec; + + fn flat_rerank_start() -> Self::FlatRerankVec { + Vec::new() + } + + fn flat_rerank_preprocess( + &self, + vector: Borrowed<'_, O>, + _: &SearchOptions, + ) -> Result { + Err(self.preprocess(vector)) + } + + fn flat_rerank_continue( + &self, + _: impl Fn(u32) -> C, + _: impl Fn(u32) -> C, + _: &Result, + range: Range, + heap: &mut Vec, + ) where + C: AsRef<[u8]>, + { + heap.extend(range); + } + + fn flat_rerank_break<'a, T: 'a, R>( + &'a self, + heap: Vec, + rerank: R, + _: &SearchOptions, + ) -> impl RerankerPop + 'a + where + R: Fn(u32) -> (Distance, T) + 'a, + { + heap.into_iter() + .map(|u| { + let (dis_u, pay_u) = rerank(u); + (Reverse(dis_u), AlwaysEqual(u), AlwaysEqual(pay_u)) + }) + .collect::, AlwaysEqual, AlwaysEqual)>>() + } + + fn graph_rerank<'a, T, R, C>( + &'a self, + _: impl Fn(u32) -> C + 'a, + _: Borrowed<'a, O>, + rerank: R, + ) -> impl RerankerPush + RerankerPop + 'a + where + T: 'a, + R: Fn(u32) -> (Distance, T) + 'a, + C: AsRef<[u8]>, + { + GraphReranker::new(rerank) + } +} diff --git a/crates/quantization/src/trivial/mod.rs b/crates/quantization/src/trivial/mod.rs deleted file mode 100644 index b18f8ab19..000000000 --- a/crates/quantization/src/trivial/mod.rs +++ /dev/null @@ -1,78 +0,0 @@ -pub mod operator; - -use self::operator::OperatorTrivialQuantization; -use crate::reranker::flat::DisabledFlatReranker; -use crate::reranker::graph::GraphReranker; -use base::always_equal::AlwaysEqual; -use base::distance::Distance; -use base::index::*; -use base::operator::*; -use base::search::*; -use base::vector::VectorBorrowed; -use serde::Deserialize; -use serde::Serialize; -use std::cmp::Reverse; -use std::marker::PhantomData; -use std::ops::Range; - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "")] -pub struct TrivialQuantizer { - dims: u32, - _maker: PhantomData, -} - -impl TrivialQuantizer { - pub fn train( - vector_options: VectorOptions, - _: TrivialQuantizationOptions, - _: &impl Vectors>, - _: impl Fn(Borrowed<'_, O>) -> Owned + Copy, - ) -> Self { - Self { - dims: vector_options.dims, - _maker: PhantomData, - } - } - - pub fn project(&self, vector: Borrowed<'_, O>) -> Owned { - vector.own() - } - - pub fn preprocess(&self, lhs: Borrowed<'_, O>) -> O::TrivialQuantizationPreprocessed { - O::trivial_quantization_preprocess(lhs) - } - - pub fn process( - &self, - preprocessed: &O::TrivialQuantizationPreprocessed, - rhs: Borrowed<'_, O>, - ) -> Distance { - O::trivial_quantization_process(preprocessed, rhs) - } - - pub fn push_batch( - &self, - _preprocessed: &O::TrivialQuantizationPreprocessed, - rhs: Range, - heap: &mut Vec<(Reverse, AlwaysEqual)>, - ) { - heap.extend(rhs.map(|u| (Reverse(Distance::ZERO), AlwaysEqual(u)))); - } - - pub fn flat_rerank<'a, T: 'a>( - &'a self, - heap: Vec<(Reverse, AlwaysEqual)>, - r: impl Fn(u32) -> (Distance, T) + 'a, - ) -> impl RerankerPop + 'a { - DisabledFlatReranker::new(heap, r) - } - - pub fn graph_rerank<'a, T: 'a, R: Fn(u32) -> (Distance, T) + 'a>( - &'a self, - _: Borrowed<'a, O>, - r: R, - ) -> GraphReranker { - GraphReranker::new(None, r) - } -} diff --git a/crates/quantization/src/trivial/operator.rs b/crates/quantization/src/trivial/operator.rs deleted file mode 100644 index 6af5d5172..000000000 --- a/crates/quantization/src/trivial/operator.rs +++ /dev/null @@ -1,34 +0,0 @@ -use base::distance::Distance; -use base::operator::*; -use base::vector::VectorBorrowed; -use base::vector::VectorOwned; - -pub trait OperatorTrivialQuantization: Operator { - type TrivialQuantizationPreprocessed; - - fn trivial_quantization_preprocess( - lhs: Borrowed<'_, Self>, - ) -> Self::TrivialQuantizationPreprocessed; - - fn trivial_quantization_process( - preprocessed: &Self::TrivialQuantizationPreprocessed, - rhs: Borrowed<'_, Self>, - ) -> Distance; -} - -impl OperatorTrivialQuantization for O { - type TrivialQuantizationPreprocessed = Owned; - - fn trivial_quantization_preprocess( - lhs: Borrowed<'_, Self>, - ) -> Self::TrivialQuantizationPreprocessed { - lhs.own() - } - - fn trivial_quantization_process( - preprocessed: &Self::TrivialQuantizationPreprocessed, - rhs: Borrowed<'_, Self>, - ) -> Distance { - O::distance(preprocessed.as_borrowed(), rhs) - } -} diff --git a/crates/quantization/src/utils.rs b/crates/quantization/src/utils.rs index c61b87e6f..f8bf86499 100644 --- a/crates/quantization/src/utils.rs +++ b/crates/quantization/src/utils.rs @@ -18,3 +18,15 @@ impl, const N: usize> Iterator for InfiniteByteChunks u8 { + b0 | (b1 << 1) | (b2 << 2) | (b3 << 3) | (b4 << 4) | (b5 << 5) | (b6 << 6) | (b7 << 7) +} + +pub fn merge_4([b0, b1, b2, b3]: [u8; 4]) -> u8 { + b0 | (b1 << 2) | (b2 << 4) | (b3 << 6) +} + +pub fn merge_2([b0, b1]: [u8; 2]) -> u8 { + b0 | (b1 << 4) +} diff --git a/crates/rabitq/src/quant/quantization.rs b/crates/rabitq/src/quant/quantization.rs index a2184857b..a504ec92a 100644 --- a/crates/rabitq/src/quant/quantization.rs +++ b/crates/rabitq/src/quant/quantization.rs @@ -195,7 +195,7 @@ impl Quantization { pub fn push_batch( &self, preprocessed: &QuantizationAnyPreprocessed, - rhs: Range, + range: Range, heap: &mut Vec<(Reverse, AlwaysEqual)>, rq_epsilon: f32, ) { @@ -203,7 +203,7 @@ impl Quantization { (Quantizer::Rabitq(x), QuantizationAnyPreprocessed::Rabitq((a, b))) => x.push_batch( a, b, - rhs, + range, heap, &self.codes, &self.packed_codes, diff --git a/crates/rabitq/src/quant/quantizer.rs b/crates/rabitq/src/quant/quantizer.rs index 442a71cd5..c65739765 100644 --- a/crates/rabitq/src/quant/quantizer.rs +++ b/crates/rabitq/src/quant/quantizer.rs @@ -103,7 +103,7 @@ impl RabitqQuantizer { &self, alpha: &O::Params, beta: &Result>, - rhs: Range, + range: Range, heap: &mut Vec<(Reverse, AlwaysEqual)>, codes: &[u8], packed_codes: &[u8], @@ -113,9 +113,9 @@ impl RabitqQuantizer { match beta { Err(lut) => { use quantization::fast_scan::b4::{fast_scan_b4, BLOCK_SIZE}; - let s = rhs.start.next_multiple_of(BLOCK_SIZE); - let e = (rhs.end + 1 - BLOCK_SIZE).next_multiple_of(BLOCK_SIZE); - if rhs.start != s { + let s = range.start.next_multiple_of(BLOCK_SIZE); + let e = (range.end + 1 - BLOCK_SIZE).next_multiple_of(BLOCK_SIZE); + if range.start != s { let i = s - BLOCK_SIZE; let t = self.dims.div_ceil(4); let bytes = (t * 16) as usize; @@ -123,7 +123,7 @@ impl RabitqQuantizer { let end = start + bytes; let res = fast_scan_b4(t, &packed_codes[start..end], lut); heap.extend({ - (rhs.start..s).map(|u| { + (range.start..s).map(|u| { ( Reverse({ let a = meta[4 * u as usize + 0]; @@ -160,7 +160,7 @@ impl RabitqQuantizer { }) }); } - if e != rhs.end { + if e != range.end { let i = e; let t = self.dims.div_ceil(4); let bytes = (t * 16) as usize; @@ -168,7 +168,7 @@ impl RabitqQuantizer { let end = start + bytes; let res = fast_scan_b4(t, &packed_codes[start..end], lut); heap.extend({ - (e..rhs.end).map(|u| { + (e..range.end).map(|u| { ( Reverse({ let a = meta[4 * u as usize + 0]; @@ -185,7 +185,7 @@ impl RabitqQuantizer { } } Ok(blut) => { - heap.extend(rhs.map(|u| { + heap.extend(range.map(|u| { ( Reverse(self.process_lowerbound( alpha, @@ -212,8 +212,8 @@ impl RabitqQuantizer { pub fn rerank<'a, T: 'a>( &'a self, heap: Vec<(Reverse, AlwaysEqual)>, - r: impl Fn(u32) -> (Distance, T) + 'a, + rerank: impl Fn(u32) -> (Distance, T) + 'a, ) -> impl RerankerPop + 'a { - ErrorFlatReranker::new(heap, r) + ErrorFlatReranker::new(heap, rerank) } } diff --git a/rust-toolchain.toml b/rust-toolchain.toml index fe3aba684..737397f07 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,5 +1,5 @@ [toolchain] -channel = "nightly-2024-08-05" +channel = "nightly-2024-09-14" profile = "default" components = ["rust-analyzer", "rust-src"] targets = ["aarch64-unknown-linux-gnu", "x86_64-unknown-linux-gnu"] diff --git a/src/datatype/memory_bvector.rs b/src/datatype/memory_bvector.rs index 48841a9a7..1400d4e89 100644 --- a/src/datatype/memory_bvector.rs +++ b/src/datatype/memory_bvector.rs @@ -89,9 +89,9 @@ impl BVectorOutput { let layout = BVectorHeader::layout(dims as usize); let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut BVectorHeader; ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); - std::ptr::addr_of_mut!((*ptr).varlena).write(BVectorHeader::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).magic).write(HEADER_MAGIC); - std::ptr::addr_of_mut!((*ptr).dims).write(internal_dims); + (&raw mut (*ptr).varlena).write(BVectorHeader::varlena(layout.size())); + (&raw mut (*ptr).magic).write(HEADER_MAGIC); + (&raw mut (*ptr).dims).write(internal_dims); std::ptr::copy_nonoverlapping( vector.data().as_ptr(), (*ptr).phantom.as_mut_ptr(), diff --git a/src/datatype/memory_svecf32.rs b/src/datatype/memory_svecf32.rs index a4d49d074..74b783fba 100644 --- a/src/datatype/memory_svecf32.rs +++ b/src/datatype/memory_svecf32.rs @@ -97,11 +97,11 @@ impl SVecf32Output { let layout = SVecf32Header::layout(vector.len() as usize); let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut SVecf32Header; ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); - std::ptr::addr_of_mut!((*ptr).varlena).write(SVecf32Header::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).reserved).write(0); - std::ptr::addr_of_mut!((*ptr).magic).write(HEADER_MAGIC); - std::ptr::addr_of_mut!((*ptr).dims).write(vector.dims()); - std::ptr::addr_of_mut!((*ptr).len).write(vector.len()); + (&raw mut (*ptr).varlena).write(SVecf32Header::varlena(layout.size())); + (&raw mut (*ptr).reserved).write(0); + (&raw mut (*ptr).magic).write(HEADER_MAGIC); + (&raw mut (*ptr).dims).write(vector.dims()); + (&raw mut (*ptr).len).write(vector.len()); let mut data_ptr = (*ptr).phantom.as_mut_ptr().cast::(); std::ptr::copy_nonoverlapping( vector.indexes().as_ptr(), diff --git a/src/datatype/memory_vecf16.rs b/src/datatype/memory_vecf16.rs index e2ddc0807..292244bad 100644 --- a/src/datatype/memory_vecf16.rs +++ b/src/datatype/memory_vecf16.rs @@ -93,9 +93,9 @@ impl Vecf16Output { let internal_dims = dims as u16; let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf16Header; ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); - std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16Header::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).magic).write(HEADER_MAGIC); - std::ptr::addr_of_mut!((*ptr).dims).write(internal_dims); + (&raw mut (*ptr).varlena).write(Vecf16Header::varlena(layout.size())); + (&raw mut (*ptr).magic).write(HEADER_MAGIC); + (&raw mut (*ptr).dims).write(internal_dims); std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); Vecf16Output(NonNull::new(ptr).unwrap()) } diff --git a/src/datatype/memory_vecf32.rs b/src/datatype/memory_vecf32.rs index 46a335c61..d53ae41dd 100644 --- a/src/datatype/memory_vecf32.rs +++ b/src/datatype/memory_vecf32.rs @@ -92,9 +92,9 @@ impl Vecf32Output { let internal_dims = dims as u16; let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf32Header; ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); - std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32Header::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).magic).write(HEADER_MAGIC); - std::ptr::addr_of_mut!((*ptr).dims).write(internal_dims); + (&raw mut (*ptr).varlena).write(Vecf32Header::varlena(layout.size())); + (&raw mut (*ptr).magic).write(HEADER_MAGIC); + (&raw mut (*ptr).dims).write(internal_dims); std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); Vecf32Output(NonNull::new(ptr).unwrap()) } diff --git a/src/datatype/subscript_bvector.rs b/src/datatype/subscript_bvector.rs index aebc496ac..0e0a996d9 100644 --- a/src/datatype/subscript_bvector.rs +++ b/src/datatype/subscript_bvector.rs @@ -177,5 +177,5 @@ fn _vectors_bvector_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Intern fetch_leakproof: false, store_leakproof: false, }; - Internal::from(Some(Datum::from(std::ptr::addr_of!(SBSROUTINES)))) + Internal::from(Some(Datum::from(&SBSROUTINES as *const _))) } diff --git a/src/datatype/subscript_svecf32.rs b/src/datatype/subscript_svecf32.rs index bf5831f3c..7b84168ab 100644 --- a/src/datatype/subscript_svecf32.rs +++ b/src/datatype/subscript_svecf32.rs @@ -177,5 +177,5 @@ fn _vectors_svecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Intern fetch_leakproof: false, store_leakproof: false, }; - Internal::from(Some(Datum::from(std::ptr::addr_of!(SBSROUTINES)))) + Internal::from(Some(Datum::from(&SBSROUTINES as *const _))) } diff --git a/src/datatype/subscript_vecf16.rs b/src/datatype/subscript_vecf16.rs index fd34c01be..9bcfee1ef 100644 --- a/src/datatype/subscript_vecf16.rs +++ b/src/datatype/subscript_vecf16.rs @@ -177,5 +177,5 @@ fn _vectors_vecf16_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Interna fetch_leakproof: false, store_leakproof: false, }; - Internal::from(Some(Datum::from(std::ptr::addr_of!(SBSROUTINES)))) + Internal::from(Some(Datum::from(&SBSROUTINES as *const _))) } diff --git a/src/datatype/subscript_vecf32.rs b/src/datatype/subscript_vecf32.rs index 2c1d8f35b..b824c1b45 100644 --- a/src/datatype/subscript_vecf32.rs +++ b/src/datatype/subscript_vecf32.rs @@ -177,5 +177,5 @@ fn _vectors_vecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Interna fetch_leakproof: false, store_leakproof: false, }; - Internal::from(Some(Datum::from(std::ptr::addr_of!(SBSROUTINES)))) + Internal::from(Some(Datum::from(&SBSROUTINES as *const _))) } diff --git a/src/gucs/executing.rs b/src/gucs/executing.rs index ff94aaac2..d5d2ba465 100644 --- a/src/gucs/executing.rs +++ b/src/gucs/executing.rs @@ -1,29 +1,17 @@ use base::index::*; use pgrx::guc::{GucContext, GucFlags, GucRegistry, GucSetting}; -static FLAT_SQ_RERANK_SIZE: GucSetting = - GucSetting::::new(SearchOptions::default_flat_sq_rerank_size() as i32); +static SQ_RERANK_SIZE: GucSetting = + GucSetting::::new(SearchOptions::default_sq_rerank_size() as i32); -static FLAT_SQ_FAST_SCAN: GucSetting = - GucSetting::::new(SearchOptions::default_flat_sq_fast_scan()); +static SQ_FAST_SCAN: GucSetting = + GucSetting::::new(SearchOptions::default_sq_fast_scan()); -static FLAT_PQ_RERANK_SIZE: GucSetting = - GucSetting::::new(SearchOptions::default_flat_pq_rerank_size() as i32); +static PQ_RERANK_SIZE: GucSetting = + GucSetting::::new(SearchOptions::default_pq_rerank_size() as i32); -static FLAT_PQ_FAST_SCAN: GucSetting = - GucSetting::::new(SearchOptions::default_flat_pq_fast_scan()); - -static IVF_SQ_RERANK_SIZE: GucSetting = - GucSetting::::new(SearchOptions::default_ivf_sq_rerank_size() as i32); - -static IVF_SQ_FAST_SCAN: GucSetting = - GucSetting::::new(SearchOptions::default_ivf_sq_fast_scan()); - -static IVF_PQ_RERANK_SIZE: GucSetting = - GucSetting::::new(SearchOptions::default_ivf_pq_rerank_size() as i32); - -static IVF_PQ_FAST_SCAN: GucSetting = - GucSetting::::new(SearchOptions::default_ivf_pq_fast_scan()); +static PQ_FAST_SCAN: GucSetting = + GucSetting::::new(SearchOptions::default_pq_fast_scan()); static IVF_NPROBE: GucSetting = GucSetting::::new(SearchOptions::default_ivf_nprobe() as i32); @@ -45,74 +33,38 @@ static DISKANN_EF_SEARCH: GucSetting = pub unsafe fn init() { GucRegistry::define_int_guc( - "vectors.flat_sq_rerank_size", - "Scalar quantization reranker size.", - "https://docs.pgvecto.rs/usage/search.html", - &FLAT_SQ_RERANK_SIZE, - 0, - 65535, - GucContext::Userset, - GucFlags::default(), - ); - GucRegistry::define_bool_guc( - "vectors.flat_sq_fast_scan", - "Enables fast scan or not.", - "https://docs.pgvecto.rs/usage/search.html", - &FLAT_SQ_FAST_SCAN, - GucContext::Userset, - GucFlags::default(), - ); - GucRegistry::define_int_guc( - "vectors.flat_pq_rerank_size", - "Product quantization reranker size.", - "https://docs.pgvecto.rs/usage/search.html", - &FLAT_PQ_RERANK_SIZE, - 0, - 65535, - GucContext::Userset, - GucFlags::default(), - ); - GucRegistry::define_bool_guc( - "vectors.flat_pq_fast_scan", - "Enables fast scan or not.", - "https://docs.pgvecto.rs/usage/search.html", - &FLAT_PQ_FAST_SCAN, - GucContext::Userset, - GucFlags::default(), - ); - GucRegistry::define_int_guc( - "vectors.ivf_sq_rerank_size", + "vectors.sq_rerank_size", "Scalar quantization reranker size.", "https://docs.pgvecto.rs/usage/search.html", - &IVF_SQ_RERANK_SIZE, + &SQ_RERANK_SIZE, 0, 65535, GucContext::Userset, GucFlags::default(), ); GucRegistry::define_bool_guc( - "vectors.ivf_sq_fast_scan", + "vectors.sq_fast_scan", "Enables fast scan or not.", "https://docs.pgvecto.rs/usage/search.html", - &IVF_SQ_FAST_SCAN, + &SQ_FAST_SCAN, GucContext::Userset, GucFlags::default(), ); GucRegistry::define_int_guc( - "vectors.ivf_pq_rerank_size", + "vectors.pq_rerank_size", "Product quantization reranker size.", "https://docs.pgvecto.rs/usage/search.html", - &IVF_PQ_RERANK_SIZE, + &PQ_RERANK_SIZE, 0, 65535, GucContext::Userset, GucFlags::default(), ); GucRegistry::define_bool_guc( - "vectors.ivf_pq_fast_scan", + "vectors.pq_fast_scan", "Enables fast scan or not.", "https://docs.pgvecto.rs/usage/search.html", - &IVF_PQ_FAST_SCAN, + &PQ_FAST_SCAN, GucContext::Userset, GucFlags::default(), ); @@ -178,14 +130,10 @@ pub unsafe fn init() { pub fn search_options() -> SearchOptions { SearchOptions { - flat_sq_rerank_size: FLAT_SQ_RERANK_SIZE.get() as u32, - flat_sq_fast_scan: FLAT_SQ_FAST_SCAN.get(), - flat_pq_rerank_size: FLAT_PQ_RERANK_SIZE.get() as u32, - flat_pq_fast_scan: FLAT_PQ_FAST_SCAN.get(), - ivf_sq_rerank_size: IVF_SQ_RERANK_SIZE.get() as u32, - ivf_sq_fast_scan: IVF_SQ_FAST_SCAN.get(), - ivf_pq_rerank_size: IVF_PQ_RERANK_SIZE.get() as u32, - ivf_pq_fast_scan: IVF_PQ_FAST_SCAN.get(), + sq_rerank_size: SQ_RERANK_SIZE.get() as u32, + sq_fast_scan: SQ_FAST_SCAN.get(), + pq_rerank_size: PQ_RERANK_SIZE.get() as u32, + pq_fast_scan: PQ_FAST_SCAN.get(), ivf_nprobe: IVF_NPROBE.get() as u32, hnsw_ef_search: HNSW_EF_SEARCH.get() as u32, rabitq_nprobe: RABITQ_NPROBE.get() as u32, diff --git a/src/index/am_options.rs b/src/index/am_options.rs index 00a798cf2..3ecc05740 100644 --- a/src/index/am_options.rs +++ b/src/index/am_options.rs @@ -32,7 +32,7 @@ impl Reloption { }]; unsafe fn options(&self) -> &CStr { unsafe { - let ptr = std::ptr::addr_of!(*self) + let ptr = (self as *const Self) .cast::() .offset(self.options as _); CStr::from_ptr(ptr)