feat: replicate boundary vectors to multiple partitions #2258

Status: Closed. Wants to merge 4 commits. Diff shows changes from 3 commits.

111 changes: 107 additions & 4 deletions rust/lance-index/src/vector/ivf/transform.rs
@@ -6,16 +6,19 @@
use std::ops::Range;
use std::sync::Arc;

use arrow::array::{ArrayBuilder, ListBuilder, UInt32Builder, UInt64Builder};

[CI warning, GitHub Actions / linux-build (nightly), line 9: unused import: `ArrayBuilder`]
use arrow_array::types::UInt32Type;
use arrow_array::{cast::AsArray, Array, ArrowPrimitiveType, RecordBatch, UInt32Array};
use arrow_array::{GenericListArray, ListArray, UInt64Array};

[CI warning, GitHub Actions / linux-build (nightly), line 12: unused import: `ListArray`]
use arrow_schema::Field;
use futures::{stream, StreamExt};
use lance_linalg::kmeans::compute_multiple_partitions;
use log::info;
use snafu::{location, Location};
use tracing::{instrument, Instrument};

use lance_arrow::{ArrowFloatType, RecordBatchExt};
use lance_core::Result;
use lance_core::{Result, ROW_ID};
use lance_linalg::distance::{Dot, MetricType, L2};
use lance_linalg::MatrixView;

@@ -36,6 +39,8 @@
metric_type: MetricType,
input_column: String,
output_column: String,
replicate_factor: f32,
max_replica: usize,
}

impl<T: ArrowFloatType + L2 + Dot> IvfTransformer<T> {
@@ -49,6 +54,8 @@
metric_type,
input_column: input_column.as_ref().to_owned(),
output_column: PART_ID_COLUMN.to_owned(),
replicate_factor: 1.02,
max_replica: 8,
}
}

@@ -106,6 +113,80 @@

UInt32Array::from_iter(result.iter().flatten().copied())
}

/// Compute the partitions for each row in the input Matrix.
Review comment (Contributor): do we use this in search? or just in build time.
Reply (Author): only in build time

///
#[instrument(level = "debug", skip(data))]
pub(super) async fn compute_multiple_partitions(
&self,
data: &MatrixView<T>,
) -> GenericListArray<i32> {
let dimension = data.ndim();
let centroids = self.centroids.data();
let data = data.data();
let metric_type = self.metric_type;

let num_centroids = centroids.len() / dimension;
let num_rows = data.len() / dimension;

let chunks = std::cmp::min(num_cpus::get(), num_rows);

info!(
"computing partition on {} chunks, out of {} centroids, and {} vectors",
chunks, num_centroids, num_rows,
);

let chunk_size = num_rows.div_ceil(chunks);
let stride = chunk_size * dimension;

let result: Vec<Vec<Option<Vec<u32>>>> = stream::iter(0..chunks)
.map(|chunk_id| stride * chunk_id..std::cmp::min(stride * (chunk_id + 1), data.len()))
// When there are a large number of CPUs and a small number of rows,
// it's possible there isn't an even split of rows that both covers
// all CPUs and all rows.
// For example, for 400 rows and 32 CPUs, 12-element chunks (12 * 32 = 384)
// wouldn't cover all rows but 13-element chunks (13 * 32 = 416) would
// have one empty chunk at the end. This filter removes those empty chunks.
.filter(|range| futures::future::ready(range.start < range.end))
.map(|range| async {
let range: Range<usize> = range;
let centroids = centroids.clone();
let data = Arc::new(
data.slice(range.start, range.end - range.start)
.as_any()
.downcast_ref::<T::ArrayType>()
.unwrap()
.clone(),
);

compute_multiple_partitions::<T>(
centroids,
data,
dimension,
metric_type,
self.replicate_factor,
self.max_replica,
)
.in_current_span()
.await
})
.buffered(chunks)
.collect::<Vec<_>>()
.await;
let result = result.into_iter().flatten().collect::<Vec<_>>();

let mut builder = ListBuilder::new(UInt32Builder::new());
for part_ids in result {
match part_ids {
Some(part_ids) => {
builder.append_value(part_ids.into_iter().map(|x| Some(x)));
}
None => builder.append_null(),
}
}

builder.finish()
}
}
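
For reference, here is a minimal standalone sketch of the Arrow builder pattern used in the loop above (the partition assignments are made up for illustration): per-row partition lists become a `ListArray`, and a null entry marks a row whose assignment was invalid.

```rust
use arrow_array::builder::{ListBuilder, UInt32Builder};
use arrow_array::cast::AsArray;
use arrow_array::types::UInt32Type;
use arrow_array::Array;

fn main() {
    // Hypothetical per-row assignments: row 0 -> [2], row 1 is invalid (all NaN),
    // row 2 -> [2, 7] (a boundary vector replicated to two partitions).
    let assignments: Vec<Option<Vec<u32>>> = vec![Some(vec![2]), None, Some(vec![2, 7])];

    let mut builder = ListBuilder::new(UInt32Builder::new());
    for part_ids in assignments {
        match part_ids {
            Some(ids) => builder.append_value(ids.into_iter().map(Some)),
            None => builder.append_null(),
        }
    }
    let list = builder.finish();

    assert_eq!(list.len(), 3);
    assert!(list.is_null(1));
    // Row 2 was replicated into two partitions.
    let row2 = list.value(2);
    let row2 = row2.as_primitive::<UInt32Type>();
    assert_eq!(row2.values().to_vec(), vec![2u32, 7]);
}
```

In `compute_multiple_partitions` above, the null entry corresponds to an assignment of `None` coming back from the k-means membership computation.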

#[async_trait::async_trait]
@@ -137,9 +218,31 @@
})?;

let mat = MatrixView::<T>::try_from(fsl)?;
let part_ids = self.compute_partitions(&mat).await;
let field = Field::new(PART_ID_COLUMN, part_ids.data_type().clone(), true);
Ok(batch.try_with_column(field, Arc::new(part_ids))?)
let part_ids = self.compute_multiple_partitions(&mat).await;
Review comment (Contributor): hmm, maybe we should use a different transformer for search?
Reply (Author): do we use transformer for search? the code path should be IVFIndex::find_partitions -> IvfImpl::find_partitions -> kmeans::find_partitions right?

let row_ids = batch
.column_by_name(ROW_ID)
.unwrap()
.as_any()
.downcast_ref::<UInt64Array>()
.unwrap();

let total_num_rows = part_ids.values().len();
let mut row_id_builder = UInt64Builder::with_capacity(total_num_rows);
let mut part_id_builder = UInt32Builder::with_capacity(total_num_rows);
for i in 0..row_ids.len() {
let part_ids = part_ids.value(i);
let part_ids: &UInt32Array = part_ids.as_primitive();
for part_id in part_ids {
row_id_builder.append_value(row_ids.value(i));
part_id_builder.append_value(part_id.unwrap());
}
}

let field = Field::new(PART_ID_COLUMN, part_ids.value_type().clone(), true);
Ok(batch
.drop_column(&self.input_column)?
.replace_column_by_name(ROW_ID, Arc::new(row_id_builder.finish()))?
.try_with_column(field, Arc::new(part_id_builder.finish()))?)
}
}
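
To make the row-replication step concrete, here is a small self-contained sketch of the expansion performed in `transform` above, using plain `Vec`s instead of Arrow builders and hypothetical row ids: each row id is repeated once per assigned partition, so a boundary vector ends up in several partitions.

```rust
fn main() {
    // Hypothetical inputs: three row ids and their already-computed partition lists.
    let row_ids: Vec<u64> = vec![10, 11, 12];
    let part_ids_per_row: Vec<Vec<u32>> = vec![vec![0], vec![0, 3], vec![5]];

    let mut expanded_row_ids = Vec::new();
    let mut expanded_part_ids = Vec::new();
    for (row_id, parts) in row_ids.iter().zip(part_ids_per_row.iter()) {
        for part_id in parts {
            // Row 11 sits near a partition boundary, so it is emitted twice.
            expanded_row_ids.push(*row_id);
            expanded_part_ids.push(*part_id);
        }
    }

    assert_eq!(expanded_row_ids, vec![10, 11, 11, 12]);
    assert_eq!(expanded_part_ids, vec![0, 0, 3, 5]);
}
```

This mirrors the builder loop in `transform`, where the `ROW_ID` column is replaced with the expanded row ids and the `PART_ID_COLUMN` holds the matching partition ids.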

132 changes: 132 additions & 0 deletions rust/lance-linalg/src/kmeans.rs
@@ -280,6 +280,21 @@
}
}

pub struct KMeanMultipleMembership {
dimension: usize,

[CI warning, GitHub Actions / linux-build (stable, nightly), line 284: fields `dimension`, `k`, and `metric_type` are never read]

/// Cluster id and distance for each vector.
///
/// If an entry is `None`, the assignment is not valid, i.e., the input vector
/// might be all `NaN`.
pub assignments: Vec<Option<Vec<(u32, f32)>>>,

/// Number of centroids.
k: usize,

metric_type: MetricType,
}
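
As an illustration with made-up numbers (and the default `replicate_factor = 1.02` from the transformer above), the `assignments` field might look like this for a vector well inside a cluster, a vector near a partition boundary, and an all-`NaN` vector:

```rust
fn main() {
    // Hypothetical assignments for three input vectors with replicate_factor = 1.02:
    // - vector 0 is firmly inside cluster 4;
    // - vector 1 lies near the boundary between clusters 2 and 7
    //   (0.508 < 0.50 * 1.02 = 0.51, so both are kept);
    // - vector 2 is all NaN, so it has no valid assignment.
    let assignments: Vec<Option<Vec<(u32, f32)>>> = vec![
        Some(vec![(4, 0.31)]),
        Some(vec![(2, 0.50), (7, 0.508)]),
        None,
    ];
    assert_eq!(assignments.iter().filter(|a| a.is_some()).count(), 2);
}
```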

impl<T: ArrowFloatType> KMeans<T>
where
T: L2 + Dot + Normalize,
@@ -508,6 +523,61 @@
}
}

pub async fn compute_multiple_memberships(
&self,
data: Arc<T::ArrayType>,
replicate_factor: f32,
max_replica: usize,
) -> KMeanMultipleMembership {
let dimension = self.dimension;
let n = data.len() / self.dimension;
let metric_type = self.metric_type;
const CHUNK_SIZE: usize = 1024;

let assignments = stream::iter((0..n).step_by(CHUNK_SIZE))
// make tiles of input data to split between threads.
.zip(repeat_with(|| (data.clone(), self.centroids.clone())))
.map(|(start_idx, (data, centroids))| async move {
let data = tokio::task::spawn_blocking(move || {
let last_idx = min(start_idx + CHUNK_SIZE, n);

let centroids_array = centroids.as_slice();
let values = &data.as_slice()[start_idx * dimension..last_idx * dimension];

match metric_type {
MetricType::L2 => {
return compute_multiple_partitions_l2(centroids_array, values, dimension, replicate_factor, max_replica)
}
// MetricType::Dot => values
// .chunks_exact(dimension)
// .map(|vector| {
// let centroid_stream = centroids_array.chunks_exact(dimension);
// argmin_value(centroid_stream.map(|cent| dot_distance(vector, cent)))
// })
// .collect::<Vec<_>>(),
_ => {
panic!("KMeans: should not use cosine distance to train kmeans, use L2 instead.");
}
}
})
.await
.map_err(|e| {
ArrowError::ComputeError(format!("KMeans: failed to compute membership: {}", e))
})?;
Ok::<Vec<_>, Error>(data)
})
.buffered(num_cpus::get())
.try_collect::<Vec<_>>()
.await
.unwrap();
KMeanMultipleMembership {
dimension,
assignments: assignments.iter().flatten().cloned().collect(),
k: self.k,
metric_type: self.metric_type,
}
}

pub fn find_partitions(&self, query: &[T::Native], nprobes: usize) -> Result<UInt32Array> {
if query.len() != self.dimension {
return Err(Error::InvalidArgumentError(format!(
@@ -635,6 +705,41 @@
Box::new(stream)
}

fn compute_multiple_partitions_l2<'a, T: FloatToArrayType>(
centroids: &'a [T],
data: &'a [T],
dim: usize,
replicate_factor: f32,
max_replica: usize,
) -> Vec<Option<Vec<(u32, f32)>>>
where
T::ArrowType: L2,
{
data.chunks(dim)
.map(|row| {
let mut dists = l2_distance_batch(row, centroids, dim)
.enumerate()
.map(|(i, d)| (i as u32, d))
.collect::<Vec<_>>();
// `total_cmp` avoids panicking in the comparator when a distance is NaN;
// NaN distances sort last, and the `is_nan` check below still catches the
// all-NaN case.
dists.sort_by(|a, b| a.1.total_cmp(&b.1));
let min_dist = dists[0].1;
if min_dist.is_nan() {
return None;
}
let mut keep_length = 1;
while keep_length < max_replica
&& keep_length < dists.len()
&& dists[keep_length].1 < min_dist * replicate_factor
{
keep_length += 1;
}

dists.truncate(keep_length);
Some(dists)
})
.collect()
}
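
A quick standalone check of the keep rule with hypothetical distances (no crate distance kernels involved): with `replicate_factor = 1.02` and `max_replica = 8`, only the centroids whose distance is strictly below `min_dist * 1.02` are retained.

```rust
fn main() {
    let replicate_factor = 1.02f32;
    let max_replica = 8usize;

    // Distances from one vector to the centroids, already sorted ascending
    // (hypothetical values): (partition id, distance).
    let dists: Vec<(u32, f32)> = vec![(3, 0.50), (1, 0.507), (6, 0.52), (0, 0.90)];

    let min_dist = dists[0].1;
    let mut keep_length = 1;
    while keep_length < max_replica
        && keep_length < dists.len()
        && dists[keep_length].1 < min_dist * replicate_factor
    {
        keep_length += 1;
    }

    // 0.507 clears the 0.51 threshold, 0.52 does not, so the vector is
    // replicated to partitions 3 and 1 only.
    assert_eq!(keep_length, 2);
    let kept: Vec<u32> = dists[..keep_length].iter().map(|(c, _)| *c).collect();
    assert_eq!(kept, vec![3, 1]);
}
```

With the defaults above, only vectors whose runner-up centroid is within 2% of the nearest distance get replicated, and never to more than 8 partitions.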

/// Compute partition ID of each vector in the KMeans.
///
/// If it returns `None`, the vector is not valid, i.e., all `NaN`.
@@ -656,6 +761,33 @@
.collect()
}

/// Compute partition IDs of each vector in the KMeans.
pub async fn compute_multiple_partitions<T: ArrowFloatType>(
centroids: Arc<T::ArrayType>,
vectors: Arc<T::ArrayType>,
dimension: usize,
metric_type: MetricType,
replicate_factor: f32,
max_replica: usize,
) -> Vec<Option<Vec<u32>>>
where
<T::Native as FloatToArrayType>::ArrowType: Dot + L2 + Normalize,
{
let kmeans: KMeans<T> = KMeans::with_centroids(centroids, dimension, metric_type);
let membership = kmeans
.compute_multiple_memberships(vectors, replicate_factor, max_replica)
.await;
membership
.assignments
.iter()
.map(|assignment| {
assignment
.as_ref()
.map(|a| a.iter().map(|(c, _)| *c).collect())
})
.collect()
}
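
A sketch of how this entry point could be called. This assumes arrow's `Float32Type` implements `lance_arrow::ArrowFloatType` with `ArrayType = Float32Array` and satisfies the `L2`/`Dot`/`Normalize` bounds, and that a Tokio runtime with the `macros` feature is available; the centroids and vectors are made up.

```rust
use std::sync::Arc;

use arrow_array::types::Float32Type;
use arrow_array::Float32Array;
use lance_linalg::distance::MetricType;
use lance_linalg::kmeans::compute_multiple_partitions;

#[tokio::main]
async fn main() {
    // Two 2-d centroids at (0, 0) and (10, 0), three 2-d vectors (made-up data).
    let centroids = Arc::new(Float32Array::from(vec![0.0_f32, 0.0, 10.0, 0.0]));
    let vectors = Arc::new(Float32Array::from(vec![0.1_f32, 0.0, 5.01, 0.0, 9.9, 0.0]));

    let parts = compute_multiple_partitions::<Float32Type>(
        centroids,
        vectors,
        2, // dimension
        MetricType::L2,
        1.02, // replicate_factor
        8,    // max_replica
    )
    .await;

    // The middle vector sits almost exactly on the boundary between the two
    // centroids, so it should be assigned to both partitions; the outer
    // vectors get a single partition each.
    assert_eq!(parts[0].as_ref().map(|p| p.len()), Some(1));
    assert_eq!(parts[1].as_ref().map(|p| p.len()), Some(2));
    assert_eq!(parts[2].as_ref().map(|p| p.len()), Some(1));
}
```

Whether the L2 kernel reports squared or unsquared distances changes the exact values, but with this geometry the middle vector clears the 2% threshold either way.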

#[cfg(test)]
mod tests {
