Skip to content

Commit d3f368d

Browse files
0ctopus13prime authored and naveentatikonda committed
Faiss BBQ Bulk SIMD. (#3170)
Signed-off-by: Dooyong Kim <kdooyong@amazon.com>
1 parent c7be037 commit d3f368d

File tree

10 files changed

+915
-14
lines changed

10 files changed

+915
-14
lines changed

jni/include/org_opensearch_knn_jni_SimdVectorComputeService.h

Lines changed: 10 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

jni/include/platform_defs.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,20 @@
1919
#define LIKELY(x) (x)
2020
#define UNLIKELY(x) (x)
2121
#endif
22+
23+
// FORCE_INLINE: compiler-specific spelling of a "force inlining" request.
//   - MSVC:        __forceinline
//   - GCC / Clang: inline __attribute__((always_inline))
//   - otherwise:   plain `inline`
#if defined(_MSC_VER)
24+
// Microsoft Visual C++
25+
#define FORCE_INLINE __forceinline
26+
#elif defined(__GNUC__) || defined(__clang__)
27+
// GCC and Clang
28+
#define FORCE_INLINE inline __attribute__((always_inline))
29+
#else
30+
// Fallback for other compilers: no forced-inline spelling is used, only plain `inline`
31+
#define FORCE_INLINE inline
32+
#endif // FORCE_INLINE (closes the compiler-detection chain above; this is not an include guard)
33+
34+
// HOT_SPOT: expands to the [[gnu::hot]] attribute when compiling with GCC or Clang,
// and to nothing on every other compiler.
#if defined(__GNUC__) || defined(__clang__)
35+
#define HOT_SPOT [[gnu::hot]]
36+
#else
37+
#define HOT_SPOT
38+
#endif // HOT_SPOT

jni/include/simd/similarity_function/similarity_function.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@ namespace knn_jni::simd::similarity_function {
1212
// Max inner product will transform inner product to v < 0 ? 1 / (1 - v) : (1 + v)
1313
FP16_MAXIMUM_INNER_PRODUCT,
1414
// L2 for FP16
15-
FP16_L2
15+
FP16_L2,
16+
BBQ_IP,
17+
BBQ_L2
1618
};
1719

1820
struct SimilarityFunction;

jni/src/org_opensearch_knn_jni_SimdVectorComputeService.cpp

Lines changed: 60 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#include <cstring>
2+
#include <limits>
3+
#include <algorithm>
14
#include "org_opensearch_knn_jni_SimdVectorComputeService.h"
25
#include "jni_util.h"
36
#include "simd/similarity_function/similarity_function.h"
@@ -26,8 +29,11 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) {
2629
JNI_UTIL.Uninitialize(env);
2730
}
2831

29-
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_SimdVectorComputeService_scoreSimilarityInBulk
32+
JNIEXPORT jfloat JNICALL Java_org_opensearch_knn_jni_SimdVectorComputeService_scoreSimilarityInBulk
3033
(JNIEnv *env, jclass clazz, jintArray internalVectorIds, jfloatArray jscores, const jint numVectors) {
34+
if (numVectors <= 0) {
35+
return std::numeric_limits<float>::min();
36+
}
3137

3238
try {
3339
// Get search context
@@ -38,7 +44,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_SimdVectorComputeService_scor
3844

3945
// Get pointers of vectorIds and scores
4046
jint* vectorIds = static_cast<jint*>(JNI_UTIL.GetPrimitiveArrayCritical(env, internalVectorIds, nullptr));
47+
knn_jni::JNIReleaseElements releaseVectorIds {[=]{
48+
JNI_UTIL.ReleasePrimitiveArrayCritical(env, internalVectorIds, vectorIds, 0);
49+
}};
50+
4151
jfloat* scores = static_cast<jfloat*>(JNI_UTIL.GetPrimitiveArrayCritical(env, jscores, nullptr));
52+
knn_jni::JNIReleaseElements releaseScores {[=]{
53+
JNI_UTIL.ReleasePrimitiveArrayCritical(env, jscores, scores, 0);
54+
}};
4255

4356
// Bulk similarity calculation
4457
srchContext->similarityFunction->calculateSimilarityInBulk(
@@ -47,11 +60,11 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_SimdVectorComputeService_scor
4760
reinterpret_cast<float*>(scores),
4861
numVectors);
4962

50-
// Release pinned pointers
51-
JNI_UTIL.ReleasePrimitiveArrayCritical(env, internalVectorIds, vectorIds, 0);
52-
JNI_UTIL.ReleasePrimitiveArrayCritical(env, jscores, scores, 0);
63+
jfloat maxScore = *std::max_element(scores, scores + numVectors);
64+
return maxScore;
5365
} catch (...) {
5466
JNI_UTIL.CatchCppExceptionAndThrowJava(env);
67+
return 0.0f; // value ignored if exception pending
5568
}
5669
}
5770

