Skip to content

Commit d58bf4c

Browse files
committed
refactor: rework quantization abstraction
Signed-off-by: usamoi <[email protected]>
1 parent 8abfd82 commit d58bf4c

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

48 files changed

+1614
-1462
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/base/src/index.rs

Lines changed: 23 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,11 @@ pub struct IndexOptions {
106106
impl IndexOptions {
107107
fn validate_self_quantization(
108108
&self,
109-
quantization: &QuantizationOptions,
109+
quantization: &Option<QuantizationOptions>,
110110
) -> Result<(), ValidationError> {
111111
match quantization {
112-
QuantizationOptions::Trivial(_) => Ok(()),
113-
QuantizationOptions::Scalar(_) | QuantizationOptions::Product(_) => {
112+
None => Ok(()),
113+
Some(QuantizationOptions::Scalar(_) | QuantizationOptions::Product(_)) => {
114114
if !matches!(self.vector.v, VectorKind::Vecf32 | VectorKind::Vecf16) {
115115
return Err(ValidationError::new(
116116
"scalar quantization or product quantization is not support for vectors that are not dense vectors",
@@ -356,13 +356,13 @@ impl Default for InvertedIndexingOptions {
356356
pub struct FlatIndexingOptions {
357357
#[serde(default)]
358358
#[validate(nested)]
359-
pub quantization: QuantizationOptions,
359+
pub quantization: Option<QuantizationOptions>,
360360
}
361361

362362
impl Default for FlatIndexingOptions {
363363
fn default() -> Self {
364364
Self {
365-
quantization: QuantizationOptions::default(),
365+
quantization: Default::default(),
366366
}
367367
}
368368
}
@@ -379,7 +379,7 @@ pub struct IvfIndexingOptions {
379379
pub residual_quantization: bool,
380380
#[serde(default)]
381381
#[validate(nested)]
382-
pub quantization: QuantizationOptions,
382+
pub quantization: Option<QuantizationOptions>,
383383
}
384384

385385
impl IvfIndexingOptions {
@@ -418,7 +418,7 @@ pub struct HnswIndexingOptions {
418418
pub ef_construction: u32,
419419
#[serde(default)]
420420
#[validate(nested)]
421-
pub quantization: QuantizationOptions,
421+
pub quantization: Option<QuantizationOptions>,
422422
}
423423

424424
impl HnswIndexingOptions {
@@ -472,37 +472,19 @@ impl Default for RabitqIndexingOptions {
472472
#[serde(deny_unknown_fields)]
473473
#[serde(rename_all = "snake_case")]
474474
pub enum QuantizationOptions {
475-
Trivial(TrivialQuantizationOptions),
476475
Scalar(ScalarQuantizationOptions),
477476
Product(ProductQuantizationOptions),
478477
}
479478

480479
impl Validate for QuantizationOptions {
481480
fn validate(&self) -> Result<(), validator::ValidationErrors> {
482481
match self {
483-
Self::Trivial(x) => x.validate(),
484482
Self::Scalar(x) => x.validate(),
485483
Self::Product(x) => x.validate(),
486484
}
487485
}
488486
}
489487

490-
impl Default for QuantizationOptions {
491-
fn default() -> Self {
492-
Self::Trivial(Default::default())
493-
}
494-
}
495-
496-
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
497-
#[serde(deny_unknown_fields)]
498-
pub struct TrivialQuantizationOptions {}
499-
500-
impl Default for TrivialQuantizationOptions {
501-
fn default() -> Self {
502-
Self {}
503-
}
504-
}
505-
506488
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
507489
#[serde(deny_unknown_fields)]
508490
#[validate(schema(function = "Self::validate_self"))]
@@ -569,26 +551,16 @@ impl Default for ProductQuantizationOptions {
569551
#[derive(Debug, Clone, Serialize, Deserialize, Validate, Alter)]
570552
#[serde(deny_unknown_fields)]
571553
pub struct SearchOptions {
572-
#[serde(default = "SearchOptions::default_flat_sq_rerank_size")]
573-
#[validate(range(min = 0, max = 65535))]
574-
pub flat_sq_rerank_size: u32,
575-
#[serde(default = "SearchOptions::default_flat_sq_fast_scan")]
576-
pub flat_sq_fast_scan: bool,
577-
#[serde(default = "SearchOptions::default_flat_pq_rerank_size")]
578-
#[validate(range(min = 0, max = 65535))]
579-
pub flat_pq_rerank_size: u32,
580-
#[serde(default = "SearchOptions::default_flat_pq_fast_scan")]
581-
pub flat_pq_fast_scan: bool,
582-
#[serde(default = "SearchOptions::default_ivf_sq_rerank_size")]
554+
#[serde(default = "SearchOptions::default_sq_rerank_size")]
583555
#[validate(range(min = 0, max = 65535))]
584-
pub ivf_sq_rerank_size: u32,
585-
#[serde(default = "SearchOptions::default_ivf_sq_fast_scan")]
586-
pub ivf_sq_fast_scan: bool,
587-
#[serde(default = "SearchOptions::default_ivf_pq_rerank_size")]
556+
pub sq_rerank_size: u32,
557+
#[serde(default = "SearchOptions::default_sq_fast_scan")]
558+
pub sq_fast_scan: bool,
559+
#[serde(default = "SearchOptions::default_pq_rerank_size")]
588560
#[validate(range(min = 0, max = 65535))]
589-
pub ivf_pq_rerank_size: u32,
590-
#[serde(default = "SearchOptions::default_ivf_pq_fast_scan")]
591-
pub ivf_pq_fast_scan: bool,
561+
pub pq_rerank_size: u32,
562+
#[serde(default = "SearchOptions::default_pq_fast_scan")]
563+
pub pq_fast_scan: bool,
592564
#[serde(default = "SearchOptions::default_ivf_nprobe")]
593565
#[validate(range(min = 1, max = 65535))]
594566
pub ivf_nprobe: u32,
@@ -609,28 +581,16 @@ pub struct SearchOptions {
609581
}
610582

611583
impl SearchOptions {
612-
pub const fn default_flat_sq_rerank_size() -> u32 {
613-
0
614-
}
615-
pub const fn default_flat_sq_fast_scan() -> bool {
616-
false
617-
}
618-
pub const fn default_flat_pq_rerank_size() -> u32 {
619-
0
620-
}
621-
pub const fn default_flat_pq_fast_scan() -> bool {
622-
false
623-
}
624-
pub const fn default_ivf_sq_rerank_size() -> u32 {
584+
pub const fn default_sq_rerank_size() -> u32 {
625585
0
626586
}
627-
pub const fn default_ivf_sq_fast_scan() -> bool {
587+
pub const fn default_sq_fast_scan() -> bool {
628588
false
629589
}
630-
pub const fn default_ivf_pq_rerank_size() -> u32 {
590+
pub const fn default_pq_rerank_size() -> u32 {
631591
0
632592
}
633-
pub const fn default_ivf_pq_fast_scan() -> bool {
593+
pub const fn default_pq_fast_scan() -> bool {
634594
false
635595
}
636596
pub const fn default_ivf_nprobe() -> u32 {
@@ -656,14 +616,10 @@ impl SearchOptions {
656616
impl Default for SearchOptions {
657617
fn default() -> Self {
658618
Self {
659-
flat_sq_rerank_size: Self::default_flat_sq_rerank_size(),
660-
flat_sq_fast_scan: Self::default_flat_sq_fast_scan(),
661-
flat_pq_rerank_size: Self::default_flat_pq_rerank_size(),
662-
flat_pq_fast_scan: Self::default_flat_pq_fast_scan(),
663-
ivf_sq_rerank_size: Self::default_ivf_sq_rerank_size(),
664-
ivf_sq_fast_scan: Self::default_ivf_sq_fast_scan(),
665-
ivf_pq_rerank_size: Self::default_ivf_pq_rerank_size(),
666-
ivf_pq_fast_scan: Self::default_ivf_pq_fast_scan(),
619+
sq_rerank_size: Self::default_sq_rerank_size(),
620+
sq_fast_scan: Self::default_sq_fast_scan(),
621+
pq_rerank_size: Self::default_pq_rerank_size(),
622+
pq_fast_scan: Self::default_pq_fast_scan(),
667623
ivf_nprobe: Self::default_ivf_nprobe(),
668624
hnsw_ef_search: Self::default_hnsw_ef_search(),
669625
rabitq_nprobe: Self::default_rabitq_nprobe(),

crates/base/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
#![feature(const_float_bits_conv)]
21
#![feature(avx512_target_feature)]
32
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))]
43
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512_f16))]

crates/base/src/pod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ unsafe impl Pod for Distance {}
3636
unsafe impl Pod for Impossible {}
3737

3838
pub fn bytes_of<T: Pod>(t: &T) -> &[u8] {
39-
unsafe { core::slice::from_raw_parts(std::ptr::addr_of!(*t) as *const u8, size_of::<T>()) }
39+
unsafe { core::slice::from_raw_parts(t as *const T as *const u8, size_of::<T>()) }
4040
}
4141

4242
pub fn zeroed_vec<T: Pod>(length: usize) -> Vec<T> {

crates/base/src/search.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use crate::distance::Distance;
33
use crate::vector::VectorOwned;
44
use serde::{Deserialize, Serialize};
55
use std::any::Any;
6+
use std::cmp::Reverse;
7+
use std::collections::BinaryHeap;
68
use std::fmt::Display;
79

810
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
@@ -93,14 +95,13 @@ pub trait RerankerPop<T> {
9395
fn pop(&mut self) -> Option<(Distance, u32, T)>;
9496
}
9597

96-
pub trait RerankerPush {
97-
fn push(&mut self, u: u32);
98-
}
99-
100-
pub trait FlatReranker<T>: RerankerPop<T> {}
101-
102-
impl<'a, T> RerankerPop<T> for Box<dyn FlatReranker<T> + 'a> {
98+
impl<T> RerankerPop<T> for BinaryHeap<(Reverse<Distance>, AlwaysEqual<u32>, AlwaysEqual<T>)> {
10399
fn pop(&mut self) -> Option<(Distance, u32, T)> {
104-
self.as_mut().pop()
100+
let (Reverse(dis_u), AlwaysEqual(u), AlwaysEqual(pay_u)) = self.pop()?;
101+
Some((dis_u, u, pay_u))
105102
}
106103
}
104+
105+
pub trait RerankerPush {
106+
fn push(&mut self, u: u32);
107+
}

crates/base/src/vector/bvect.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ impl VectorOwned for BVectOwned {
5555
data: &self.data,
5656
}
5757
}
58+
59+
#[inline(always)]
60+
fn zero(dims: u32) -> Self {
61+
Self::new(dims, vec![0; dims.div_ceil(BVECTOR_WIDTH) as usize])
62+
}
5863
}
5964

6065
#[derive(Debug, Clone, Copy)]

crates/base/src/vector/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ pub trait VectorOwned: Clone + Serialize + for<'a> Deserialize<'a> + 'static {
2424
type Borrowed<'a>: VectorBorrowed<Owned = Self>;
2525

2626
fn as_borrowed(&self) -> Self::Borrowed<'_>;
27+
28+
fn zero(dims: u32) -> Self;
2729
}
2830

2931
pub trait VectorBorrowed: Copy + PartialEq + PartialOrd {

crates/base/src/vector/svect.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ impl<S: ScalarLike> VectorOwned for SVectOwned<S> {
7777
values: &self.values,
7878
}
7979
}
80+
81+
#[inline(always)]
82+
fn zero(dims: u32) -> Self {
83+
Self::new(dims, vec![], vec![])
84+
}
8085
}
8186

8287
#[derive(Debug, Clone, Copy)]

crates/base/src/vector/vect.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ impl<S: ScalarLike> VectorOwned for VectOwned<S> {
4848
fn as_borrowed(&self) -> VectBorrowed<'_, S> {
4949
VectBorrowed(self.0.as_slice())
5050
}
51+
52+
#[inline(always)]
53+
fn zero(dims: u32) -> Self {
54+
Self::new(vec![S::zero(); dims as usize])
55+
}
5156
}
5257

5358
#[derive(Debug, Clone, Copy)]

crates/cli/src/args.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -131,17 +131,13 @@ pub struct QueryArguments {
131131
impl QueryArguments {
132132
pub fn get_search_options(&self) -> SearchOptions {
133133
SearchOptions {
134-
flat_sq_rerank_size: 0,
135-
flat_pq_rerank_size: 0,
136-
ivf_sq_rerank_size: 0,
137-
ivf_pq_rerank_size: 0,
134+
sq_rerank_size: 0,
135+
pq_rerank_size: 0,
138136
hnsw_ef_search: self.ef,
139137
ivf_nprobe: self.probe,
140138
diskann_ef_search: 100,
141-
flat_sq_fast_scan: false,
142-
flat_pq_fast_scan: false,
143-
ivf_sq_fast_scan: false,
144-
ivf_pq_fast_scan: false,
139+
sq_fast_scan: false,
140+
pq_fast_scan: false,
145141
rabitq_epsilon: 1.9,
146142
rabitq_fast_scan: true,
147143
rabitq_nprobe: self.probe,

0 commit comments

Comments
 (0)