
Commit 4ed17da

mulugetam authored and 0ctopus13prime committed

Use native AVX512-FP16 instructions to speed up FP16 bulk similarity.

Signed-off-by: Mulugeta Mammo <mulugeta.mammo@intel.com>

1 parent d66d79e commit 4ed17da

File tree

jni/src/simd/similarity_function/avx512_spr_simd_similarity_function.cpp
jni/src/simd/similarity_function/similarity_function.cpp

2 files changed: +365 -7 lines
jni/src/simd/similarity_function/avx512_spr_simd_similarity_function.cpp

Lines changed: 358 additions & 0 deletions

@@ -0,0 +1,358 @@
#include <immintrin.h>
#include <algorithm>
#include <array>
#include <cstddef>
#include <stdint.h>
#include <cmath>

#include "simd_similarity_function_common.cpp"
#include "faiss_score_to_lucene_transform.cpp"

// Max FP16 accumulations before draining to FP32. Trades accuracy for speed.
// Lower values improve precision; higher values improve performance.
static constexpr int32_t FP16_DRAIN_INTERVAL = 4;
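// With 32 FP16 lanes and an interval of 4, each lane absorbs at most 4
// fused multiply-adds between drains, keeping partial sums well within
// FP16's 11-bit significand before they are widened to FP32.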

// Drain the FP16 accumulator into the FP32 accumulator, then zero the FP16 accumulator.
static inline void drain_ph_to_ps(__m512& acc_ps, __m512h& acc_ph) {
    __m512i raw = _mm512_castph_si512(acc_ph);
    acc_ps = _mm512_add_ps(acc_ps,
                           _mm512_cvtph_ps(_mm512_castsi512_si256(raw)));
    acc_ps = _mm512_add_ps(acc_ps,
                           _mm512_cvtph_ps(_mm512_extracti64x4_epi64(raw, 1)));
    acc_ph = _mm512_castsi512_ph(_mm512_setzero_si512());
}
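// (the castph/castsi casts above are free bit reinterpretations; the extract,
// the two FP16-to-FP32 widenings, and the adds are the only real work)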

// Convert 32 floats to a packed 512-bit float16 register (2x16 FP32 -> 1x32 FP16).
static inline __m512h cvt2x16_fp32_to_ph(const float* p) {
    __m256i lo = _mm512_cvtps_ph(_mm512_loadu_ps(p),
                                 _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    __m256i hi = _mm512_cvtps_ph(_mm512_loadu_ps(p + 16),
                                 _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    return _mm512_castsi512_ph(
        _mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
}
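// (lanes 0-15 of the result hold p[0..15] and lanes 16-31 hold p[16..31],
// rounded to nearest-even with floating-point exceptions suppressed)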

//
// FP16 inner product (max IP) using native AVX512-FP16. The FP32 query is
// converted to FP16 on the fly. Precision loss is limited by draining the
// FP16 accumulators to FP32 at regular intervals.
//
template <BulkScoreTransform BulkScoreTransformFunc, ScoreTransform ScoreTransformFunc>
struct AVX512SPRFP16MaxIP final : BaseSimilarityFunction<BulkScoreTransformFunc, ScoreTransformFunc> {
    void calculateSimilarityInBulk(SimdVectorSearchContext* srchContext,
                                   int32_t* internalVectorIds,
                                   float* scores,
                                   const int32_t numVectors) {

        int32_t processedCount = 0;
        const auto* queryPtr = (const float*) srchContext->queryVectorSimdAligned;
        const int32_t dim = srchContext->dimension;

        constexpr int32_t vecBlock = 8;
        constexpr int32_t elemPerLoad = 32;

        // SIMD-aligned dim and tail dim
        const int32_t simdDim = (dim / elemPerLoad) * elemPerLoad;
        const int32_t tailDim = dim - simdDim;

        // Precompute the FP16 tail mask (one bit per remaining element).
        const __mmask32 tailMask = tailDim > 0 ? (__mmask32)((1ULL << tailDim) - 1) : 0;

        // Split the tail into a full 16-lane FP32 load and a final masked load.
        const int32_t tailMid16 = (tailDim >= 16) ? 16 : 0;
        const int32_t tailFinal = tailDim - tailMid16;
        const __mmask16 tailFinalMask = tailFinal > 0 ? (__mmask16)((1U << tailFinal) - 1) : 0;
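        // Worked example: dim = 100 gives simdDim = 96 and tailDim = 4, so
        // tailMask = 0b1111 (4 FP16 lanes), tailMid16 = 0, and tailFinal = 4.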

        __m512h sumFp16[vecBlock]; // FP16 accumulators (32 lanes each)
        __m512  sumFp32[vecBlock]; // FP32 accumulators (16 lanes each)

        for (; processedCount <= numVectors - vecBlock; processedCount += vecBlock) {
            const uint8_t* vectors[vecBlock];
            srchContext->getVectorPointersInBulk((uint8_t**)vectors, &internalVectorIds[processedCount], vecBlock);

            #pragma unroll
            for (int32_t v = 0; v < vecBlock; ++v) {
                sumFp16[v] = _mm512_castsi512_ph(_mm512_setzero_si512());
                sumFp32[v] = _mm512_setzero_ps();
            }

            int32_t drainCount = 0;

            // Mask-free hot loop
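            // (each iteration converts 32 query floats to FP16 once and reuses
            // them against all 8 candidate vectors; the next block of each
            // candidate is prefetched for reads with high temporal locality)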
            for (int32_t i = 0; i < simdDim; i += elemPerLoad) {
                __m512h q = cvt2x16_fp32_to_ph(queryPtr + i);

                if ((i + elemPerLoad) < dim) {
                    const int32_t nextByteOffset = (i + elemPerLoad) * 2;
                    #pragma unroll
                    for (int32_t v = 0; v < vecBlock; ++v) {
                        __builtin_prefetch(vectors[v] + nextByteOffset, 0, 3);
                    }
                    __builtin_prefetch(queryPtr + i + elemPerLoad, 0, 3);
                }

                #pragma unroll
                for (int32_t v = 0; v < vecBlock; ++v) {
                    __m512h vec = _mm512_loadu_ph(vectors[v] + 2 * i);
                    sumFp16[v] = _mm512_fmadd_ph(q, vec, sumFp16[v]);
                }

                if (++drainCount >= FP16_DRAIN_INTERVAL) {
                    #pragma unroll
                    for (int32_t v = 0; v < vecBlock; ++v)
                        drain_ph_to_ps(sumFp32[v], sumFp16[v]);
                    drainCount = 0;
                }
            }

            // Single masked tail
            if (tailDim > 0) {
                __m256i qLoH, qHiH;
                if (tailMid16 > 0) {
                    qLoH = _mm512_cvtps_ph(_mm512_loadu_ps(queryPtr + simdDim),
                                           _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
                    qHiH = tailFinal > 0
                               ? _mm512_cvtps_ph(_mm512_maskz_loadu_ps(tailFinalMask, queryPtr + simdDim + 16),
                                                 _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
                               : _mm256_setzero_si256();
                } else {
                    qLoH = _mm512_cvtps_ph(_mm512_maskz_loadu_ps(tailFinalMask, queryPtr + simdDim),
                                           _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
                    qHiH = _mm256_setzero_si256();
                }
                __m512h q = _mm512_castsi512_ph(
                    _mm512_inserti64x4(_mm512_castsi256_si512(qLoH), qHiH, 1));

                #pragma unroll
                for (int32_t v = 0; v < vecBlock; ++v) {
                    __m512h vec = _mm512_castsi512_ph(
                        _mm512_maskz_loadu_epi16(tailMask, vectors[v] + 2 * simdDim));
                    sumFp16[v] = _mm512_fmadd_ph(q, vec, sumFp16[v]);
                }
            }

            #pragma unroll
            for (int32_t v = 0; v < vecBlock; ++v) {
                drain_ph_to_ps(sumFp32[v], sumFp16[v]);
                scores[processedCount + v] = _mm512_reduce_add_ps(sumFp32[v]);
            }
        }

        // Tail loop for remaining vectors
        for (; processedCount < numVectors; ++processedCount) {
            const auto* vecPtr = (const uint8_t*) srchContext->getVectorPointer(internalVectorIds[processedCount]);
            __m512h sumFp16 = _mm512_castsi512_ph(_mm512_setzero_si512());
            __m512 sumFp32 = _mm512_setzero_ps();
            int32_t drainCount = 0;

            for (int32_t i = 0; i < simdDim; i += elemPerLoad) {
                sumFp16 = _mm512_fmadd_ph(cvt2x16_fp32_to_ph(queryPtr + i),
                                          _mm512_loadu_ph(vecPtr + 2 * i), sumFp16);
                if (++drainCount >= FP16_DRAIN_INTERVAL) {
                    drain_ph_to_ps(sumFp32, sumFp16);
                    drainCount = 0;
                }
            }

            if (tailDim > 0) {
                __m256i qLoH, qHiH;
                if (tailMid16 > 0) {
                    qLoH = _mm512_cvtps_ph(_mm512_loadu_ps(queryPtr + simdDim),
                                           _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
                    qHiH = tailFinal > 0
                               ? _mm512_cvtps_ph(_mm512_maskz_loadu_ps(tailFinalMask, queryPtr + simdDim + 16),
                                                 _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
                               : _mm256_setzero_si256();
                } else {
                    qLoH = _mm512_cvtps_ph(_mm512_maskz_loadu_ps(tailFinalMask, queryPtr + simdDim),
                                           _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
                    qHiH = _mm256_setzero_si256();
                }
                __m512h q = _mm512_castsi512_ph(
                    _mm512_inserti64x4(_mm512_castsi256_si512(qLoH), qHiH, 1));
                sumFp16 = _mm512_fmadd_ph(q,
                                          _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(tailMask, vecPtr + 2 * simdDim)),
                                          sumFp16);
            }

            drain_ph_to_ps(sumFp32, sumFp16);
            scores[processedCount] = _mm512_reduce_add_ps(sumFp32);
        }

        BulkScoreTransformFunc(scores, numVectors);
    }
};
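
// The L2 kernel below follows the same structure; only the innermost step
// changes: it accumulates (q - v)^2 with a subtract followed by an FMA.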

//
// FP16 L2 (squared Euclidean distance), using the same drained-accumulation
// scheme as the inner-product kernel above.
//
template <BulkScoreTransform BulkScoreTransformFunc, ScoreTransform ScoreTransformFunc>
struct AVX512SPRFP16L2 final : BaseSimilarityFunction<BulkScoreTransformFunc, ScoreTransformFunc> {
    void calculateSimilarityInBulk(SimdVectorSearchContext* srchContext,
                                   int32_t* internalVectorIds,
                                   float* scores,
                                   const int32_t numVectors) {

        int32_t processedCount = 0;
        const auto* queryPtr = (const float*) srchContext->queryVectorSimdAligned;
        const int32_t dim = srchContext->dimension;

        constexpr int32_t vecBlock = 8;
        constexpr int32_t elemPerLoad = 32;

        // SIMD-aligned dim and tail dim
        const int32_t simdDim = (dim / elemPerLoad) * elemPerLoad;
        const int32_t tailDim = dim - simdDim;

        // Precompute the FP16 tail mask (one bit per remaining element).
        const __mmask32 tailMask = tailDim > 0 ? (__mmask32)((1ULL << tailDim) - 1) : 0;

        // Split the tail into a full 16-lane FP32 load and a final masked load.
        const int32_t tailMid16 = (tailDim >= 16) ? 16 : 0;
        const int32_t tailFinal = tailDim - tailMid16;
        const __mmask16 tailFinalMask = tailFinal > 0 ? (__mmask16)((1U << tailFinal) - 1) : 0;

        __m512h sumFp16[vecBlock]; // FP16 accumulators (32 lanes each)
        __m512  sumFp32[vecBlock]; // FP32 accumulators (16 lanes each)

        for (; processedCount <= numVectors - vecBlock; processedCount += vecBlock) {
            const uint8_t* vectors[vecBlock];
            srchContext->getVectorPointersInBulk((uint8_t**)vectors, &internalVectorIds[processedCount], vecBlock);

            #pragma unroll
            for (int32_t v = 0; v < vecBlock; ++v) {
                sumFp16[v] = _mm512_castsi512_ph(_mm512_setzero_si512());
                sumFp32[v] = _mm512_setzero_ps();
            }

            int32_t drainCount = 0;

            // Mask-free hot loop
            for (int32_t i = 0; i < simdDim; i += elemPerLoad) {
                __m512h q = cvt2x16_fp32_to_ph(queryPtr + i);

                if ((i + elemPerLoad) < dim) {
                    const int32_t nextByteOffset = (i + elemPerLoad) * 2;
                    #pragma unroll
                    for (int32_t v = 0; v < vecBlock; ++v) {
                        __builtin_prefetch(vectors[v] + nextByteOffset, 0, 3);
                    }
                    __builtin_prefetch(queryPtr + i + elemPerLoad, 0, 3);
                }

                #pragma unroll
                for (int32_t v = 0; v < vecBlock; ++v) {
                    __m512h vec = _mm512_loadu_ph(vectors[v] + 2 * i);
                    __m512h diff = _mm512_sub_ph(q, vec);
                    sumFp16[v] = _mm512_fmadd_ph(diff, diff, sumFp16[v]);
                }

                if (++drainCount >= FP16_DRAIN_INTERVAL) {
                    #pragma unroll
                    for (int32_t v = 0; v < vecBlock; ++v)
                        drain_ph_to_ps(sumFp32[v], sumFp16[v]);
                    drainCount = 0;
                }
            }

            // Single masked tail
            if (tailDim > 0) {
                __m256i qLoH, qHiH;
                if (tailMid16 > 0) {
                    qLoH = _mm512_cvtps_ph(_mm512_loadu_ps(queryPtr + simdDim),
                                           _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
                    qHiH = tailFinal > 0
                               ? _mm512_cvtps_ph(_mm512_maskz_loadu_ps(tailFinalMask, queryPtr + simdDim + 16),
                                                 _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
                               : _mm256_setzero_si256();
                } else {
                    qLoH = _mm512_cvtps_ph(_mm512_maskz_loadu_ps(tailFinalMask, queryPtr + simdDim),
                                           _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
                    qHiH = _mm256_setzero_si256();
                }
                __m512h q = _mm512_castsi512_ph(
                    _mm512_inserti64x4(_mm512_castsi256_si512(qLoH), qHiH, 1));

                #pragma unroll
                for (int32_t v = 0; v < vecBlock; ++v) {
                    __m512h vec = _mm512_castsi512_ph(
                        _mm512_maskz_loadu_epi16(tailMask, vectors[v] + 2 * simdDim));
                    __m512h diff = _mm512_sub_ph(q, vec);
                    sumFp16[v] = _mm512_fmadd_ph(diff, diff, sumFp16[v]);
                }
            }

            #pragma unroll
            for (int32_t v = 0; v < vecBlock; ++v) {
                drain_ph_to_ps(sumFp32[v], sumFp16[v]);
                scores[processedCount + v] = _mm512_reduce_add_ps(sumFp32[v]);
            }
        }

        // Tail loop for remaining vectors
        for (; processedCount < numVectors; ++processedCount) {
            const auto* vecPtr = (const uint8_t*) srchContext->getVectorPointer(internalVectorIds[processedCount]);
            __m512h sumFp16 = _mm512_castsi512_ph(_mm512_setzero_si512());
            __m512 sumFp32 = _mm512_setzero_ps();
            int32_t drainCount = 0;

            for (int32_t i = 0; i < simdDim; i += elemPerLoad) {
                __m512h diff = _mm512_sub_ph(cvt2x16_fp32_to_ph(queryPtr + i),
                                             _mm512_loadu_ph(vecPtr + 2 * i));
                sumFp16 = _mm512_fmadd_ph(diff, diff, sumFp16);
                if (++drainCount >= FP16_DRAIN_INTERVAL) {
                    drain_ph_to_ps(sumFp32, sumFp16);
                    drainCount = 0;
                }
            }

            if (tailDim > 0) {
                __m256i qLoH, qHiH;
                if (tailMid16 > 0) {
                    qLoH = _mm512_cvtps_ph(_mm512_loadu_ps(queryPtr + simdDim),
                                           _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
                    qHiH = tailFinal > 0
                               ? _mm512_cvtps_ph(_mm512_maskz_loadu_ps(tailFinalMask, queryPtr + simdDim + 16),
                                                 _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)
                               : _mm256_setzero_si256();
                } else {
                    qLoH = _mm512_cvtps_ph(_mm512_maskz_loadu_ps(tailFinalMask, queryPtr + simdDim),
                                           _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
                    qHiH = _mm256_setzero_si256();
                }
                __m512h q = _mm512_castsi512_ph(
                    _mm512_inserti64x4(_mm512_castsi256_si512(qLoH), qHiH, 1));
                __m512h diff = _mm512_sub_ph(q,
                                             _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(tailMask, vecPtr + 2 * simdDim)));
                sumFp16 = _mm512_fmadd_ph(diff, diff, sumFp16);
            }

            drain_ph_to_ps(sumFp32, sumFp16);
            scores[processedCount] = _mm512_reduce_add_ps(sumFp32);
        }

        BulkScoreTransformFunc(scores, numVectors);
    }
};
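
// Note: both kernels above produce raw sums (inner products or squared L2
// distances). The BulkScoreTransform/ScoreTransform template arguments below
// convert those raw values into Lucene scores.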

//
// FP16 similarity function instances
//
// 1. Max IP
AVX512SPRFP16MaxIP<FaissScoreToLuceneScoreTransform::ipToMaxIpTransformBulk, FaissScoreToLuceneScoreTransform::ipToMaxIpTransform> FP16_MAX_INNER_PRODUCT_SIMIL_FUNC;
// 2. L2
AVX512SPRFP16L2<FaissScoreToLuceneScoreTransform::l2TransformBulk, FaissScoreToLuceneScoreTransform::l2Transform> FP16_L2_SIMIL_FUNC;

#ifndef __NO_SELECT_FUNCTION
SimilarityFunction* SimilarityFunction::selectSimilarityFunction(const NativeSimilarityFunctionType nativeFunctionType) {
    if (nativeFunctionType == NativeSimilarityFunctionType::FP16_MAXIMUM_INNER_PRODUCT) {
        return &FP16_MAX_INNER_PRODUCT_SIMIL_FUNC;
    } else if (nativeFunctionType == NativeSimilarityFunctionType::FP16_L2) {
        return &FP16_L2_SIMIL_FUNC;
    }

    throw std::runtime_error("Invalid native similarity function type was given, nativeFunctionType="
                             + std::to_string(static_cast<int32_t>(nativeFunctionType)));
}
#endif
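
A brief usage sketch, not part of the commit: given the selector above, a caller might obtain an FP16 kernel as follows. SimilarityFunction, NativeSimilarityFunctionType, and selectSimilarityFunction come from the sources in this commit; the wrapper function itself is hypothetical.

// Hypothetical helper around the commit's selector (illustrative only).
SimilarityFunction* pickFp16Kernel(bool useL2) {
    return SimilarityFunction::selectSimilarityFunction(
        useL2 ? NativeSimilarityFunctionType::FP16_L2
              : NativeSimilarityFunctionType::FP16_MAXIMUM_INNER_PRODUCT);
}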

jni/src/simd/similarity_function/similarity_function.cpp

Lines changed: 7 additions & 7 deletions
@@ -9,14 +9,14 @@
  * GitHub history for details.
  */
 
-#ifdef KNN_HAVE_AVX512
+#if defined(KNN_HAVE_AVX512_SPR)
+// Convert FP32 query vector to FP16 and do bulk
+// similarity with native AVX512-FP16 instructions.
+#include "avx512_spr_simd_similarity_function.cpp"
+#elif defined(KNN_HAVE_AVX512)
+// Convert FP16 vectors to FP32 and do bulk similarity.
 #include "avx512_simd_similarity_function.cpp"
-#elif KNN_HAVE_AVX512_SPR
-// Since we convert FP16 to FP32 then do bulk operation,
-// we're not really using SPR instruction set.
-// Therefore, both AVX512 and AVX512_SPR are sharing the same code piece.
-#include "avx512_simd_similarity_function.cpp"
-#elif KNN_HAVE_ARM_FP16
+#elif defined(KNN_HAVE_ARM_FP16)
 #include "arm_neon_simd_similarity_function.cpp"
 #else
 #include "default_simd_similarity_function.cpp"
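
To see the numeric technique the SPR kernels rely on in isolation, here is a minimal standalone sketch, not part of the commit: a dot product of two FP32 vectors accumulated in FP16 and drained to FP32, compared against an FP64 reference. The helpers to_ph and drain mirror cvt2x16_fp32_to_ph and drain_ph_to_ps from the new file; the dimension, inputs, and drain interval are illustrative. It assumes AVX512-FP16 hardware (e.g., Sapphire Rapids) and a compiler flag such as -mavx512fp16.

#include <immintrin.h>
#include <cstdio>

// Pack 32 consecutive FP32 values into one 32-lane FP16 register
// (same pattern as cvt2x16_fp32_to_ph).
static __m512h to_ph(const float* p) {
    __m256i lo = _mm512_cvtps_ph(_mm512_loadu_ps(p),
                                 _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    __m256i hi = _mm512_cvtps_ph(_mm512_loadu_ps(p + 16),
                                 _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
    return _mm512_castsi512_ph(
        _mm512_inserti64x4(_mm512_castsi256_si512(lo), hi, 1));
}

// Widen both 16-lane halves of the FP16 accumulator into FP32 and reset it
// (same pattern as drain_ph_to_ps).
static void drain(__m512& acc_ps, __m512h& acc_ph) {
    __m512i raw = _mm512_castph_si512(acc_ph);
    acc_ps = _mm512_add_ps(acc_ps, _mm512_cvtph_ps(_mm512_castsi512_si256(raw)));
    acc_ps = _mm512_add_ps(acc_ps, _mm512_cvtph_ps(_mm512_extracti64x4_epi64(raw, 1)));
    acc_ph = _mm512_castsi512_ph(_mm512_setzero_si512());
}

int main() {
    constexpr int dim = 128;          // multiple of 32, so no tail handling
    constexpr int drainInterval = 4;  // matches FP16_DRAIN_INTERVAL
    float a[dim], b[dim];
    double ref = 0.0;
    for (int i = 0; i < dim; ++i) {
        a[i] = 0.5f + 0.01f * i;
        b[i] = 1.0f - 0.005f * i;
        ref += static_cast<double>(a[i]) * b[i];  // FP64 reference
    }

    __m512h accPh = _mm512_castsi512_ph(_mm512_setzero_si512());
    __m512  accPs = _mm512_setzero_ps();
    int drains = 0;

    for (int i = 0; i < dim; i += 32) {
        accPh = _mm512_fmadd_ph(to_ph(a + i), to_ph(b + i), accPh);
        if (++drains >= drainInterval) {
            drain(accPs, accPh);
            drains = 0;
        }
    }
    drain(accPs, accPh);  // final drain for any leftover iterations

    std::printf("fp16-accumulated IP = %.6f (fp64 reference = %.6f)\n",
                static_cast<double>(_mm512_reduce_add_ps(accPs)), ref);
    return 0;
}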
