Skip to content

Commit

Permalink
Add openai embeddings (#915)
Browse files Browse the repository at this point in the history
* Add OpenAI embedding compatibility

* Use OPENAI_API_KEY by default

* lint

* Add default OpenAI URL
replace `authorization` by `apiKey`

* Add a note in readme

---------

Co-authored-by: Nathan Sarrazin <[email protected]>
  • Loading branch information
Martok88 and nsarrazin committed Mar 8, 2024
1 parent 0081568 commit f7db219
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -120,7 +120,7 @@ TEXT_EMBEDDING_MODELS = `[
```

The required fields are `name`, `chunkCharLength` and `endpoints`.
Supported text embedding backends are: [`transformers.js`](https://huggingface.co/docs/transformers.js) and [`TEI`](https://github.com/huggingface/text-embeddings-inference). `transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a different environment & accessed through an API endpoint.
Supported text embedding backends are: [`transformers.js`](https://huggingface.co/docs/transformers.js), [`TEI`](https://github.com/huggingface/text-embeddings-inference) and [`OpenAI`](https://platform.openai.com/docs/guides/embeddings). `transformers.js` models run locally as part of `chat-ui`, whereas `TEI` models run in a different environment & accessed through an API endpoint. `openai` models are accessed through the [OpenAI API](https://platform.openai.com/docs/guides/embeddings).

When more than one embedding models are supplied in `.env.local` file, the first will be used by default, and the others will only be used on LLM's which configured `embeddingModel` to the name of the model.

Expand Down
6 changes: 6 additions & 0 deletions src/lib/server/embeddingEndpoints/embeddingEndpoints.ts
Expand Up @@ -7,6 +7,10 @@ import {
embeddingEndpointTransformersJS,
embeddingEndpointTransformersJSParametersSchema,
} from "./transformersjs/embeddingEndpoints";
import {
embeddingEndpointOpenAI,
embeddingEndpointOpenAIParametersSchema,
} from "./openai/embeddingEndpoints";

// parameters passed when generating text
interface EmbeddingEndpointParameters {
Expand All @@ -21,6 +25,7 @@ export type EmbeddingEndpoint = (params: EmbeddingEndpointParameters) => Promise
export const embeddingEndpointSchema = z.discriminatedUnion("type", [
embeddingEndpointTeiParametersSchema,
embeddingEndpointTransformersJSParametersSchema,
embeddingEndpointOpenAIParametersSchema,
]);

type EmbeddingEndpointTypeOptions = z.infer<typeof embeddingEndpointSchema>["type"];
Expand All @@ -36,6 +41,7 @@ export const embeddingEndpoints: {
} = {
tei: embeddingEndpointTei,
transformersjs: embeddingEndpointTransformersJS,
openai: embeddingEndpointOpenAI,
};

export default embeddingEndpoints;
51 changes: 51 additions & 0 deletions src/lib/server/embeddingEndpoints/openai/embeddingEndpoints.ts
@@ -0,0 +1,51 @@
import { z } from "zod";
import type { EmbeddingEndpoint, Embedding } from "../embeddingEndpoints";
import { chunk } from "$lib/utils/chunk";
import { OPENAI_API_KEY } from "$env/static/private";

export const embeddingEndpointOpenAIParametersSchema = z.object({
weight: z.number().int().positive().default(1),
model: z.any(),
type: z.literal("openai"),
url: z.string().url().default("https://api.openai.com/v1/embeddings"),
apiKey: z.string().default(OPENAI_API_KEY),
});

export async function embeddingEndpointOpenAI(
input: z.input<typeof embeddingEndpointOpenAIParametersSchema>
): Promise<EmbeddingEndpoint> {
const { url, model, apiKey } = embeddingEndpointOpenAIParametersSchema.parse(input);

const maxBatchSize = model.maxBatchSize || 100;

return async ({ inputs }) => {
const requestURL = new URL(url);

const batchesInputs = chunk(inputs, maxBatchSize);

const batchesResults = await Promise.all(
batchesInputs.map(async (batchInputs) => {
const response = await fetch(requestURL, {
method: "POST",
headers: {
Accept: "application/json",
"Content-Type": "application/json",
...(apiKey ? { Authorization: `Bearer ${apiKey}` } : {}),
},
body: JSON.stringify({ input: batchInputs, model: model.name }),
});

const embeddings: Embedding[] = [];
const responseObject = await response.json();
for (const embeddingObject of responseObject.data) {
embeddings.push(embeddingObject.embedding);
}
return embeddings;
})
);

const flatAllEmbeddings = batchesResults.flat();

return flatAllEmbeddings;
};
}
3 changes: 3 additions & 0 deletions src/lib/server/embeddingModels.ts
Expand Up @@ -22,6 +22,7 @@ const modelConfig = z.object({
modelUrl: z.string().url().optional(),
endpoints: z.array(embeddingEndpointSchema).nonempty(),
chunkCharLength: z.number().positive(),
maxBatchSize: z.number().positive().optional(),
preQuery: z.string().default(""),
prePassage: z.string().default(""),
});
Expand Down Expand Up @@ -70,6 +71,8 @@ const addEndpoint = (m: Awaited<ReturnType<typeof processEmbeddingModel>>) => ({
return embeddingEndpoints.tei(args);
case "transformersjs":
return embeddingEndpoints.transformersjs(args);
case "openai":
return embeddingEndpoints.openai(args);
}
}

Expand Down

0 comments on commit f7db219

Please sign in to comment.