Skip to content

Commit 48236b2

Browse files
committed
flatten x binary u64
Signed-off-by: Keming <[email protected]>
1 parent 37c688f commit 48236b2

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

src/rabitq.rs

+15 -7
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@ pub struct RaBitQ {
2727
centroids: Mat<f32>,
2828
offsets: Vec<u32>,
2929
map_ids: Vec<u32>,
30-
x_binary_vec: Vec<Vec<u64>>,
30+
x_binary_vec: Vec<u64>,
3131
x_c_distance_square: Vec<f32>,
3232
error_bound: Vec<f32>,
3333
factor_ip: Vec<f32>,
@@ -63,8 +63,11 @@ impl RaBitQ {
6363
let error_bound = factors[2].clone();
6464
let x_c_distance_square = factors[3].clone();
6565

66-
let x_binary_vec =
67-
read_u64_vecs(&path.join("x_binary_vec.u64vecs")).expect("open x_binary_vec error");
66+
let x_binary_vec = read_u64_vecs(&path.join("x_binary_vec.u64vecs"))
67+
.expect("open x_binary_vec error")
68+
.into_iter()
69+
.flatten()
70+
.collect();
6871

6972
let dim = orthogonal.nrows();
7073
let base = matrix_from_fvecs(&path.join("base.fvecs"))
@@ -113,7 +116,8 @@ impl RaBitQ {
113116
.expect("write factors error");
114117
write_vecs(
115118
&path.join("x_binary_vec.u64vecs"),
116-
&self.x_binary_vec.iter().collect::<Vec<_>>(),
119+
// &self.x_binary_vec.iter().collect::<Vec<_>>(),
120+
&[&self.x_binary_vec],
117121
)
118122
.expect("write x_binary_vec error");
119123
}
@@ -187,7 +191,7 @@ impl RaBitQ {
187191
.to_owned();
188192
let x_binary_vec = flat_labels
189193
.iter()
190-
.map(|i| x_binary_vec[*i as usize].clone())
194+
.flat_map(|i| x_binary_vec[*i as usize].clone())
191195
.collect();
192196
let x_c_distance_square = flat_labels
193197
.iter()
@@ -227,6 +231,7 @@ impl RaBitQ {
227231
topk: usize,
228232
heuristic_rank: bool,
229233
) -> Vec<(f32, u32)> {
234+
assert_eq!(self.dim as usize, query.nrows());
230235
let y_projected = project(query, &self.orthogonal.as_ref());
231236
let k = self.centroids.shape().1;
232237
let mut lists = Vec::with_capacity(k);
@@ -287,15 +292,18 @@ impl RaBitQ {
287292
rough_distances: &mut Vec<(f32, u32)>,
288293
) {
289294
let dist_sqrt = y_c_distance_square.sqrt();
295+
let binary_offset = y_binary_vec.len() / THETA_LOG_DIM as usize;
290296
for j in self.offsets[cluster_id]..self.offsets[cluster_id + 1] {
291297
let ju = j as usize;
292298
rough_distances.push((
293299
(self.x_c_distance_square[ju]
294300
+ y_c_distance_square
295301
+ lower_bound * self.factor_ppc[ju]
296302
+ (2.0
297-
* asymmetric_binary_dot_product(&self.x_binary_vec[ju], y_binary_vec)
298-
as f32
303+
* asymmetric_binary_dot_product(
304+
&self.x_binary_vec[ju * binary_offset..(ju + 1) * binary_offset],
305+
y_binary_vec,
306+
) as f32
299307
- scalar_sum)
300308
* self.factor_ip[ju]
301309
* delta

0 commit comments

Comments
 (0)