-
Notifications
You must be signed in to change notification settings - Fork 179
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: replicate boundary vectors to multiple partitions #2258
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,16 +6,19 @@ | |
use std::ops::Range; | ||
use std::sync::Arc; | ||
|
||
use arrow::array::{ArrayBuilder, ListBuilder, UInt32Builder, UInt64Builder}; | ||
use arrow_array::types::UInt32Type; | ||
use arrow_array::{cast::AsArray, Array, ArrowPrimitiveType, RecordBatch, UInt32Array}; | ||
use arrow_array::{GenericListArray, ListArray, UInt64Array}; | ||
use arrow_schema::Field; | ||
use futures::{stream, StreamExt}; | ||
use lance_linalg::kmeans::compute_multiple_partitions; | ||
use log::info; | ||
use snafu::{location, Location}; | ||
use tracing::{instrument, Instrument}; | ||
|
||
use lance_arrow::{ArrowFloatType, RecordBatchExt}; | ||
use lance_core::Result; | ||
use lance_core::{Result, ROW_ID}; | ||
use lance_linalg::distance::{Dot, MetricType, L2}; | ||
use lance_linalg::MatrixView; | ||
|
||
|
@@ -36,6 +39,8 @@ | |
metric_type: MetricType, | ||
input_column: String, | ||
output_column: String, | ||
replicate_factor: f32, | ||
max_replica: usize, | ||
} | ||
|
||
impl<T: ArrowFloatType + L2 + Dot> IvfTransformer<T> { | ||
|
@@ -49,6 +54,8 @@ | |
metric_type, | ||
input_column: input_column.as_ref().to_owned(), | ||
output_column: PART_ID_COLUMN.to_owned(), | ||
replicate_factor: 1.02, | ||
max_replica: 8, | ||
} | ||
} | ||
|
||
|
@@ -106,6 +113,80 @@ | |
|
||
UInt32Array::from_iter(result.iter().flatten().copied()) | ||
} | ||
|
||
/// Compute the partition for each row in the input Matrix. | ||
/// | ||
#[instrument(level = "debug", skip(data))] | ||
pub(super) async fn compute_multiple_partitions( | ||
&self, | ||
data: &MatrixView<T>, | ||
) -> GenericListArray<i32> { | ||
let dimension = data.ndim(); | ||
let centroids = self.centroids.data(); | ||
let data = data.data(); | ||
let metric_type = self.metric_type; | ||
|
||
let num_centroids = centroids.len() / dimension; | ||
let num_rows = data.len() / dimension; | ||
|
||
let chunks = std::cmp::min(num_cpus::get(), num_rows); | ||
|
||
info!( | ||
"computing partition on {} chunks, out of {} centroids, and {} vectors", | ||
chunks, num_centroids, num_rows, | ||
); | ||
|
||
let chunk_size = num_rows.div_ceil(chunks); | ||
let stride = chunk_size * dimension; | ||
|
||
let result: Vec<Vec<Option<Vec<u32>>>> = stream::iter(0..chunks) | ||
.map(|chunk_id| stride * chunk_id..std::cmp::min(stride * (chunk_id + 1), data.len())) | ||
// When there are a large number of CPUs and a small number of rows, | ||
// it's possible there isn't an split of rows that there isn't | ||
// an even split of rows that both covers all CPUs and all rows. | ||
// For example, for 400 rows and 32 CPUs, 12-element chunks (12 * 32 = 384) | ||
// wouldn't cover all rows but 13-element chunks (13 * 32 = 416) would | ||
// have one empty chunk at the end. This filter removes those empty chunks. | ||
.filter(|range| futures::future::ready(range.start < range.end)) | ||
.map(|range| async { | ||
let range: Range<usize> = range; | ||
let centroids = centroids.clone(); | ||
let data = Arc::new( | ||
data.slice(range.start, range.end - range.start) | ||
.as_any() | ||
.downcast_ref::<T::ArrayType>() | ||
.unwrap() | ||
.clone(), | ||
); | ||
|
||
compute_multiple_partitions::<T>( | ||
centroids, | ||
data, | ||
dimension, | ||
metric_type, | ||
self.replicate_factor, | ||
self.max_replica, | ||
) | ||
.in_current_span() | ||
.await | ||
}) | ||
.buffered(chunks) | ||
.collect::<Vec<_>>() | ||
.await; | ||
let result = result.into_iter().flatten().collect::<Vec<_>>(); | ||
|
||
let mut builder = ListBuilder::new(UInt32Builder::new()); | ||
for part_ids in result { | ||
match part_ids { | ||
Some(part_ids) => { | ||
builder.append_value(part_ids.into_iter().map(|x| Some(x))); | ||
} | ||
None => builder.append_null(), | ||
} | ||
} | ||
|
||
builder.finish() | ||
} | ||
} | ||
|
||
#[async_trait::async_trait] | ||
|
@@ -137,9 +218,31 @@ | |
})?; | ||
|
||
let mat = MatrixView::<T>::try_from(fsl)?; | ||
let part_ids = self.compute_partitions(&mat).await; | ||
let field = Field::new(PART_ID_COLUMN, part_ids.data_type().clone(), true); | ||
Ok(batch.try_with_column(field, Arc::new(part_ids))?) | ||
let part_ids = self.compute_multiple_partitions(&mat).await; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hmm, maybe we should use a different transformer for search? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we use the transformer for search? |
||
let row_ids = batch | ||
.column_by_name(ROW_ID) | ||
.unwrap() | ||
.as_any() | ||
.downcast_ref::<UInt64Array>() | ||
.unwrap(); | ||
|
||
let total_num_rows = part_ids.values().len(); | ||
let mut row_id_builder = UInt64Builder::with_capacity(total_num_rows); | ||
let mut part_id_builder = UInt32Builder::with_capacity(total_num_rows); | ||
for i in 0..row_ids.len() { | ||
let part_ids = part_ids.value(i); | ||
let part_ids: &UInt32Array = part_ids.as_primitive(); | ||
for part_id in part_ids { | ||
row_id_builder.append_value(row_ids.value(i)); | ||
part_id_builder.append_value(part_id.unwrap()); | ||
} | ||
} | ||
|
||
let field = Field::new(PART_ID_COLUMN, part_ids.value_type().clone(), true); | ||
Ok(batch | ||
.drop_column(&self.input_column)? | ||
.replace_column_by_name(ROW_ID, Arc::new(row_id_builder.finish()))? | ||
.try_with_column(field, Arc::new(part_id_builder.finish()))?) | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we use this in search, or just at build time?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Only at build time.