
Commit 8eb4090

Update nn_descent.rs
1 parent c8c1689 commit 8eb4090

File tree

1 file changed: +77 -65 lines


crates/famst/src/nn_descent.rs

Lines changed: 77 additions & 65 deletions
@@ -232,78 +232,90 @@ where
         // Build combined neighbor lists (forward + reverse, with reservoir sampling)
         // Also marks all neighbors as old for next iteration
         let (old_neighbors, new_neighbors) = build_neighbor_lists(&mut graph, rng);
-
-        // For each point, generate candidates from neighbors of neighbors
-        // Key optimization: only consider pairs where at least one is "new"
-        // Collect all candidates with their distances, including reverse edges
-        let mut candidates: Vec<(u32, u32, f32)> = (0..n)
-            .into_par_iter()
-            .flat_map_iter(|i| {
-                let old_i = old_neighbors.get(i);
-                let new_i = new_neighbors.get(i);
-
-                // new-new pairs: (u, v) where u < v
-                let new_new = new_i.iter().flat_map(|&u| {
-                    new_i
-                        .iter()
-                        .filter_map(move |&v| if u < v { Some((u, v)) } else { None })
-                });
-
-                // new-old pairs: (min, max) where u != v
-                let new_old = new_i.iter().flat_map(|&u| {
-                    old_i.iter().filter_map(move |&v| {
-                        if u != v {
-                            Some((u.min(v), u.max(v)))
-                        } else {
-                            None
+        println!(" Built neighbor lists");
+
+        // Process in batches of central nodes to limit memory usage
+        let batch_size: usize = n / k;
+        let mut total_updates = 0;
+        let mut total_candidates = 0;
+
+        for batch_start in (0..n).step_by(batch_size) {
+            let batch_end = (batch_start + batch_size).min(n);
+
+            // For each point in this batch, generate candidates from neighbors of neighbors
+            // Key optimization: only consider pairs where at least one is "new"
+            // Collect all candidates with their distances, including reverse edges
+            let mut candidates: Vec<(u32, u32, f32)> = (batch_start..batch_end)
+                .into_par_iter()
+                .flat_map_iter(|i| {
+                    let old_i = old_neighbors.get(i);
+                    let new_i = new_neighbors.get(i);
+
+                    // new-new pairs: (u, v) where u < v
+                    let new_new = new_i.iter().flat_map(|&u| {
+                        new_i
+                            .iter()
+                            .filter_map(move |&v| if u < v { Some((u, v)) } else { None })
+                    });
+
+                    // new-old pairs: (min, max) where u != v
+                    let new_old = new_i.iter().flat_map(|&u| {
+                        old_i.iter().filter_map(move |&v| {
+                            if u != v {
+                                Some((u.min(v), u.max(v)))
+                            } else {
+                                None
+                            }
+                        })
+                    });
+
+                    new_new
+                        .chain(new_old)
+                        .flat_map(|(u, v)| {
+                            let d = distance_fn(&data[u as usize], &data[v as usize]);
+                            // Insert both (u, v) and (v, u) so we can parallelize by node
+                            [(u, v, d), (v, u, d)]
+                        })
+                })
+                .collect();
+
+            // Sort by first node so candidates for each node are contiguous
+            candidates.par_sort_unstable_by_key(|(a, _, _)| *a);
+            candidates.dedup();
+            total_candidates += candidates.len();
+
+            // Process in parallel by chunks of nodes
+            const CHUNK_SIZE: usize = 64;
+            let updates: usize = graph
+                .neighbors_chunks_mut(CHUNK_SIZE)
+                .enumerate()
+                .map(|(i, chunk_neighbors)| {
+                    let start_node = (CHUNK_SIZE * i) as u32;
+                    // Binary search to find the range of candidates for this chunk
+                    let start = candidates.partition_point(|&(u, _, _)| u < start_node);
+                    let mut count = 0;
+                    for &(u, v, d) in &candidates[start..] {
+                        if u >= start_node + CHUNK_SIZE as u32 {
+                            break;
+                        }
+                        let neighbors = &mut chunk_neighbors[(u - start_node) as usize * k..][..k];
+                        if insert_neighbor(neighbors, NodeId::new(v), d) {
+                            count += 1;
                         }
-                    })
-                });
-
-                new_new
-                    .chain(new_old)
-                    .flat_map(|(u, v)| {
-                        let d = distance_fn(&data[u as usize], &data[v as usize]);
-                        // Insert both (u, v) and (v, u) so we can parallelize by node
-                        [(u, v, d), (v, u, d)]
-                    })
-            })
-            .collect();
-
-        // Sort by first node so candidates for each node are contiguous
-        candidates.par_sort_unstable_by_key(|(a, _, _)| *a);
-        candidates.dedup();
-
-        // Process in parallel by chunks of nodes
-        const CHUNK_SIZE: usize = 64;
-        let updates: usize = graph
-            .neighbors_chunks_mut(CHUNK_SIZE)
-            .enumerate()
-            .map(|(i, chunk_neighbors)| {
-                let start_node = (CHUNK_SIZE * i) as u32;
-                // Binary search to find the range of candidates for this node
-                let start = candidates.partition_point(|&(u, _, _)| u < start_node);
-                let mut count = 0;
-                for &(u, v, d) in &candidates[start..] {
-                    if u >= start_node + CHUNK_SIZE as u32 {
-                        break;
-                    }
-                    let neighbors = &mut chunk_neighbors[(u - start_node) as usize * k..][..k];
-                    if insert_neighbor(neighbors, NodeId::new(v), d) {
-                        count += 1;
                     }
-                }
-                count
-            })
-            .sum();
+                    count
+                })
+                .sum();
+
+            total_updates += updates;
+        }

         println!(
-            "NN-Descent iteration {iter}: {updates} updates of {} candidates",
-            candidates.len(),
+            "NN-Descent iteration {iter}: {total_updates} updates of {total_candidates} candidates",
         );

         // Early termination if no updates
-        if updates == 0 {
+        if total_updates == 0 {
             break;
         }
     }
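
Note on the chunked update: it relies on `candidates` being sorted by its first node, so each 64-node chunk owns one contiguous slice of the list. Below is a minimal standalone sketch of that lookup in plain Rust; it is not taken from the crate, the helper `candidates_for_chunk` and the toy data are illustrative only, and it finds the end of the slice with a second `partition_point` where the commit simply scans forward and breaks once `u` leaves the chunk.

const CHUNK_SIZE: usize = 64;

/// `candidates` must be sorted by its first field (the central node id).
/// Returns the contiguous sub-slice belonging to the chunk of nodes
/// [start_node, start_node + CHUNK_SIZE), found with two binary searches.
fn candidates_for_chunk(candidates: &[(u32, u32, f32)], start_node: u32) -> &[(u32, u32, f32)] {
    let start = candidates.partition_point(|&(u, _, _)| u < start_node);
    let end = candidates.partition_point(|&(u, _, _)| u < start_node + CHUNK_SIZE as u32);
    &candidates[start..end]
}

fn main() {
    // Toy data: (node, neighbor, distance), already sorted by node.
    let candidates = vec![(0, 3, 0.5), (1, 2, 0.1), (63, 9, 0.7), (64, 1, 0.2)];
    assert_eq!(candidates_for_chunk(&candidates, 0).len(), 3); // nodes 0..64
    assert_eq!(candidates_for_chunk(&candidates, 64).len(), 1); // nodes 64..128
}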
