Skip to content

Commit

Permalink
feat: add metrics dot and cos (#566)
Browse files Browse the repository at this point in the history
* feat: add metrics dot and cos

Signed-off-by: cutecutecat <[email protected]>

* add option residual_quantization

Signed-off-by: cutecutecat <[email protected]>

* deprecate residual except l2

Signed-off-by: cutecutecat <[email protected]>

* fix by comments

Signed-off-by: cutecutecat <[email protected]>

---------

Signed-off-by: cutecutecat <[email protected]>
  • Loading branch information
cutecutecat authored Sep 18, 2024
1 parent 1ed47d8 commit bb46189
Show file tree
Hide file tree
Showing 5 changed files with 683 additions and 310 deletions.
12 changes: 9 additions & 3 deletions crates/base/src/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,9 @@ impl IndexOptions {
}
}
IndexingOptions::Rabitq(_) => {
if !matches!(self.vector.d, DistanceKind::L2) {
if !matches!(self.vector.d, DistanceKind::L2 | DistanceKind::Dot) {
return Err(ValidationError::new(
"rabitq is not support for distance that is not l2",
"rabitq is not support for distance that is not l2 or dot",
));
}
if !matches!(self.vector.v, VectorKind::Vecf32) {
Expand Down Expand Up @@ -446,8 +446,10 @@ pub struct RabitqIndexingOptions {
#[serde(default = "RabitqIndexingOptions::default_nlist")]
#[validate(range(min = 1, max = 1_000_000))]
pub nlist: u32,
#[serde(default = "IvfIndexingOptions::default_spherical_centroids")]
#[serde(default = "RabitqIndexingOptions::default_spherical_centroids")]
pub spherical_centroids: bool,
#[serde(default = "RabitqIndexingOptions::default_residual_quantization")]
pub residual_quantization: bool,
}

impl RabitqIndexingOptions {
Expand All @@ -457,13 +459,17 @@ impl RabitqIndexingOptions {
fn default_spherical_centroids() -> bool {
false
}
fn default_residual_quantization() -> bool {
false
}
}

impl Default for RabitqIndexingOptions {
fn default() -> Self {
Self {
nlist: Self::default_nlist(),
spherical_centroids: Self::default_spherical_centroids(),
residual_quantization: Self::default_residual_quantization(),
}
}
}
Expand Down
64 changes: 40 additions & 24 deletions crates/rabitq/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub struct Rabitq<O: Op> {
offsets: Json<Vec<u32>>,
projected_centroids: Json<Vec2<f32>>,
projection: Json<Vec<Vec<f32>>>,
is_residual: Json<bool>,
}

impl<O: Op> Rabitq<O> {
Expand Down Expand Up @@ -74,21 +75,16 @@ impl<O: Op> Rabitq<O> {
opts.rabitq_nprobe as usize,
);
let mut heap = Vec::new();
for &(_, i) in lists.iter() {
for &(dis_v2, i) in lists.iter() {
let trans_vector = if *self.is_residual {
&O::residual(&projected_query, &self.projected_centroids[(i,)])
} else {
&projected_query
};
let preprocessed = if opts.rabitq_fast_scan {
self.quantization
.fscan_preprocess(&O::residual(
&projected_query,
&self.projected_centroids[(i,)],
))
.into()
self.quantization.fscan_preprocess(trans_vector, dis_v2)
} else {
self.quantization
.preprocess(&O::residual(
&projected_query,
&self.projected_centroids[(i,)],
))
.into()
self.quantization.preprocess(trans_vector, dis_v2)
};
let start = self.offsets[i];
let end = self.offsets[i + 1];
Expand Down Expand Up @@ -116,6 +112,7 @@ fn from_nothing<O: Op>(
let RabitqIndexingOptions {
nlist,
spherical_centroids,
residual_quantization,
} = options.indexing.clone().unwrap_rabitq();
let projection = {
use nalgebra::{DMatrix, QR};
Expand All @@ -137,6 +134,7 @@ fn from_nothing<O: Op>(
}
projection
};
let is_residual = residual_quantization && O::SUPPORT_RESIDUAL;
rayon::check();
let samples = O::sample(collection, nlist);
rayon::check();
Expand Down Expand Up @@ -174,16 +172,30 @@ fn from_nothing<O: Op>(
let collection = RemappedCollection::from_collection(collection, remap);
rayon::check();
let storage = O::Storage::create(path.as_ref().join("storage"), &collection);
let quantization = Quantization::create(
path.as_ref().join("quantization"),
options.vector,
collection.len(),
|vector| {
let vector = O::cast(collection.vector(vector));
let target = k_means_lookup(vector, &centroids);
O::proj(&projection, &O::residual(vector, &centroids[(target,)]))
},
);

let quantization = if is_residual {
Quantization::create(
path.as_ref().join("quantization"),
options.vector,
collection.len(),
|vector| {
let vector = O::cast(collection.vector(vector));
let target = k_means_lookup(vector, &centroids);
O::proj(&projection, &O::residual(vector, &centroids[(target,)]))
},
)
} else {
Quantization::create(
path.as_ref().join("quantization"),
options.vector,
collection.len(),
|vector| {
let vector = O::cast(collection.vector(vector));
O::proj(&projection, vector)
},
)
};

let projected_centroids = Vec2::from_vec(
(centroids.shape_0(), centroids.shape_1()),
(0..centroids.shape_0())
Expand All @@ -200,13 +212,15 @@ fn from_nothing<O: Op>(
projected_centroids,
);
let projection = Json::create(path.as_ref().join("projection"), projection);
let is_residual = Json::create(path.as_ref().join("is_residual"), is_residual);
Rabitq {
storage,
payloads,
offsets,
projected_centroids,
quantization,
projection,
is_residual,
}
}

Expand All @@ -217,17 +231,19 @@ fn open<O: Op>(path: impl AsRef<Path>) -> Rabitq<O> {
let offsets = Json::open(path.as_ref().join("offsets"));
let projected_centroids = Json::open(path.as_ref().join("projected_centroids"));
let projection = Json::open(path.as_ref().join("projection"));
let is_residual = Json::open(path.as_ref().join("is_residual"));
Rabitq {
storage,
quantization,
payloads,
offsets,
projected_centroids,
projection,
is_residual,
}
}

fn select(mut lists: Vec<(f32, usize)>, n: usize) -> Vec<(f32, usize)> {
fn select<T>(mut lists: Vec<(f32, T)>, n: usize) -> Vec<(f32, T)> {
if lists.is_empty() || n == 0 {
return Vec::new();
}
Expand Down
Loading

0 comments on commit bb46189

Please sign in to comment.