From f7db2192660cc970b4a20a0a99a644dfbba6f86e Mon Sep 17 00:00:00 2001
From: Martok88 <17170367+Martok88@users.noreply.github.com>
Date: Fri, 8 Mar 2024 09:07:44 -0800
Subject: [PATCH] Add openai embeddings (#915)

* Add OpenAI embedding compatibility

* Use OPENAI_API_KEY by default

* lint

* Add default OpenAI URL

Replace `authorization` with `apiKey`

* Add a note in readme

---------

Co-authored-by: Nathan Sarrazin
---
 README.md                                    | 14 ++++-
 .../embeddingEndpoints/embeddingEndpoints.ts |  6 ++
 .../openai/embeddingEndpoints.ts             | 60 ++++++++++++++++++++
 src/lib/server/embeddingModels.ts            |  3 +
 4 files changed, 82 insertions(+), 1 deletion(-)
 create mode 100644 src/lib/server/embeddingEndpoints/openai/embeddingEndpoints.ts

diff --git a/README.md b/README.md
index 5a79c95648..d97ccbca0e 100644
--- a/README.md
+++ b/README.md
@@ -120,7 +120,19 @@ TEXT_EMBEDDING_MODELS = `[
 ```
 
 The required fields are `name`, `chunkCharLength` and `endpoints`.
 
-Supported text embedding backends are: [`transformers.js`](https://huggingface.co/docs/transformers.js) and [`TEI`](https://github.com/huggingface/text-embeddings-inference). `transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a different environment & accessed through an API endpoint.
+Supported text embedding backends are: [`transformers.js`](https://huggingface.co/docs/transformers.js), [`TEI`](https://github.com/huggingface/text-embeddings-inference) and [`OpenAI`](https://platform.openai.com/docs/guides/embeddings). `transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a separate environment and are accessed through an API endpoint. `openai` models are accessed through an OpenAI-compatible embeddings API, the hosted [OpenAI API](https://platform.openai.com/docs/guides/embeddings) by default; see the example below.
 
 When more than one embedding model is supplied in the `.env.local` file, the first will be used by default, and the others will only be used on LLMs whose `embeddingModel` is configured to the model's name.
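+
+For example, a minimal `openai` embedding model entry might look like this (a sketch: the model name is illustrative, the API key falls back to the `OPENAI_API_KEY` environment variable, and the optional `maxBatchSize` field caps how many inputs are sent per request, defaulting to 100):
+
+```env
+TEXT_EMBEDDING_MODELS = `[
+  {
+    "name": "text-embedding-ada-002",
+    "chunkCharLength": 1024,
+    "endpoints": [{ "type": "openai" }]
+  }
+]`
+```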
diff --git a/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts
index 7beb33c4b8..2644d20e59 100644
--- a/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts
+++ b/src/lib/server/embeddingEndpoints/embeddingEndpoints.ts
@@ -7,6 +7,10 @@ import {
 	embeddingEndpointTransformersJS,
 	embeddingEndpointTransformersJSParametersSchema,
 } from "./transformersjs/embeddingEndpoints";
+import {
+	embeddingEndpointOpenAI,
+	embeddingEndpointOpenAIParametersSchema,
+} from "./openai/embeddingEndpoints";
 
 // parameters passed when generating embeddings
 interface EmbeddingEndpointParameters {
@@ -21,6 +25,7 @@ export type EmbeddingEndpoint = (params: EmbeddingEndpointParameters) => Promise<Embedding[]>;
 export const embeddingEndpointSchema = z.discriminatedUnion("type", [
 	embeddingEndpointTeiParametersSchema,
 	embeddingEndpointTransformersJSParametersSchema,
+	embeddingEndpointOpenAIParametersSchema,
 ]);
 
 type EmbeddingEndpointTypeOptions = z.infer<typeof embeddingEndpointSchema>["type"];
@@ -36,6 +41,7 @@ export const embeddingEndpoints: {
 } = {
 	tei: embeddingEndpointTei,
 	transformersjs: embeddingEndpointTransformersJS,
+	openai: embeddingEndpointOpenAI,
 };
 
 export default embeddingEndpoints;
diff --git a/src/lib/server/embeddingEndpoints/openai/embeddingEndpoints.ts b/src/lib/server/embeddingEndpoints/openai/embeddingEndpoints.ts
new file mode 100644
index 0000000000..89d7900bb2
--- /dev/null
+++ b/src/lib/server/embeddingEndpoints/openai/embeddingEndpoints.ts
@@ -0,0 +1,60 @@
+import { z } from "zod";
+import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
+import { chunk } from "$lib/utils/chunk";
+import { OPENAI_API_KEY } from "$env/static/private";
+
+export const embeddingEndpointOpenAIParametersSchema = z.object({
+	weight: z.number().int().positive().default(1),
+	model: z.any(),
+	type: z.literal("openai"),
+	url: z.string().url().default("https://api.openai.com/v1/embeddings"),
+	apiKey: z.string().default(OPENAI_API_KEY),
+});
+
+export async function embeddingEndpointOpenAI(
+	input: z.input<typeof embeddingEndpointOpenAIParametersSchema>
+): Promise<EmbeddingEndpoint> {
+	const { url, model, apiKey } = embeddingEndpointOpenAIParametersSchema.parse(input);
+
+	const maxBatchSize = model.maxBatchSize || 100;
+
+	return async ({ inputs }) => {
+		const requestURL = new URL(url);
+
+		const batchesInputs = chunk(inputs, maxBatchSize);
+
+		const batchesResults = await Promise.all(
+			batchesInputs.map(async (batchInputs) => {
+				const response = await fetch(requestURL, {
+					method: "POST",
+					headers: {
+						Accept: "application/json",
+						"Content-Type": "application/json",
+						...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}),
+					},
+					body: JSON.stringify({ input: batchInputs, model: model.name }),
+				});
+
+				const embeddings: Embedding[] = [];
+				const responseObject = await response.json();
+				for (const embeddingObject of responseObject.data) {
+					embeddings.push(embeddingObject.embedding);
+				}
+				return embeddings;
+			})
+		);
+
+		const flatAllEmbeddings = batchesResults.flat();
+
+		return flatAllEmbeddings;
+	};
+}
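+
+// Example usage (a sketch: the model name below is illustrative, and the API
+// key falls back to the OPENAI_API_KEY environment variable):
+//
+//   const endpoint = await embeddingEndpointOpenAI({
+//     type: "openai",
+//     model: { name: "text-embedding-ada-002" },
+//   });
+//   const embeddings = await endpoint({ inputs: ["Hello, world!"] });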
diff --git a/src/lib/server/embeddingModels.ts b/src/lib/server/embeddingModels.ts
index 75d65606bc..303c90c2e0 100644
--- a/src/lib/server/embeddingModels.ts
+++ b/src/lib/server/embeddingModels.ts
@@ -22,6 +22,7 @@ const modelConfig = z.object({
 	modelUrl: z.string().url().optional(),
 	endpoints: z.array(embeddingEndpointSchema).nonempty(),
 	chunkCharLength: z.number().positive(),
+	maxBatchSize: z.number().positive().optional(),
 	preQuery: z.string().default(""),
 	prePassage: z.string().default(""),
 });
@@ -70,6 +71,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
 						return embeddingEndpoints.tei(args);
 					case "transformersjs":
 						return embeddingEndpoints.transformersjs(args);
+					case "openai":
+						return embeddingEndpoints.openai(args);
 				}
 			}