Skip to content

Commit 34b6e1f

Browse files
committed
Use more cache-friendly loop
1 parent 5477533 commit 34b6e1f

File tree

2 files changed

+68
-78
lines changed

2 files changed

+68
-78
lines changed

crates/famst/src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ mod union_find;
1717
use nn_descent::nn_descent;
1818
use rand::seq::SliceRandom;
1919
use rand::Rng;
20+
use rayon::iter::IndexedParallelIterator;
21+
use rayon::slice::ParallelSliceMut;
2022
use union_find::{find_components, UnionFind};
2123

2224
/// Node index with embedded "new" flag in the least significant bit.
@@ -169,6 +171,11 @@ impl AnnGraph {
169171
let start = i * self.k;
170172
&mut self.data[start..start + self.k]
171173
}
174+
175+
/// Get mutable access to all neighbor chunks for parallel processing
176+
pub(crate) fn neighbors_chunks_mut(&mut self) -> impl IndexedParallelIterator<Item = &mut [Neighbor]> {
177+
self.data.par_chunks_mut(self.k)
178+
}
172179
}
173180

174181
/// FAMST algorithm configuration

crates/famst/src/nn_descent.rs

Lines changed: 61 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::{AnnGraph, FamstConfig, Neighbor, NodeId};
1919
/// Uses reservoir sampling when more than k neighbors exist.
2020
struct Neighbors {
2121
/// Flat storage of neighbor IDs
22-
data: Vec<NodeId>,
22+
data: Vec<u32>,
2323
/// Count of neighbors seen so far (for reservoir sampling)
2424
counts: Vec<u32>,
2525
/// Max neighbors per node
@@ -29,7 +29,7 @@ struct Neighbors {
2929
impl Neighbors {
3030
fn new(n: usize, k: usize) -> Self {
3131
Neighbors {
32-
data: vec![NodeId::new(0); n * k],
32+
data: vec![0; n * k],
3333
counts: vec![0; n],
3434
k,
3535
}
@@ -38,7 +38,7 @@ impl Neighbors {
3838
/// Add a neighbor using reservoir sampling to maintain at most k neighbors.
3939
/// Skips if the neighbor is already present.
4040
#[inline]
41-
fn add(&mut self, node: usize, neighbor: NodeId, rng: &mut impl Rng) {
41+
fn add(&mut self, node: usize, neighbor: u32, rng: &mut impl Rng) {
4242
let count = self.counts[node] as usize;
4343
let start = node * self.k;
4444
let filled = count.min(self.k);
@@ -63,7 +63,7 @@ impl Neighbors {
6363

6464
/// Get the neighbors of node i (only the filled slots)
6565
#[inline]
66-
fn get(&self, i: usize) -> &[NodeId] {
66+
fn get(&self, i: usize) -> &[u32] {
6767
let start = i * self.k;
6868
let count = (self.counts[i] as usize).min(self.k);
6969
&self.data[start..start + count]
@@ -77,37 +77,38 @@ fn build_neighbor_lists(graph: &mut AnnGraph, rng: &mut impl Rng) -> (Neighbors,
7777
let n = graph.n();
7878
let k = graph.k();
7979
// 2*k capacity: k for forward + k for reverse
80-
let mut old_neighbors = Neighbors::new(n, 2 * k);
81-
let mut new_neighbors = Neighbors::new(n, 2 * k);
80+
let mut old_neighbors = Neighbors::new(n, k * 2);
81+
let mut new_neighbors = Neighbors::new(n, k * 2);
8282

8383
for i in 0..n {
84-
let i_id = NodeId::new(i as u32);
8584
for neighbor in graph.neighbors(i) {
86-
let target = neighbor.index.index() as usize;
87-
let target_id = neighbor.index.as_old(); // Strip new flag for storage
85+
let target = neighbor.index.index();
8886
if neighbor.index.is_new() {
8987
// Forward: i -> target, Reverse: target <- i
90-
new_neighbors.add(i, target_id, rng);
91-
new_neighbors.add(target, i_id, rng);
88+
new_neighbors.add(i, target, rng);
89+
new_neighbors.add(target as usize, i as u32, rng);
9290
} else {
93-
old_neighbors.add(i, target_id, rng);
94-
old_neighbors.add(target, i_id, rng);
91+
old_neighbors.add(i, target, rng);
92+
old_neighbors.add(target as usize, i as u32, rng);
9593
}
9694
}
9795
}
9896

9997
// Only mark neighbors as old if they were selected into new_neighbors
100-
for i in 0..n {
101-
for &selected_id in new_neighbors.get(i) {
102-
// Find this neighbor in the graph and mark as old if it's still new
103-
for nb in graph.neighbors_mut(i) {
104-
if nb.index == selected_id {
105-
nb.index = nb.index.as_old();
106-
break;
98+
graph
99+
.neighbors_chunks_mut()
100+
.enumerate()
101+
.for_each(|(i, neighbors)| {
102+
for &selected_id in new_neighbors.get(i) {
103+
// Find this neighbor in the graph and mark as old if it's still new
104+
for nb in neighbors.iter_mut() {
105+
if nb.index.index() == selected_id {
106+
nb.index = nb.index.as_old();
107+
break;
108+
}
107109
}
108110
}
109-
}
110-
}
111+
});
111112

112113
(old_neighbors, new_neighbors)
113114
}
@@ -232,66 +233,48 @@ where
232233
// Also marks the neighbors that were sampled into new_neighbors as old for the next iteration
233234
let (old_neighbors, new_neighbors) = build_neighbor_lists(&mut graph, rng);
234235

235-
let mut updates = 0;
236-
237236
// For each point, generate candidates from neighbors of neighbors
238237
// Key optimization: only consider pairs where at least one is "new"
239-
for i in 0..n {
240-
let i_id = NodeId::new(i as u32);
241-
242-
let old_i = old_neighbors.get(i);
243-
let new_i = new_neighbors.get(i);
244-
245-
// Skip if no new neighbors
246-
if new_i.is_empty() {
247-
continue;
248-
}
249-
250-
// Build set of current neighbors for O(1) lookup
251-
let current_neighbors: HashSet<NodeId> = graph
252-
.neighbors(i)
253-
.iter()
254-
.map(|nb| nb.index.as_old())
255-
.collect();
256-
257-
let mut candidates: HashSet<NodeId> = HashSet::new();
258-
259-
// new-new pairs: for each new neighbor, look at their new neighbors
260-
for &u in new_i {
261-
let u_idx = u.index() as usize;
262-
for v in new_neighbors.get(u_idx) {
263-
if *v != i_id && !current_neighbors.contains(v) {
264-
candidates.insert(*v);
238+
let candidates: HashSet<(u32, u32)> = (0..n)
239+
.into_par_iter()
240+
.fold(
241+
HashSet::new,
242+
|mut local_candidates, i| {
243+
let old_i = old_neighbors.get(i);
244+
let new_i = new_neighbors.get(i);
245+
246+
// Only generate candidate pairs when this node has new neighbors
247+
if !new_i.is_empty() {
248+
for &u in new_i {
249+
for &v in new_i {
250+
if u < v {
251+
local_candidates.insert((u, v));
252+
}
253+
}
254+
for &v in old_i {
255+
if u != v {
256+
local_candidates.insert((u.min(v), u.max(v)));
257+
}
258+
}
259+
}
265260
}
266-
}
267-
}
268-
269-
// new-old pairs: for each new neighbor, look at their old neighbors
270-
for &u in new_i {
271-
let u_idx = u.index() as usize;
272-
for v in old_neighbors.get(u_idx) {
273-
if *v != i_id && !current_neighbors.contains(v) {
274-
candidates.insert(*v);
275-
}
276-
}
277-
}
278-
279-
// old-new pairs: for each old neighbor, look at their new neighbors
280-
for &u in old_i {
281-
let u_idx = u.index() as usize;
282-
for v in new_neighbors.get(u_idx) {
283-
if *v != i_id && !current_neighbors.contains(v) {
284-
candidates.insert(*v);
285-
}
286-
}
261+
local_candidates
262+
},
263+
)
264+
.reduce(HashSet::new, |mut a, b| {
265+
a.extend(b);
266+
a
267+
});
268+
269+
// Try to improve neighbors with candidates
270+
let mut updates = 0;
271+
for &(u, v) in &candidates {
272+
let d = distance_fn(&data[u as usize], &data[v as usize]);
273+
if insert_neighbor(graph.neighbors_mut(u as usize), NodeId::new(v), d) {
274+
updates += 1;
287275
}
288-
289-
// Try to improve neighbors with candidates
290-
for c in candidates {
291-
let d = distance_fn(&data[i], &data[c.index() as usize]);
292-
if insert_neighbor(graph.neighbors_mut(i), c, d) {
293-
updates += 1;
294-
}
276+
if insert_neighbor(graph.neighbors_mut(v as usize), NodeId::new(u), d) {
277+
updates += 1;
295278
}
296279
}
297280

0 commit comments

Comments
 (0)