From 64c5a7d92eccf0fde27d09465d614b1eefc02c87 Mon Sep 17 00:00:00 2001 From: usamoi Date: Mon, 23 Sep 2024 19:15:29 +0800 Subject: [PATCH] refactor: improve SQ&PQ scan performance (#598) Signed-off-by: usamoi --- crates/quantization/src/product.rs | 84 +++++++++++++++++------------- crates/quantization/src/scalar.rs | 82 +++++++++++++++++------------ 2 files changed, 98 insertions(+), 68 deletions(-) diff --git a/crates/quantization/src/product.rs b/crates/quantization/src/product.rs index 8808cbbf..ac5df845 100644 --- a/crates/quantization/src/product.rs +++ b/crates/quantization/src/product.rs @@ -463,27 +463,33 @@ impl OperatorProductQuantization for VectDot { } } fn process(dims: u32, ratio: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance { - fn internal(dims: u32, ratio: u32, t: &[f32], f: F) -> Distance - where - F: Fn(usize) -> usize, - { - let mut xy = 0.0f32; - for i in 0..dims.div_ceil(ratio) as usize { - xy += t[i * (1 << BITS) + f(i)]; + #[inline(never)] + fn internal(n: usize, lut: &[f32], code: &[u8]) -> Distance { + assert!(n >= 1); + assert!(n <= 65535); + assert!(code.len() == n / (8 / BITS)); + assert!(lut.len() == n * (1 << BITS)); + let mut sum = 0.0f32; + for i in 0..n { + unsafe { + // Safety: `i < n` + std::hint::assert_unchecked(i / (8 / BITS) < n / (8 / BITS)); + } + let (alpha, beta) = (i / (8 / BITS), i % (8 / BITS)); + let j = (code[alpha] >> (beta * BITS)) as usize % (1 << BITS); + unsafe { + // Safety: `i < n`, `j < (1 << BITS)` + std::hint::assert_unchecked(i * (1 << BITS) + j < n * (1 << BITS)); + } + sum += lut[i * (1 << BITS) + j]; } - Distance::from(-xy) + Distance::from(-sum) } match bits { - 1 => internal::<1, _>(dims, ratio, lut, |i| { - ((code[i >> 3] >> ((i & 7) << 0)) & 1u8) as usize - }), - 2 => internal::<2, _>(dims, ratio, lut, |i| { - ((code[i >> 2] >> ((i & 3) << 1)) & 3u8) as usize - }), - 4 => internal::<4, _>(dims, ratio, lut, |i| { - ((code[i >> 1] >> ((i & 1) << 2)) & 15u8) as usize - }), - 8 => internal::<8, _>(dims, ratio, lut, |i| code[i] as usize), + 1 => internal::<1>(dims.div_ceil(ratio) as _, lut, code), + 2 => internal::<2>(dims.div_ceil(ratio) as _, lut, code), + 4 => internal::<4>(dims.div_ceil(ratio) as _, lut, code), + 8 => internal::<8>(dims.div_ceil(ratio) as _, lut, code), _ => unreachable!(), } } @@ -657,27 +663,33 @@ impl OperatorProductQuantization for VectL2 { } } fn process(dims: u32, ratio: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance { - fn internal(dims: u32, ratio: u32, t: &[f32], f: F) -> Distance - where - F: Fn(usize) -> usize, - { - let mut d2 = 0.0f32; - for i in 0..dims.div_ceil(ratio) as usize { - d2 += t[i * (1 << BITS) + f(i)]; + #[inline(never)] + fn internal(n: usize, lut: &[f32], code: &[u8]) -> Distance { + assert!(n >= 1); + assert!(n <= 65535); + assert!(code.len() == n / (8 / BITS)); + assert!(lut.len() == n * (1 << BITS)); + let mut sum = 0.0f32; + for i in 0..n { + unsafe { + // Safety: `i < n` + std::hint::assert_unchecked(i / (8 / BITS) < n / (8 / BITS)); + } + let (alpha, beta) = (i / (8 / BITS), i % (8 / BITS)); + let j = (code[alpha] >> (beta * BITS)) as usize % (1 << BITS); + unsafe { + // Safety: `i < n`, `j < (1 << BITS)` + std::hint::assert_unchecked(i * (1 << BITS) + j < n * (1 << BITS)); + } + sum += lut[i * (1 << BITS) + j]; } - Distance::from(d2) + Distance::from(sum) } match bits { - 1 => internal::<1, _>(dims, ratio, lut, |i| { - ((code[i >> 3] >> ((i & 7) << 0)) & 1u8) as usize - }), - 2 => internal::<2, _>(dims, ratio, lut, |i| { - ((code[i >> 2] >> ((i & 3) << 1)) & 3u8) as usize - }), - 4 => internal::<4, _>(dims, ratio, lut, |i| { - ((code[i >> 1] >> ((i & 1) << 2)) & 15u8) as usize - }), - 8 => internal::<8, _>(dims, ratio, lut, |i| code[i] as usize), + 1 => internal::<1>(dims.div_ceil(ratio) as _, lut, code), + 2 => internal::<2>(dims.div_ceil(ratio) as _, lut, code), + 4 => internal::<4>(dims.div_ceil(ratio) as _, lut, code), + 8 => internal::<8>(dims.div_ceil(ratio) as _, lut, code), _ => unreachable!(), } } diff --git a/crates/quantization/src/scalar.rs b/crates/quantization/src/scalar.rs index 9f691808..cf625e7a 100644 --- a/crates/quantization/src/scalar.rs +++ b/crates/quantization/src/scalar.rs @@ -376,25 +376,34 @@ impl OperatorScalarQuantization for VectDot { _ => unreachable!(), } } - fn process(dims: u32, bits: u32, lut: &[f32], rhs: &[u8]) -> Distance { - fn internal(dims: u32, t: &[f32], f: impl Fn(usize) -> usize) -> Distance { - let mut xy = 0.0f32; - for i in 0..dims as usize { - xy += t[i * (1 << BITS) + f(i)]; + fn process(dims: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance { + #[inline(never)] + fn internal(n: usize, lut: &[f32], code: &[u8]) -> Distance { + assert!(n >= 1); + assert!(n <= 65535); + assert!(code.len() == n / (8 / BITS)); + assert!(lut.len() == n * (1 << BITS)); + let mut sum = 0.0f32; + for i in 0..n { + unsafe { + // Safety: `i < n` + std::hint::assert_unchecked(i / (8 / BITS) < n / (8 / BITS)); + } + let (alpha, beta) = (i / (8 / BITS), i % (8 / BITS)); + let j = (code[alpha] >> (beta * BITS)) as usize % (1 << BITS); + unsafe { + // Safety: `i < n`, `j < (1 << BITS)` + std::hint::assert_unchecked(i * (1 << BITS) + j < n * (1 << BITS)); + } + sum += lut[i * (1 << BITS) + j]; } - Distance::from(-xy) + Distance::from(-sum) } match bits { - 1 => internal::<1>(dims, lut, |i| { - ((rhs[i >> 3] >> ((i & 7) << 0)) & 1u8) as usize - }), - 2 => internal::<2>(dims, lut, |i| { - ((rhs[i >> 2] >> ((i & 3) << 1)) & 3u8) as usize - }), - 4 => internal::<4>(dims, lut, |i| { - ((rhs[i >> 1] >> ((i & 1) << 2)) & 15u8) as usize - }), - 8 => internal::<8>(dims, lut, |i| rhs[i] as usize), + 1 => internal::<1>(dims as _, lut, code), + 2 => internal::<2>(dims as _, lut, code), + 4 => internal::<4>(dims as _, lut, code), + 8 => internal::<8>(dims as _, lut, code), _ => unreachable!(), } } @@ -466,25 +475,34 @@ impl OperatorScalarQuantization for VectL2 { _ => unreachable!(), } } - fn process(dims: u32, bits: u32, lut: &[f32], rhs: &[u8]) -> Distance { - fn internal(dims: u32, t: &[f32], f: impl Fn(usize) -> usize) -> Distance { - let mut d2 = 0.0f32; - for i in 0..dims as usize { - d2 += t[i * (1 << BITS) + f(i)]; + fn process(dims: u32, bits: u32, lut: &[f32], code: &[u8]) -> Distance { + #[inline(never)] + fn internal(n: usize, lut: &[f32], code: &[u8]) -> Distance { + assert!(n >= 1); + assert!(n <= 65535); + assert!(code.len() == n / (8 / BITS)); + assert!(lut.len() == n * (1 << BITS)); + let mut sum = 0.0f32; + for i in 0..n { + unsafe { + // Safety: `i < n` + std::hint::assert_unchecked(i / (8 / BITS) < n / (8 / BITS)); + } + let (alpha, beta) = (i / (8 / BITS), i % (8 / BITS)); + let j = (code[alpha] >> (beta * BITS)) as usize % (1 << BITS); + unsafe { + // Safety: `i < n`, `j < (1 << BITS)` + std::hint::assert_unchecked(i * (1 << BITS) + j < n * (1 << BITS)); + } + sum += lut[i * (1 << BITS) + j]; } - Distance::from(d2) + Distance::from(sum) } match bits { - 1 => internal::<1>(dims, lut, |i| { - ((rhs[i >> 3] >> ((i & 7) << 0)) & 1u8) as usize - }), - 2 => internal::<2>(dims, lut, |i| { - ((rhs[i >> 2] >> ((i & 3) << 1)) & 3u8) as usize - }), - 4 => internal::<4>(dims, lut, |i| { - ((rhs[i >> 1] >> ((i & 1) << 2)) & 15u8) as usize - }), - 8 => internal::<8>(dims, lut, |i| rhs[i] as usize), + 1 => internal::<1>(dims as _, lut, code), + 2 => internal::<2>(dims as _, lut, code), + 4 => internal::<4>(dims as _, lut, code), + 8 => internal::<8>(dims as _, lut, code), _ => unreachable!(), } }