Skip to content

Commit

Permalink
refactor: rework quantization abstraction (#591)
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi authored Sep 18, 2024
1 parent 8abfd82 commit 1ed47d8
Show file tree
Hide file tree
Showing 51 changed files with 1,638 additions and 1,490 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

90 changes: 23 additions & 67 deletions crates/base/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ pub struct IndexOptions {
impl IndexOptions {
fn validate_self_quantization(
&self,
quantization: &QuantizationOptions,
quantization: &Option<QuantizationOptions>,
) -> Result<(), ValidationError> {
match quantization {
QuantizationOptions::Trivial(_) => Ok(()),
QuantizationOptions::Scalar(_) | QuantizationOptions::Product(_) => {
None => Ok(()),
Some(QuantizationOptions::Scalar(_) | QuantizationOptions::Product(_)) => {
if !matches!(self.vector.v, VectorKind::Vecf32 | VectorKind::Vecf16) {
return Err(ValidationError::new(
"scalar quantization or product quantization is not support for vectors that are not dense vectors",
Expand Down Expand Up @@ -356,13 +356,13 @@ impl Default for InvertedIndexingOptions {
pub struct FlatIndexingOptions {
#[serde(default)]
#[validate(nested)]
pub quantization: QuantizationOptions,
pub quantization: Option<QuantizationOptions>,
}

impl Default for FlatIndexingOptions {
fn default() -> Self {
Self {
quantization: QuantizationOptions::default(),
quantization: Default::default(),
}
}
}
Expand All @@ -379,7 +379,7 @@ pub struct IvfIndexingOptions {
pub residual_quantization: bool,
#[serde(default)]
#[validate(nested)]
pub quantization: QuantizationOptions,
pub quantization: Option<QuantizationOptions>,
}

impl IvfIndexingOptions {
Expand Down Expand Up @@ -418,7 +418,7 @@ pub struct HnswIndexingOptions {
pub ef_construction: u32,
#[serde(default)]
#[validate(nested)]
pub quantization: QuantizationOptions,
pub quantization: Option<QuantizationOptions>,
}

impl HnswIndexingOptions {
Expand Down Expand Up @@ -472,37 +472,19 @@ impl Default for RabitqIndexingOptions {
#[serde(deny_unknown_fields)]
#[serde(rename_all = "snake_case")]
pub enum QuantizationOptions {
Trivial(TrivialQuantizationOptions),
Scalar(ScalarQuantizationOptions),
Product(ProductQuantizationOptions),
}

impl Validate for QuantizationOptions {
fn validate(&self) -> Result<(), validator::ValidationErrors> {
match self {
Self::Trivial(x) => x.validate(),
Self::Scalar(x) => x.validate(),
Self::Product(x) => x.validate(),
}
}
}

impl Default for QuantizationOptions {
fn default() -> Self {
Self::Trivial(Default::default())
}
}

#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[serde(deny_unknown_fields)]
pub struct TrivialQuantizationOptions {}

impl Default for TrivialQuantizationOptions {
fn default() -> Self {
Self {}
}
}

#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[serde(deny_unknown_fields)]
#[validate(schema(function = "Self::validate_self"))]
Expand Down Expand Up @@ -569,26 +551,16 @@ impl Default for ProductQuantizationOptions {
#[derive(Debug, Clone, Serialize, Deserialize, Validate, Alter)]
#[serde(deny_unknown_fields)]
pub struct SearchOptions {
#[serde(default = "SearchOptions::default_flat_sq_rerank_size")]
#[validate(range(min = 0, max = 65535))]
pub flat_sq_rerank_size: u32,
#[serde(default = "SearchOptions::default_flat_sq_fast_scan")]
pub flat_sq_fast_scan: bool,
#[serde(default = "SearchOptions::default_flat_pq_rerank_size")]
#[validate(range(min = 0, max = 65535))]
pub flat_pq_rerank_size: u32,
#[serde(default = "SearchOptions::default_flat_pq_fast_scan")]
pub flat_pq_fast_scan: bool,
#[serde(default = "SearchOptions::default_ivf_sq_rerank_size")]
#[serde(default = "SearchOptions::default_sq_rerank_size")]
#[validate(range(min = 0, max = 65535))]
pub ivf_sq_rerank_size: u32,
#[serde(default = "SearchOptions::default_ivf_sq_fast_scan")]
pub ivf_sq_fast_scan: bool,
#[serde(default = "SearchOptions::default_ivf_pq_rerank_size")]
pub sq_rerank_size: u32,
#[serde(default = "SearchOptions::default_sq_fast_scan")]
pub sq_fast_scan: bool,
#[serde(default = "SearchOptions::default_pq_rerank_size")]
#[validate(range(min = 0, max = 65535))]
pub ivf_pq_rerank_size: u32,
#[serde(default = "SearchOptions::default_ivf_pq_fast_scan")]
pub ivf_pq_fast_scan: bool,
pub pq_rerank_size: u32,
#[serde(default = "SearchOptions::default_pq_fast_scan")]
pub pq_fast_scan: bool,
#[serde(default = "SearchOptions::default_ivf_nprobe")]
#[validate(range(min = 1, max = 65535))]
pub ivf_nprobe: u32,
Expand All @@ -609,28 +581,16 @@ pub struct SearchOptions {
}

impl SearchOptions {
pub const fn default_flat_sq_rerank_size() -> u32 {
0
}
pub const fn default_flat_sq_fast_scan() -> bool {
false
}
pub const fn default_flat_pq_rerank_size() -> u32 {
0
}
pub const fn default_flat_pq_fast_scan() -> bool {
false
}
pub const fn default_ivf_sq_rerank_size() -> u32 {
pub const fn default_sq_rerank_size() -> u32 {
0
}
pub const fn default_ivf_sq_fast_scan() -> bool {
pub const fn default_sq_fast_scan() -> bool {
false
}
pub const fn default_ivf_pq_rerank_size() -> u32 {
pub const fn default_pq_rerank_size() -> u32 {
0
}
pub const fn default_ivf_pq_fast_scan() -> bool {
pub const fn default_pq_fast_scan() -> bool {
false
}
pub const fn default_ivf_nprobe() -> u32 {
Expand All @@ -656,14 +616,10 @@ impl SearchOptions {
impl Default for SearchOptions {
fn default() -> Self {
Self {
flat_sq_rerank_size: Self::default_flat_sq_rerank_size(),
flat_sq_fast_scan: Self::default_flat_sq_fast_scan(),
flat_pq_rerank_size: Self::default_flat_pq_rerank_size(),
flat_pq_fast_scan: Self::default_flat_pq_fast_scan(),
ivf_sq_rerank_size: Self::default_ivf_sq_rerank_size(),
ivf_sq_fast_scan: Self::default_ivf_sq_fast_scan(),
ivf_pq_rerank_size: Self::default_ivf_pq_rerank_size(),
ivf_pq_fast_scan: Self::default_ivf_pq_fast_scan(),
sq_rerank_size: Self::default_sq_rerank_size(),
sq_fast_scan: Self::default_sq_fast_scan(),
pq_rerank_size: Self::default_pq_rerank_size(),
pq_fast_scan: Self::default_pq_fast_scan(),
ivf_nprobe: Self::default_ivf_nprobe(),
hnsw_ef_search: Self::default_hnsw_ef_search(),
rabitq_nprobe: Self::default_rabitq_nprobe(),
Expand Down
1 change: 0 additions & 1 deletion crates/base/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#![feature(const_float_bits_conv)]
#![feature(avx512_target_feature)]
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))]
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512_f16))]
Expand Down
2 changes: 1 addition & 1 deletion crates/base/src/pod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ unsafe impl Pod for Distance {}
unsafe impl Pod for Impossible {}

pub fn bytes_of<T: Pod>(t: &T) -> &[u8] {
unsafe { core::slice::from_raw_parts(std::ptr::addr_of!(*t) as *const u8, size_of::<T>()) }
unsafe { core::slice::from_raw_parts(t as *const T as *const u8, size_of::<T>()) }
}

pub fn zeroed_vec<T: Pod>(length: usize) -> Vec<T> {
Expand Down
17 changes: 9 additions & 8 deletions crates/base/src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use crate::distance::Distance;
use crate::vector::VectorOwned;
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::fmt::Display;

#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
Expand Down Expand Up @@ -93,14 +95,13 @@ pub trait RerankerPop<T> {
fn pop(&mut self) -> Option<(Distance, u32, T)>;
}

pub trait RerankerPush {
fn push(&mut self, u: u32);
}

pub trait FlatReranker<T>: RerankerPop<T> {}

impl<'a, T> RerankerPop<T> for Box<dyn FlatReranker<T> + 'a> {
impl<T> RerankerPop<T> for BinaryHeap<(Reverse<Distance>, AlwaysEqual<u32>, AlwaysEqual<T>)> {
fn pop(&mut self) -> Option<(Distance, u32, T)> {
self.as_mut().pop()
let (Reverse(dis_u), AlwaysEqual(u), AlwaysEqual(pay_u)) = self.pop()?;
Some((dis_u, u, pay_u))
}
}

pub trait RerankerPush {
fn push(&mut self, u: u32);
}
5 changes: 5 additions & 0 deletions crates/base/src/vector/bvect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ impl VectorOwned for BVectOwned {
data: &self.data,
}
}

#[inline(always)]
fn zero(dims: u32) -> Self {
Self::new(dims, vec![0; dims.div_ceil(BVECTOR_WIDTH) as usize])
}
}

#[derive(Debug, Clone, Copy)]
Expand Down
2 changes: 2 additions & 0 deletions crates/base/src/vector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub trait VectorOwned: Clone + Serialize + for<'a> Deserialize<'a> + 'static {
type Borrowed<'a>: VectorBorrowed<Owned = Self>;

fn as_borrowed(&self) -> Self::Borrowed<'_>;

fn zero(dims: u32) -> Self;
}

pub trait VectorBorrowed: Copy + PartialEq + PartialOrd {
Expand Down
5 changes: 5 additions & 0 deletions crates/base/src/vector/svect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ impl<S: ScalarLike> VectorOwned for SVectOwned<S> {
values: &self.values,
}
}

#[inline(always)]
fn zero(dims: u32) -> Self {
Self::new(dims, vec![], vec![])
}
}

#[derive(Debug, Clone, Copy)]
Expand Down
5 changes: 5 additions & 0 deletions crates/base/src/vector/vect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ impl<S: ScalarLike> VectorOwned for VectOwned<S> {
fn as_borrowed(&self) -> VectBorrowed<'_, S> {
VectBorrowed(self.0.as_slice())
}

#[inline(always)]
fn zero(dims: u32) -> Self {
Self::new(vec![S::zero(); dims as usize])
}
}

#[derive(Debug, Clone, Copy)]
Expand Down
12 changes: 4 additions & 8 deletions crates/cli/src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,13 @@ pub struct QueryArguments {
impl QueryArguments {
pub fn get_search_options(&self) -> SearchOptions {
SearchOptions {
flat_sq_rerank_size: 0,
flat_pq_rerank_size: 0,
ivf_sq_rerank_size: 0,
ivf_pq_rerank_size: 0,
sq_rerank_size: 0,
pq_rerank_size: 0,
hnsw_ef_search: self.ef,
ivf_nprobe: self.probe,
diskann_ef_search: 100,
flat_sq_fast_scan: false,
flat_pq_fast_scan: false,
ivf_sq_fast_scan: false,
ivf_pq_fast_scan: false,
sq_fast_scan: false,
pq_fast_scan: false,
rabitq_epsilon: 1.9,
rabitq_fast_scan: true,
rabitq_nprobe: self.probe,
Expand Down
38 changes: 16 additions & 22 deletions crates/flat/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@ use base::search::*;
use base::vector::VectorBorrowed;
use common::mmap_array::MmapArray;
use common::remap::RemappedCollection;
use quantization::operator::OperatorQuantization;
use quantization::quantizer::Quantizer;
use quantization::Quantization;
use std::fs::create_dir;
use std::path::Path;
use storage::OperatorStorage;
use storage::Storage;

pub trait OperatorFlat: OperatorQuantization + OperatorStorage {}
pub trait OperatorFlat: OperatorStorage {}

impl<T: OperatorQuantization + OperatorStorage> OperatorFlat for T {}
impl<T: OperatorStorage> OperatorFlat for T {}

pub struct Flat<O: OperatorFlat> {
pub struct Flat<O: OperatorFlat, Q: Quantizer<O>> {
storage: O::Storage,
quantization: Quantization<O>,
quantization: Quantization<O, Q>,
payloads: MmapArray<Payload>,
}

impl<O: OperatorFlat> Flat<O> {
impl<O: OperatorFlat, Q: Quantizer<O>> Flat<O, Q> {
pub fn create(
path: impl AsRef<Path>,
options: IndexOptions,
Expand All @@ -43,20 +43,14 @@ impl<O: OperatorFlat> Flat<O> {
vector: Borrowed<'a, O>,
opts: &'a SearchOptions,
) -> Box<dyn Iterator<Item = Element> + 'a> {
let mut heap = Vec::new();
let preprocessed = self.quantization.preprocess(vector);
self.quantization.push_batch(
&preprocessed,
0..self.storage.len(),
&mut heap,
opts.flat_sq_fast_scan,
opts.flat_pq_fast_scan,
);
let mut reranker = self.quantization.flat_rerank(
let mut heap = Q::flat_rerank_start();
let lut = self.quantization.flat_rerank_preprocess(vector, opts);
self.quantization
.flat_rerank_continue(&lut, 0..self.storage.len(), &mut heap);
let mut reranker = self.quantization.flat_rerank_break(
heap,
move |u| (O::distance(vector, self.storage.vector(u)), ()),
opts.flat_sq_rerank_size,
opts.flat_pq_rerank_size,
opts,
);
Box::new(std::iter::from_fn(move || {
reranker.pop().map(|(dis_u, u, ())| Element {
Expand All @@ -83,15 +77,15 @@ impl<O: OperatorFlat> Flat<O> {
}
}

fn from_nothing<O: OperatorFlat>(
fn from_nothing<O: OperatorFlat, Q: Quantizer<O>>(
path: impl AsRef<Path>,
options: IndexOptions,
collection: &(impl Vectors<Owned<O>> + Collection + Sync),
) -> Flat<O> {
) -> Flat<O, Q> {
create_dir(path.as_ref()).unwrap();
let flat_indexing_options = options.indexing.clone().unwrap_flat();
let storage = O::Storage::create(path.as_ref().join("storage"), collection);
let quantization = Quantization::<O>::create(
let quantization = Quantization::<O, Q>::create(
path.as_ref().join("quantization"),
options.vector,
flat_indexing_options.quantization,
Expand All @@ -109,7 +103,7 @@ fn from_nothing<O: OperatorFlat>(
}
}

fn open<O: OperatorFlat>(path: impl AsRef<Path>) -> Flat<O> {
fn open<O: OperatorFlat, Q: Quantizer<O>>(path: impl AsRef<Path>) -> Flat<O, Q> {
let storage = O::Storage::open(path.as_ref().join("storage"));
let quantization = Quantization::open(path.as_ref().join("quantization"));
let payloads = MmapArray::open(path.as_ref().join("payloads"));
Expand Down
Loading

0 comments on commit 1ed47d8

Please sign in to comment.