Skip to content

Commit d946967

Browse files
committed
add multilayer kmeans script, fix l2 simd
Signed-off-by: Keming <[email protected]>
1 parent 0e087df commit d946967

File tree

2 files changed

+121
-5
lines changed

2 files changed

+121
-5
lines changed

scripts/cluster.py

+118
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from struct import unpack, pack
2+
from sys import argv
3+
from functools import partial
4+
5+
from faiss import Kmeans
6+
import numpy as np
7+
from tqdm import tqdm
8+
9+
10+
def default_filter(vec):
11+
return True
12+
13+
14+
def reservoir_sampling(iterator, k: int):
15+
"""Reservoir sampling from an iterator."""
16+
res = []
17+
while len(res) < k:
18+
res.append(next(iterator))
19+
for i, vec in enumerate(iterator, k + 1):
20+
j = np.random.randint(0, i)
21+
if j < k:
22+
res[j] = vec
23+
return res
24+
25+
26+
def read_vec_yield(
27+
filepath: str, vec_type: np.dtype = np.float32, filter=default_filter
28+
):
29+
"""Read vectors and yield an iterator."""
30+
size = np.dtype(vec_type).itemsize
31+
with open(filepath, "rb") as f:
32+
while True:
33+
try:
34+
buf = f.read(4)
35+
if len(buf) == 0:
36+
break
37+
dim = unpack("<i", buf)[0]
38+
vec = np.frombuffer(f.read(dim * size), dtype=vec_type)
39+
if filter(vec):
40+
yield vec
41+
except Exception as err:
42+
print(err)
43+
break
44+
45+
46+
def read_vec(filepath: str, vec_type: np.dtype = np.float32):
47+
"""Read vectors from a file. Support `fvecs`, `ivecs` and `bvecs` format.
48+
Args:
49+
filepath: The path of the file.
50+
vec_type: The type of the vectors.
51+
"""
52+
size = np.dtype(vec_type).itemsize
53+
with open(filepath, "rb") as f:
54+
vecs = []
55+
while True:
56+
try:
57+
buf = f.read(4)
58+
if len(buf) == 0:
59+
break
60+
dim = unpack("<i", buf)[0]
61+
vecs.append(np.frombuffer(f.read(dim * size), dtype=vec_type))
62+
except Exception as err:
63+
print(err)
64+
break
65+
return np.array(vecs)
66+
67+
68+
def write_vec(filepath: str, vecs: np.ndarray, vec_type: np.dtype = np.float32):
69+
"""Write vectors to a file. Support `fvecs`, `ivecs` and `bvecs` format."""
70+
with open(filepath, "wb") as f:
71+
for vec in vecs:
72+
f.write(pack("<i", len(vec)))
73+
f.write(vec.tobytes())
74+
75+
76+
def hierarchical_kmeans(vecs, n_cluster_top, n_cluster_down):
77+
dim = vecs.shape[1]
78+
top = Kmeans(dim, n_cluster_top)
79+
top.train(vecs)
80+
_, labels = top.assign(vecs)
81+
82+
centroids = []
83+
for i in range(n_cluster_top):
84+
down = Kmeans(dim, n_cluster_down)
85+
down.train(vecs[labels == i])
86+
centroids.append(down.centroids)
87+
88+
return np.vstack(centroids)
89+
90+
91+
if __name__ == "__main__":
92+
filename = argv[1]
93+
top_n = int(argv[2])
94+
down_n = int(argv[3])
95+
max_point_per_cluster = 256
96+
top_points = reservoir_sampling(
97+
read_vec_yield(filename), top_n * max_point_per_cluster
98+
)
99+
dim = top_points[0].shape[0]
100+
101+
top_cluster = Kmeans(dim, top_n)
102+
top_cluster.train(top_points)
103+
104+
def filter_label(label, vec):
105+
_, label = top_cluster.assign(vec.reshape((1, -1)))
106+
return label[0] == label
107+
108+
centroids = []
109+
for i in tqdm(range(top_n)):
110+
down_points = reservoir_sampling(
111+
read_vec_yield(filename, filter=partial(filter_label, i)),
112+
down_n * max_point_per_cluster,
113+
)
114+
down_cluster = Kmeans(dim, down_n)
115+
down_cluster.train(down_points)
116+
centroids.append(down_cluster.centroids)
117+
118+
write_vec(f"centroids_{top_n}_{down_n}.fvecs", np.vstack(centroids))

src/simd.rs

+3-5
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,10 @@ pub unsafe fn l2_squared_distance(lhs: &[f32], rhs: &[f32]) -> f32 {
2020
assert_eq!(lhs.len(), rhs.len());
2121
let mut lhs_ptr = lhs.as_ptr();
2222
let mut rhs_ptr = rhs.as_ptr();
23-
let block_16_num = lhs.len() >> 4;
24-
let rest_num = lhs.len() & 0b1111;
2523
let (mut diff, mut vx, mut vy): (__m256, __m256, __m256);
2624
let mut sum = _mm256_setzero_ps();
2725

28-
for _ in 0..block_16_num {
26+
for _ in 0..(lhs.len() / 16) {
2927
vx = _mm256_loadu_ps(lhs_ptr);
3028
vy = _mm256_loadu_ps(rhs_ptr);
3129
lhs_ptr = lhs_ptr.add(8);
@@ -41,7 +39,7 @@ pub unsafe fn l2_squared_distance(lhs: &[f32], rhs: &[f32]) -> f32 {
4139
sum = _mm256_fmadd_ps(diff, diff, sum);
4240
}
4341

44-
for _ in 0..rest_num / 8 {
42+
for _ in 0..((lhs.len() & 0b1111) / 8) {
4543
vx = _mm256_loadu_ps(lhs_ptr);
4644
vy = _mm256_loadu_ps(rhs_ptr);
4745
lhs_ptr = lhs_ptr.add(8);
@@ -65,7 +63,7 @@ pub unsafe fn l2_squared_distance(lhs: &[f32], rhs: &[f32]) -> f32 {
6563
}
6664

6765
let mut res = reduce_f32_256(sum);
68-
for _ in 0..rest_num {
66+
for _ in 0..(lhs.len() & 0b111) {
6967
let residual = *lhs_ptr - *rhs_ptr;
7068
res += residual * residual;
7169
lhs_ptr = lhs_ptr.add(1);

0 commit comments

Comments
 (0)