Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Benchmark for C++ PDQ index #1699

Merged
merged 19 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions pdq/cpp/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ LIBHDRS=\
io/pdqio.h \
index/mih.h \
hashing/torben.h \
./CImg.h
./CImg.h \
common/pdqutils.h

LIBSRCS=\
common/pdqhashtypes.cpp \
Expand All @@ -41,11 +42,13 @@ LIBSRCS=\
hashing/pdqhashing.cpp \
downscaling/downscaling.cpp \
io/pdqio.cpp \
hashing/torben.cpp
hashing/torben.cpp \
common/pdqutils.cpp

MAINS=\
pdq-photo-hasher \
test-mih \
benchmark-query \
clusterize256 \
snowball-clusterize256 \
clusterize256x \
Expand Down Expand Up @@ -111,6 +114,8 @@ pdq-downsample-demo: bin/pdq-downsample-demo.cpp $(LIBSRCS) $(LIBHDRS)

test-mih: bin/test-mih.cpp $(LIBSRCS) $(LIBHDRS)
$(CCOPT) bin/test-mih.cpp $(LIBSRCS) -o test-mih $(LFLAGS)
benchmark-query: bin/benchmark-query.cpp $(LIBSRCS) $(LIBHDRS)
$(CCOPT) bin/benchmark-query.cpp $(LIBSRCS) -o benchmark-query $(LFLAGS)
test-mihg: bin/test-mih.cpp $(LIBSRCS) $(LIBHDRS)
$(CCDBG) bin/test-mih.cpp $(LIBSRCS) -o test-mihg $(LFLAGS)
clusterize256: bin/clusterize256.cpp $(LIBSRCS) $(LIBHDRS)
Expand Down
288 changes: 288 additions & 0 deletions pdq/cpp/bin/benchmark-query.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <pdq/cpp/common/pdqutils.h>
#include <pdq/cpp/index/mih.h>
#include <pdq/cpp/io/hashio.h>

#include <algorithm>
#include <random>
#include <set>
#include <string>
#include <vector>

// ================================================================

// Summary of a single benchmark run. It is produced by queryLinear() or
// queryMIH() via positional brace-initialization (so the field order below is
// part of the contract) and printed by query().
struct BenchmarkResult {
// Human-readable label of the query method that produced this result.
std::string method;
// Number of query hashes that were looked up.
int queryCount;
// Number of hashes that were searched against.
int indexCount;
// Total number of matches collected across all queries.
int totalMatchCount;
// Time spent in the query loop, in seconds, as reported by Timer::elapsed().
double totalQuerySeconds;
};

// Forward declarations of the file-local entry points so that main() can be
// defined before them.
// usage() prints the option summary and terminates the process with `rc`.
static void usage(char* argv0, int rc);
// query() parses the options, generates the random data set, runs the
// selected query method, and prints the result.
static void query(char* argv0, int argc, char** argv);

// Forward declarations of the query methods. Both share one signature so that
// query() can call either with the same argument list; each returns the
// counts and elapsed time of its run as a BenchmarkResult.
// queryLinear() compares every query hash against every index entry.
static BenchmarkResult queryLinear(
const int maxDistance,
const bool verbose,
const unsigned int seed,
const size_t indexSize,
const size_t querySize,
const std::vector<std::pair<facebook::pdq::hashing::Hash256, std::string>>&
queries,
const std::vector<std::pair<facebook::pdq::hashing::Hash256, std::string>>&
index);
// queryMIH() inserts the index entries into an MIH256 and queries through it.
static BenchmarkResult queryMIH(
const int maxDistance,
const bool verbose,
const unsigned int seed,
const size_t indexSize,
const size_t querySize,
const std::vector<std::pair<facebook::pdq::hashing::Hash256, std::string>>&
queries,
const std::vector<std::pair<facebook::pdq::hashing::Hash256, std::string>>&
index);

