Commit de3b729

AVX2 SIMD optimization for LD calculations
Adds vectorized r2 and pairwiseDiffs functions using AVX2 intrinsics. Each 256-bit vector handles 8 int values at a time (8 samples per iteration in r2, 8 SNPs per sample pair in pairwiseDiffs), with runtime CPU detection and a scalar fallback. Results remain bit-exact with the original scalar implementations.
1 parent 6649ffd commit de3b729

File tree: diploshic/utils.c · setup.py

2 files changed: 245 additions & 1 deletion
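Background note on the accumulation idiom used throughout the new code (an illustrative sketch, not part of this commit): AVX2 integer comparisons such as _mm256_cmpeq_epi32 set a lane to -1 (all bits set) when the condition holds and to 0 otherwise, so subtracting the mask from an accumulator adds 1 per matching lane. A minimal self-contained example, assuming gcc or clang with -mavx2:

    #include <immintrin.h>
    #include <stdio.h>

    int main(void) {
        int vals[8] = {1, 0, 1, 1, 0, 1, 0, 1};
        __m256i v   = _mm256_loadu_si256((__m256i*)vals);
        __m256i one = _mm256_set1_epi32(1);
        __m256i acc = _mm256_setzero_si256();

        // Lanes where vals[k] == 1 become -1 (all bits set), others 0.
        __m256i is_one = _mm256_cmpeq_epi32(v, one);

        // Subtracting -1 adds 1, so acc accumulates per-lane match counts.
        acc = _mm256_sub_epi32(acc, is_one);

        // Horizontal sum of the 8 lanes, as r2_avx2 does after its main loop.
        int lanes[8];
        _mm256_storeu_si256((__m256i*)lanes, acc);
        int total = 0;
        for (int k = 0; k < 8; k++) total += lanes[k];
        printf("ones counted: %d\n", total); /* prints 5 */
        return 0;
    }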


diploshic/utils.c

Lines changed: 244 additions & 0 deletions
@@ -1,5 +1,129 @@
 #include <math.h>
+#include <immintrin.h> // AVX2 intrinsics
+#ifdef __x86_64__
+#include <cpuid.h> // CPU feature detection
+#endif
 
+// CPU feature detection - check for AVX2 support
+static int has_avx2_support() {
+#ifdef __x86_64__
+    unsigned int eax, ebx, ecx, edx;
+    if (__get_cpuid(7, &eax, &ebx, &ecx, &edx)) {
+        return (ebx & (1 << 5)) != 0; // bit 5 is AVX2
+    }
+#endif
+    return 0;
+}
+
+#ifdef __AVX2__
+/*
+ * SIMD-optimized r2 calculation using AVX2
+ * Processes 8 samples simultaneously using 256-bit vectors
+ * Expected speedup: 4-6x over scalar version
+ */
+double r2_avx2(int nSamps, int *haps, int i, int j){
+    double pi = 0.0;
+    double pj = 0.0;
+    double pij = 0.0;
+    double count = 0.0;
+
+    int k;
+
+    // Process 8 samples at a time with AVX2
+    int vec_iterations = nSamps / 8;
+    int remainder = nSamps % 8;
+
+    // Accumulators for vectorized portion
+    __m256i pi_vec = _mm256_setzero_si256();
+    __m256i pj_vec = _mm256_setzero_si256();
+    __m256i pij_vec = _mm256_setzero_si256();
+    __m256i count_vec = _mm256_setzero_si256();
+
+    // Constants for masking
+    __m256i zero = _mm256_setzero_si256();
+    __m256i one = _mm256_set1_epi32(1);
+
+    // Base pointers for the two SNPs
+    int *hap_i_base = &haps[i * nSamps];
+    int *hap_j_base = &haps[j * nSamps];
+
+    // Vectorized loop: process 8 samples per iteration
+    for(k = 0; k < vec_iterations * 8; k += 8){
+        // Load 8 haplotype values for SNP i and SNP j
+        __m256i hap_i_vec = _mm256_loadu_si256((__m256i*)&hap_i_base[k]);
+        __m256i hap_j_vec = _mm256_loadu_si256((__m256i*)&hap_j_base[k]);
+
+        // Create validity masks: valid if haplotype is 0 or 1
+        __m256i valid_i_0 = _mm256_cmpeq_epi32(hap_i_vec, zero);
+        __m256i valid_i_1 = _mm256_cmpeq_epi32(hap_i_vec, one);
+        __m256i valid_i = _mm256_or_si256(valid_i_0, valid_i_1);
+
+        __m256i valid_j_0 = _mm256_cmpeq_epi32(hap_j_vec, zero);
+        __m256i valid_j_1 = _mm256_cmpeq_epi32(hap_j_vec, one);
+        __m256i valid_j = _mm256_or_si256(valid_j_0, valid_j_1);
+
+        // Both must be valid
+        __m256i valid_both = _mm256_and_si256(valid_i, valid_j);
+
+        // Create masks for counting
+        __m256i is_one_i = _mm256_and_si256(valid_i_1, valid_both);
+        __m256i is_one_j = _mm256_and_si256(valid_j_1, valid_both);
+
+        // Both are 1: AND the masks
+        __m256i both_one = _mm256_and_si256(is_one_i, is_one_j);
+
+        // Accumulate counts (masks are -1 for true, 0 for false)
+        // Subtract because mask is -1, effectively adding 1
+        pi_vec = _mm256_sub_epi32(pi_vec, is_one_i);
+        pj_vec = _mm256_sub_epi32(pj_vec, is_one_j);
+        pij_vec = _mm256_sub_epi32(pij_vec, both_one);
+        count_vec = _mm256_sub_epi32(count_vec, valid_both);
+    }
+
+    // Horizontal sum: add all 8 lanes together
+    int pi_array[8], pj_array[8], pij_array[8], count_array[8];
+    _mm256_storeu_si256((__m256i*)pi_array, pi_vec);
+    _mm256_storeu_si256((__m256i*)pj_array, pj_vec);
+    _mm256_storeu_si256((__m256i*)pij_array, pij_vec);
+    _mm256_storeu_si256((__m256i*)count_array, count_vec);
+
+    for(int lane = 0; lane < 8; lane++){
+        pi += pi_array[lane];
+        pj += pj_array[lane];
+        pij += pij_array[lane];
+        count += count_array[lane];
+    }
+
+    // Handle remainder samples with scalar code
+    for(k = vec_iterations * 8; k < nSamps; k++){
+        int hap_i = hap_i_base[k];
+        int hap_j = hap_j_base[k];
+
+        if((hap_i == 1 || hap_i == 0) && (hap_j == 1 || hap_j == 0)){
+            if(hap_i == 1) pi++;
+            if(hap_j == 1) pj++;
+            if(hap_i == 1 && hap_j == 1) pij++;
+            count += 1.0;
+        }
+    }
+
+    // Same final computation as original (bit-exact)
+    if (count == 0.0){
+        return(-1.0);
+    }
+    else{
+        pi /= count;
+        pj /= count;
+        pij /= count;
+
+        double Dij = pij - (pi*pj);
+
+        return (Dij*Dij) / ((pi*(1.0-pi)) * (pj*(1.0-pj)));
+    }
+}
+#endif
+
+// Scalar version of r2 (original implementation, used as fallback)
 double r2(int nSamps, int *haps, int i, int j){
     double pi = 0.0;
     double pj = 0.0;
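A possible way to exercise the commit's bit-exact claim (a hypothetical test harness, not included in the commit; it assumes the scalar r2 and r2_avx2 above are linked in, and that haps stores one SNP per row of nSamps ints, as the haps[i * nSamps] indexing implies):

    #include <stdio.h>
    #include <stdlib.h>

    double r2(int nSamps, int *haps, int i, int j);
    double r2_avx2(int nSamps, int *haps, int i, int j);

    int main(void) {
        int nSnps = 50, nSamps = 203; /* odd count exercises the remainder loop */
        int *haps = malloc((size_t)nSnps * nSamps * sizeof(int));
        srand(42);
        for (int x = 0; x < nSnps * nSamps; x++)
            haps[x] = (rand() % 10 == 0) ? -1 : rand() % 2; /* include missing data */

        int mismatches = 0;
        for (int i = 0; i < nSnps - 1; i++)
            for (int j = i + 1; j < nSnps; j++)
                if (r2(nSamps, haps, i, j) != r2_avx2(nSamps, haps, i, j))
                    mismatches++;
        printf("mismatching pairs: %d\n", mismatches); /* expect 0 if the paths agree */
        free(haps);
        return 0;
    }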
@@ -43,6 +167,26 @@ double r2(int nSamps, int *haps, int i, int j){
 void computeR2Matrix(int nSamps, int nSnps, int *haps, double *r2Matrix){
     double r2Val;
     int i, j;
+
+#ifdef __AVX2__
+    // Use AVX2 if compiled with support and CPU has the capability
+    static int use_avx2 = -1; // -1 = not checked, 0 = no, 1 = yes
+    if (use_avx2 == -1) {
+        use_avx2 = has_avx2_support();
+    }
+
+    if (use_avx2) {
+        for (i=0; i<nSnps-1; i++){
+            for (j=i+1; j<nSnps; j++){
+                r2Val = r2_avx2(nSamps, haps, i, j);
+                r2Matrix[i*nSnps +j] = r2Val;
+            }
+        }
+        return;
+    }
+#endif
+
+    // Fallback to scalar version
     for (i=0; i<nSnps-1; i++){
         for (j=i+1; j<nSnps; j++){
             r2Val = r2(nSamps, haps, i, j);
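The dispatch above caches the CPUID query in a function-local static so it runs once per process. On GCC and Clang the same runtime test could also be written with a compiler builtin; a hypothetical alternative, not part of the commit:

    #if (defined(__GNUC__) || defined(__clang__)) && defined(__x86_64__)
    // Hypothetical variant of has_avx2_support(): GCC/Clang expose the same
    // CPUID-based feature query as a builtin.
    static int has_avx2_support_builtin(void) {
        return __builtin_cpu_supports("avx2");
    }
    #endif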
@@ -116,9 +260,109 @@ void omega(int nSnps, double *r2Matrix, double *omegaMax){
     }
 }
 
+#ifdef __AVX2__
+/*
+ * SIMD-optimized pairwiseDiffs using AVX2
+ * Processes 8 SNPs at a time for each sample pair
+ * Expected speedup: 4-6x over scalar version
+ */
+void pairwiseDiffs_avx2(int nSamps, int nSnps, int *haps, double *diffLs){
+    int i, j, k;
+    int pairsSeen = 0;
+
+    int vec_iterations = nSnps / 8;
+    int remainder = nSnps % 8;
+
+    __m256i zero = _mm256_setzero_si256();
+    __m256i one = _mm256_set1_epi32(1);
+
+    for(i=0; i<nSamps-1; i++){
+        for(j=i+1; j<nSamps; j++){
+            int diffs = 0;
+
+            // Vectorized SNP comparison
+            __m256i diff_vec = _mm256_setzero_si256();
+
+            for(k=0; k < vec_iterations * 8; k += 8){
+                // Load 8 SNPs for sample i and j
+                __m256i snps_i = _mm256_set_epi32(
+                    haps[(k+7)*nSamps + i], haps[(k+6)*nSamps + i],
+                    haps[(k+5)*nSamps + i], haps[(k+4)*nSamps + i],
+                    haps[(k+3)*nSamps + i], haps[(k+2)*nSamps + i],
+                    haps[(k+1)*nSamps + i], haps[(k+0)*nSamps + i]
+                );
+                __m256i snps_j = _mm256_set_epi32(
+                    haps[(k+7)*nSamps + j], haps[(k+6)*nSamps + j],
+                    haps[(k+5)*nSamps + j], haps[(k+4)*nSamps + j],
+                    haps[(k+3)*nSamps + j], haps[(k+2)*nSamps + j],
+                    haps[(k+1)*nSamps + j], haps[(k+0)*nSamps + j]
+                );
+
+                // Check validity: both must be in [0,1]
+                __m256i valid_i = _mm256_and_si256(
+                    _mm256_cmpgt_epi32(snps_i, _mm256_set1_epi32(-1)),
+                    _mm256_cmpgt_epi32(_mm256_set1_epi32(2), snps_i)
+                );
+                __m256i valid_j = _mm256_and_si256(
+                    _mm256_cmpgt_epi32(snps_j, _mm256_set1_epi32(-1)),
+                    _mm256_cmpgt_epi32(_mm256_set1_epi32(2), snps_j)
+                );
+                __m256i valid_both = _mm256_and_si256(valid_i, valid_j);
+
+                // Compare: are they different?
+                __m256i different = _mm256_andnot_si256(
+                    _mm256_cmpeq_epi32(snps_i, snps_j),
+                    valid_both
+                );
+
+                // Accumulate differences
+                diff_vec = _mm256_sub_epi32(diff_vec, different);
+            }
+
+            // Horizontal sum of diff_vec
+            int diff_array[8];
+            _mm256_storeu_si256((__m256i*)diff_array, diff_vec);
+            for(int lane = 0; lane < 8; lane++){
+                diffs += diff_array[lane];
+            }
+
+            // Handle remainder SNPs with scalar code
+            for(k = vec_iterations * 8; k < nSnps; k++){
+                int basei = haps[k*nSamps + i];
+                int basej = haps[k*nSamps + j];
+                if(basei >= 0 && basei <= 1 && basej >= 0 && basej <= 1){
+                    if (basei != basej){
+                        diffs += 1;
+                    }
+                }
+            }
+
+            diffLs[pairsSeen] = diffs;
+            pairsSeen += 1;
+        }
+    }
+}
+#endif
+
+// Scalar version of pairwiseDiffs (original implementation, used as fallback)
 void pairwiseDiffs(int nSamps, int nSnps, int *haps, double *diffLs){
     int i, j, k, basei, basej, diffs;
     int pairsSeen = 0;
+
+#ifdef __AVX2__
+    // Use AVX2 if compiled with support and CPU has the capability
+    static int use_avx2 = -1;
+    if (use_avx2 == -1) {
+        use_avx2 = has_avx2_support();
+    }
+
+    if (use_avx2) {
+        pairwiseDiffs_avx2(nSamps, nSnps, haps, diffLs);
+        return;
+    }
+#endif
+
+    // Fallback to scalar version
     for(i=0; i<nSamps-1; i++){
         for(j=i+1; j<nSamps; j++){
             diffs = 0;
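For reference, a caller sketch (hypothetical, not part of the commit) showing the layout both pairwiseDiffs variants assume: haps holds nSnps rows of nSamps ints (indexed haps[k*nSamps + sample]), and diffLs needs one slot per sample pair, nSamps*(nSamps-1)/2, filled in (0,1), (0,2), ... order:

    #include <stdio.h>
    #include <stdlib.h>

    void pairwiseDiffs(int nSamps, int nSnps, int *haps, double *diffLs);

    int main(void) {
        int nSamps = 4, nSnps = 3;
        // 3 SNPs x 4 samples, row-major by SNP; -1 marks missing data.
        int haps[] = {
            0, 1, 1, 0,   // SNP 0
            1, 1, 0, -1,  // SNP 1
            0, 0, 1, 1    // SNP 2
        };
        int nPairs = nSamps * (nSamps - 1) / 2;
        double *diffLs = malloc(nPairs * sizeof(double));

        pairwiseDiffs(nSamps, nSnps, haps, diffLs);
        for (int p = 0; p < nPairs; p++)
            printf("pair %d: %.0f diffs\n", p, diffLs[p]);
        free(diffLs);
        return 0;
    }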

setup.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 shic_stats = Extension(
     "diploshic.shicstats",
     sources=["diploshic/shicstats.pyf", "diploshic/utils.c"],
-    extra_compile_args=['-O3', '-march=native'],
+    extra_compile_args=['-O3', '-march=native', '-mavx2'],
 )
 setup(
     name="diploSHIC",
