Skip to content

Commit c8c1689

Browse files
committed
fastest version so far
1 parent ca7b074 commit c8c1689

File tree

2 files changed

+44
-20
lines changed

2 files changed

+44
-20
lines changed

crates/famst/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,8 @@ impl AnnGraph {
173173
}
174174

175175
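/// Get mutable access to the neighbor storage as parallel chunks of `k * group_size` entries, i.e. the neighbor lists of `group_size` consecutive nodes per chunk (the last chunk may be shorter)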
176-
pub(crate) fn neighbors_chunks_mut(&mut self) -> impl IndexedParallelIterator<Item = &mut [Neighbor]> {
177-
self.data.par_chunks_mut(self.k)
176+
pub(crate) fn neighbors_chunks_mut(&mut self, group_size: usize) -> impl IndexedParallelIterator<Item = &mut [Neighbor]> {
177+
self.data.par_chunks_mut(self.k * group_size)
178178
}
179179
}
180180

crates/famst/src/nn_descent.rs

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ fn build_neighbor_lists(graph: &mut AnnGraph, rng: &mut impl Rng) -> (Neighbors,
9696

9797
// Only mark neighbors as old if they were selected into new_neighbors
9898
graph
99-
.neighbors_chunks_mut()
99+
.neighbors_chunks_mut(1)
100100
.enumerate()
101101
.for_each(|(i, neighbors)| {
102102
for &selected_id in new_neighbors.get(i) {
@@ -235,15 +235,18 @@ where
235235

236236
// For each point, generate candidates from neighbors of neighbors
237237
// Key optimization: only consider pairs where at least one is "new"
238-
let mut candidates: Vec<(u32, u32)> = (0..n)
238+
// Collect all candidates with their distances, including reverse edges
239+
let mut candidates: Vec<(u32, u32, f32)> = (0..n)
239240
.into_par_iter()
240241
.flat_map_iter(|i| {
241242
let old_i = old_neighbors.get(i);
242243
let new_i = new_neighbors.get(i);
243244

244245
// new-new pairs: (u, v) where u < v
245246
let new_new = new_i.iter().flat_map(|&u| {
246-
new_i.iter().filter_map(move |&v| if u < v { Some((u, v)) } else { None })
247+
new_i
248+
.iter()
249+
.filter_map(move |&v| if u < v { Some((u, v)) } else { None })
247250
});
248251

249252
// new-old pairs: (min, max) where u != v
@@ -257,26 +260,47 @@ where
257260
})
258261
});
259262

260-
new_new.chain(new_old)
263+
new_new
264+
.chain(new_old)
265+
.flat_map(|(u, v)| {
266+
let d = distance_fn(&data[u as usize], &data[v as usize]);
267+
// Insert both (u, v) and (v, u) so we can parallelize by node
268+
[(u, v, d), (v, u, d)]
269+
})
261270
})
262-
.take_any(n * k)
263271
.collect();
264-
candidates.par_sort_unstable();
272+
273+
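// Sort by first node so candidates for each node are contiguous. NOTE(review): the key ignores the second node, so the `dedup()` below only removes duplicates that happen to end up adjacent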
274+
candidates.par_sort_unstable_by_key(|(a, _, _)| *a);
265275
candidates.dedup();
266276

267-
// Try to improve neighbors with candidates
268-
let mut updates = 0;
269-
for &(u, v) in &candidates {
270-
let d = distance_fn(&data[u as usize], &data[v as usize]);
271-
if insert_neighbor(graph.neighbors_mut(u as usize), NodeId::new(v), d) {
272-
updates += 1;
273-
}
274-
if insert_neighbor(graph.neighbors_mut(v as usize), NodeId::new(u), d) {
275-
updates += 1;
276-
}
277-
}
277+
// Process in parallel by chunks of nodes
278+
const CHUNK_SIZE: usize = 64;
279+
let updates: usize = graph
280+
.neighbors_chunks_mut(CHUNK_SIZE)
281+
.enumerate()
282+
.map(|(i, chunk_neighbors)| {
283+
let start_node = (CHUNK_SIZE * i) as u32;
284+
// Binary search for the first candidate whose source node falls in this chunk of nodes
285+
let start = candidates.partition_point(|&(u, _, _)| u < start_node);
286+
let mut count = 0;
287+
for &(u, v, d) in &candidates[start..] {
288+
if u >= start_node + CHUNK_SIZE as u32 {
289+
break;
290+
}
291+
let neighbors = &mut chunk_neighbors[(u - start_node) as usize * k..][..k];
292+
if insert_neighbor(neighbors, NodeId::new(v), d) {
293+
count += 1;
294+
}
295+
}
296+
count
297+
})
298+
.sum();
278299

279-
println!("NN-Descent iteration {iter}: {updates} updates of {} candidates", candidates.len());
300+
println!(
301+
"NN-Descent iteration {iter}: {updates} updates of {} candidates",
302+
candidates.len(),
303+
);
280304

281305
// Early termination if no updates
282306
if updates == 0 {

0 commit comments

Comments
 (0)