// ----------------------------------------------------------------
// Program entry point. Shows the help text when the first argument is -h or
// --help; otherwise hands the remaining arguments (program name stripped) to
// query(). Always returns 0 when control comes back here.
int main(int argc, char** argv) {
  const bool helpRequested = argc > 1 &&
      (strcmp(argv[1], "-h") == 0 || strcmp(argv[1], "--help") == 0);

  if (helpRequested) {
    usage(argv[0], 0);
    return 0;
  }

  query(argv[0], argc - 1, argv + 1);
  return 0;
}

// ----------------------------------------------------------------
// Prints the command-line help and terminates the process with exit code
// `rc`. The text goes to stdout when rc is 0 and to stderr otherwise.
//
// argv0: program name shown in the "Usage:" line.
// rc:    process exit status; also selects the output stream.
static void usage(char* argv0, int rc) {
  FILE* out = stderr;
  if (rc == 0) {
    out = stdout;
  }

  fprintf(out, "Usage: %s [options]\n", argv0);

  // The remaining lines are fixed text (no format specifiers), so they are
  // written verbatim in order.
  static const char* const kHelpLines[] = {
      "Options:\n",
      " -v Verbose output\n",
      " --seed N Random seed (default: 41)\n",
      " -q N Number of queries to run (default: 1000)\n",
      " -b N Number of PDQ hashes to query against (default: 10000)\n",
      " -d N Maximum Hamming distance for matches (default: 31)\n",
      " -m Method for querying (default: linear), Available: linear, mih\n",
  };
  for (const char* line : kHelpLines) {
    fputs(line, out);
  }

  exit(rc);
}

// Parses `text` as a base-10 integer and stores it in `out` only when the
// entire string is a number within [minValue, maxValue]. Returns false for
// empty input, trailing characters, overflow, or out-of-range values.
static bool parseBoundedInt(
    const char* text, long long minValue, long long maxValue, long long& out) {
  errno = 0;
  char* end = nullptr;
  const long long value = strtoll(text, &end, 10);
  if (end == text || *end != '\0' || errno == ERANGE) {
    return false;
  }
  if (value < minValue || value > maxValue) {
    return false;
  }
  out = value;
  return true;
}

// Reads the integer value that follows the option at argv[i], advancing `i`
// past it. On a missing or invalid value an error is printed to stderr and
// false is returned; `out` is only written on success.
static bool readIntOption(
    int argc,
    char** argv,
    int& i,
    long long minValue,
    long long maxValue,
    long long& out) {
  const char* flag = argv[i];
  if (i + 1 >= argc) {
    fprintf(stderr, "Error: Missing argument for %s\n", flag);
    return false;
  }
  const char* text = argv[++i];
  if (!parseBoundedInt(text, minValue, maxValue, out)) {
    fprintf(
        stderr,
        "Error: Invalid value for %s: %s (expected an integer in [%lld, %lld])\n",
        flag,
        text,
        minValue,
        maxValue);
    return false;
  }
  return true;
}

