@@ -19,7 +19,7 @@ use crate::{AnnGraph, FamstConfig, Neighbor, NodeId};
1919/// Uses reservoir sampling when more than k neighbors exist.
2020struct Neighbors {
2121 /// Flat storage of neighbor IDs
22- data : Vec < NodeId > ,
22+ data : Vec < u32 > ,
2323 /// Count of neighbors seen so far (for reservoir sampling)
2424 counts : Vec < u32 > ,
2525 /// Max neighbors per node
@@ -29,7 +29,7 @@ struct Neighbors {
2929impl Neighbors {
3030 fn new ( n : usize , k : usize ) -> Self {
3131 Neighbors {
32- data : vec ! [ NodeId :: new ( 0 ) ; n * k] ,
32+ data : vec ! [ 0 ; n * k] ,
3333 counts : vec ! [ 0 ; n] ,
3434 k,
3535 }
@@ -38,7 +38,7 @@ impl Neighbors {
3838 /// Add a neighbor using reservoir sampling to maintain at most k neighbors.
3939 /// Skips if the neighbor is already present.
4040 #[ inline]
41- fn add ( & mut self , node : usize , neighbor : NodeId , rng : & mut impl Rng ) {
41+ fn add ( & mut self , node : usize , neighbor : u32 , rng : & mut impl Rng ) {
4242 let count = self . counts [ node] as usize ;
4343 let start = node * self . k ;
4444 let filled = count. min ( self . k ) ;
@@ -63,7 +63,7 @@ impl Neighbors {
6363
6464 /// Get the neighbors of node i (only the filled slots)
6565 #[ inline]
66- fn get ( & self , i : usize ) -> & [ NodeId ] {
66+ fn get ( & self , i : usize ) -> & [ u32 ] {
6767 let start = i * self . k ;
6868 let count = ( self . counts [ i] as usize ) . min ( self . k ) ;
6969 & self . data [ start..start + count]
@@ -77,37 +77,38 @@ fn build_neighbor_lists(graph: &mut AnnGraph, rng: &mut impl Rng) -> (Neighbors,
7777 let n = graph. n ( ) ;
7878 let k = graph. k ( ) ;
7979 // 2*k capacity: k for forward + k for reverse
80- let mut old_neighbors = Neighbors :: new ( n, 2 * k ) ;
81- let mut new_neighbors = Neighbors :: new ( n, 2 * k ) ;
80+ let mut old_neighbors = Neighbors :: new ( n, k * 2 ) ;
81+ let mut new_neighbors = Neighbors :: new ( n, k * 2 ) ;
8282
8383 for i in 0 ..n {
84- let i_id = NodeId :: new ( i as u32 ) ;
8584 for neighbor in graph. neighbors ( i) {
86- let target = neighbor. index . index ( ) as usize ;
87- let target_id = neighbor. index . as_old ( ) ; // Strip new flag for storage
85+ let target = neighbor. index . index ( ) ;
8886 if neighbor. index . is_new ( ) {
8987 // Forward: i -> target, Reverse: target <- i
90- new_neighbors. add ( i, target_id , rng) ;
91- new_neighbors. add ( target, i_id , rng) ;
88+ new_neighbors. add ( i, target , rng) ;
89+ new_neighbors. add ( target as usize , i as u32 , rng) ;
9290 } else {
93- old_neighbors. add ( i, target_id , rng) ;
94- old_neighbors. add ( target, i_id , rng) ;
91+ old_neighbors. add ( i, target , rng) ;
92+ old_neighbors. add ( target as usize , i as u32 , rng) ;
9593 }
9694 }
9795 }
9896
9997 // Only mark neighbors as old if they were selected into new_neighbors
100- for i in 0 ..n {
101- for & selected_id in new_neighbors. get ( i) {
102- // Find this neighbor in the graph and mark as old if it's still new
103- for nb in graph. neighbors_mut ( i) {
104- if nb. index == selected_id {
105- nb. index = nb. index . as_old ( ) ;
106- break ;
98+ graph
99+ . neighbors_chunks_mut ( )
100+ . enumerate ( )
101+ . for_each ( |( i, neighbors) | {
102+ for & selected_id in new_neighbors. get ( i) {
103+ // Find this neighbor in the graph and mark as old if it's still new
104+ for nb in neighbors. iter_mut ( ) {
105+ if nb. index . index ( ) == selected_id {
106+ nb. index = nb. index . as_old ( ) ;
107+ break ;
108+ }
107109 }
108110 }
109- }
110- }
111+ } ) ;
111112
112113 ( old_neighbors, new_neighbors)
113114}
@@ -232,66 +233,48 @@ where
232233 // Also marks all neighbors as old for next iteration
233234 let ( old_neighbors, new_neighbors) = build_neighbor_lists ( & mut graph, rng) ;
234235
235- let mut updates = 0 ;
236-
237236 // For each point, generate candidates from neighbors of neighbors
238237 // Key optimization: only consider pairs where at least one is "new"
239- for i in 0 ..n {
240- let i_id = NodeId :: new ( i as u32 ) ;
241-
242- let old_i = old_neighbors. get ( i) ;
243- let new_i = new_neighbors. get ( i) ;
244-
245- // Skip if no new neighbors
246- if new_i. is_empty ( ) {
247- continue ;
248- }
249-
250- // Build set of current neighbors for O(1) lookup
251- let current_neighbors: HashSet < NodeId > = graph
252- . neighbors ( i)
253- . iter ( )
254- . map ( |nb| nb. index . as_old ( ) )
255- . collect ( ) ;
256-
257- let mut candidates: HashSet < NodeId > = HashSet :: new ( ) ;
258-
259- // new-new pairs: for each new neighbor, look at their new neighbors
260- for & u in new_i {
261- let u_idx = u. index ( ) as usize ;
262- for v in new_neighbors. get ( u_idx) {
263- if * v != i_id && !current_neighbors. contains ( v) {
264- candidates. insert ( * v) ;
238+ let candidates: HashSet < ( u32 , u32 ) > = ( 0 ..n)
239+ . into_par_iter ( )
240+ . fold (
241+ HashSet :: new,
242+ |mut local_candidates, i| {
243+ let old_i = old_neighbors. get ( i) ;
244+ let new_i = new_neighbors. get ( i) ;
245+
246+ // Skip if no new neighbors
247+ if !new_i. is_empty ( ) {
248+ for & u in new_i {
249+ for & v in new_i {
250+ if u < v {
251+ local_candidates. insert ( ( u, v) ) ;
252+ }
253+ }
254+ for & v in old_i {
255+ if u != v {
256+ local_candidates. insert ( ( u. min ( v) , u. max ( v) ) ) ;
257+ }
258+ }
259+ }
265260 }
266- }
267- }
268-
269- // new-old pairs: for each new neighbor, look at their old neighbors
270- for & u in new_i {
271- let u_idx = u. index ( ) as usize ;
272- for v in old_neighbors. get ( u_idx) {
273- if * v != i_id && !current_neighbors. contains ( v) {
274- candidates. insert ( * v) ;
275- }
276- }
277- }
278-
279- // old-new pairs: for each old neighbor, look at their new neighbors
280- for & u in old_i {
281- let u_idx = u. index ( ) as usize ;
282- for v in new_neighbors. get ( u_idx) {
283- if * v != i_id && !current_neighbors. contains ( v) {
284- candidates. insert ( * v) ;
285- }
286- }
261+ local_candidates
262+ } ,
263+ )
264+ . reduce ( HashSet :: new, |mut a, b| {
265+ a. extend ( b) ;
266+ a
267+ } ) ;
268+
269+ // Try to improve neighbors with candidates
270+ let mut updates = 0 ;
271+ for & ( u, v) in & candidates {
272+ let d = distance_fn ( & data[ u as usize ] , & data[ v as usize ] ) ;
273+ if insert_neighbor ( graph. neighbors_mut ( u as usize ) , NodeId :: new ( v) , d) {
274+ updates += 1 ;
287275 }
288-
289- // Try to improve neighbors with candidates
290- for c in candidates {
291- let d = distance_fn ( & data[ i] , & data[ c. index ( ) as usize ] ) ;
292- if insert_neighbor ( graph. neighbors_mut ( i) , c, d) {
293- updates += 1 ;
294- }
276+ if insert_neighbor ( graph. neighbors_mut ( v as usize ) , NodeId :: new ( u) , d) {
277+ updates += 1 ;
295278 }
296279 }
297280
0 commit comments