@@ -232,78 +232,90 @@ where
         // Build combined neighbor lists (forward + reverse, with reservoir sampling)
         // Also marks all neighbors as old for next iteration
         let (old_neighbors, new_neighbors) = build_neighbor_lists(&mut graph, rng);
-
-        // For each point, generate candidates from neighbors of neighbors
-        // Key optimization: only consider pairs where at least one is "new"
-        // Collect all candidates with their distances, including reverse edges
-        let mut candidates: Vec<(u32, u32, f32)> = (0..n)
-            .into_par_iter()
-            .flat_map_iter(|i| {
-                let old_i = old_neighbors.get(i);
-                let new_i = new_neighbors.get(i);
-
-                // new-new pairs: (u, v) where u < v
-                let new_new = new_i.iter().flat_map(|&u| {
-                    new_i
-                        .iter()
-                        .filter_map(move |&v| if u < v { Some((u, v)) } else { None })
-                });
-
-                // new-old pairs: (min, max) where u != v
-                let new_old = new_i.iter().flat_map(|&u| {
-                    old_i.iter().filter_map(move |&v| {
-                        if u != v {
-                            Some((u.min(v), u.max(v)))
-                        } else {
-                            None
+        println!(" Built neighbor lists");
+
+        // Process in batches of central nodes to limit memory usage
+        let batch_size: usize = n / k;
+        let mut total_updates = 0;
+        let mut total_candidates = 0;
+
+        for batch_start in (0..n).step_by(batch_size) {
+            let batch_end = (batch_start + batch_size).min(n);
+
+            // For each point in this batch, generate candidates from neighbors of neighbors
+            // Key optimization: only consider pairs where at least one is "new"
+            // Collect all candidates with their distances, including reverse edges
+            let mut candidates: Vec<(u32, u32, f32)> = (batch_start..batch_end)
+                .into_par_iter()
+                .flat_map_iter(|i| {
+                    let old_i = old_neighbors.get(i);
+                    let new_i = new_neighbors.get(i);
+
+                    // new-new pairs: (u, v) where u < v
+                    let new_new = new_i.iter().flat_map(|&u| {
+                        new_i
+                            .iter()
+                            .filter_map(move |&v| if u < v { Some((u, v)) } else { None })
+                    });
+
+                    // new-old pairs: (min, max) where u != v
+                    let new_old = new_i.iter().flat_map(|&u| {
+                        old_i.iter().filter_map(move |&v| {
+                            if u != v {
+                                Some((u.min(v), u.max(v)))
+                            } else {
+                                None
+                            }
+                        })
+                    });
+
+                    new_new
+                        .chain(new_old)
+                        .flat_map(|(u, v)| {
+                            let d = distance_fn(&data[u as usize], &data[v as usize]);
+                            // Insert both (u, v) and (v, u) so we can parallelize by node
+                            [(u, v, d), (v, u, d)]
+                        })
+                })
+                .collect();
+
+            // Sort by first node so candidates for each node are contiguous
+            candidates.par_sort_unstable_by_key(|(a, _, _)| *a);
+            candidates.dedup();
+            total_candidates += candidates.len();
+
+            // Process in parallel by chunks of nodes
+            const CHUNK_SIZE: usize = 64;
+            let updates: usize = graph
+                .neighbors_chunks_mut(CHUNK_SIZE)
+                .enumerate()
+                .map(|(i, chunk_neighbors)| {
+                    let start_node = (CHUNK_SIZE * i) as u32;
+                    // Binary search to find the range of candidates for this chunk
+                    let start = candidates.partition_point(|&(u, _, _)| u < start_node);
+                    let mut count = 0;
+                    for &(u, v, d) in &candidates[start..] {
+                        if u >= start_node + CHUNK_SIZE as u32 {
+                            break;
+                        }
+                        let neighbors = &mut chunk_neighbors[(u - start_node) as usize * k..][..k];
+                        if insert_neighbor(neighbors, NodeId::new(v), d) {
+                            count += 1;
                         }
-                    })
-                });
-
-                new_new
-                    .chain(new_old)
-                    .flat_map(|(u, v)| {
-                        let d = distance_fn(&data[u as usize], &data[v as usize]);
-                        // Insert both (u, v) and (v, u) so we can parallelize by node
-                        [(u, v, d), (v, u, d)]
-                    })
-            })
-            .collect();
-
-        // Sort by first node so candidates for each node are contiguous
-        candidates.par_sort_unstable_by_key(|(a, _, _)| *a);
-        candidates.dedup();
-
-        // Process in parallel by chunks of nodes
-        const CHUNK_SIZE: usize = 64;
-        let updates: usize = graph
-            .neighbors_chunks_mut(CHUNK_SIZE)
-            .enumerate()
-            .map(|(i, chunk_neighbors)| {
-                let start_node = (CHUNK_SIZE * i) as u32;
-                // Binary search to find the range of candidates for this node
-                let start = candidates.partition_point(|&(u, _, _)| u < start_node);
-                let mut count = 0;
-                for &(u, v, d) in &candidates[start..] {
-                    if u >= start_node + CHUNK_SIZE as u32 {
-                        break;
-                    }
-                    let neighbors = &mut chunk_neighbors[(u - start_node) as usize * k..][..k];
-                    if insert_neighbor(neighbors, NodeId::new(v), d) {
-                        count += 1;
                     }
-                }
-                count
-            })
-            .sum();
+                    count
+                })
+                .sum();
+
+            total_updates += updates;
+        }

         println!(
-            "NN-Descent iteration {iter}: {updates} updates of {} candidates",
-            candidates.len(),
+            "NN-Descent iteration {iter}: {total_updates} updates of {total_candidates} candidates",
         );

         // Early termination if no updates
-        if updates == 0 {
+        if total_updates == 0 {
             break;
         }
     }
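
Note: the hunk calls an insert_neighbor helper that is defined elsewhere in the crate and not shown here. The following is only a rough sketch of what such a helper might look like, consistent with the call site above; the NodeId and Neighbor types and the assumption that each node owns a full, distance-sorted block of k entries are stand-ins, not the crate's actual definitions.

// Sketch of an insert_neighbor helper compatible with the call site above.
// NodeId and Neighbor are stand-in types; the crate's real definitions may differ.
#[derive(Clone, Copy, PartialEq, Eq)]
struct NodeId(u32);

impl NodeId {
    fn new(id: u32) -> Self {
        NodeId(id)
    }
}

#[derive(Clone, Copy)]
struct Neighbor {
    id: NodeId,
    dist: f32,
}

// Assumes the k-entry block is already full (NN-Descent starts from a random
// k-NN graph) and is kept sorted by ascending distance. Returns true when the
// list changed, which is what the `updates` counter above tallies.
fn insert_neighbor(neighbors: &mut [Neighbor], id: NodeId, dist: f32) -> bool {
    // Not closer than the current worst neighbor: nothing to do.
    if dist >= neighbors[neighbors.len() - 1].dist {
        return false;
    }
    // Already a neighbor: ignore, so the mirrored (u, v)/(v, u) candidates
    // cannot insert the same edge twice.
    if neighbors.iter().any(|n| n.id == id) {
        return false;
    }
    // Shift the tail right by one (dropping the worst entry) and write the
    // new neighbor at its sorted position.
    let pos = neighbors.partition_point(|n| n.dist < dist);
    neighbors[pos..].rotate_right(1);
    neighbors[pos] = Neighbor { id, dist };
    true
}

fn main() {
    let mut block = [
        Neighbor { id: NodeId::new(7), dist: 0.2 },
        Neighbor { id: NodeId::new(3), dist: 0.5 },
        Neighbor { id: NodeId::new(9), dist: 0.8 },
    ];
    assert!(insert_neighbor(&mut block, NodeId::new(4), 0.4)); // displaces node 9
    assert!(!insert_neighbor(&mut block, NodeId::new(4), 0.4)); // duplicate, no update
}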
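
The chunked update step relies on the candidate list being sorted by its first element: all entries belonging to one chunk of nodes then form a contiguous run, which partition_point from the standard library locates in logarithmic time. Below is a toy, self-contained illustration of that range lookup with made-up data; the commit itself computes only the start index and breaks out of the loop once u leaves the chunk, which amounts to the same range.

fn main() {
    const CHUNK_SIZE: usize = 2;
    // (source, target, distance) candidates, already sorted by source node.
    let candidates: Vec<(u32, u32, f32)> = vec![
        (0, 3, 0.9),
        (1, 2, 0.4),
        (2, 0, 0.7),
        (2, 5, 0.1),
        (3, 1, 0.6),
    ];

    let start_node: u32 = 2; // first node of the chunk being processed
    let start = candidates.partition_point(|&(u, _, _)| u < start_node);
    let end = candidates.partition_point(|&(u, _, _)| u < start_node + CHUNK_SIZE as u32);

    // Prints the two candidates for node 2 and the single candidate for node 3.
    println!("{:?}", &candidates[start..end]);
}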