// Runs one benchmark: parses the options, generates random query and index
// hashes, plants a noisy copy of every query in the index, runs the selected
// query method, and prints the resulting counts and timings to stdout.
//
// argv0: program name, used only for the usage message.
// argc:  number of entries in `argv`.
// argv:  option arguments with the program name already removed.
//
// Any invalid option is reported on stderr and ends the process through
// usage(argv0, 1).
static void query(char* argv0, int argc, char** argv) {
  // addNoise() picks bits out of a 256-entry index table, so a larger
  // distance cannot be honoured.
  const int kHashBits = 256;

  int maxDistance = 31;
  bool verbose = false;
  unsigned int seed = 41;
  size_t indexSize = 10000;
  size_t querySize = 1000;
  std::string method = "linear";

  // Parse command line arguments
  for (int i = 0; i < argc; i++) {
    const std::string arg = argv[i];
    long long value = 0;
    if (arg == "-q") {
      // Counts are capped at INT_MAX because the results are reported as int.
      if (!readIntOption(argc, argv, i, 0, INT_MAX, value)) {
        usage(argv0, 1);
        return;
      }
      querySize = static_cast<size_t>(value);
    } else if (arg == "-b") {
      if (!readIntOption(argc, argv, i, 0, INT_MAX, value)) {
        usage(argv0, 1);
        return;
      }
      indexSize = static_cast<size_t>(value);
    } else if (arg == "-d") {
      if (!readIntOption(argc, argv, i, 0, kHashBits, value)) {
        usage(argv0, 1);
        return;
      }
      maxDistance = static_cast<int>(value);
    } else if (arg == "--seed") {
      if (!readIntOption(argc, argv, i, 0, UINT_MAX, value)) {
        usage(argv0, 1);
        return;
      }
      seed = static_cast<unsigned int>(value);
    } else if (arg == "-m") {
      if (i + 1 >= argc) {
        fprintf(stderr, "Error: Missing argument for -m\n");
        usage(argv0, 1);
        return;
      }
      const std::string methodName = argv[++i];
      if (methodName != "linear" && methodName != "mih") {
        fprintf(stderr, "Invalid method: %s\n", methodName.c_str());
        usage(argv0, 1);
        return;
      }
      method = methodName;
    } else if (arg == "-v") {
      verbose = true;
    } else if (arg == "-h" || arg == "--help") {
      usage(argv0, 0);
      return;
    } else if (arg.length() > 0) {
      fprintf(stderr, "Unknown argument: %s\n", arg.c_str());
      usage(argv0, 1);
      return;
    }
  }

  // Every query contributes one noisy entry to the index, so the index must
  // be at least as large as the query set. Without this check the unsigned
  // subtraction below would wrap around to a huge count.
  if (querySize > indexSize) {
    fprintf(
        stderr,
        "Error: Number of queries (%zu) must not exceed the index size (%zu)\n",
        querySize,
        indexSize);
    usage(argv0, 1);
    return;
  }

  // Initialize random number generator
  std::mt19937 gen(seed);

  // Generate random hashes for queries
  std::vector<std::pair<facebook::pdq::hashing::Hash256, std::string>> queries;
  queries.reserve(querySize);
  for (size_t i = 0; i < querySize; i++) {
    auto hash = facebook::pdq::hashing::generateRandomHash(gen);
    queries.push_back({hash, "query_" + std::to_string(i)});
  }

  // Generate random hashes for index
  std::vector<std::pair<facebook::pdq::hashing::Hash256, std::string>> index;
  index.reserve(indexSize);
  const size_t randomIndexCount = indexSize - querySize;
  for (size_t i = 0; i < randomIndexCount; i++) {
    auto hash = facebook::pdq::hashing::generateRandomHash(gen);
    index.push_back({hash, "index_" + std::to_string(i)});
  }

  // Add noise to queries then insert into index. The distribution requires
  // lower <= upper, so for a maximum distance of 0 it is never sampled and
  // the query is inserted with no bits flipped.
  std::uniform_int_distribution<int> noiseDist(1, std::max(1, maxDistance));
  for (const auto& entry : queries) {
    const int bitsToFlip = maxDistance > 0 ? noiseDist(gen) : 0;
    auto noisyHash =
        facebook::pdq::hashing::addNoise(entry.first, bitsToFlip, gen);
    index.push_back({noisyHash, "index_noisy_" + entry.second});
  }
  std::shuffle(index.begin(), index.end(), gen);

  if (verbose) {
    printf("GENERATED QUERIES:\n");
    for (const auto& it : queries) {
      printf("%s,%s\n", it.first.format().c_str(), it.second.c_str());
    }
    printf("\n");

    printf("GENERATED INDEX:\n");
    for (const auto& it : index) {
      printf("%s,%s\n", it.first.format().c_str(), it.second.c_str());
    }
    printf("\n");
  }

  BenchmarkResult result;
  if (method == "linear") {
    result = queryLinear(
        maxDistance, verbose, seed, indexSize, querySize, queries, index);
  } else if (method == "mih") {
    result = queryMIH(
        maxDistance, verbose, seed, indexSize, querySize, queries, index);
  } else {
    // Not reachable while -m validation accepts only the names above; kept
    // so that a newly accepted name without a dispatch branch fails loudly.
    fprintf(stderr, "Unknown method: %s\n", method.c_str());
    usage(argv0, 1);
    return;
  }

  printf("METHOD: %s\n", result.method.c_str());
  printf("QUERY COUNT: %d\n", result.queryCount);
  printf("INDEX COUNT: %d\n", result.indexCount);
  printf("TOTAL MATCH COUNT: %d\n", result.totalMatchCount);
  printf("TOTAL QUERY SECONDS: %.6lf\n", result.totalQuerySeconds);
  // Guard against division by zero when the timer reports no elapsed time.
  double queriesPerSecond = result.totalQuerySeconds > 0
      ? result.queryCount / result.totalQuerySeconds
      : 0;
  printf("QUERIES PER SECOND: %.2lf\n", queriesPerSecond);
  printf("\n");
}

