From d2f63d6c5f4f51cd9fa8fea27248bd3af824edea Mon Sep 17 00:00:00 2001 From: howlger Date: Wed, 27 Mar 2024 12:10:07 +0100 Subject: [PATCH 1/2] embedding : show full embedding for single prompt To support the use case of creating an embedding for a given prompt, the entire embedding and not just the first part needed to be printed. Also, show cosine similarity matrix only if there is more than one prompt, as the cosine similarity matrix for a single prompt is always `1.00`. --- examples/embedding/embedding.cpp | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 9aede7fadfe31..d68525fdc8a63 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -178,25 +178,27 @@ int main(int argc, char ** argv) { float * out = emb + p * n_embd; batch_decode(ctx, batch, out, s, n_embd); - // print the first part of the embeddings + // print the first part of the embeddings or for a single prompt, the full embedding fprintf(stdout, "\n"); for (int j = 0; j < n_prompts; j++) { fprintf(stdout, "embedding %d: ", j); - for (int i = 0; i < std::min(16, n_embd); i++) { + for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) { fprintf(stdout, "%9.6f ", emb[j * n_embd + i]); } fprintf(stdout, "\n"); } // print cosine similarity matrix - fprintf(stdout, "\n"); - printf("cosine similarity matrix:\n\n"); - for (int i = 0; i < n_prompts; i++) { - for (int j = 0; j < n_prompts; j++) { - float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); - fprintf(stdout, "%6.2f ", sim); - } + if (n_prompts > 1) { fprintf(stdout, "\n"); + printf("cosine similarity matrix:\n\n"); + for (int i = 0; i < n_prompts; i++) { + for (int j = 0; j < n_prompts; j++) { + float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); + fprintf(stdout, "%6.2f ", sim); + } + fprintf(stdout, "\n"); + } } // clean up From 9a8649625a7c927bb735c2d982dc5f8949c09e5c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 27 Mar 2024 13:15:28 +0200 Subject: [PATCH 2/2] Update examples/embedding/embedding.cpp --- examples/embedding/embedding.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index d68525fdc8a63..536657526685c 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -178,7 +178,7 @@ int main(int argc, char ** argv) { float * out = emb + p * n_embd; batch_decode(ctx, batch, out, s, n_embd); - // print the first part of the embeddings or for a single prompt, the full embedding + // print the first part of the embeddings or for a single prompt, the full embedding fprintf(stdout, "\n"); for (int j = 0; j < n_prompts; j++) { fprintf(stdout, "embedding %d: ", j);