@@ -96,3 +109,46 @@ JNIEXPORT jfloat JNICALL Java_org_opensearch_knn_jni_SimdVectorComputeService_sc
96109

97110
return 0;
98111
}
112+
113+
/**
 * JNI entry point that saves the search context for a BBQ query and then stores the
 * query's correction factors in the search context's tmpBuffer.
 *
 * @param env                   JNI environment.
 * @param clazz                 Calling Java class (unused).
 * @param quantizedQuery        Quantized query bytes, forwarded to saveSearchContext.
 * @param lowerInterval         Correction factor written to tmpBuffer slot 0.
 * @param upperInterval         Correction factor written to tmpBuffer slot 1.
 * @param additionalCorrection  Correction factor written to tmpBuffer slot 2.
 * @param quantizedComponentSum Integer whose raw 32 bits are copied into tmpBuffer slot 3.
 * @param addressAndSize        Long array forwarded to saveSearchContext together with its length
 *                              (named "mmap address and size" by the original code).
 * @param functionTypeOrd       Similarity function ordinal, forwarded to saveSearchContext.
 * @param dimension             Vector dimension, forwarded to saveSearchContext.
 * @param centroidDp            Correction factor written to tmpBuffer slot 4.
 *
 * Any C++ exception is converted to a pending Java exception via
 * JNI_UTIL.CatchCppExceptionAndThrowJava.
 */
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_SimdVectorComputeService_saveBBQSearchContext
  (JNIEnv *env, jclass clazz, jbyteArray quantizedQuery,
   jfloat lowerInterval, jfloat upperInterval, jfloat additionalCorrection,
   jint quantizedComponentSum, jlongArray addressAndSize,
   jint functionTypeOrd, jint dimension, jfloat centroidDp) {
    try {
        // Read both array lengths BEFORE pinning anything. The JNI specification forbids calling
        // other JNI functions between GetPrimitiveArrayCritical and ReleasePrimitiveArrayCritical,
        // and these length helpers take `env`, so they must not run inside the critical region.
        const jsize queryByteSize = JNI_UTIL.GetJavaBytesArrayLength(env, quantizedQuery);
        const jsize mmapAddressAndSizeLength = JNI_UTIL.GetJavaLongArrayLength(env, addressAndSize);

        // Pin the quantized query bytes; the guard releases them on every exit path.
        jbyte* queryPtr = static_cast<jbyte*>(JNI_UTIL.GetPrimitiveArrayCritical(env, quantizedQuery, nullptr));
        knn_jni::JNIReleaseElements queryRelease {[=]{
            JNI_UTIL.ReleasePrimitiveArrayCritical(env, quantizedQuery, queryPtr, 0);
        }};

        // Pin the address-and-size array (nested critical regions are permitted by JNI).
        jlong* mmapAddressAndSize = static_cast<jlong*>(JNI_UTIL.GetPrimitiveArrayCritical(env, addressAndSize, nullptr));
        knn_jni::JNIReleaseElements mmapAddressAndSizeRelease {[=]{
            JNI_UTIL.ReleasePrimitiveArrayCritical(env, addressAndSize, mmapAddressAndSize, 0);
        }};

        // Ordering matters: per the original author's note, saveSearchContext resets tmpBuffer at
        // its start, so it has to run first and the correction factors are written afterwards.
        SimilarityFunction::saveSearchContext(
            reinterpret_cast<uint8_t*>(queryPtr), queryByteSize,
            dimension,
            reinterpret_cast<int64_t*>(mmapAddressAndSize), mmapAddressAndSizeLength,
            functionTypeOrd);

        // Store the correction factors in tmpBuffer as five 4-byte slots:
        //   [0] lowerInterval, [1] upperInterval, [2] additionalCorrection,
        //   [3] raw bits of quantizedComponentSum (an integer, not a float value),
        //   [4] centroidDp
        SimdVectorSearchContext* ctx = SimilarityFunction::getSearchContext();
        ctx->tmpBuffer.resize(5 * sizeof(float));
        auto* correctionPtr = reinterpret_cast<float*>(ctx->tmpBuffer.data());
        correctionPtr[0] = lowerInterval;
        correctionPtr[1] = upperInterval;
        correctionPtr[2] = additionalCorrection;
        // memcpy keeps the integer's bit pattern intact instead of converting it to float.
        std::memcpy(&correctionPtr[3], &quantizedComponentSum, sizeof(int32_t));
        correctionPtr[4] = centroidDp;
    } catch (...) {
        JNI_UTIL.CatchCppExceptionAndThrowJava(env);
    }
}

0 commit comments

Comments
 (0)