Skip to content

Commit

Permalink
refactor: rabitq (#593)
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi authored Sep 19, 2024
1 parent bb46189 commit 12aca46
Show file tree
Hide file tree
Showing 36 changed files with 1,248 additions and 2,156 deletions.
35 changes: 7 additions & 28 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

109 changes: 26 additions & 83 deletions crates/base/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,14 @@ impl IndexOptions {
) -> Result<(), ValidationError> {
match quantization {
None => Ok(()),
Some(QuantizationOptions::Scalar(_) | QuantizationOptions::Product(_)) => {
Some(
QuantizationOptions::Scalar(_)
| QuantizationOptions::Product(_)
| QuantizationOptions::Rabitq(_),
) => {
if !matches!(self.vector.v, VectorKind::Vecf32 | VectorKind::Vecf16) {
return Err(ValidationError::new(
"scalar quantization or product quantization is not support for vectors that are not dense vectors",
"quantization is not support for vectors that are not dense vectors",
));
}
Ok(())
Expand Down Expand Up @@ -148,18 +152,6 @@ impl IndexOptions {
));
}
}
IndexingOptions::Rabitq(_) => {
if !matches!(self.vector.d, DistanceKind::L2 | DistanceKind::Dot) {
return Err(ValidationError::new(
"rabitq is not support for distance that is not l2 or dot",
));
}
if !matches!(self.vector.v, VectorKind::Vecf32) {
return Err(ValidationError::new(
"rabitq is not support for vectors that are not vector",
));
}
}
}
Ok(())
}
Expand Down Expand Up @@ -293,7 +285,6 @@ pub enum IndexingOptions {
Ivf(IvfIndexingOptions),
Hnsw(HnswIndexingOptions),
InvertedIndex(InvertedIndexingOptions),
Rabitq(RabitqIndexingOptions),
}

impl IndexingOptions {
Expand All @@ -315,12 +306,6 @@ impl IndexingOptions {
};
x
}
/// Consumes the options and returns the payload of the `Rabitq` variant.
///
/// # Panics
///
/// Reaches `unreachable!()` if `self` is any other variant, so callers must
/// already know which variant they hold.
pub fn unwrap_rabitq(self) -> RabitqIndexingOptions {
    match self {
        IndexingOptions::Rabitq(options) => options,
        _ => unreachable!(),
    }
}
}

impl Default for IndexingOptions {
Expand All @@ -336,7 +321,6 @@ impl Validate for IndexingOptions {
Self::Ivf(x) => x.validate(),
Self::Hnsw(x) => x.validate(),
Self::InvertedIndex(x) => x.validate(),
Self::Rabitq(x) => x.validate(),
}
}
}
Expand Down Expand Up @@ -440,53 +424,21 @@ impl Default for HnswIndexingOptions {
}
}

/// Indexing options for the `Rabitq` index kind.
///
/// Unknown keys are rejected during deserialization (`deny_unknown_fields`);
/// every field falls back to its `default_*` helper when omitted.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
#[serde(deny_unknown_fields)]
pub struct RabitqIndexingOptions {
/// Validated to lie between 1 and 1_000_000.
/// NOTE(review): presumably the number of lists/clusters to build — confirm.
#[serde(default = "RabitqIndexingOptions::default_nlist")]
#[validate(range(min = 1, max = 1_000_000))]
pub nlist: u32,
/// Boolean switch; off unless set explicitly.
#[serde(default = "RabitqIndexingOptions::default_spherical_centroids")]
pub spherical_centroids: bool,
/// Boolean switch; off unless set explicitly.
#[serde(default = "RabitqIndexingOptions::default_residual_quantization")]
pub residual_quantization: bool,
}

// Per-field default providers, referenced by name from the
// `#[serde(default = "...")]` attributes on `RabitqIndexingOptions`.
impl RabitqIndexingOptions {
/// Default for `nlist`.
fn default_nlist() -> u32 {
1000
}
/// Default for `spherical_centroids`.
fn default_spherical_centroids() -> bool {
false
}
/// Default for `residual_quantization`.
fn default_residual_quantization() -> bool {
false
}
}

impl Default for RabitqIndexingOptions {
    /// Builds the options from the same per-field helpers that serde uses
    /// for omitted fields.
    fn default() -> Self {
        let nlist = Self::default_nlist();
        let spherical_centroids = Self::default_spherical_centroids();
        let residual_quantization = Self::default_residual_quantization();
        Self {
            nlist,
            spherical_centroids,
            residual_quantization,
        }
    }
}

/// Selects a quantization method together with its options.
///
/// Variant names are serialized in snake_case (`scalar`, `product`,
/// `rabitq`), and unknown keys are rejected during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
#[serde(rename_all = "snake_case")]
pub enum QuantizationOptions {
/// Scalar quantization with its options.
Scalar(ScalarQuantizationOptions),
/// Product quantization with its options.
Product(ProductQuantizationOptions),
/// RaBitQ quantization with its options.
Rabitq(RabitqQuantizationOptions),
}

