Skip to content

Commit 5477533

Browse files
committed
Update nn_descent.rs
1 parent d4e68c7 commit 5477533

File tree

1 file changed

+64
-90
lines changed

1 file changed

+64
-90
lines changed

crates/famst/src/nn_descent.rs

Lines changed: 64 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -14,48 +14,54 @@ use std::collections::HashSet;
1414

1515
use crate::{AnnGraph, FamstConfig, Neighbor, NodeId};
1616

17-
/// Reverse neighbor lists with bounded size k per node.
18-
/// Uses flat storage: data[i*k..(i+1)*k] contains up to k reverse neighbors of node i.
19-
/// Uses reservoir sampling when more than k reverse edges exist.
20-
struct ReverseNeighbors {
21-
/// Flat storage of reverse neighbor IDs (with new flag preserved)
17+
/// Neighbor lists with bounded size k per node.
18+
/// Uses flat storage: data[i*k..(i+1)*k] contains up to k neighbors of node i.
19+
/// Uses reservoir sampling when more than k neighbors exist.
20+
struct Neighbors {
21+
/// Flat storage of neighbor IDs
2222
data: Vec<NodeId>,
23-
/// Count of reverse neighbors seen so far (for reservoir sampling)
23+
/// Count of neighbors seen so far (for reservoir sampling)
2424
counts: Vec<u32>,
25-
/// Max reverse neighbors per node
25+
/// Max neighbors per node
2626
k: usize,
2727
}
2828

29-
impl ReverseNeighbors {
29+
impl Neighbors {
3030
fn new(n: usize, k: usize) -> Self {
31-
ReverseNeighbors {
31+
Neighbors {
3232
data: vec![NodeId::new(0); n * k],
3333
counts: vec![0; n],
3434
k,
3535
}
3636
}
3737

38-
/// Add a reverse edge: node `from` is a neighbor of node `to`, so `to` has reverse edge to `from`.
39-
/// Uses reservoir sampling to maintain at most k reverse neighbors.
38+
/// Add a neighbor using reservoir sampling to maintain at most k neighbors.
39+
/// Skips if the neighbor is already present.
4040
#[inline]
41-
fn add(&mut self, to: usize, from: NodeId, rng: &mut impl Rng) {
42-
let count = self.counts[to] as usize;
43-
let start = to * self.k;
41+
fn add(&mut self, node: usize, neighbor: NodeId, rng: &mut impl Rng) {
42+
let count = self.counts[node] as usize;
43+
let start = node * self.k;
44+
let filled = count.min(self.k);
45+
46+
// Check if neighbor already exists in the filled portion
47+
if self.data[start..start + filled].contains(&neighbor) {
48+
return;
49+
}
4450

4551
if count < self.k {
4652
// Still have room, just append
47-
self.data[start + count] = from;
53+
self.data[start + count] = neighbor;
4854
} else {
4955
// Reservoir sampling: replace with probability k / (count + 1)
5056
let j = rng.gen_range(0..=count);
5157
if j < self.k {
52-
self.data[start + j] = from;
58+
self.data[start + j] = neighbor;
5359
}
5460
}
55-
self.counts[to] += 1;
61+
self.counts[node] += 1;
5662
}
5763

58-
/// Get the reverse neighbors of node i (only the filled slots)
64+
/// Get the neighbors of node i (only the filled slots)
5965
#[inline]
6066
fn get(&self, i: usize) -> &[NodeId] {
6167
let start = i * self.k;
@@ -64,26 +70,46 @@ impl ReverseNeighbors {
6470
}
6571
}
6672

67-
/// Build reverse neighbor lists with reservoir sampling.
68-
/// Returns separate old and new reverse neighbor structures.
69-
fn build_reverse_lists(graph: &AnnGraph, rng: &mut impl Rng) -> (ReverseNeighbors, ReverseNeighbors) {
73+
/// Build combined neighbor lists (forward + reverse) with reservoir sampling.
74+
/// Returns (old_neighbors, new_neighbors), each with 2*k capacity per node.
75+
/// Only marks neighbors that were selected into new_neighbors as old.
76+
fn build_neighbor_lists(graph: &mut AnnGraph, rng: &mut impl Rng) -> (Neighbors, Neighbors) {
7077
let n = graph.n();
7178
let k = graph.k();
72-
let mut old_reverse = ReverseNeighbors::new(n, k);
73-
let mut new_reverse = ReverseNeighbors::new(n, k);
79+
// 2*k capacity: k for forward + k for reverse
80+
let mut old_neighbors = Neighbors::new(n, 2 * k);
81+
let mut new_neighbors = Neighbors::new(n, 2 * k);
7482

7583
for i in 0..n {
7684
let i_id = NodeId::new(i as u32);
7785
for neighbor in graph.neighbors(i) {
7886
let target = neighbor.index.index() as usize;
87+
let target_id = neighbor.index.as_old(); // Strip new flag for storage
7988
if neighbor.index.is_new() {
80-
new_reverse.add(target, i_id, rng);
89+
// Forward: i -> target, Reverse: target <- i
90+
new_neighbors.add(i, target_id, rng);
91+
new_neighbors.add(target, i_id, rng);
8192
} else {
82-
old_reverse.add(target, i_id, rng);
93+
old_neighbors.add(i, target_id, rng);
94+
old_neighbors.add(target, i_id, rng);
8395
}
8496
}
8597
}
86-
(old_reverse, new_reverse)
98+
99+
// Only mark neighbors as old if they were selected into new_neighbors
100+
for i in 0..n {
101+
for &selected_id in new_neighbors.get(i) {
102+
// Find this neighbor in the graph and mark as old if it's still new
103+
for nb in graph.neighbors_mut(i) {
104+
if nb.index == selected_id {
105+
nb.index = nb.index.as_old();
106+
break;
107+
}
108+
}
109+
}
110+
}
111+
112+
(old_neighbors, new_neighbors)
87113
}
88114

89115
/// Initialize ANN graph with random neighbors
@@ -201,30 +227,10 @@ where
201227

202228
// NN-Descent iterations
203229
for iter in 0..config.nn_descent_iterations {
204-
// Build reverse neighbor lists, separating old and new (with reservoir sampling)
205-
let (old_reverse, new_reverse) = build_reverse_lists(&graph, rng);
206-
207-
// For each point, collect old and new forward neighbors
208-
let mut old_neighbors: Vec<Vec<NodeId>> = vec![Vec::new(); n];
209-
let mut new_neighbors: Vec<Vec<NodeId>> = vec![Vec::new(); n];
210-
211-
for i in 0..n {
212-
for nb in graph.neighbors(i) {
213-
let idx = nb.index.as_old(); // Strip new flag for storage
214-
if nb.index.is_new() {
215-
new_neighbors[i].push(idx);
216-
} else {
217-
old_neighbors[i].push(idx);
218-
}
219-
}
220-
}
221-
222-
// Mark all neighbors as old for next iteration
223-
for i in 0..n {
224-
for nb in graph.neighbors_mut(i) {
225-
nb.index = nb.index.as_old();
226-
}
227-
}
230+
println!("NN-Descent iteration {iter}...");
231+
// Build combined neighbor lists (forward + reverse, with reservoir sampling)
232+
// Also marks all neighbors as old for next iteration
233+
let (old_neighbors, new_neighbors) = build_neighbor_lists(&mut graph, rng);
228234

229235
let mut updates = 0;
230236

@@ -233,17 +239,8 @@ where
233239
for i in 0..n {
234240
let i_id = NodeId::new(i as u32);
235241

236-
// Combine forward and reverse neighbors
237-
let old_i: Vec<NodeId> = old_neighbors[i]
238-
.iter()
239-
.chain(old_reverse.get(i).iter())
240-
.copied()
241-
.collect();
242-
let new_i: Vec<NodeId> = new_neighbors[i]
243-
.iter()
244-
.chain(new_reverse.get(i).iter())
245-
.copied()
246-
.collect();
242+
let old_i = old_neighbors.get(i);
243+
let new_i = new_neighbors.get(i);
247244

248245
// Skip if no new neighbors
249246
if new_i.is_empty() {
@@ -260,44 +257,29 @@ where
260257
let mut candidates: HashSet<NodeId> = HashSet::new();
261258

262259
// new-new pairs: for each new neighbor, look at their new neighbors
263-
for &u in &new_i {
260+
for &u in new_i {
264261
let u_idx = u.index() as usize;
265-
for &v in &new_neighbors[u_idx] {
266-
if v != i_id && !current_neighbors.contains(&v) {
267-
candidates.insert(v);
268-
}
269-
}
270-
for v in new_reverse.get(u_idx) {
262+
for v in new_neighbors.get(u_idx) {
271263
if *v != i_id && !current_neighbors.contains(v) {
272264
candidates.insert(*v);
273265
}
274266
}
275267
}
276268

277269
// new-old pairs: for each new neighbor, look at their old neighbors
278-
for &u in &new_i {
270+
for &u in new_i {
279271
let u_idx = u.index() as usize;
280-
for &v in &old_neighbors[u_idx] {
281-
if v != i_id && !current_neighbors.contains(&v) {
282-
candidates.insert(v);
283-
}
284-
}
285-
for v in old_reverse.get(u_idx) {
272+
for v in old_neighbors.get(u_idx) {
286273
if *v != i_id && !current_neighbors.contains(v) {
287274
candidates.insert(*v);
288275
}
289276
}
290277
}
291278

292279
// old-new pairs: for each old neighbor, look at their new neighbors
293-
for &u in &old_i {
280+
for &u in old_i {
294281
let u_idx = u.index() as usize;
295-
for &v in &new_neighbors[u_idx] {
296-
if v != i_id && !current_neighbors.contains(&v) {
297-
candidates.insert(v);
298-
}
299-
}
300-
for v in new_reverse.get(u_idx) {
282+
for v in new_neighbors.get(u_idx) {
301283
if *v != i_id && !current_neighbors.contains(v) {
302284
candidates.insert(*v);
303285
}
@@ -320,13 +302,5 @@ where
320302
break;
321303
}
322304
}
323-
324-
// Strip the new flag from all neighbors before returning
325-
for i in 0..n {
326-
for nb in graph.neighbors_mut(i) {
327-
nb.index = nb.index.as_old();
328-
}
329-
}
330-
331305
graph
332306
}

0 commit comments

Comments
 (0)