@@ -96,7 +96,7 @@ fn build_neighbor_lists(graph: &mut AnnGraph, rng: &mut impl Rng) -> (Neighbors,
9696
9797 // Only mark neighbors as old if they were selected into new_neighbors
9898 graph
99- . neighbors_chunks_mut ( )
99+ . neighbors_chunks_mut ( 1 )
100100 . enumerate ( )
101101 . for_each ( |( i, neighbors) | {
102102 for & selected_id in new_neighbors. get ( i) {
@@ -235,15 +235,18 @@ where
235235
236236 // For each point, generate candidates from neighbors of neighbors
237237 // Key optimization: only consider pairs where at least one is "new"
238- let mut candidates: Vec < ( u32 , u32 ) > = ( 0 ..n)
238+ // Collect all candidates with their distances, including reverse edges
239+ let mut candidates: Vec < ( u32 , u32 , f32 ) > = ( 0 ..n)
239240 . into_par_iter ( )
240241 . flat_map_iter ( |i| {
241242 let old_i = old_neighbors. get ( i) ;
242243 let new_i = new_neighbors. get ( i) ;
243244
244245 // new-new pairs: (u, v) where u < v
245246 let new_new = new_i. iter ( ) . flat_map ( |& u| {
246- new_i. iter ( ) . filter_map ( move |& v| if u < v { Some ( ( u, v) ) } else { None } )
247+ new_i
248+ . iter ( )
249+ . filter_map ( move |& v| if u < v { Some ( ( u, v) ) } else { None } )
247250 } ) ;
248251
249252 // new-old pairs: (min, max) where u != v
@@ -257,26 +260,47 @@ where
257260 } )
258261 } ) ;
259262
260- new_new. chain ( new_old)
263+ new_new
264+ . chain ( new_old)
265+ . flat_map ( |( u, v) | {
266+ let d = distance_fn ( & data[ u as usize ] , & data[ v as usize ] ) ;
267+ // Insert both (u, v) and (v, u) so we can parallelize by node
268+ [ ( u, v, d) , ( v, u, d) ]
269+ } )
261270 } )
262- . take_any ( n * k)
263271 . collect ( ) ;
264- candidates. par_sort_unstable ( ) ;
272+
273+ // Sort by first node so candidates for each node are contiguous
274+ candidates. par_sort_unstable_by_key ( |( a, _, _) | * a) ;
265275 candidates. dedup ( ) ;
266276
267- // Try to improve neighbors with candidates
268- let mut updates = 0 ;
269- for & ( u, v) in & candidates {
270- let d = distance_fn ( & data[ u as usize ] , & data[ v as usize ] ) ;
271- if insert_neighbor ( graph. neighbors_mut ( u as usize ) , NodeId :: new ( v) , d) {
272- updates += 1 ;
273- }
274- if insert_neighbor ( graph. neighbors_mut ( v as usize ) , NodeId :: new ( u) , d) {
275- updates += 1 ;
276- }
277- }
277+ // Process in parallel by chunks of nodes
278+ const CHUNK_SIZE : usize = 64 ;
279+ let updates: usize = graph
280+ . neighbors_chunks_mut ( CHUNK_SIZE )
281+ . enumerate ( )
282+ . map ( |( i, chunk_neighbors) | {
283+ let start_node = ( CHUNK_SIZE * i) as u32 ;
284+ // Binary search to find the range of candidates for this node
285+ let start = candidates. partition_point ( |& ( u, _, _) | u < start_node) ;
286+ let mut count = 0 ;
287+ for & ( u, v, d) in & candidates[ start..] {
288+ if u >= start_node + CHUNK_SIZE as u32 {
289+ break ;
290+ }
291+ let neighbors = & mut chunk_neighbors[ ( u - start_node) as usize * k..] [ ..k] ;
292+ if insert_neighbor ( neighbors, NodeId :: new ( v) , d) {
293+ count += 1 ;
294+ }
295+ }
296+ count
297+ } )
298+ . sum ( ) ;
278299
279- println ! ( "NN-Descent iteration {iter}: {updates} updates of {} candidates" , candidates. len( ) ) ;
300+ println ! (
301+ "NN-Descent iteration {iter}: {updates} updates of {} candidates" ,
302+ candidates. len( ) ,
303+ ) ;
280304
281305 // Early termination if no updates
282306 if updates == 0 {
0 commit comments