Skip to content

Commit ad5507a

Browse files
Added default bulk simd bbq scoring.
Signed-off-by: Dooyong Kim <kdooyong@amazon.com>
1 parent 509e90c commit ad5507a

File tree

1 file changed

+257
-0
lines changed

1 file changed

+257
-0
lines changed

jni/src/simd/similarity_function/default_simd_similarity_function.cpp

Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <algorithm>
22
#include <array>
33
#include <cstddef>
4+
#include <cstring>
45
#include <stdint.h>
56
#include <cmath>
67

@@ -44,12 +45,268 @@ DefaultFP16SimilarityFunction<FaissScoreToLuceneScoreTransform::ipToMaxIpTransfo
4445
// 2. L2
4546
DefaultFP16SimilarityFunction<FaissScoreToLuceneScoreTransform::l2TransformBulk, FaissScoreToLuceneScoreTransform::l2Transform> DEFAULT_FP16_L2_SIMIL_FUNC;
4647

48+
49+
//
50+
// BBQ (ADC: 4-bit query x 1-bit data) - Default (non-SIMD) implementation
51+
//
52+
// Pure C++ implementation written for compiler auto-vectorization with -O3.
53+
// No SIMD intrinsics. Uses __builtin_popcount[ll] and std::memcpy.
54+
//
55+
// The query is 4-bit quantized and transposed into 4 bit planes (via transposeHalfByte).
56+
// Each bit plane has `binaryCodeBytes` bytes. The int4BitDotProduct computes:
57+
// Result = popcount(plane0 AND data) * 1
58+
// + popcount(plane1 AND data) * 2
59+
// + popcount(plane2 AND data) * 4
60+
// + popcount(plane3 AND data) * 8
61+
//
62+
63+
static constexpr float FOUR_BIT_SCALE = 1.0f / 15.0f;
64+
65+
// Reads the per-vector correction factors from a potentially unaligned address.
66+
// On-disk layout after binaryCode: [lowerInterval(f32)][upperInterval(f32)][additionalCorrection(f32)][quantizedComponentSum(i32)]
67+
static FORCE_INLINE void readDataCorrections(const uint8_t* ptr, float& ax, float& lx, float& additional, float& x1) {
68+
float lower, upper;
69+
std::memcpy(&lower, ptr, sizeof(float));
70+
std::memcpy(&upper, ptr + 4, sizeof(float));
71+
std::memcpy(&additional, ptr + 8, sizeof(float));
72+
int32_t componentSum;
73+
std::memcpy(&componentSum, ptr + 12, sizeof(int32_t));
74+
ax = lower;
75+
lx = upper - lower;
76+
x1 = static_cast<float>(componentSum);
77+
}
78+
79+
// Scalar int4BitDotProduct
80+
// q has 4 * binaryCodeBytes bytes (4 bit planes), d has binaryCodeBytes bytes.
81+
// Uses std::memcpy for uint64_t loads to avoid undefined behavior from unaligned
82+
// reinterpret_cast when binaryCodeBytes is not a multiple of 8.
83+
static FORCE_INLINE int64_t int4BitDotProduct(const uint8_t* q, const uint8_t* d, const int32_t binaryCodeBytes) {
84+
int64_t result = 0;
85+
for (int32_t bitPlane = 0 ; bitPlane < 4 ; ++bitPlane) {
86+
const int32_t words = binaryCodeBytes >> 3;
87+
88+
int64_t subResult = 0;
89+
for (int32_t w = 0 ; w < words ; ++w) {
90+
uint64_t qWord, dWord;
91+
std::memcpy(&qWord, q + bitPlane * binaryCodeBytes + w * 8, sizeof(uint64_t));
92+
std::memcpy(&dWord, d + w * 8, sizeof(uint64_t));
93+
subResult += __builtin_popcountll(qWord & dWord);
94+
}
95+
96+
const int32_t remainStart = words * 8;
97+
for (int32_t r = remainStart ; r < binaryCodeBytes ; ++r) {
98+
subResult += __builtin_popcount((q[bitPlane * binaryCodeBytes + r] & d[r]) & 0xFF);
99+
}
100+
101+
result += subResult << bitPlane;
102+
}
103+
return result;
104+
}
105+
106+
// Default (non-SIMD) batched int4BitDotProduct.
107+
// Processes one batch element at a time with the byte loop as the inner loop,
108+
// giving the compiler the best auto-vectorization opportunity across the byte dimension.
109+
template <int BATCH_SIZE>
110+
static FORCE_INLINE void default4bitDotProductBatch(
111+
const uint8_t* queryPtr,
112+
uint8_t** dataVecs,
113+
const int32_t binaryCodeBytes,
114+
float* results) {
115+
116+
const uint8_t* plane0 = queryPtr;
117+
const uint8_t* plane1 = queryPtr + binaryCodeBytes;
118+
const uint8_t* plane2 = queryPtr + 2 * binaryCodeBytes;
119+
const uint8_t* plane3 = queryPtr + 3 * binaryCodeBytes;
120+
121+
const int32_t words = binaryCodeBytes >> 3;
122+
const int32_t remainStart = words * 8;
123+
124+
for (int32_t b = 0 ; b < BATCH_SIZE ; ++b) {
125+
int64_t acc = 0;
126+
const uint8_t* data = dataVecs[b];
127+
128+
// 8-byte word loop — compiler can auto-vectorize this
129+
for (int32_t w = 0 ; w < words ; ++w) {
130+
const int32_t offset = w * 8;
131+
uint64_t dWord;
132+
std::memcpy(&dWord, data + offset, sizeof(uint64_t));
133+
134+
uint64_t q0, q1, q2, q3;
135+
std::memcpy(&q0, plane0 + offset, sizeof(uint64_t));
136+
std::memcpy(&q1, plane1 + offset, sizeof(uint64_t));
137+
std::memcpy(&q2, plane2 + offset, sizeof(uint64_t));
138+
std::memcpy(&q3, plane3 + offset, sizeof(uint64_t));
139+
140+
acc += __builtin_popcountll(q0 & dWord) * 1
141+
+ __builtin_popcountll(q1 & dWord) * 2
142+
+ __builtin_popcountll(q2 & dWord) * 4
143+
+ __builtin_popcountll(q3 & dWord) * 8;
144+
}
145+
146+
// Byte remainder — handles non-multiple-of-8 binaryCodeBytes
147+
for (int32_t r = remainStart ; r < binaryCodeBytes ; ++r) {
148+
uint8_t db = data[r];
149+
acc += __builtin_popcount((plane0[r] & db) & 0xFF) * 1
150+
+ __builtin_popcount((plane1[r] & db) & 0xFF) * 2
151+
+ __builtin_popcount((plane2[r] & db) & 0xFF) * 4
152+
+ __builtin_popcount((plane3[r] & db) & 0xFF) * 8;
153+
}
154+
155+
results[b] = static_cast<float>(acc);
156+
}
157+
}
158+
159+
template <bool IsMaxIP>
160+
struct DefaultBBQSimilarityFunction final : SimilarityFunction {
161+
HOT_SPOT void calculateSimilarityInBulk(SimdVectorSearchContext* srchContext,
162+
int32_t* internalVectorIds,
163+
float* scores,
164+
const int32_t numVectors) {
165+
const auto* queryPtr = reinterpret_cast<const uint8_t*>(srchContext->queryVectorSimdAligned);
166+
const int32_t dim = srchContext->dimension;
167+
const int32_t binaryCodeBytes = (dim + 7) / 8;
168+
169+
// Read query correction factors from tmpBuffer
170+
const auto* queryCorrectionPtr = reinterpret_cast<const float*>(srchContext->tmpBuffer.data());
171+
const float ay = queryCorrectionPtr[0];
172+
const float ly = (queryCorrectionPtr[1] - queryCorrectionPtr[0]) * FOUR_BIT_SCALE;
173+
const float queryAdditional = queryCorrectionPtr[2];
174+
int32_t y1Raw; std::memcpy(&y1Raw, &queryCorrectionPtr[3], sizeof(int32_t));
175+
const float y1 = static_cast<float>(y1Raw);
176+
const float centroidDp = queryCorrectionPtr[4];
177+
178+
int32_t processedCount = 0;
179+
constexpr int32_t vecBlock = 8;
180+
constexpr int32_t vecHalfBlock = 4;
181+
uint8_t* vectors[vecBlock];
182+
183+
// Batch size 8
184+
for ( ; (processedCount + vecBlock) <= numVectors ; processedCount += vecBlock) {
185+
srchContext->getVectorPointersInBulk(vectors, &internalVectorIds[processedCount], vecBlock);
186+
default4bitDotProductBatch<vecBlock>(queryPtr, vectors, binaryCodeBytes, &scores[processedCount]);
187+
188+
for (int32_t i = 0 ; i < vecBlock ; ++i) {
189+
float ax, lx, additional, x1;
190+
readDataCorrections(vectors[i] + binaryCodeBytes, ax, lx, additional, x1);
191+
192+
scores[processedCount + i] = ax * ay * dim
193+
+ ay * lx * x1
194+
+ ax * ly * y1
195+
+ lx * ly * scores[processedCount + i];
196+
197+
if constexpr (IsMaxIP) {
198+
scores[processedCount + i] += queryAdditional + additional - centroidDp;
199+
} else {
200+
scores[processedCount + i] = std::max(0.0F, queryAdditional + additional - 2 * scores[processedCount + i]);
201+
}
202+
}
203+
}
204+
205+
// Batch size 4
206+
for ( ; (processedCount + vecHalfBlock) <= numVectors ; processedCount += vecHalfBlock) {
207+
srchContext->getVectorPointersInBulk(vectors, &internalVectorIds[processedCount], vecHalfBlock);
208+
default4bitDotProductBatch<vecHalfBlock>(queryPtr, vectors, binaryCodeBytes, &scores[processedCount]);
209+
210+
for (int32_t i = 0 ; i < vecHalfBlock ; ++i) {
211+
float ax, lx, additional, x1;
212+
readDataCorrections(vectors[i] + binaryCodeBytes, ax, lx, additional, x1);
213+
214+
scores[processedCount + i] = ax * ay * dim
215+
+ ay * lx * x1
216+
+ ax * ly * y1
217+
+ lx * ly * scores[processedCount + i];
218+
219+
if constexpr (IsMaxIP) {
220+
scores[processedCount + i] += queryAdditional + additional - centroidDp;
221+
} else {
222+
scores[processedCount + i] =
223+
std::max(0.0F, queryAdditional + additional - 2 * scores[processedCount + i]);
224+
}
225+
}
226+
}
227+
228+
// Tail: remaining vectors (scalar)
229+
for ( ; processedCount < numVectors ; ++processedCount) {
230+
const auto* dataVec = srchContext->getVectorPointer(internalVectorIds[processedCount]);
231+
const float qcDist = static_cast<float>(
232+
int4BitDotProduct(queryPtr, dataVec, binaryCodeBytes));
233+
234+
float ax, lx, additional, x1;
235+
readDataCorrections(dataVec + binaryCodeBytes, ax, lx, additional, x1);
236+
237+
scores[processedCount] = ax * ay * dim
238+
+ ay * lx * x1
239+
+ ax * ly * y1
240+
+ lx * ly * qcDist;
241+
242+
if constexpr (IsMaxIP) {
243+
scores[processedCount] += queryAdditional + additional - centroidDp;
244+
} else {
245+
scores[processedCount] =
246+
std::max(0.0F, queryAdditional + additional - 2 * scores[processedCount]);
247+
}
248+
}
249+
250+
if constexpr (IsMaxIP) {
251+
FaissScoreToLuceneScoreTransform::ipToMaxIpTransformBulk(scores, numVectors);
252+
} else {
253+
FaissScoreToLuceneScoreTransform::l2TransformBulk(scores, numVectors);
254+
}
255+
}
256+
257+
float calculateSimilarity(SimdVectorSearchContext* srchContext, const int32_t internalVectorId) {
258+
const auto* queryPtr = reinterpret_cast<const uint8_t*>(srchContext->queryVectorSimdAligned);
259+
const int32_t dim = srchContext->dimension;
260+
const int32_t binaryCodeBytes = (dim + 7) / 8;
261+
262+
const auto* queryCorrectionPtr = reinterpret_cast<const float*>(srchContext->tmpBuffer.data());
263+
const float ay = queryCorrectionPtr[0];
264+
const float ly = (queryCorrectionPtr[1] - queryCorrectionPtr[0]) * FOUR_BIT_SCALE;
265+
const float queryAdditional = queryCorrectionPtr[2];
266+
int32_t y1Raw; std::memcpy(&y1Raw, &queryCorrectionPtr[3], sizeof(int32_t));
267+
const float y1 = static_cast<float>(y1Raw);
268+
const float centroidDp = queryCorrectionPtr[4];
269+
270+
const auto* dataVec = srchContext->getVectorPointer(internalVectorId);
271+
const float qcDist = static_cast<float>(
272+
int4BitDotProduct(queryPtr, dataVec, binaryCodeBytes));
273+
274+
float ax, lx, additional, x1;
275+
readDataCorrections(dataVec + binaryCodeBytes, ax, lx, additional, x1);
276+
277+
float score = ax * ay * dim
278+
+ ay * lx * x1
279+
+ ax * ly * y1
280+
+ lx * ly * qcDist;
281+
282+
if constexpr (IsMaxIP) {
283+
score += queryAdditional + additional - centroidDp;
284+
return FaissScoreToLuceneScoreTransform::ipToMaxIpTransform(score);
285+
} else {
286+
score = std::max(0.0F, queryAdditional + additional - 2 * score);
287+
return FaissScoreToLuceneScoreTransform::l2Transform(score);
288+
}
289+
}
290+
};
291+
292+
//
293+
// BBQ
294+
//
295+
// 1. Max IP
296+
DefaultBBQSimilarityFunction<true> BBQ_IP_SIMIL_FUNC;
297+
// 2. L2
298+
DefaultBBQSimilarityFunction<false> BBQ_L2_SIMIL_FUNC;
299+
47300
#ifndef __NO_SELECT_FUNCTION
48301
SimilarityFunction* SimilarityFunction::selectSimilarityFunction(const NativeSimilarityFunctionType nativeFunctionType) {
49302
if (nativeFunctionType == NativeSimilarityFunctionType::FP16_MAXIMUM_INNER_PRODUCT) {
50303
return &DEFAULT_FP16_MAX_INNER_PRODUCT_SIMIL_FUNC;
51304
} else if (nativeFunctionType == NativeSimilarityFunctionType::FP16_L2) {
52305
return &DEFAULT_FP16_L2_SIMIL_FUNC;
306+
} else if (nativeFunctionType == NativeSimilarityFunctionType::BBQ_IP) {
307+
return &BBQ_IP_SIMIL_FUNC;
308+
} else if (nativeFunctionType == NativeSimilarityFunctionType::BBQ_L2) {
309+
return &BBQ_L2_SIMIL_FUNC;
53310
}
54311

55312
throw std::runtime_error("Invalid native similarity function type was given, nativeFunctionType="

0 commit comments

Comments
 (0)