Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: rework quantization abstraction #591

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

90 changes: 23 additions & 67 deletions crates/base/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ pub struct IndexOptions {
impl IndexOptions {
fn validate_self_quantization(
&self,
quantization: &QuantizationOptions,
quantization: &Option<QuantizationOptions>,
) -> Result<(), ValidationError> {
match quantization {
QuantizationOptions::Trivial(_) => Ok(()),
QuantizationOptions::Scalar(_) | QuantizationOptions::Product(_) => {
None => Ok(()),
Some(QuantizationOptions::Scalar(_) | QuantizationOptions::Product(_)) => {
if !matches!(self.vector.v, VectorKind::Vecf32 | VectorKind::Vecf16) {
return Err(ValidationError::new(
"scalar quantization or product quantization is not support for vectors that are not dense vectors",
Expand Down Expand Up @@ -356,13 +356,13 @@ impl Default for InvertedIndexingOptions {
pub struct FlatIndexingOptions {
#[serde(default)]
#[validate(nested)]
pub quantization: QuantizationOptions,
pub quantization: Option<QuantizationOptions>,
}

impl Default for FlatIndexingOptions {
fn default() -> Self {
Self {
quantization: QuantizationOptions::default(),
quantization: Default::default(),
}
}
}
Expand All @@ -379,7 +379,7 @@ pub struct IvfIndexingOptions {
pub residual_quantization: bool,
#[serde(default)]
#[validate(nested)]
pub quantization: QuantizationOptions,
pub quantization: Option<QuantizationOptions>,
}

impl IvfIndexingOptions {
Expand Down Expand Up @@ -418,7 +418,7 @@ pub struct HnswIndexingOptions {
pub ef_construction: u32,
#[serde(default)]
#[validate(nested)]
pub quantization: QuantizationOptions,
pub quantization: Option<QuantizationOptions>,
}

impl HnswIndexingOptions {
Expand Down Expand Up @@ -472,37 +472,19 @@ impl Default for RabitqIndexingOptions {
#[serde(deny_unknown_fields)]
#[serde(rename_all = "snake_case")]
pub enum QuantizationOptions {
Trivial(TrivialQuantizationOptions),
Scalar(ScalarQuantizationOptions),
Product(ProductQuantizationOptions),
}

impl Validate for QuantizationOptions {
fn validate(&self) -> Result<(), validator::ValidationErrors> {
match self {
Self::Trivial(x) => x.validate(),
Self::Scalar(x) => x.validate(),
Self::Product(x) => x.validate(),
}
}
}

impl Default for QuantizationOptions {
fn default() -> Self {
Self::Trivial(Default::default())
}
}

#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[serde(deny_unknown_fields)]
pub struct TrivialQuantizationOptions {}

impl Default for TrivialQuantizationOptions {
fn default() -> Self {
Self {}
}
}

#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[serde(deny_unknown_fields)]
#[validate(schema(function = "Self::validate_self"))]
Expand Down Expand Up @@ -569,26 +551,16 @@ impl Default for ProductQuantizationOptions {
#[derive(Debug, Clone, Serialize, Deserialize, Validate, Alter)]
#[serde(deny_unknown_fields)]
pub struct SearchOptions {
#[serde(default = "SearchOptions::default_flat_sq_rerank_size")]
#[validate(range(min = 0, max = 65535))]
pub flat_sq_rerank_size: u32,
#[serde(default = "SearchOptions::default_flat_sq_fast_scan")]
pub flat_sq_fast_scan: bool,
#[serde(default = "SearchOptions::default_flat_pq_rerank_size")]
#[validate(range(min = 0, max = 65535))]
pub flat_pq_rerank_size: u32,
#[serde(default = "SearchOptions::default_flat_pq_fast_scan")]
pub flat_pq_fast_scan: bool,
#[serde(default = "SearchOptions::default_ivf_sq_rerank_size")]
#[serde(default = "SearchOptions::default_sq_rerank_size")]
#[validate(range(min = 0, max = 65535))]
pub ivf_sq_rerank_size: u32,
#[serde(default = "SearchOptions::default_ivf_sq_fast_scan")]
pub ivf_sq_fast_scan: bool,
#[serde(default = "SearchOptions::default_ivf_pq_rerank_size")]
pub sq_rerank_size: u32,
#[serde(default = "SearchOptions::default_sq_fast_scan")]
pub sq_fast_scan: bool,
#[serde(default = "SearchOptions::default_pq_rerank_size")]
#[validate(range(min = 0, max = 65535))]
pub ivf_pq_rerank_size: u32,
#[serde(default = "SearchOptions::default_ivf_pq_fast_scan")]
pub ivf_pq_fast_scan: bool,
pub pq_rerank_size: u32,
#[serde(default = "SearchOptions::default_pq_fast_scan")]
pub pq_fast_scan: bool,
#[serde(default = "SearchOptions::default_ivf_nprobe")]
#[validate(range(min = 1, max = 65535))]
pub ivf_nprobe: u32,
Expand All @@ -609,28 +581,16 @@ pub struct SearchOptions {
}

impl SearchOptions {
pub const fn default_flat_sq_rerank_size() -> u32 {
0
}
pub const fn default_flat_sq_fast_scan() -> bool {
false
}
pub const fn default_flat_pq_rerank_size() -> u32 {
0
}
pub const fn default_flat_pq_fast_scan() -> bool {
false
}
pub const fn default_ivf_sq_rerank_size() -> u32 {
pub const fn default_sq_rerank_size() -> u32 {
0
}
pub const fn default_ivf_sq_fast_scan() -> bool {
pub const fn default_sq_fast_scan() -> bool {
false
}
pub const fn default_ivf_pq_rerank_size() -> u32 {
pub const fn default_pq_rerank_size() -> u32 {
0
}
pub const fn default_ivf_pq_fast_scan() -> bool {
pub const fn default_pq_fast_scan() -> bool {
false
}
pub const fn default_ivf_nprobe() -> u32 {
Expand All @@ -656,14 +616,10 @@ impl SearchOptions {
impl Default for SearchOptions {
fn default() -> Self {
Self {
flat_sq_rerank_size: Self::default_flat_sq_rerank_size(),
flat_sq_fast_scan: Self::default_flat_sq_fast_scan(),
flat_pq_rerank_size: Self::default_flat_pq_rerank_size(),
flat_pq_fast_scan: Self::default_flat_pq_fast_scan(),
ivf_sq_rerank_size: Self::default_ivf_sq_rerank_size(),
ivf_sq_fast_scan: Self::default_ivf_sq_fast_scan(),
ivf_pq_rerank_size: Self::default_ivf_pq_rerank_size(),
ivf_pq_fast_scan: Self::default_ivf_pq_fast_scan(),
sq_rerank_size: Self::default_sq_rerank_size(),
sq_fast_scan: Self::default_sq_fast_scan(),
pq_rerank_size: Self::default_pq_rerank_size(),
pq_fast_scan: Self::default_pq_fast_scan(),
ivf_nprobe: Self::default_ivf_nprobe(),
hnsw_ef_search: Self::default_hnsw_ef_search(),
rabitq_nprobe: Self::default_rabitq_nprobe(),
Expand Down
1 change: 0 additions & 1 deletion crates/base/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
#![feature(const_float_bits_conv)]
#![feature(avx512_target_feature)]
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))]
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512_f16))]
Expand Down
2 changes: 1 addition & 1 deletion crates/base/src/pod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ unsafe impl Pod for Distance {}
unsafe impl Pod for Impossible {}

pub fn bytes_of<T: Pod>(t: &T) -> &[u8] {
unsafe { core::slice::from_raw_parts(std::ptr::addr_of!(*t) as *const u8, size_of::<T>()) }
unsafe { core::slice::from_raw_parts(t as *const T as *const u8, size_of::<T>()) }
}

pub fn zeroed_vec<T: Pod>(length: usize) -> Vec<T> {
Expand Down
17 changes: 9 additions & 8 deletions crates/base/src/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use crate::distance::Distance;
use crate::vector::VectorOwned;
use serde::{Deserialize, Serialize};
use std::any::Any;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::fmt::Display;

#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
Expand Down Expand Up @@ -93,14 +95,13 @@ pub trait RerankerPop<T> {
fn pop(&mut self) -> Option<(Distance, u32, T)>;
}

pub trait RerankerPush {
fn push(&mut self, u: u32);
}

pub trait FlatReranker<T>: RerankerPop<T> {}

impl<'a, T> RerankerPop<T> for Box<dyn FlatReranker<T> + 'a> {
impl<T> RerankerPop<T> for BinaryHeap<(Reverse<Distance>, AlwaysEqual<u32>, AlwaysEqual<T>)> {
fn pop(&mut self) -> Option<(Distance, u32, T)> {
self.as_mut().pop()
let (Reverse(dis_u), AlwaysEqual(u), AlwaysEqual(pay_u)) = self.pop()?;
Some((dis_u, u, pay_u))
}
}

pub trait RerankerPush {
fn push(&mut self, u: u32);
}
5 changes: 5 additions & 0 deletions crates/base/src/vector/bvect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ impl VectorOwned for BVectOwned {
data: &self.data,
}
}

#[inline(always)]
fn zero(dims: u32) -> Self {
Self::new(dims, vec![0; dims.div_ceil(BVECTOR_WIDTH) as usize])
}
}

#[derive(Debug, Clone, Copy)]
Expand Down
2 changes: 2 additions & 0 deletions crates/base/src/vector/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub trait VectorOwned: Clone + Serialize + for<'a> Deserialize<'a> + 'static {
type Borrowed<'a>: VectorBorrowed<Owned = Self>;

fn as_borrowed(&self) -> Self::Borrowed<'_>;

fn zero(dims: u32) -> Self;
}

pub trait VectorBorrowed: Copy + PartialEq + PartialOrd {
Expand Down
5 changes: 5 additions & 0 deletions crates/base/src/vector/svect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ impl<S: ScalarLike> VectorOwned for SVectOwned<S> {
values: &self.values,
}
}

#[inline(always)]
fn zero(dims: u32) -> Self {
Self::new(dims, vec![], vec![])
}
}

#[derive(Debug, Clone, Copy)]
Expand Down
5 changes: 5 additions & 0 deletions crates/base/src/vector/vect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ impl<S: ScalarLike> VectorOwned for VectOwned<S> {
fn as_borrowed(&self) -> VectBorrowed<'_, S> {
VectBorrowed(self.0.as_slice())
}

#[inline(always)]
fn zero(dims: u32) -> Self {
Self::new(vec![S::zero(); dims as usize])
}
}

#[derive(Debug, Clone, Copy)]
Expand Down
12 changes: 4 additions & 8 deletions crates/cli/src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,13 @@ pub struct QueryArguments {
impl QueryArguments {
pub fn get_search_options(&self) -> SearchOptions {
SearchOptions {
flat_sq_rerank_size: 0,
flat_pq_rerank_size: 0,
ivf_sq_rerank_size: 0,
ivf_pq_rerank_size: 0,
sq_rerank_size: 0,
pq_rerank_size: 0,
hnsw_ef_search: self.ef,
ivf_nprobe: self.probe,
diskann_ef_search: 100,
flat_sq_fast_scan: false,
flat_pq_fast_scan: false,
ivf_sq_fast_scan: false,
ivf_pq_fast_scan: false,
sq_fast_scan: false,
pq_fast_scan: false,
rabitq_epsilon: 1.9,
rabitq_fast_scan: true,
rabitq_nprobe: self.probe,
Expand Down
38 changes: 16 additions & 22 deletions crates/flat/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,24 @@ use base::search::*;
use base::vector::VectorBorrowed;
use common::mmap_array::MmapArray;
use common::remap::RemappedCollection;
use quantization::operator::OperatorQuantization;
use quantization::quantizer::Quantizer;
use quantization::Quantization;
use std::fs::create_dir;
use std::path::Path;
use storage::OperatorStorage;
use storage::Storage;

pub trait OperatorFlat: OperatorQuantization + OperatorStorage {}
pub trait OperatorFlat: OperatorStorage {}

impl<T: OperatorQuantization + OperatorStorage> OperatorFlat for T {}
impl<T: OperatorStorage> OperatorFlat for T {}

pub struct Flat<O: OperatorFlat> {
pub struct Flat<O: OperatorFlat, Q: Quantizer<O>> {
storage: O::Storage,
quantization: Quantization<O>,
quantization: Quantization<O, Q>,
payloads: MmapArray<Payload>,
}

impl<O: OperatorFlat> Flat<O> {
impl<O: OperatorFlat, Q: Quantizer<O>> Flat<O, Q> {
pub fn create(
path: impl AsRef<Path>,
options: IndexOptions,
Expand All @@ -43,20 +43,14 @@ impl<O: OperatorFlat> Flat<O> {
vector: Borrowed<'a, O>,
opts: &'a SearchOptions,
) -> Box<dyn Iterator<Item = Element> + 'a> {
let mut heap = Vec::new();
let preprocessed = self.quantization.preprocess(vector);
self.quantization.push_batch(
&preprocessed,
0..self.storage.len(),
&mut heap,
opts.flat_sq_fast_scan,
opts.flat_pq_fast_scan,
);
let mut reranker = self.quantization.flat_rerank(
let mut heap = Q::flat_rerank_start();
let lut = self.quantization.flat_rerank_preprocess(vector, opts);
self.quantization
.flat_rerank_continue(&lut, 0..self.storage.len(), &mut heap);
let mut reranker = self.quantization.flat_rerank_break(
heap,
move |u| (O::distance(vector, self.storage.vector(u)), ()),
opts.flat_sq_rerank_size,
opts.flat_pq_rerank_size,
opts,
);
Box::new(std::iter::from_fn(move || {
reranker.pop().map(|(dis_u, u, ())| Element {
Expand All @@ -83,15 +77,15 @@ impl<O: OperatorFlat> Flat<O> {
}
}

fn from_nothing<O: OperatorFlat>(
fn from_nothing<O: OperatorFlat, Q: Quantizer<O>>(
path: impl AsRef<Path>,
options: IndexOptions,
collection: &(impl Vectors<Owned<O>> + Collection + Sync),
) -> Flat<O> {
) -> Flat<O, Q> {
create_dir(path.as_ref()).unwrap();
let flat_indexing_options = options.indexing.clone().unwrap_flat();
let storage = O::Storage::create(path.as_ref().join("storage"), collection);
let quantization = Quantization::<O>::create(
let quantization = Quantization::<O, Q>::create(
path.as_ref().join("quantization"),
options.vector,
flat_indexing_options.quantization,
Expand All @@ -109,7 +103,7 @@ fn from_nothing<O: OperatorFlat>(
}
}

fn open<O: OperatorFlat>(path: impl AsRef<Path>) -> Flat<O> {
fn open<O: OperatorFlat, Q: Quantizer<O>>(path: impl AsRef<Path>) -> Flat<O, Q> {
let storage = O::Storage::open(path.as_ref().join("storage"));
let quantization = Quantization::open(path.as_ref().join("quantization"));
let payloads = MmapArray::open(path.as_ref().join("payloads"));
Expand Down
Loading