@@ -14,48 +14,54 @@ use std::collections::HashSet;
1414
1515use crate :: { AnnGraph , FamstConfig , Neighbor , NodeId } ;
1616
17- /// Reverse neighbor lists with bounded size k per node.
18- /// Uses flat storage: data[i*k..(i+1)*k] contains up to k reverse neighbors of node i.
19- /// Uses reservoir sampling when more than k reverse edges exist.
20- struct ReverseNeighbors {
21- /// Flat storage of reverse neighbor IDs (with new flag preserved)
17+ /// Neighbor lists with bounded size k per node.
18+ /// Uses flat storage: data[i*k..(i+1)*k] contains up to k neighbors of node i.
19+ /// Uses reservoir sampling when more than k neighbors exist.
20+ struct Neighbors {
21+ /// Flat storage of neighbor IDs
2222 data : Vec < NodeId > ,
23- /// Count of reverse neighbors seen so far (for reservoir sampling)
23+ /// Count of neighbors seen so far (for reservoir sampling)
2424 counts : Vec < u32 > ,
25- /// Max reverse neighbors per node
25+ /// Max neighbors per node
2626 k : usize ,
2727}
2828
29- impl ReverseNeighbors {
29+ impl Neighbors {
3030 fn new ( n : usize , k : usize ) -> Self {
31- ReverseNeighbors {
31+ Neighbors {
3232 data : vec ! [ NodeId :: new( 0 ) ; n * k] ,
3333 counts : vec ! [ 0 ; n] ,
3434 k,
3535 }
3636 }
3737
38- /// Add a reverse edge: node `from` is a neighbor of node `to`, so `to` has reverse edge to `from` .
39- /// Uses reservoir sampling to maintain at most k reverse neighbors .
38+ /// Add a neighbor using reservoir sampling to maintain at most k neighbors .
39+ /// Skips if the neighbor is already present .
4040 #[ inline]
41- fn add ( & mut self , to : usize , from : NodeId , rng : & mut impl Rng ) {
42- let count = self . counts [ to] as usize ;
43- let start = to * self . k ;
41+ fn add ( & mut self , node : usize , neighbor : NodeId , rng : & mut impl Rng ) {
42+ let count = self . counts [ node] as usize ;
43+ let start = node * self . k ;
44+ let filled = count. min ( self . k ) ;
45+
46+ // Check if neighbor already exists in the filled portion
47+ if self . data [ start..start + filled] . contains ( & neighbor) {
48+ return ;
49+ }
4450
4551 if count < self . k {
4652 // Still have room, just append
47- self . data [ start + count] = from ;
53+ self . data [ start + count] = neighbor ;
4854 } else {
4955 // Reservoir sampling: replace with probability k / (count + 1)
5056 let j = rng. gen_range ( 0 ..=count) ;
5157 if j < self . k {
52- self . data [ start + j] = from ;
58+ self . data [ start + j] = neighbor ;
5359 }
5460 }
55- self . counts [ to ] += 1 ;
61+ self . counts [ node ] += 1 ;
5662 }
5763
58- /// Get the reverse neighbors of node i (only the filled slots)
64+ /// Get the neighbors of node i (only the filled slots)
5965 #[ inline]
6066 fn get ( & self , i : usize ) -> & [ NodeId ] {
6167 let start = i * self . k ;
@@ -64,26 +70,46 @@ impl ReverseNeighbors {
6470 }
6571}
6672
67- /// Build reverse neighbor lists with reservoir sampling.
68- /// Returns separate old and new reverse neighbor structures.
69- fn build_reverse_lists ( graph : & AnnGraph , rng : & mut impl Rng ) -> ( ReverseNeighbors , ReverseNeighbors ) {
73+ /// Build combined neighbor lists (forward + reverse) with reservoir sampling.
74+ /// Returns (old_neighbors, new_neighbors), each with 2*k capacity per node.
75+ /// Only marks neighbors that were selected into new_neighbors as old.
76+ fn build_neighbor_lists ( graph : & mut AnnGraph , rng : & mut impl Rng ) -> ( Neighbors , Neighbors ) {
7077 let n = graph. n ( ) ;
7178 let k = graph. k ( ) ;
72- let mut old_reverse = ReverseNeighbors :: new ( n, k) ;
73- let mut new_reverse = ReverseNeighbors :: new ( n, k) ;
79+ // 2*k capacity: k for forward + k for reverse
80+ let mut old_neighbors = Neighbors :: new ( n, 2 * k) ;
81+ let mut new_neighbors = Neighbors :: new ( n, 2 * k) ;
7482
7583 for i in 0 ..n {
7684 let i_id = NodeId :: new ( i as u32 ) ;
7785 for neighbor in graph. neighbors ( i) {
7886 let target = neighbor. index . index ( ) as usize ;
87+ let target_id = neighbor. index . as_old ( ) ; // Strip new flag for storage
7988 if neighbor. index . is_new ( ) {
80- new_reverse. add ( target, i_id, rng) ;
89+ // Forward: i -> target, Reverse: target <- i
90+ new_neighbors. add ( i, target_id, rng) ;
91+ new_neighbors. add ( target, i_id, rng) ;
8192 } else {
82- old_reverse. add ( target, i_id, rng) ;
93+ old_neighbors. add ( i, target_id, rng) ;
94+ old_neighbors. add ( target, i_id, rng) ;
8395 }
8496 }
8597 }
86- ( old_reverse, new_reverse)
98+
99+ // Only mark neighbors as old if they were selected into new_neighbors
100+ for i in 0 ..n {
101+ for & selected_id in new_neighbors. get ( i) {
102+ // Find this neighbor in the graph and mark as old if it's still new
103+ for nb in graph. neighbors_mut ( i) {
104+ if nb. index == selected_id {
105+ nb. index = nb. index . as_old ( ) ;
106+ break ;
107+ }
108+ }
109+ }
110+ }
111+
112+ ( old_neighbors, new_neighbors)
87113}
88114
89115/// Initialize ANN graph with random neighbors
@@ -201,30 +227,10 @@ where
201227
202228 // NN-Descent iterations
203229 for iter in 0 ..config. nn_descent_iterations {
204- // Build reverse neighbor lists, separating old and new (with reservoir sampling)
205- let ( old_reverse, new_reverse) = build_reverse_lists ( & graph, rng) ;
206-
207- // For each point, collect old and new forward neighbors
208- let mut old_neighbors: Vec < Vec < NodeId > > = vec ! [ Vec :: new( ) ; n] ;
209- let mut new_neighbors: Vec < Vec < NodeId > > = vec ! [ Vec :: new( ) ; n] ;
210-
211- for i in 0 ..n {
212- for nb in graph. neighbors ( i) {
213- let idx = nb. index . as_old ( ) ; // Strip new flag for storage
214- if nb. index . is_new ( ) {
215- new_neighbors[ i] . push ( idx) ;
216- } else {
217- old_neighbors[ i] . push ( idx) ;
218- }
219- }
220- }
221-
222- // Mark all neighbors as old for next iteration
223- for i in 0 ..n {
224- for nb in graph. neighbors_mut ( i) {
225- nb. index = nb. index . as_old ( ) ;
226- }
227- }
230+ println ! ( "NN-Descent iteration {iter}..." ) ;
231+ // Build combined neighbor lists (forward + reverse, with reservoir sampling)
232+ // Also marks as old only the new neighbors that were sampled into new_neighbors
233+ let ( old_neighbors, new_neighbors) = build_neighbor_lists ( & mut graph, rng) ;
228234
229235 let mut updates = 0 ;
230236
@@ -233,17 +239,8 @@ where
233239 for i in 0 ..n {
234240 let i_id = NodeId :: new ( i as u32 ) ;
235241
236- // Combine forward and reverse neighbors
237- let old_i: Vec < NodeId > = old_neighbors[ i]
238- . iter ( )
239- . chain ( old_reverse. get ( i) . iter ( ) )
240- . copied ( )
241- . collect ( ) ;
242- let new_i: Vec < NodeId > = new_neighbors[ i]
243- . iter ( )
244- . chain ( new_reverse. get ( i) . iter ( ) )
245- . copied ( )
246- . collect ( ) ;
242+ let old_i = old_neighbors. get ( i) ;
243+ let new_i = new_neighbors. get ( i) ;
247244
248245 // Skip if no new neighbors
249246 if new_i. is_empty ( ) {
@@ -260,44 +257,29 @@ where
260257 let mut candidates: HashSet < NodeId > = HashSet :: new ( ) ;
261258
262259 // new-new pairs: for each new neighbor, look at their new neighbors
263- for & u in & new_i {
260+ for & u in new_i {
264261 let u_idx = u. index ( ) as usize ;
265- for & v in & new_neighbors[ u_idx] {
266- if v != i_id && !current_neighbors. contains ( & v) {
267- candidates. insert ( v) ;
268- }
269- }
270- for v in new_reverse. get ( u_idx) {
262+ for v in new_neighbors. get ( u_idx) {
271263 if * v != i_id && !current_neighbors. contains ( v) {
272264 candidates. insert ( * v) ;
273265 }
274266 }
275267 }
276268
277269 // new-old pairs: for each new neighbor, look at their old neighbors
278- for & u in & new_i {
270+ for & u in new_i {
279271 let u_idx = u. index ( ) as usize ;
280- for & v in & old_neighbors[ u_idx] {
281- if v != i_id && !current_neighbors. contains ( & v) {
282- candidates. insert ( v) ;
283- }
284- }
285- for v in old_reverse. get ( u_idx) {
272+ for v in old_neighbors. get ( u_idx) {
286273 if * v != i_id && !current_neighbors. contains ( v) {
287274 candidates. insert ( * v) ;
288275 }
289276 }
290277 }
291278
292279 // old-new pairs: for each old neighbor, look at their new neighbors
293- for & u in & old_i {
280+ for & u in old_i {
294281 let u_idx = u. index ( ) as usize ;
295- for & v in & new_neighbors[ u_idx] {
296- if v != i_id && !current_neighbors. contains ( & v) {
297- candidates. insert ( v) ;
298- }
299- }
300- for v in new_reverse. get ( u_idx) {
282+ for v in new_neighbors. get ( u_idx) {
301283 if * v != i_id && !current_neighbors. contains ( v) {
302284 candidates. insert ( * v) ;
303285 }
@@ -320,13 +302,5 @@ where
320302 break ;
321303 }
322304 }
323-
324- // Strip the new flag from all neighbors before returning
325- for i in 0 ..n {
326- for nb in graph. neighbors_mut ( i) {
327- nb. index = nb. index . as_old ( ) ;
328- }
329- }
330-
331305 graph
332306}
0 commit comments