Skip to content

Commit 014281b

Browse files
committed
refactor: rabitq
Signed-off-by: usamoi <[email protected]>
1 parent bb46189 commit 014281b

File tree

36 files changed

+1118
-2051
lines changed

36 files changed

+1118
-2051
lines changed

Cargo.lock

Lines changed: 7 additions & 28 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/base/src/index.rs

Lines changed: 26 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,14 @@ impl IndexOptions {
110110
) -> Result<(), ValidationError> {
111111
match quantization {
112112
None => Ok(()),
113-
Some(QuantizationOptions::Scalar(_) | QuantizationOptions::Product(_)) => {
113+
Some(
114+
QuantizationOptions::Scalar(_)
115+
| QuantizationOptions::Product(_)
116+
| QuantizationOptions::Rabitq(_),
117+
) => {
114118
if !matches!(self.vector.v, VectorKind::Vecf32 | VectorKind::Vecf16) {
115119
return Err(ValidationError::new(
116-
"scalar quantization or product quantization is not support for vectors that are not dense vectors",
120+
"quantization is not support for vectors that are not dense vectors",
117121
));
118122
}
119123
Ok(())
@@ -148,18 +152,6 @@ impl IndexOptions {
148152
));
149153
}
150154
}
151-
IndexingOptions::Rabitq(_) => {
152-
if !matches!(self.vector.d, DistanceKind::L2 | DistanceKind::Dot) {
153-
return Err(ValidationError::new(
154-
"rabitq is not support for distance that is not l2 or dot",
155-
));
156-
}
157-
if !matches!(self.vector.v, VectorKind::Vecf32) {
158-
return Err(ValidationError::new(
159-
"rabitq is not support for vectors that are not vector",
160-
));
161-
}
162-
}
163155
}
164156
Ok(())
165157
}
@@ -293,7 +285,6 @@ pub enum IndexingOptions {
293285
Ivf(IvfIndexingOptions),
294286
Hnsw(HnswIndexingOptions),
295287
InvertedIndex(InvertedIndexingOptions),
296-
Rabitq(RabitqIndexingOptions),
297288
}
298289

299290
impl IndexingOptions {
@@ -315,12 +306,6 @@ impl IndexingOptions {
315306
};
316307
x
317308
}
318-
pub fn unwrap_rabitq(self) -> RabitqIndexingOptions {
319-
let IndexingOptions::Rabitq(x) = self else {
320-
unreachable!()
321-
};
322-
x
323-
}
324309
}
325310

326311
impl Default for IndexingOptions {
@@ -336,7 +321,6 @@ impl Validate for IndexingOptions {
336321
Self::Ivf(x) => x.validate(),
337322
Self::Hnsw(x) => x.validate(),
338323
Self::InvertedIndex(x) => x.validate(),
339-
Self::Rabitq(x) => x.validate(),
340324
}
341325
}
342326
}
@@ -440,53 +424,21 @@ impl Default for HnswIndexingOptions {
440424
}
441425
}
442426

443-
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
444-
#[serde(deny_unknown_fields)]
445-
pub struct RabitqIndexingOptions {
446-
#[serde(default = "RabitqIndexingOptions::default_nlist")]
447-
#[validate(range(min = 1, max = 1_000_000))]
448-
pub nlist: u32,
449-
#[serde(default = "RabitqIndexingOptions::default_spherical_centroids")]
450-
pub spherical_centroids: bool,
451-
#[serde(default = "RabitqIndexingOptions::default_residual_quantization")]
452-
pub residual_quantization: bool,
453-
}
454-
455-
impl RabitqIndexingOptions {
456-
fn default_nlist() -> u32 {
457-
1000
458-
}
459-
fn default_spherical_centroids() -> bool {
460-
false
461-
}
462-
fn default_residual_quantization() -> bool {
463-
false
464-
}
465-
}
466-
467-
impl Default for RabitqIndexingOptions {
468-
fn default() -> Self {
469-
Self {
470-
nlist: Self::default_nlist(),
471-
spherical_centroids: Self::default_spherical_centroids(),
472-
residual_quantization: Self::default_residual_quantization(),
473-
}
474-
}
475-
}
476-
477427
#[derive(Debug, Clone, Serialize, Deserialize)]
478428
#[serde(deny_unknown_fields)]
479429
#[serde(rename_all = "snake_case")]
480430
pub enum QuantizationOptions {
481431
Scalar(ScalarQuantizationOptions),
482432
Product(ProductQuantizationOptions),
433+
Rabitq(RabitqQuantizationOptions),
483434
}
484435