///////////////////////
//// Query methods ////
///////////////////////

// Brute-force benchmark: compares every query hash against every index entry
// and collects each entry whose Hamming distance is at most `maxDistance`.
//
// maxDistance: inclusive Hamming-distance threshold for a match.
// verbose:     forwarded to the Timer constructor.
// seed, indexSize, querySize: not used by this method; present so that all
//              query methods share one signature.
// queries:     hashes to look up.
// index:       hashes to search.
//
// Returns the counts of queries, index entries and collected matches together
// with the time spent in the comparison loop.
static BenchmarkResult queryLinear(
    const int maxDistance,
    const bool verbose,
    const unsigned int seed,
    const size_t indexSize,
    const size_t querySize,
    const std::vector<std::pair<facebook::pdq::hashing::Hash256, std::string>>&
        queries,
    const std::vector<std::pair<facebook::pdq::hashing::Hash256, std::string>>&
        index) {
  using HashEntry = std::pair<facebook::pdq::hashing::Hash256, std::string>;

  std::vector<HashEntry> found;

  // Only the comparison loop is timed.
  Timer stopwatch("Linear query", verbose);
  for (const HashEntry& needle : queries) {
    for (const HashEntry& candidate : index) {
      if (needle.first.hammingDistance(candidate.first) > maxDistance) {
        continue;
      }
      found.push_back(candidate);
    }
  }
  const double elapsedSeconds = stopwatch.elapsed();

  BenchmarkResult outcome;
  outcome.method = "linear query";
  outcome.queryCount = static_cast<int>(queries.size());
  outcome.indexCount = static_cast<int>(index.size());
  outcome.totalMatchCount = static_cast<int>(found.size());
  outcome.totalQuerySeconds = elapsedSeconds;
  return outcome;
}

// Indexed benchmark: inserts every index entry into an MIH256 and then runs
// each query through MIH256::queryAll(), collecting matches into one vector.
//
// maxDistance: distance threshold passed to queryAll().
// verbose:     when true, dumps the index contents and is also forwarded to
//              the Timer constructor.
// seed, indexSize, querySize: not used by this method; present so that all
//              query methods share one signature.
// queries:     hashes to look up.
// index:       hashes inserted into the MIH.
//
// Returns the counts of queries, indexed entries (as reported by mih.size())
// and collected matches together with the time spent in the query loop.
// Building the index is not included in the timing.
static BenchmarkResult queryMIH(
    const int maxDistance,
    const bool verbose,
    const unsigned int seed,
    const size_t indexSize,
    const size_t querySize,
    const std::vector<std::pair<facebook::pdq::hashing::Hash256, std::string>>&
        queries,
    const std::vector<std::pair<facebook::pdq::hashing::Hash256, std::string>>&
        index) {
  using HashEntry = std::pair<facebook::pdq::hashing::Hash256, std::string>;

  // Populate the index before the timer starts.
  facebook::pdq::index::MIH256<std::string> mih;
  for (const HashEntry& entry : index) {
    mih.insert(entry.first, entry.second);
  }

  // A blank line is always emitted; the dump is framed by two more when
  // verbose output is requested.
  printf("\n");
  if (verbose) {
    printf("\n");
    mih.dump();
    printf("\n");
  }

  std::vector<HashEntry> found;

  Timer stopwatch("MIH query", verbose);
  for (const HashEntry& needle : queries) {
    mih.queryAll(needle.first, maxDistance, found);
  }
  const double elapsedSeconds = stopwatch.elapsed();

  BenchmarkResult outcome;
  outcome.method = "mutually-indexed hashing query";
  outcome.queryCount = static_cast<int>(queries.size());
  outcome.indexCount = static_cast<int>(mih.size());
  outcome.totalMatchCount = static_cast<int>(found.size());
  outcome.totalQuerySeconds = elapsedSeconds;
  return outcome;
}
58 changes: 58 additions & 0 deletions pdq/cpp/common/pdqutils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include <algorithm>
#include <chrono>
#include <iostream>
#include <random>
#include <string>
#include <vector>
#include <pdq/cpp/common/pdqutils.h>

