1+ #include < cstring>
2+ #include < limits>
3+ #include < algorithm>
14#include " org_opensearch_knn_jni_SimdVectorComputeService.h"
25#include " jni_util.h"
36#include " simd/similarity_function/similarity_function.h"
@@ -26,8 +29,11 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) {
2629 JNI_UTIL.Uninitialize (env);
2730}
2831
29- JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_SimdVectorComputeService_scoreSimilarityInBulk
32+ JNIEXPORT jfloat JNICALL Java_org_opensearch_knn_jni_SimdVectorComputeService_scoreSimilarityInBulk
3033 (JNIEnv *env, jclass clazz, jintArray internalVectorIds, jfloatArray jscores, const jint numVectors) {
34+ if (numVectors <= 0 ) {
35+ return std::numeric_limits<float >::min ();
36+ }
3137
3238 try {
3339 // Get search context
@@ -38,7 +44,14 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_SimdVectorComputeService_scor
3844
3945 // Get pointers of vectorIds and scores
4046 jint* vectorIds = static_cast <jint*>(JNI_UTIL.GetPrimitiveArrayCritical (env, internalVectorIds, nullptr ));
47+ knn_jni::JNIReleaseElements releaseVectorIds {[=]{
48+ JNI_UTIL.ReleasePrimitiveArrayCritical (env, internalVectorIds, vectorIds, 0 );
49+ }};
50+
4151 jfloat* scores = static_cast <jfloat*>(JNI_UTIL.GetPrimitiveArrayCritical (env, jscores, nullptr ));
52+ knn_jni::JNIReleaseElements releaseScores {[=]{
53+ JNI_UTIL.ReleasePrimitiveArrayCritical (env, jscores, scores, 0 );
54+ }};
4255
4356 // Bulk similarity calculation
4457 srchContext->similarityFunction ->calculateSimilarityInBulk (
@@ -47,11 +60,11 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_SimdVectorComputeService_scor
4760 reinterpret_cast <float *>(scores),
4861 numVectors);
4962
50- // Release pinned pointers
51- JNI_UTIL.ReleasePrimitiveArrayCritical (env, internalVectorIds, vectorIds, 0 );
52- JNI_UTIL.ReleasePrimitiveArrayCritical (env, jscores, scores, 0 );
63+ jfloat maxScore = *std::max_element (scores, scores + numVectors);
64+ return maxScore;
5365 } catch (...) {
5466 JNI_UTIL.CatchCppExceptionAndThrowJava (env);
67+ return 0 .0f ; // value ignored if exception pending
5568 }
5669}
5770
@@ -96,3 +109,46 @@ JNIEXPORT jfloat JNICALL Java_org_opensearch_knn_jni_SimdVectorComputeService_sc
96109
97110 return 0 ;
98111}
112+
113+ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_SimdVectorComputeService_saveBBQSearchContext
114+ (JNIEnv *env, jclass clazz, jbyteArray quantizedQuery,
115+ jfloat lowerInterval, jfloat upperInterval, jfloat additionalCorrection,
116+ jint quantizedComponentSum, jlongArray addressAndSize,
117+ jint functionTypeOrd, jint dimension, jfloat centroidDp) {
118+ try {
119+ // Get quantized query bytes
120+ const jsize queryByteSize = JNI_UTIL.GetJavaBytesArrayLength (env, quantizedQuery);
121+ jbyte* queryPtr = static_cast <jbyte*>(JNI_UTIL.GetPrimitiveArrayCritical (env, quantizedQuery, nullptr ));
122+ knn_jni::JNIReleaseElements queryRelease {[=]{
123+ JNI_UTIL.ReleasePrimitiveArrayCritical (env, quantizedQuery, queryPtr, 0 );
124+ }};
125+
126+ // Get mmap address and size
127+ const jsize mmapAddressAndSizeLength = JNI_UTIL.GetJavaLongArrayLength (env, addressAndSize);
128+ jlong* mmapAddressAndSize = static_cast <jlong*>(JNI_UTIL.GetPrimitiveArrayCritical (env, addressAndSize, nullptr ));
129+ knn_jni::JNIReleaseElements mmapAddressAndSizeRelease {[=]{
130+ JNI_UTIL.ReleasePrimitiveArrayCritical (env, addressAndSize, mmapAddressAndSize, 0 );
131+ }};
132+
133+ // Store correction factors in tmpBuffer before calling saveSearchContext.
134+ // saveSearchContext will reset tmpBuffer at the beginning, so we need to call it first,
135+ // then write correction factors after.
136+ SimilarityFunction::saveSearchContext (
137+ reinterpret_cast <uint8_t *>(queryPtr), queryByteSize,
138+ dimension,
139+ reinterpret_cast <int64_t *>(mmapAddressAndSize), mmapAddressAndSizeLength,
140+ functionTypeOrd);
141+
142+ // Now store correction factors in tmpBuffer (saveSearchContext clears it, then BBQ_IP branch leaves it empty)
143+ SimdVectorSearchContext* ctx = SimilarityFunction::getSearchContext ();
144+ ctx->tmpBuffer .resize (5 * sizeof (float ));
145+ auto * correctionPtr = reinterpret_cast <float *>(ctx->tmpBuffer .data ());
146+ correctionPtr[0 ] = lowerInterval;
147+ correctionPtr[1 ] = upperInterval;
148+ correctionPtr[2 ] = additionalCorrection;
149+ std::memcpy (&correctionPtr[3 ], &quantizedComponentSum, sizeof (int32_t ));
150+ correctionPtr[4 ] = centroidDp;
151+ } catch (...) {
152+ JNI_UTIL.CatchCppExceptionAndThrowJava (env);
153+ }
154+ }
0 commit comments