impl Validate for QuantizationOptions {
fn validate(&self) -> Result<(), validator::ValidationErrors> {
match self {
Self::Scalar(x) => x.validate(),
Self::Product(x) => x.validate(),
Self::Rabitq(x) => x.validate(),
}
}
}
Expand Down Expand Up @@ -554,6 +506,18 @@ impl Default for ProductQuantizationOptions {
}
}

/// Options for RaBitQ quantization.
///
/// The struct currently has no fields. Because of `deny_unknown_fields`, any
/// key supplied for it during deserialization is rejected.
///
/// `Default` is derived: with no fields there is nothing for a hand-written
/// implementation to customize.
#[derive(Debug, Clone, Default, Serialize, Deserialize, Validate)]
#[serde(deny_unknown_fields)]
pub struct RabitqQuantizationOptions {}

#[derive(Debug, Clone, Serialize, Deserialize, Validate, Alter)]
#[serde(deny_unknown_fields)]
pub struct SearchOptions {
Expand All @@ -567,23 +531,14 @@ pub struct SearchOptions {
pub pq_rerank_size: u32,
#[serde(default = "SearchOptions::default_pq_fast_scan")]
pub pq_fast_scan: bool,
#[serde(default = "SearchOptions::default_rq_fast_scan")]
pub rq_fast_scan: bool,
#[serde(default = "SearchOptions::default_ivf_nprobe")]
#[validate(range(min = 1, max = 65535))]
pub ivf_nprobe: u32,
#[serde(default = "SearchOptions::default_hnsw_ef_search")]
#[validate(range(min = 1, max = 65535))]
pub hnsw_ef_search: u32,
#[serde(default = "SearchOptions::default_rabitq_nprobe")]
#[validate(range(min = 1, max = 65535))]
pub rabitq_nprobe: u32,
#[serde(default = "SearchOptions::default_rabitq_epsilon")]
#[validate(range(min = 1.0, max = 4.0))]
pub rabitq_epsilon: f32,
#[serde(default = "SearchOptions::default_rabitq_fast_scan")]
pub rabitq_fast_scan: bool,
#[serde(default = "SearchOptions::default_diskann_ef_search")]
#[validate(range(min = 1, max = 65535))]
pub diskann_ef_search: u32,
}

impl SearchOptions {
Expand All @@ -599,24 +554,15 @@ impl SearchOptions {
/// Default for `SearchOptions::pq_fast_scan`: disabled.
pub const fn default_pq_fast_scan() -> bool {
false
}
/// Default for `SearchOptions::rq_fast_scan`: enabled.
pub const fn default_rq_fast_scan() -> bool {
true
}
/// Default for `SearchOptions::ivf_nprobe` (the field is validated to lie
/// between 1 and 65535).
pub const fn default_ivf_nprobe() -> u32 {
10
}
/// Default for `SearchOptions::hnsw_ef_search` (the field is validated to lie
/// between 1 and 65535).
pub const fn default_hnsw_ef_search() -> u32 {
100
}
/// Default for `SearchOptions::rabitq_nprobe` (the field is validated to lie
/// between 1 and 65535).
pub const fn default_rabitq_nprobe() -> u32 {
10
}
/// Default for `SearchOptions::rabitq_epsilon` (the field is validated to lie
/// between 1.0 and 4.0).
pub const fn default_rabitq_epsilon() -> f32 {
1.9
}
/// Default for `SearchOptions::rabitq_fast_scan`: enabled.
pub const fn default_rabitq_fast_scan() -> bool {
true
}
/// Default for `SearchOptions::diskann_ef_search` (the field is validated to
/// lie between 1 and 65535).
pub const fn default_diskann_ef_search() -> u32 {
100
}
}

impl Default for SearchOptions {
Expand All @@ -626,12 +572,9 @@ impl Default for SearchOptions {
sq_fast_scan: Self::default_sq_fast_scan(),
pq_rerank_size: Self::default_pq_rerank_size(),
pq_fast_scan: Self::default_pq_fast_scan(),
rq_fast_scan: Self::default_rq_fast_scan(),
ivf_nprobe: Self::default_ivf_nprobe(),
hnsw_ef_search: Self::default_hnsw_ef_search(),
rabitq_nprobe: Self::default_rabitq_nprobe(),
rabitq_epsilon: Self::default_rabitq_epsilon(),
rabitq_fast_scan: Self::default_rabitq_fast_scan(),
diskann_ef_search: Self::default_diskann_ef_search(),
}
}
}
Expand Down
14 changes: 14 additions & 0 deletions crates/base/src/scalar/f16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ impl ScalarLike for f16 {
lhs * rhs
}

#[inline(always)]
fn scalar_is_sign_positive(self) -> bool {
self.is_sign_positive()
}

#[inline(always)]
fn scalar_is_sign_negative(self) -> bool {
self.is_sign_negative()
}

#[inline(always)]
fn from_f32(x: f32) -> Self {
f16::from_f32(x)
Expand Down Expand Up @@ -236,6 +246,10 @@ impl ScalarLike for f16 {
r
}

/// Produces an `f32` view of an `f16` slice.
///
/// The elements have to be converted, so this goes through
/// `vector_to_f32` and hands back the resulting owned buffer.
fn vector_to_f32_borrowed(this: &[Self]) -> impl AsRef<[f32]> {
    let converted: Vec<f32> = Self::vector_to_f32(this);
    converted
}

#[detect::multiversion(v4, v3, v2, neon, fallback)]
fn kmeans_helper(this: &mut [f16], x: f32, y: f32) {
let x = f16::from_f32(x);
Expand Down
14 changes: 14 additions & 0 deletions crates/base/src/scalar/f32.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ impl ScalarLike for f32 {
lhs * rhs
}

/// Reports whether the value's sign is positive, as defined by
/// `f32::is_sign_positive`.
#[inline(always)]
fn scalar_is_sign_positive(self) -> bool {
    f32::is_sign_positive(self)
}

/// Reports whether the value's sign is negative, as defined by
/// `f32::is_sign_negative`.
#[inline(always)]
fn scalar_is_sign_negative(self) -> bool {
    f32::is_sign_negative(self)
}

#[inline(always)]
fn from_f32(x: f32) -> Self {
x
Expand Down Expand Up @@ -187,6 +197,10 @@ impl ScalarLike for f32 {
this.to_vec()
}

/// Returns the input slice itself: the elements are already `f32`, so no
/// conversion or copy is made.
fn vector_to_f32_borrowed(this: &[Self]) -> impl AsRef<[f32]> {
this
}

#[detect::multiversion(v4, v3, v2, neon, fallback)]
fn kmeans_helper(this: &mut [f32], x: f32, y: f32) {
let n = this.len();
Expand Down
13 changes: 13 additions & 0 deletions crates/base/src/scalar/impossible.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ impl ScalarLike for Impossible {
unimplemented!()
}

/// Stub: always panics via `unimplemented!()`.
fn scalar_is_sign_positive(self) -> bool {
unimplemented!()
}

/// Stub: always panics via `unimplemented!()`.
fn scalar_is_sign_negative(self) -> bool {
unimplemented!()
}

/// Stub: ignores its argument and always panics via `unimplemented!()`.
fn from_f32(_: f32) -> Self {
unimplemented!()
}
Expand Down Expand Up @@ -101,6 +109,11 @@ impl ScalarLike for Impossible {
unimplemented!()
}

/// Stub: ignores its argument and always panics via `unimplemented!()`.
// The `as Vec<f32>` cast is never executed; presumably it is only there to
// give the opaque `impl AsRef<[f32]>` return type a concrete type to resolve
// to. Because it follows a diverging expression, the `unreachable_code` lint
// is silenced for this function.
#[allow(unreachable_code)]
fn vector_to_f32_borrowed(_: &[Self]) -> impl AsRef<[f32]> {
unimplemented!() as Vec<f32>
}

/// Stub: ignores its arguments and always panics via `unimplemented!()`.
fn vector_add(_lhs: &[Self], _rhs: &[Self]) -> Vec<Self> {
unimplemented!()
}
Expand Down
3 changes: 3 additions & 0 deletions crates/base/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ pub trait ScalarLike:
fn scalar_add(lhs: Self, rhs: Self) -> Self;
fn scalar_sub(lhs: Self, rhs: Self) -> Self;
fn scalar_mul(lhs: Self, rhs: Self) -> Self;
/// Reports whether the scalar's sign is positive. The `f32` and `f16`
/// implementations forward to the type's own `is_sign_positive`.
fn scalar_is_sign_positive(self) -> bool;
/// Reports whether the scalar's sign is negative. The `f32` and `f16`
/// implementations forward to the type's own `is_sign_negative`.
fn scalar_is_sign_negative(self) -> bool;

fn from_f32(x: f32) -> Self;
fn to_f32(self) -> f32;
Expand All @@ -42,6 +44,7 @@ pub trait ScalarLike:

fn vector_from_f32(this: &[f32]) -> Vec<Self>;
fn vector_to_f32(this: &[Self]) -> Vec<f32>;
/// Returns something viewable as `&[f32]` for the given slice. An
/// implementation may hand back the input itself when it is already `f32`
/// (as the `f32` impl does) or a freshly converted buffer (as the `f16`
/// impl does via `vector_to_f32`).
fn vector_to_f32_borrowed(this: &[Self]) -> impl AsRef<[f32]>;
fn vector_add(lhs: &[Self], rhs: &[Self]) -> Vec<Self>;
fn vector_add_inplace(lhs: &mut [Self], rhs: &[Self]);
fn vector_sub(lhs: &[Self], rhs: &[Self]) -> Vec<Self>;
Expand Down
Loading

0 comments on commit 12aca46

Please sign in to comment.