// Starts timing at construction. When `printOnEnter` is true the context
// label followed by "..." is written to std::cout immediately.
//
// context:      label identifying what is being timed.
// printOnEnter: whether to announce the label on construction.
Timer::Timer(const std::string& context, bool printOnEnter)
    : context_(context),
      printOnEnter_(printOnEnter),
      startTime_(std::chrono::steady_clock::now()) {
  if (!printOnEnter_) {
    return;
  }
  std::cout << context_ << "..." << std::endl;
}

// Returns the number of seconds between construction of this Timer and the
// moment of the call, measured on std::chrono::steady_clock.
double Timer::elapsed() const {
  using Seconds = std::chrono::duration<double>;
  const Seconds sinceStart = std::chrono::steady_clock::now() - startTime_;
  return sinceStart.count();
}

namespace facebook {
namespace pdq {
namespace hashing {

// Builds a hash whose HASH256_NUM_WORDS words are each drawn uniformly from
// [0, UINT16_MAX] using `gen`. Words are filled in ascending index order, so
// the result is reproducible for a given generator state.
Hash256 generateRandomHash(std::mt19937& gen) {
  std::uniform_int_distribution<uint16_t> wordDist(0, UINT16_MAX);

  Hash256 result;
  for (int wordIndex = 0; wordIndex < HASH256_NUM_WORDS; ++wordIndex) {
    result.w[wordIndex] = wordDist(gen);
  }
  return result;
}

// Returns a copy of `original` with randomly chosen bits inverted.
//
// original:      hash to perturb; it is not modified.
// numBitsToFlip: how many bits to invert. Values below 0 are treated as 0 and
//                values above 256 as 256, because the selection table below
//                holds exactly 256 bit positions.
// gen:           random generator used to choose the bit positions.
//
// The bit positions are taken from the front of a shuffled list of all 256
// indices, so they are distinct: each chosen bit is flipped exactly once and
// the Hamming distance between the result and `original` equals the clamped
// count. (Review note from the pull request: this selection looks uniformly
// random.)
Hash256 addNoise(
    const Hash256& original, int numBitsToFlip, std::mt19937& gen) {
  const int kBitsPerWord = 16;
  const int kTotalBits = 256;

  // Clamp so the loop below can never index past the end of bitIndices.
  const int flipCount = std::min(std::max(numBitsToFlip, 0), kTotalBits);

  std::vector<int> bitIndices(kTotalBits);
  for (int i = 0; i < kTotalBits; i++) {
    bitIndices[i] = i;
  }
  std::shuffle(bitIndices.begin(), bitIndices.end(), gen);

  Hash256 noisy = original;
  for (int i = 0; i < flipCount; i++) {
    const int bitIndex = bitIndices[i];
    const int wordIndex = bitIndex / kBitsPerWord;
    const int position = bitIndex % kBitsPerWord;
    noisy.w[wordIndex] ^= (1 << position);
  }
  return noisy;
}

} // namespace hashing
} // namespace pdq
} // namespace facebook
Loading
Loading