Skip to content

Commit

Permalink
refactor: improve SQ&PQ scan performance (#598)
Browse files Browse the repository at this point in the history
Signed-off-by: usamoi <[email protected]>
  • Loading branch information
usamoi authored Sep 23, 2024
1 parent 1d723fe commit 64c5a7d
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 68 deletions.
84 changes: 48 additions & 36 deletions crates/quantization/src/product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,27 +463,33 @@ impl<S: ScalarLike> OperatorProductQuantization for VectDot<S> {
}
}
fn process(dims: u32, ratio: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance {
fn internal<const BITS: u32, F>(dims: u32, ratio: u32, t: &[f32], f: F) -> Distance
where
F: Fn(usize) -> usize,
{
let mut xy = 0.0f32;
for i in 0..dims.div_ceil(ratio) as usize {
xy += t[i * (1 << BITS) + f(i)];
#[inline(never)]
fn internal<const BITS: usize>(n: usize, lut: &[f32], code: &[u8]) -> Distance {
assert!(n >= 1);
assert!(n <= 65535);
assert!(code.len() == n / (8 / BITS));
assert!(lut.len() == n * (1 << BITS));
let mut sum = 0.0f32;
for i in 0..n {
unsafe {
// Safety: `i < n`
std::hint::assert_unchecked(i / (8 / BITS) < n / (8 / BITS));
}
let (alpha, beta) = (i / (8 / BITS), i % (8 / BITS));
let j = (code[alpha] >> (beta * BITS)) as usize % (1 << BITS);
unsafe {
// Safety: `i < n`, `j < (1 << BITS)`
std::hint::assert_unchecked(i * (1 << BITS) + j < n * (1 << BITS));
}
sum += lut[i * (1 << BITS) + j];
}
Distance::from(-xy)
Distance::from(-sum)
}
match bits {
1 => internal::<1, _>(dims, ratio, lut, |i| {
((code[i >> 3] >> ((i & 7) << 0)) & 1u8) as usize
}),
2 => internal::<2, _>(dims, ratio, lut, |i| {
((code[i >> 2] >> ((i & 3) << 1)) & 3u8) as usize
}),
4 => internal::<4, _>(dims, ratio, lut, |i| {
((code[i >> 1] >> ((i & 1) << 2)) & 15u8) as usize
}),
8 => internal::<8, _>(dims, ratio, lut, |i| code[i] as usize),
1 => internal::<1>(dims.div_ceil(ratio) as _, lut, code),
2 => internal::<2>(dims.div_ceil(ratio) as _, lut, code),
4 => internal::<4>(dims.div_ceil(ratio) as _, lut, code),
8 => internal::<8>(dims.div_ceil(ratio) as _, lut, code),
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -657,27 +663,33 @@ impl<S: ScalarLike> OperatorProductQuantization for VectL2<S> {
}
}
fn process(dims: u32, ratio: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance {
fn internal<const BITS: u32, F>(dims: u32, ratio: u32, t: &[f32], f: F) -> Distance
where
F: Fn(usize) -> usize,
{
let mut d2 = 0.0f32;
for i in 0..dims.div_ceil(ratio) as usize {
d2 += t[i * (1 << BITS) + f(i)];
#[inline(never)]
fn internal<const BITS: usize>(n: usize, lut: &[f32], code: &[u8]) -> Distance {
assert!(n >= 1);
assert!(n <= 65535);
assert!(code.len() == n / (8 / BITS));
assert!(lut.len() == n * (1 << BITS));
let mut sum = 0.0f32;
for i in 0..n {
unsafe {
// Safety: `i < n`
std::hint::assert_unchecked(i / (8 / BITS) < n / (8 / BITS));
}
let (alpha, beta) = (i / (8 / BITS), i % (8 / BITS));
let j = (code[alpha] >> (beta * BITS)) as usize % (1 << BITS);
unsafe {
// Safety: `i < n`, `j < (1 << BITS)`
std::hint::assert_unchecked(i * (1 << BITS) + j < n * (1 << BITS));
}
sum += lut[i * (1 << BITS) + j];
}
Distance::from(d2)
Distance::from(sum)
}
match bits {
1 => internal::<1, _>(dims, ratio, lut, |i| {
((code[i >> 3] >> ((i & 7) << 0)) & 1u8) as usize
}),
2 => internal::<2, _>(dims, ratio, lut, |i| {
((code[i >> 2] >> ((i & 3) << 1)) & 3u8) as usize
}),
4 => internal::<4, _>(dims, ratio, lut, |i| {
((code[i >> 1] >> ((i & 1) << 2)) & 15u8) as usize
}),
8 => internal::<8, _>(dims, ratio, lut, |i| code[i] as usize),
1 => internal::<1>(dims.div_ceil(ratio) as _, lut, code),
2 => internal::<2>(dims.div_ceil(ratio) as _, lut, code),
4 => internal::<4>(dims.div_ceil(ratio) as _, lut, code),
8 => internal::<8>(dims.div_ceil(ratio) as _, lut, code),
_ => unreachable!(),
}
}
Expand Down
82 changes: 50 additions & 32 deletions crates/quantization/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,25 +376,34 @@ impl<S: ScalarLike> OperatorScalarQuantization for VectDot<S> {
_ => unreachable!(),
}
}
fn process(dims: u32, bits: u32, lut: &[f32], rhs: &[u8]) -> Distance {
fn internal<const BITS: u32>(dims: u32, t: &[f32], f: impl Fn(usize) -> usize) -> Distance {
let mut xy = 0.0f32;
for i in 0..dims as usize {
xy += t[i * (1 << BITS) + f(i)];
fn process(dims: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance {
#[inline(never)]
fn internal<const BITS: usize>(n: usize, lut: &[f32], code: &[u8]) -> Distance {
assert!(n >= 1);
assert!(n <= 65535);
assert!(code.len() == n / (8 / BITS));
assert!(lut.len() == n * (1 << BITS));
let mut sum = 0.0f32;
for i in 0..n {
unsafe {
// Safety: `i < n`
std::hint::assert_unchecked(i / (8 / BITS) < n / (8 / BITS));
}
let (alpha, beta) = (i / (8 / BITS), i % (8 / BITS));
let j = (code[alpha] >> (beta * BITS)) as usize % (1 << BITS);
unsafe {
// Safety: `i < n`, `j < (1 << BITS)`
std::hint::assert_unchecked(i * (1 << BITS) + j < n * (1 << BITS));
}
sum += lut[i * (1 << BITS) + j];
}
Distance::from(-xy)
Distance::from(-sum)
}
match bits {
1 => internal::<1>(dims, lut, |i| {
((rhs[i >> 3] >> ((i & 7) << 0)) & 1u8) as usize
}),
2 => internal::<2>(dims, lut, |i| {
((rhs[i >> 2] >> ((i & 3) << 1)) & 3u8) as usize
}),
4 => internal::<4>(dims, lut, |i| {
((rhs[i >> 1] >> ((i & 1) << 2)) & 15u8) as usize
}),
8 => internal::<8>(dims, lut, |i| rhs[i] as usize),
1 => internal::<1>(dims as _, lut, code),
2 => internal::<2>(dims as _, lut, code),
4 => internal::<4>(dims as _, lut, code),
8 => internal::<8>(dims as _, lut, code),
_ => unreachable!(),
}
}
Expand Down Expand Up @@ -466,25 +475,34 @@ impl<S: ScalarLike> OperatorScalarQuantization for VectL2<S> {
_ => unreachable!(),
}
}
fn process(dims: u32, bits: u32, lut: &[f32], rhs: &[u8]) -> Distance {
fn internal<const BITS: u32>(dims: u32, t: &[f32], f: impl Fn(usize) -> usize) -> Distance {
let mut d2 = 0.0f32;
for i in 0..dims as usize {
d2 += t[i * (1 << BITS) + f(i)];
fn process(dims: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance {
#[inline(never)]
fn internal<const BITS: usize>(n: usize, lut: &[f32], code: &[u8]) -> Distance {
assert!(n >= 1);
assert!(n <= 65535);
assert!(code.len() == n / (8 / BITS));
assert!(lut.len() == n * (1 << BITS));
let mut sum = 0.0f32;
for i in 0..n {
unsafe {
// Safety: `i < n`
std::hint::assert_unchecked(i / (8 / BITS) < n / (8 / BITS));
}
let (alpha, beta) = (i / (8 / BITS), i % (8 / BITS));
let j = (code[alpha] >> (beta * BITS)) as usize % (1 << BITS);
unsafe {
// Safety: `i < n`, `j < (1 << BITS)`
std::hint::assert_unchecked(i * (1 << BITS) + j < n * (1 << BITS));
}
sum += lut[i * (1 << BITS) + j];
}
Distance::from(d2)
Distance::from(sum)
}
match bits {
1 => internal::<1>(dims, lut, |i| {
((rhs[i >> 3] >> ((i & 7) << 0)) & 1u8) as usize
}),
2 => internal::<2>(dims, lut, |i| {
((rhs[i >> 2] >> ((i & 3) << 1)) & 3u8) as usize
}),
4 => internal::<4>(dims, lut, |i| {
((rhs[i >> 1] >> ((i & 1) << 2)) & 15u8) as usize
}),
8 => internal::<8>(dims, lut, |i| rhs[i] as usize),
1 => internal::<1>(dims as _, lut, code),
2 => internal::<2>(dims as _, lut, code),
4 => internal::<4>(dims as _, lut, code),
8 => internal::<8>(dims as _, lut, code),
_ => unreachable!(),
}
}
Expand Down

0 comments on commit 64c5a7d

Please sign in to comment.