-
Notifications
You must be signed in to change notification settings - Fork 296
/
OpenAIEmbedding.ts
129 lines (114 loc) · 3.52 KB
/
OpenAIEmbedding.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import type { ClientOptions as OpenAIClientOptions } from "openai";
import type { AzureOpenAIConfig } from "../llm/azure.js";
import {
getAzureConfigFromEnv,
getAzureModel,
shouldUseAzure,
} from "../llm/azure.js";
import type { OpenAISession } from "../llm/openai.js";
import { getOpenAISession } from "../llm/openai.js";
import { BaseEmbedding } from "./types.js";
/**
 * Table of known OpenAI embedding models, keyed by model name.
 *
 * Each entry records:
 * - `dimensions`: the vector size listed for the model
 * - `dimensionOptions`: the alternative vector sizes listed for the model
 *   (absent for "text-embedding-ada-002")
 * - `maxTokens`: the token limit listed for the model
 */
export const ALL_OPENAI_EMBEDDING_MODELS = {
  // No `dimensionOptions` entry is listed for this model.
  "text-embedding-ada-002": { dimensions: 1536, maxTokens: 8191 },

  "text-embedding-3-small": {
    dimensions: 1536,
    dimensionOptions: [512, 1536],
    maxTokens: 8191,
  },

  "text-embedding-3-large": {
    dimensions: 3072,
    dimensionOptions: [256, 1024, 3072],
    maxTokens: 8191,
  },
};
/**
 * Embedding implementation that requests vectors through an OpenAI session
 * (`session.openai.embeddings.create`). The session is configured for Azure
 * when an `azure` config is passed to the constructor or `shouldUseAzure()`
 * returns a truthy value.
 */
export class OpenAIEmbedding extends BaseEmbedding {
  /** Embedding model. Defaults to "text-embedding-ada-002". */
  model: string;
  /**
   * Number of dimensions of the resulting vector, for models that support
   * choosing fewer dimensions. When undefined, the model default applies.
   */
  dimensions: number | undefined;

  // OpenAI session params

  /** API key. */
  apiKey?: string = undefined;
  /** Maximum number of retries. Defaults to 10. */
  maxRetries: number;
  /** Timeout in ms. Defaults to 60 seconds. */
  timeout?: number;
  /** Other session options for OpenAI. */
  additionalSessionOptions?: Omit<
    Partial<OpenAIClientOptions>,
    "apiKey" | "maxRetries" | "timeout"
  >;

  /** Session object used to issue embedding requests. */
  session: OpenAISession;

  /**
   * OpenAI Embedding
   * @param init - initial parameters; any field left out falls back to the
   *   defaults documented on the corresponding property. `init.azure`, when
   *   present, selects the Azure code path and overrides the Azure settings
   *   read from the environment.
   */
  constructor(init?: Partial<OpenAIEmbedding> & { azure?: AzureOpenAIConfig }) {
    super();

    this.model = init?.model ?? "text-embedding-ada-002";
    this.dimensions = init?.dimensions; // if no dimensions provided, will be undefined/not sent to OpenAI
    // NOTE(review): `embedBatchSize` is not declared in this class; presumably
    // it is inherited from BaseEmbedding - confirm.
    this.embedBatchSize = init?.embedBatchSize ?? 10;

    this.maxRetries = init?.maxRetries ?? 10;
    this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
    this.additionalSessionOptions = init?.additionalSessionOptions;

    if (init?.azure || shouldUseAzure()) {
      // Environment-derived settings first, then the explicit `init.azure`
      // values, so that explicitly passed values win.
      const azureConfig = {
        ...getAzureConfigFromEnv({
          model: getAzureModel(this.model),
        }),
        ...init?.azure,
      };

      this.apiKey = azureConfig.apiKey;
      // A caller-supplied session takes precedence; otherwise build one.
      // `azureConfig` is spread last, so its keys override the earlier ones.
      this.session =
        init?.session ??
        getOpenAISession({
          azure: true,
          maxRetries: this.maxRetries,
          timeout: this.timeout,
          ...this.additionalSessionOptions,
          ...azureConfig,
        });
    } else {
      this.apiKey = init?.apiKey ?? undefined;
      // A caller-supplied session takes precedence; otherwise build one.
      this.session =
        init?.session ??
        getOpenAISession({
          apiKey: this.apiKey,
          maxRetries: this.maxRetries,
          timeout: this.timeout,
          ...this.additionalSessionOptions,
        });
    }
  }

  /**
   * Request embeddings for a batch of texts through the session.
   * @param input - texts to embed
   * @returns the `embedding` of every item in the response's `data` array, in
   *   the order the response lists them; an empty array when `input` is empty
   */
  private async getOpenAIEmbedding(input: string[]): Promise<number[][]> {
    // Nothing to embed: skip the request entirely rather than sending an
    // empty `input` array.
    if (input.length === 0) {
      return [];
    }

    const { data } = await this.session.openai.embeddings.create({
      model: this.model,
      dimensions: this.dimensions, // only sent to OpenAI if set by user
      input,
    });

    return data.map((d) => d.embedding);
  }

  /**
   * Get embeddings for a batch of texts
   * @param texts - texts to embed
   * @returns the embeddings returned for `texts`; empty when `texts` is empty
   */
  async getTextEmbeddings(texts: string[]): Promise<number[][]> {
    return await this.getOpenAIEmbedding(texts);
  }

  /**
   * Get the embedding for a single text
   * @param text - text to embed
   * @returns the first embedding of the response
   * @throws Error when the response contains no embedding
   */
  async getTextEmbedding(text: string): Promise<number[]> {
    const [embedding] = await this.getOpenAIEmbedding([text]);
    // Fail loudly instead of handing back `undefined` typed as `number[]`.
    if (embedding === undefined) {
      throw new Error(
        `No embedding was returned for model "${this.model}"`,
      );
    }
    return embedding;
  }
}