485436
impl Validate for QuantizationOptions {
486437
fn validate(&self) -> Result<(), validator::ValidationErrors> {
487438
match self {
488439
Self::Scalar(x) => x.validate(),
489440
Self::Product(x) => x.validate(),
441+
Self::Rabitq(x) => x.validate(),
490442
}
491443
}
492444
}
@@ -554,6 +506,18 @@ impl Default for ProductQuantizationOptions {
554506
}
555507
}
556508

509+
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
510+
#[serde(deny_unknown_fields)]
511+
pub struct RabitqQuantizationOptions {}
512+
513+
impl RabitqQuantizationOptions {}
514+
515+
impl Default for RabitqQuantizationOptions {
516+
fn default() -> Self {
517+
Self {}
518+
}
519+
}
520+
557521
#[derive(Debug, Clone, Serialize, Deserialize, Validate, Alter)]
558522
#[serde(deny_unknown_fields)]
559523
pub struct SearchOptions {
@@ -567,23 +531,14 @@ pub struct SearchOptions {
567531
pub pq_rerank_size: u32,
568532
#[serde(default = "SearchOptions::default_pq_fast_scan")]
569533
pub pq_fast_scan: bool,
534+
#[serde(default = "SearchOptions::default_rq_fast_scan")]
535+
pub rq_fast_scan: bool,
570536
#[serde(default = "SearchOptions::default_ivf_nprobe")]
571537
#[validate(range(min = 1, max = 65535))]
572538
pub ivf_nprobe: u32,
573539
#[serde(default = "SearchOptions::default_hnsw_ef_search")]
574540
#[validate(range(min = 1, max = 65535))]
575541
pub hnsw_ef_search: u32,
576-
#[serde(default = "SearchOptions::default_rabitq_nprobe")]
577-
#[validate(range(min = 1, max = 65535))]
578-
pub rabitq_nprobe: u32,
579-
#[serde(default = "SearchOptions::default_rabitq_epsilon")]
580-
#[validate(range(min = 1.0, max = 4.0))]
581-
pub rabitq_epsilon: f32,
582-
#[serde(default = "SearchOptions::default_rabitq_fast_scan")]
583-
pub rabitq_fast_scan: bool,
584-
#[serde(default = "SearchOptions::default_diskann_ef_search")]
585-
#[validate(range(min = 1, max = 65535))]
586-
pub diskann_ef_search: u32,
587542
}
588543

589544
impl SearchOptions {
@@ -599,24 +554,15 @@ impl SearchOptions {
599554
pub const fn default_pq_fast_scan() -> bool {
600555
false
601556
}
557+
pub const fn default_rq_fast_scan() -> bool {
558+
true
559+
}
602560
pub const fn default_ivf_nprobe() -> u32 {
603561
10
604562
}
605563
pub const fn default_hnsw_ef_search() -> u32 {
606564
100
607565
}
608-
pub const fn default_rabitq_nprobe() -> u32 {
609-
10
610-
}
611-
pub const fn default_rabitq_epsilon() -> f32 {
612-
1.9
613-
}
614-
pub const fn default_rabitq_fast_scan() -> bool {
615-
true
616-
}
617-
pub const fn default_diskann_ef_search() -> u32 {
618-
100
619-
}
620566
}
621567

622568
impl Default for SearchOptions {
@@ -626,12 +572,9 @@ impl Default for SearchOptions {
626572
sq_fast_scan: Self::default_sq_fast_scan(),
627573
pq_rerank_size: Self::default_pq_rerank_size(),
628574
pq_fast_scan: Self::default_pq_fast_scan(),
575+
rq_fast_scan: Self::default_rq_fast_scan(),
629576
ivf_nprobe: Self::default_ivf_nprobe(),
630577
hnsw_ef_search: Self::default_hnsw_ef_search(),
631-
rabitq_nprobe: Self::default_rabitq_nprobe(),
632-
rabitq_epsilon: Self::default_rabitq_epsilon(),
633-
rabitq_fast_scan: Self::default_rabitq_fast_scan(),
634-
diskann_ef_search: Self::default_diskann_ef_search(),
635578
}
636579
}
637580
}

crates/base/src/scalar/f16.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@ impl ScalarLike for f16 {
3737
lhs * rhs
3838
}
3939

40+
#[inline(always)]
41+
fn scalar_is_sign_positive(self) -> bool {
42+
self.is_sign_positive()
43+
}
44+
45+
#[inline(always)]
46+
fn scalar_is_sign_negative(self) -> bool {
47+
self.is_sign_negative()
48+
}
49+
4050
#[inline(always)]
4151
fn from_f32(x: f32) -> Self {
4252
f16::from_f32(x)
@@ -236,6 +246,10 @@ impl ScalarLike for f16 {
236246
r
237247
}
238248

249+
fn vector_to_f32_borrowed(this: &[Self]) -> impl AsRef<[f32]> {
250+
Self::vector_to_f32(this)
251+
}
252+
239253
#[detect::multiversion(v4, v3, v2, neon, fallback)]
240254
fn kmeans_helper(this: &mut [f16], x: f32, y: f32) {
241255
let x = f16::from_f32(x);

crates/base/src/scalar/f32.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,16 @@ impl ScalarLike for f32 {
3636
lhs * rhs
3737
}
3838

39+
#[inline(always)]
40+
fn scalar_is_sign_positive(self) -> bool {
41+
self.is_sign_positive()
42+
}
43+
44+
#[inline(always)]
45+
fn scalar_is_sign_negative(self) -> bool {
46+
self.is_sign_negative()
47+
}
48+
3949
#[inline(always)]
4050
fn from_f32(x: f32) -> Self {
4151
x
@@ -187,6 +197,10 @@ impl ScalarLike for f32 {
187197
this.to_vec()
188198
}
189199

200+
fn vector_to_f32_borrowed(this: &[Self]) -> impl AsRef<[f32]> {
201+
this
202+
}
203+
190204
#[detect::multiversion(v4, v3, v2, neon, fallback)]
191205
fn kmeans_helper(this: &mut [f32], x: f32, y: f32) {
192206
let n = this.len();

crates/base/src/scalar/impossible.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ impl ScalarLike for Impossible {
3939
unimplemented!()
4040
}
4141

42+
fn scalar_is_sign_positive(self) -> bool {
43+
unimplemented!()
44+
}
45+
46+
fn scalar_is_sign_negative(self) -> bool {
47+
unimplemented!()
48+
}
49+
4250
fn from_f32(_: f32) -> Self {
4351
unimplemented!()
4452
}
@@ -101,6 +109,11 @@ impl ScalarLike for Impossible {
101109
unimplemented!()
102110
}
103111

112+
#[allow(unreachable_code)]
113+
fn vector_to_f32_borrowed(_: &[Self]) -> impl AsRef<[f32]> {
114+
unimplemented!() as Vec<f32>
115+
}
116+
104117
fn vector_add(_lhs: &[Self], _rhs: &[Self]) -> Vec<Self> {
105118
unimplemented!()
106119
}

crates/base/src/scalar/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ pub trait ScalarLike:
2424
fn scalar_add(lhs: Self, rhs: Self) -> Self;
2525
fn scalar_sub(lhs: Self, rhs: Self) -> Self;
2626
fn scalar_mul(lhs: Self, rhs: Self) -> Self;
27+
fn scalar_is_sign_positive(self) -> bool;
28+
fn scalar_is_sign_negative(self) -> bool;
2729

2830
fn from_f32(x: f32) -> Self;
2931
fn to_f32(self) -> f32;
@@ -42,6 +44,7 @@ pub trait ScalarLike:
4244

4345
fn vector_from_f32(this: &[f32]) -> Vec<Self>;
4446
fn vector_to_f32(this: &[Self]) -> Vec<f32>;
47+
fn vector_to_f32_borrowed(this: &[Self]) -> impl AsRef<[f32]>;
4548
fn vector_add(lhs: &[Self], rhs: &[Self]) -> Vec<Self>;
4649
fn vector_add_inplace(lhs: &mut [Self], rhs: &[Self]);
4750
fn vector_sub(lhs: &[Self], rhs: &[Self]) -> Vec<Self>;

0 commit comments

Comments
 (0)