diff --git a/examples/abort-chat-completion.ts b/examples/abort-chat-completion.ts index c62e092..338c16c 100644 --- a/examples/abort-chat-completion.ts +++ b/examples/abort-chat-completion.ts @@ -1,6 +1,6 @@ import 'dotenv/config'; -import { ChatModel, type Msg, MsgUtil } from '@dexaai/dexter'; +import { ChatModel, type Msg, MsgUtil } from '../src/index.js'; /** * npx tsx examples/abort-chat-completion.ts diff --git a/examples/ai-function.ts b/examples/ai-function.ts index c21c42c..f1e5862 100644 --- a/examples/ai-function.ts +++ b/examples/ai-function.ts @@ -1,8 +1,14 @@ import 'dotenv/config'; -import { ChatModel, createAIFunction, type Msg, MsgUtil } from '@dexaai/dexter'; import { z } from 'zod'; +import { + ChatModel, + createAIFunction, + type Msg, + MsgUtil, +} from '../src/index.js'; + /** * npx tsx examples/ai-function.ts */ diff --git a/examples/ai-runner-with-error.ts b/examples/ai-runner-with-error.ts index faf3450..6c31467 100644 --- a/examples/ai-runner-with-error.ts +++ b/examples/ai-runner-with-error.ts @@ -1,12 +1,13 @@ import 'dotenv/config'; +import { z } from 'zod'; + import { ChatModel, createAIFunction, createAIRunner, MsgUtil, -} from '@dexaai/dexter'; -import { z } from 'zod'; +} from '../src/index.js'; /** Get the weather for a given location. */ const getWeather = createAIFunction( diff --git a/examples/ai-runner.ts b/examples/ai-runner.ts index 9f5f03a..8cd469c 100644 --- a/examples/ai-runner.ts +++ b/examples/ai-runner.ts @@ -1,12 +1,13 @@ import 'dotenv/config'; +import { z } from 'zod'; + import { ChatModel, createAIFunction, createAIRunner, MsgUtil, -} from '@dexaai/dexter'; -import { z } from 'zod'; +} from '../src/index.js'; /** Get the weather for a given location. */ const getWeather = createAIFunction( diff --git a/examples/anthropic-ai-function.ts b/examples/anthropic-ai-function.ts new file mode 100644 index 0000000..b4bda78 --- /dev/null +++ b/examples/anthropic-ai-function.ts @@ -0,0 +1,108 @@ +import 'dotenv/config'; + +import { z } from 'zod'; + +import { + ChatModel, + createAIFunction, + createAnthropicClient, + type Msg, + MsgUtil, +} from '../src/index.js'; + +/** + * npx tsx examples/anthropic-ai-function.ts + */ +async function main() { + const getWeather = createAIFunction( + { + name: 'get_weather', + description: 'Gets the weather for a given location', + argsSchema: z.object({ + location: z + .string() + .describe('The city and state e.g. 
San Francisco, CA'), + unit: z + .enum(['c', 'f']) + .optional() + .default('f') + .describe('The unit of temperature to use'), + }), + }, + // Fake weather API implementation which returns a random temperature + // after a short delay + async (args: { location: string; unit?: string }) => { + await new Promise((resolve) => setTimeout(resolve, 500)); + + return { + location: args.location, + unit: args.unit, + temperature: (30 + Math.random() * 70) | 0, + }; + } + ); + + const chatModel = new ChatModel({ + debug: true, + client: createAnthropicClient(), + params: { + model: 'claude-2.0', + temperature: 0.5, + max_tokens: 500, + tools: [ + { + type: 'function', + function: getWeather.spec, + }, + ], + }, + }); + + const messages: Msg[] = [ + MsgUtil.user('What is the weather in San Francisco?'), + ]; + + { + // Invoke the chat model and have it create the args for the `get_weather` function + const { message } = await chatModel.run({ + messages, + model: 'claude-2.0', + tool_choice: { + type: 'function', + function: { + name: 'get_weather', + }, + }, + }); + + if (!MsgUtil.isToolCall(message)) { + throw new Error('Expected tool call'); + } + messages.push(message); + + for (const toolCall of message.tool_calls) { + if (toolCall.function.name !== 'get_weather') { + throw new Error(`Invalid function name: ${toolCall.function.name}`); + } + + const result = await getWeather(toolCall.function.arguments); + const toolResult = MsgUtil.toolResult(result, toolCall.id); + messages.push(toolResult); + } + } + + { + // Invoke the chat model with the result + const { message } = await chatModel.run({ + messages, + tool_choice: 'none', + }); + if (!MsgUtil.isAssistant(message)) { + throw new Error('Expected assistant message'); + } + + console.log(message.content); + } +} + +main(); diff --git a/examples/anthropic-ai-runner.ts b/examples/anthropic-ai-runner.ts new file mode 100644 index 0000000..a008459 --- /dev/null +++ b/examples/anthropic-ai-runner.ts @@ -0,0 +1,98 @@ +import 'dotenv/config'; + +import { z } from 'zod'; + +import { + ChatModel, + createAIFunction, + createAIRunner, + createAnthropicClient, + MsgUtil, +} from '../src/index.js'; + +/** Get the weather for a given location. */ +const getWeather = createAIFunction( + { + name: 'get_weather', + description: 'Gets the weather for a given location', + argsSchema: z.object({ + location: z + .string() + .describe('The city and state e.g. San Francisco, CA'), + unit: z + .enum(['c', 'f']) + .optional() + .default('f') + .describe('The unit of temperature to use'), + }), + }, + async ({ location, unit }) => { + await new Promise((resolve) => setTimeout(resolve, 500)); + const temperature = (30 + Math.random() * 70) | 0; + return { location, unit, temperature }; + } +); + +/** Get the capital city for a given state. */ +const getCapitalCity = createAIFunction( + { + name: 'get_capital_city', + description: 'Use this to get the the capital city for a given state', + argsSchema: z.object({ + state: z + .string() + .length(2) + .describe( + 'The state to get the capital city for, using the two letter abbreviation e.g. CA' + ), + }), + }, + async ({ state }) => { + await new Promise((resolve) => setTimeout(resolve, 500)); + let capitalCity = ''; + switch (state) { + case 'CA': + capitalCity = 'Sacramento'; + break; + case 'NY': + capitalCity = 'Albany'; + break; + default: + capitalCity = 'Unknown'; + } + return { capitalCity }; + } +); + +/** A runner that uses the weather and capital city functions. 
*/ +const weatherCapitalRunner = createAIRunner({ + chatModel: new ChatModel({ + client: createAnthropicClient(), + params: { model: 'claude-2.0' }, + }), + functions: [getWeather, getCapitalCity], + systemMessage: `You use functions to answer questions about the weather and capital cities.`, +}); + +/** + * npx tsx examples/anthropic-ai-runner.ts + */ +async function main() { + // Run with a string input + const rString = await weatherCapitalRunner( + `Whats the capital of California and NY and the weather for both` + ); + console.log('rString', rString); + + // Run with a message input + const rMessage = await weatherCapitalRunner({ + messages: [ + MsgUtil.user( + `Whats the capital of California and NY and the weather for both` + ), + ], + }); + console.log('rMessage', rMessage); +} + +main().catch(console.error); diff --git a/examples/caching-redis.ts b/examples/caching-redis.ts index dcae608..a39ad8b 100644 --- a/examples/caching-redis.ts +++ b/examples/caching-redis.ts @@ -1,9 +1,10 @@ import 'dotenv/config'; -import { EmbeddingModel } from '@dexaai/dexter'; import KeyvRedis from '@keyv/redis'; import Keyv from 'keyv'; +import { EmbeddingModel } from '../src/index.js'; + /** * npx tsx examples/caching-redis.ts */ diff --git a/examples/embeddings.ts b/examples/embeddings.ts new file mode 100644 index 0000000..135e3d6 --- /dev/null +++ b/examples/embeddings.ts @@ -0,0 +1,25 @@ +import 'dotenv/config'; + +import { createOpenAIClient, EmbeddingModel } from '../src/index.js'; + +/** + * npx tsx examples/embeddings.ts + */ +async function main() { + const embeddingModel = new EmbeddingModel({ + client: createOpenAIClient(), + params: { model: 'text-embedding-3-small' }, + }); + + { + // Create embeddings for the input strings + const response = await embeddingModel.run({ + input: ['What is the weather in San Francisco?'], + model: 'text-embedding-3-small', + }); + + console.log(response); + } +} + +main(); diff --git a/examples/extract-people-names.ts b/examples/extract-people-names.ts index 4943841..0eb219f 100644 --- a/examples/extract-people-names.ts +++ b/examples/extract-people-names.ts @@ -1,8 +1,9 @@ import 'dotenv/config'; -import { ChatModel, createExtractFunction } from '@dexaai/dexter'; import { z } from 'zod'; +import { ChatModel, createExtractFunction } from '../src/index.js'; + /** A function to extract people names from text. 
*/ const extractPeopleNamesRunner = createExtractFunction({ chatModel: new ChatModel({ params: { model: 'gpt-4o-mini' } }), diff --git a/examples/with-telemetry.ts b/examples/with-telemetry.ts index 8ee7751..6b5863a 100644 --- a/examples/with-telemetry.ts +++ b/examples/with-telemetry.ts @@ -1,9 +1,10 @@ import './instrument.js'; import 'dotenv/config'; -import { ChatModel } from '@dexaai/dexter'; import * as Sentry from '@sentry/node'; +import { ChatModel } from '../src/index.js'; + const chatModel = new ChatModel({ // Send tracing data to Sentry telemetry: Sentry, diff --git a/package.json b/package.json index 5944781..964a928 100644 --- a/package.json +++ b/package.json @@ -49,7 +49,8 @@ "hash-object": "^5.0.1", "jsonrepair": "^3.8.1", "ky": "^1.7.2", - "openai-fetch": "3.3.1", + "ai-fetch": "../ai-fetch", + "openai-fetch": "../openai-fetch", "openai-zod-to-json-schema": "^1.0.3", "p-map": "^7.0.2", "p-throttle": "^6.2.0", @@ -58,7 +59,8 @@ "tslib": "^2.7.0", "zod": "^3.23.8", "zod-to-json-schema": "^3.23.3", - "zod-validation-error": "^3.4.0" + "zod-validation-error": "^3.4.0", + "anthropic-fetch": "../anthropic-fetch" }, "devDependencies": { "@dexaai/eslint-config": "^1.3.6", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index bb20d9c..676d0bc 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -11,6 +11,12 @@ importers: '@fastify/deepmerge': specifier: ^2.0.0 version: 2.0.0 + ai-fetch: + specifier: ../ai-fetch + version: link:../ai-fetch + anthropic-fetch: + specifier: ../anthropic-fetch + version: link:../anthropic-fetch dedent: specifier: ^1.5.3 version: 1.5.3 @@ -24,8 +30,13 @@ importers: specifier: ^1.7.2 version: 1.7.2 openai-fetch: +<<<<<<< HEAD + specifier: ../openai-fetch + version: link:../openai-fetch +======= specifier: 3.3.1 version: 3.3.1 +>>>>>>> origin/master openai-zod-to-json-schema: specifier: ^1.0.3 version: 1.0.3(zod@3.23.8) @@ -2319,9 +2330,16 @@ packages: resolution: {integrity: sha512-mnkeQ1qP5Ue2wd+aivTD3NHd/lZ96Lu0jgf0pwktLPtx6cTZiH7tyeGRRHs0zX0rbrahXPnXlUnbeXyaBBuIaw==} engines: {node: '>=18'} +<<<<<<< HEAD + openai-zod-to-json-schema@1.0.3: + resolution: {integrity: sha512-CFU+KtOmX1dk2nPCZcGYgbrI3YLJJgMSehx1mLbH1A2fsRmZevHzMau6vFIhtkCpHWkGQ3ossA4a0OzVHlGrkw==} +======= openai-fetch@3.3.1: resolution: {integrity: sha512-/b7rPeKLgS+3C2dxQHPiWDj4wOcbL/SF5L2dxktmJyfFza/VK6Mr3+rIldgGxRNpqsa3oonEowafPNx5Tdq9dA==} +>>>>>>> origin/master engines: {node: '>=18'} + peerDependencies: + zod: ^3.23.8 openai-zod-to-json-schema@1.0.3: resolution: {integrity: sha512-CFU+KtOmX1dk2nPCZcGYgbrI3YLJJgMSehx1mLbH1A2fsRmZevHzMau6vFIhtkCpHWkGQ3ossA4a0OzVHlGrkw==} @@ -5718,9 +5736,13 @@ snapshots: is-inside-container: 1.0.0 is-wsl: 3.1.0 +<<<<<<< HEAD + openai-zod-to-json-schema@1.0.3(zod@3.23.8): +======= openai-fetch@3.3.1: +>>>>>>> origin/master dependencies: - ky: 1.7.2 + zod: 3.23.8 openai-zod-to-json-schema@1.0.3(zod@3.23.8): dependencies: diff --git a/src/ai-function/ai-runner.ts b/src/ai-function/ai-runner.ts index 98dee0a..823dd65 100644 --- a/src/ai-function/ai-runner.ts +++ b/src/ai-function/ai-runner.ts @@ -162,7 +162,7 @@ export function createAIRunner(args: { function getParams(args: { functions?: AIFunction[]; mode: AIRunner.Mode; -}): Pick { +}): Pick, 'functions' | 'tools'> { const { functions } = args; // Return an empty object if there are no functions if (!functions?.length) { diff --git a/src/ai-function/types.ts b/src/ai-function/types.ts index a8e5f5f..d93db5f 100644 --- a/src/ai-function/types.ts +++ b/src/ai-function/types.ts @@ -12,11 +12,19 @@ export type 
AIRunner = ( ) => Promise>; export namespace AIRunner { + export type Client = any; + /** Parameters to execute a runner */ - export type Params = SetOptional; + export type Params = SetOptional< + Model.Chat.Run & Model.Chat.Config, + 'model' + >; export type ModelParams = Partial< - Omit + Omit< + Model.Chat.Run & Model.Chat.Config, + 'messages' | 'functions' | 'tools' + > >; /** Response from executing a runner */ diff --git a/src/index.ts b/src/index.ts index d5658eb..9c65d2d 100644 --- a/src/index.ts +++ b/src/index.ts @@ -12,6 +12,7 @@ export { export { createExtractFunction } from './extract/index.js'; export type { ChatModelArgs } from './model/chat.js'; export { ChatModel } from './model/chat.js'; +export { createAnthropicClient } from './model/clients/anthropic.js'; export { createOpenAIClient } from './model/clients/openai.js'; export type { CompletionModelArgs } from './model/completion.js'; export { CompletionModel } from './model/completion.js'; diff --git a/src/model/chat.test.ts b/src/model/chat.test.ts index 0af890f..d7c4340 100644 --- a/src/model/chat.test.ts +++ b/src/model/chat.test.ts @@ -120,7 +120,7 @@ describe('ChatModel', () => { expect(apiResponseEvent).toHaveBeenCalledWith({ timestamp: new Date().toISOString(), modelType: 'chat', - modelProvider: 'openai', + modelProvider: 'spy', // because we mocked the client params: { model: 'gpt-fake', messages: [{ role: 'user', content: 'content' }], @@ -133,7 +133,7 @@ describe('ChatModel', () => { it('implements extend', async () => { type ChatContext = { userId: string; cloned?: boolean }; - const chatModel = new ChatModel({ + const chatModel = new ChatModel>({ client: Client, context: { userId: '123' }, params: { model: 'gpt-fake' }, @@ -220,7 +220,7 @@ describe('ChatModel', () => { // Extend the model and make another request const secondChatModel = chatModel.extend({ - params: { model: 'gpt-fake-extended' }, + params: { model: 'gpt-fake-extended'}, context: { level: 2 }, events: { onComplete: [newOnComplete] }, }); diff --git a/src/model/chat.ts b/src/model/chat.ts index 66ed9f9..22f4fb5 100644 --- a/src/model/chat.ts +++ b/src/model/chat.ts @@ -8,10 +8,14 @@ import { calculateCost } from './utils/calculate-cost.js'; import { deepMerge, mergeEvents, type Prettify } from './utils/helpers.js'; import { MsgUtil } from './utils/message-util.js'; -export type ChatModelArgs = SetOptional< +export type ChatModelArgs< + CustomCtx extends Model.Ctx, + CustomClient extends Model.Chat.Client, + CustomConfig extends Model.Chat.Config, +> = SetOptional< ModelArgs< - Model.Chat.Client, - Model.Chat.Config, + CustomClient, + CustomConfig, Model.Chat.Run, Model.Chat.Response, CustomCtx @@ -19,25 +23,41 @@ export type ChatModelArgs = SetOptional< 'client' | 'params' >; -export type PartialChatModelArgs = Prettify< - PartialDeep>, 'params'>> & - Partial>, 'params'>> +export type PartialChatModelArgs< + CustomCtx extends Model.Ctx, + CustomClient extends Model.Chat.Client, + CustomConfig extends Model.Chat.Config, +> = Prettify< + PartialDeep< + Pick< + ChatModelArgs, CustomClient, CustomConfig>, + 'params' + > + > & + Partial< + Omit< + ChatModelArgs, CustomClient, CustomConfig>, + 'params' + > + > >; export class ChatModel< - CustomCtx extends Model.Ctx = Model.Ctx, + CustomCtx extends Model.Ctx, + CustomClient extends Model.Chat.Client, + CustomConfig extends Model.Chat.Config, > extends AbstractModel< - Model.Chat.Client, - Model.Chat.Config, + CustomClient, + CustomConfig, Model.Chat.Run, Model.Chat.Response, Model.Chat.ApiResponse, 
CustomCtx > { modelType = 'chat' as const; - modelProvider = 'openai' as const; + modelProvider = 'openai'; - constructor(args: ChatModelArgs = {}) { + constructor(args: ChatModelArgs = {}) { const { // Add a default client if none is provided client = createOpenAIClient(), @@ -49,8 +69,8 @@ export class ChatModel< } = args; super({ - client, - params, + client: client as CustomClient, + params: params as CustomConfig & Partial, debug, events: mergeEvents( events, @@ -63,23 +83,27 @@ export class ChatModel< ), ...rest, }); + + this.modelProvider = this.client.name; } - protected async runModel( - { - handleUpdate, - requestOpts, - ...params - }: Model.Chat.Run & Model.Chat.Config, + protected async runModel>( + { handleUpdate, requestOpts, ...params }: Partial & Model.Chat.Run, context: CustomCtx ): Promise { const start = Date.now(); + const allParams = { + ...this.params, + ...params, + messages: params.messages ?? this.params.messages ?? [], + }; + // Use non-streaming API if no handler is provided if (!handleUpdate) { // Make the OpenAI API request const response = await this.client.createChatCompletion( - params, + allParams, requestOpts ); @@ -90,7 +114,7 @@ export class ChatModel< timestamp: new Date().toISOString(), modelType: this.modelType, modelProvider: this.modelProvider, - params, + params: allParams, response, latency: Date.now() - start, context, @@ -106,14 +130,14 @@ export class ChatModel< message, cached: false, latency: Date.now() - start, - cost: calculateCost({ model: params.model, tokens: response.usage }), + cost: calculateCost({ model: allParams.model, tokens: response.usage }), }; return modelResponse; } else { // Use the streaming API if a handler is provided const stream = await this.client.streamChatCompletion( - params, + allParams, requestOpts ); @@ -228,7 +252,7 @@ export class ChatModel< timestamp: new Date().toISOString(), modelType: this.modelType, modelProvider: this.modelProvider, - params, + params: allParams, response, latency: Date.now() - start, context, @@ -244,7 +268,7 @@ export class ChatModel< message, cached: false, latency: Date.now() - start, - cost: calculateCost({ model: params.model, tokens: response.usage }), + cost: calculateCost({ model: allParams.model, tokens: response.usage }), }; return modelResponse; @@ -252,15 +276,16 @@ export class ChatModel< } /** Clone the model and merge/override the given properties. */ - extend(args?: PartialChatModelArgs): this { + extend(args?: PartialChatModelArgs>): this { + const { client, params, ...rest } = args ?? {}; return new ChatModel({ cacheKey: this.cacheKey, cache: this.cache, - client: this.client, + client: client ?? this.client, debug: this.debug, telemetry: this.telemetry, - ...args, - params: deepMerge(this.params, args?.params), + ...rest, + params: deepMerge(this.params, params), context: args?.context && Object.keys(args.context).length === 0 ? undefined diff --git a/src/model/clients/anthropic.ts b/src/model/clients/anthropic.ts new file mode 100644 index 0000000..0b953f5 --- /dev/null +++ b/src/model/clients/anthropic.ts @@ -0,0 +1,22 @@ +import { AnthropicClient } from 'anthropic-fetch'; + +/** Cached Anthropic clients. */ +const cachedClients = new Map(); + +/** Create a new anthropic-fetch AnthropicClient. */ +export function createAnthropicClient( + /** Options to pass to the Anthropic client. */ + opts?: ConstructorParameters[0], + /** Force a new client to be created. 
*/ + forceNew = false +): AnthropicClient { + if (!forceNew) { + const cachedClient = cachedClients.get(JSON.stringify(opts)); + if (cachedClient) return cachedClient; + } + + const client = new AnthropicClient(opts); + cachedClients.set(JSON.stringify(opts), client); + + return client; +} diff --git a/src/model/completion.ts b/src/model/completion.ts index 701e059..8a30764 100644 --- a/src/model/completion.ts +++ b/src/model/completion.ts @@ -6,10 +6,14 @@ import { type Model } from './types.js'; import { calculateCost } from './utils/calculate-cost.js'; import { deepMerge, mergeEvents, type Prettify } from './utils/helpers.js'; -export type CompletionModelArgs = SetOptional< +export type CompletionModelArgs< + CustomCtx extends Model.Ctx, + CustomClient extends Model.Completion.Client, + CustomConfig extends Model.Completion.Config, +> = SetOptional< ModelArgs< - Model.Completion.Client, - Model.Completion.Config, + CustomClient, + CustomConfig, Model.Completion.Run, Model.Completion.Response, CustomCtx @@ -17,16 +21,33 @@ export type CompletionModelArgs = SetOptional< 'client' | 'params' >; -export type PartialCompletionModelArgs = Prettify< - PartialDeep>, 'params'>> & - Partial>, 'params'>> +export type PartialCompletionModelArgs< + CustomCtx extends Model.Ctx, + CustomClient extends Model.Completion.Client, + CustomConfig extends Model.Completion.Config, +> = Prettify< + PartialDeep< + Pick< + CompletionModelArgs, CustomClient, CustomConfig>, + 'params' + > + > & + Partial< + Omit< + CompletionModelArgs, CustomClient, CustomConfig>, + 'params' + > + > >; export class CompletionModel< CustomCtx extends Model.Ctx = Model.Ctx, + CustomClient extends Model.Completion.Client = Model.Completion.Client, + CustomConfig extends + Model.Completion.Config = Model.Completion.Config, > extends AbstractModel< - Model.Completion.Client, - Model.Completion.Config, + CustomClient, + CustomConfig, Model.Completion.Run, Model.Completion.Response, Model.Completion.ApiResponse, @@ -35,24 +56,41 @@ export class CompletionModel< modelType = 'completion' as const; modelProvider = 'openai' as const; - constructor(args?: CompletionModelArgs) { + constructor( + args?: CompletionModelArgs + ) { let { client, params } = args ?? {}; const { client: _, params: __, ...rest } = args ?? {}; // Add a default client if none is provided - client = client ?? createOpenAIClient(); + client = (client ?? createOpenAIClient()) as CustomClient; // Set default model if no params are provided - params = params ?? { model: 'gpt-3.5-turbo-instruct' }; + params = + params ?? + ({ model: 'gpt-3.5-turbo-instruct' } as CustomConfig & + Partial); super({ client, params, ...rest }); } protected async runModel( - { requestOpts, ...params }: Model.Completion.Run & Model.Completion.Config, + { + requestOpts, + ...params + }: Partial>, context: CustomCtx ): Promise { const start = Date.now(); + const allParams = { + ...this.params, + ...params, + prompt: params.prompt ?? this.params.prompt ?? 
null, + }; + // Make the OpenAI API request - const response = await this.client.createCompletions(params, requestOpts); + const response = await this.client.createCompletions( + allParams, + requestOpts + ); await Promise.allSettled( this.events?.onApiResponse?.map((event) => @@ -61,7 +99,7 @@ export class CompletionModel< timestamp: new Date().toISOString(), modelType: this.modelType, modelProvider: this.modelProvider, - params, + params: allParams, response, latency: Date.now() - start, context, @@ -74,14 +112,16 @@ export class CompletionModel< ...response, completion: response.choices[0].text, cached: false, - cost: calculateCost({ model: params.model, tokens: response.usage }), + cost: calculateCost({ model: allParams.model, tokens: response.usage }), }; return modelResponse; } /** Clone the model and merge/override the given properties. */ - extend(args?: PartialCompletionModelArgs): this { + extend( + args?: PartialCompletionModelArgs + ): this { return new CompletionModel({ cacheKey: this.cacheKey, cache: this.cache, diff --git a/src/model/embedding.ts b/src/model/embedding.ts index 1430f33..1966b62 100644 --- a/src/model/embedding.ts +++ b/src/model/embedding.ts @@ -8,10 +8,15 @@ import { type Model } from './types.js'; import { calculateCost } from './utils/calculate-cost.js'; import { deepMerge, mergeEvents, type Prettify } from './utils/helpers.js'; -export type EmbeddingModelArgs = SetOptional< +export type EmbeddingModelArgs< + CustomCtx extends Model.Ctx, + CustomClient extends Model.Embedding.Client = Model.Embedding.Client, + CustomConfig extends + Model.Embedding.Config = Model.Embedding.Config, +> = SetOptional< ModelArgs< - Model.Embedding.Client, - Model.Embedding.Config, + CustomClient, + CustomConfig, Model.Embedding.Run, Model.Embedding.Response, CustomCtx @@ -19,13 +24,33 @@ export type EmbeddingModelArgs = SetOptional< 'client' | 'params' >; -export type PartialEmbeddingModelArgs = Prettify< - PartialDeep>, 'params'>> & - Partial>, 'params'>> +export type PartialEmbeddingModelArgs< + CustomCtx extends Model.Ctx, + CustomClient extends Model.Embedding.Client = Model.Embedding.Client, + CustomConfig extends + Model.Embedding.Config = Model.Embedding.Config, +> = Prettify< + PartialDeep< + Pick< + EmbeddingModelArgs, CustomClient, CustomConfig>, + 'params' + > + > & + Partial< + Omit< + EmbeddingModelArgs, CustomClient, CustomConfig>, + 'params' + > + > >; -type BulkEmbedder = ( - params: Model.Embedding.Run & Model.Embedding.Config, +type BulkEmbedder< + CustomCtx extends Model.Ctx, + CustomClient extends Model.Embedding.Client = Model.Embedding.Client, + CustomConfig extends + Model.Embedding.Config = Model.Embedding.Config, +> = ( + params: Model.Embedding.Run & CustomConfig, context: CustomCtx ) => Promise; @@ -41,9 +66,12 @@ const DEFAULTS = { export class EmbeddingModel< CustomCtx extends Model.Ctx = Model.Ctx, + CustomClient extends Model.Embedding.Client = Model.Embedding.Client, + CustomConfig extends + Model.Embedding.Config = Model.Embedding.Config, > extends AbstractModel< - Model.Embedding.Client, - Model.Embedding.Config, + CustomClient, + CustomConfig, Model.Embedding.Run, Model.Embedding.Response, Model.Embedding.ApiResponse, @@ -51,15 +79,21 @@ export class EmbeddingModel< > { modelType = 'embedding' as const; modelProvider = 'openai' as const; - throttledModel: BulkEmbedder; + throttledModel: BulkEmbedder; - constructor(args: EmbeddingModelArgs = {}) { + constructor( + args: EmbeddingModelArgs = {} + ) { const { client = createOpenAIClient(), 
params = { model: DEFAULTS.model }, ...rest } = args; - super({ client, params, ...rest }); + super({ + client: client as CustomClient, + params: params as CustomConfig, + ...rest, + }); const interval = DEFAULTS.throttleInterval; const limit = @@ -68,7 +102,7 @@ export class EmbeddingModel< // Create the throttled function this.throttledModel = pThrottle({ limit, interval })( async ( - params: Model.Embedding.Run & Model.Embedding.Config, + params: Model.Embedding.Run & Model.Embedding.Config, context: CustomCtx ) => { const start = Date.now(); @@ -86,7 +120,7 @@ export class EmbeddingModel< timestamp: new Date().toISOString(), modelType: this.modelType, modelProvider: this.modelProvider, - params, + params: params as Model.Embedding.Run & CustomConfig, response, latency: Date.now() - start, context, @@ -110,17 +144,31 @@ export class EmbeddingModel< return modelResponse; } - ); + ) as BulkEmbedder< + CustomCtx, + CustomClient, + Model.Embedding.Config + >; } protected async runModel( - { requestOpts, ...params }: Model.Embedding.Run & Model.Embedding.Config, + { + requestOpts, + ...params + }: Model.Embedding.Run & Partial>, context: CustomCtx ): Promise { const start = Date.now(); + + const allParams = { + ...this.params, + ...params, + input: params.input ?? this.params.input ?? [], + }; + // Batch the inputs for the requests const batches = batchInputs({ - input: params.input, + input: allParams.input, tokenizer: this.tokenizer, options: this.params.batch, }); @@ -136,7 +184,7 @@ export class EmbeddingModel< input: batch, model: this.params.model, requestOpts, - }, + } as Model.Embedding.Run & CustomConfig, mergedContext ); return response; @@ -164,7 +212,7 @@ export class EmbeddingModel< data: embeddingsObjs, embeddings: embeddingBatches.map((batch) => batch.embeddings).flat(), cached: false, - cost: calculateCost({ model: params.model, tokens: usage }), + cost: calculateCost({ model: allParams.model, tokens: usage }), latency: Date.now() - start, }; @@ -172,8 +220,10 @@ export class EmbeddingModel< } /** Clone the model and merge/override the given properties. 
*/ - extend(args?: PartialEmbeddingModelArgs): this { - return new EmbeddingModel({ + extend( + args?: PartialEmbeddingModelArgs + ): this { + return new EmbeddingModel({ cacheKey: this.cacheKey, cache: this.cache, client: this.client, diff --git a/src/model/model.ts b/src/model/model.ts index a5ede51..ec8b684 100644 --- a/src/model/model.ts +++ b/src/model/model.ts @@ -19,7 +19,7 @@ import { createTokenizer } from './utils/tokenizer.js'; export interface ModelArgs< MClient extends Model.Base.Client, - MConfig extends Model.Base.Config, + MConfig extends Model.Base.Config, MRun extends Model.Base.Run, MResponse extends Model.Base.Response, Ctx extends Model.Ctx, @@ -41,7 +41,7 @@ export interface ModelArgs< client: MClient; context?: Ctx; params: MConfig & Partial; - events?: Model.Events; + events?: Model.Events; telemetry?: Telemetry.Provider; /** Whether or not to add default `console.log` event handlers */ debug?: boolean; @@ -49,7 +49,7 @@ export interface ModelArgs< export type PartialModelArgs< MClient extends Model.Base.Client, - MConfig extends Model.Base.Config, + MConfig extends Model.Base.Config, MRun extends Model.Base.Run, MResponse extends Model.Base.Response, CustomCtx extends Model.Ctx, @@ -70,35 +70,46 @@ export type PartialModelArgs< export abstract class AbstractModel< MClient extends Model.Base.Client, - MConfig extends Model.Base.Config, + MConfig extends Model.Base.Config, MRun extends Model.Base.Run, MResponse extends Model.Base.Response, AResponse = any, CustomCtx extends Model.Ctx = Model.Ctx, > { /** This is used to implement specific model calls */ - protected abstract runModel( - params: Prettify, + protected abstract runModel>( + params: Prettify, context: CustomCtx ): Promise; /** Clones the model, optionally modifying its config */ - abstract extend< - Args extends PartialModelArgs, - >(args?: Args): this; + abstract extend( + args?: PartialModelArgs< + MClient, + Model.Base.Config, + MRun, + // Note: this response type maybe change over time as the user + // extends the model + // it should be inferred from some types rather than set to MResponse + MResponse, + CustomCtx + > + ): this; public abstract readonly modelType: Model.Type; - public abstract readonly modelProvider: Model.Provider; + public abstract modelProvider: Model.Provider; - protected readonly cacheKey: CacheKey; + // the cache key can be updated in a call to .extend so it doesn't necessarily conform to MRun & MConfig + protected readonly cacheKey: CacheKey, string>; protected readonly cache?: CacheStorage; public readonly client: MClient; public readonly context: CustomCtx; public readonly debug: boolean; public readonly params: MConfig & Partial; public readonly events: Model.Events< - MRun & MConfig, - MResponse, + Model.Base.Client, + Model.Base.Run & Model.Base.Config, + Model.Base.Response, CustomCtx, AResponse >; @@ -106,19 +117,28 @@ export abstract class AbstractModel< public readonly telemetry: Telemetry.Provider; constructor(args: ModelArgs) { - this.cacheKey = args.cacheKey ?? defaultCacheKey; + this.cacheKey = (args.cacheKey ?? defaultCacheKey) as CacheKey< + Model.Base.Run & Model.Base.Config, + string + >; this.cache = args.cache; this.client = args.client; this.context = args.context ?? ({} as CustomCtx); this.debug = args.debug ?? 
false; this.params = args.params; - this.events = args.events || {}; + this.events = (args.events || {}) as Model.Events< + Model.Base.Client, + Model.Base.Run & Model.Base.Config, + Model.Base.Response, + CustomCtx, + AResponse + >; this.tokenizer = createTokenizer(args.params.model); this.telemetry = args.telemetry ?? DefaultTelemetry; } - async run( - params: Prettify>, + async run>( + params: Prettify>, context?: CustomCtx ): Promise { const mergedContext = deepMerge(this.context, context); diff --git a/src/model/sparse-vector.ts b/src/model/sparse-vector.ts index ef9ba5a..25a22e3 100644 --- a/src/model/sparse-vector.ts +++ b/src/model/sparse-vector.ts @@ -8,11 +8,16 @@ import { AbstractModel, type ModelArgs } from './model.js'; import { type Model } from './types.js'; import { deepMerge, mergeEvents, type Prettify } from './utils/helpers.js'; -export type SparseVectorModelArgs = Prettify< +export type SparseVectorModelArgs< + CustomCtx extends Model.Ctx, + CustomClient extends Model.SparseVector.Client = Model.SparseVector.Client, + CustomConfig extends + Model.SparseVector.Config = Model.SparseVector.Config, +> = Prettify< Omit< ModelArgs< - Model.SparseVector.Client, - Model.SparseVector.Config, + CustomClient, + CustomConfig, Model.SparseVector.Run, Model.SparseVector.Response, CustomCtx @@ -23,17 +28,34 @@ export type SparseVectorModelArgs = Prettify< } >; -export type PartialSparseVectorModelArgs = - Prettify< - PartialDeep>, 'params'>> & - Partial>, 'params'>> - >; +export type PartialSparseVectorModelArgs< + CustomCtx extends Model.Ctx, + CustomClient extends Model.SparseVector.Client = Model.SparseVector.Client, + CustomConfig extends + Model.SparseVector.Config = Model.SparseVector.Config, +> = Prettify< + PartialDeep< + Pick< + SparseVectorModelArgs, CustomClient, CustomConfig>, + 'params' + > + > & + Partial< + Omit< + SparseVectorModelArgs, CustomClient, CustomConfig>, + 'params' + > + > +>; export class SparseVectorModel< CustomCtx extends Model.Ctx = Model.Ctx, + CustomClient extends Model.SparseVector.Client = Model.SparseVector.Client, + CustomConfig extends + Model.SparseVector.Config = Model.SparseVector.Config, > extends AbstractModel< - Model.SparseVector.Client, - Model.SparseVector.Config, + CustomClient, + CustomConfig, Model.SparseVector.Run, Model.SparseVector.Response, Model.SparseVector.ApiResponse, @@ -43,9 +65,11 @@ export class SparseVectorModel< modelProvider = 'custom' as const; serviceUrl: string; - constructor(args: SparseVectorModelArgs) { + constructor( + args: SparseVectorModelArgs + ) { const { serviceUrl, ...rest } = args; - super({ client: createSpladeClient(), ...rest }); + super({ client: createSpladeClient() as CustomClient, ...rest }); const safeProcess = globalThis.process || { env: {} }; const tempServiceUrl = serviceUrl || safeProcess.env.SPLADE_SERVICE_URL; if (!tempServiceUrl) { @@ -58,7 +82,8 @@ export class SparseVectorModel< { requestOpts: _, ...params - }: Model.SparseVector.Run & Model.SparseVector.Config, + }: Model.SparseVector.Run & + Partial>, context: CustomCtx ): Promise { const start = Date.now(); @@ -66,16 +91,19 @@ export class SparseVectorModel< const limit = params.throttleLimit ?? 600; const concurrency = params.concurrency ?? 10; + const model = params.model ?? this.params.model; + const input = params.input ?? this.params.input ?? 
[]; + // Create a throttled version of the function for a single request const throttled = pThrottle({ limit, interval })( - async (params: { input: string; model: string }) => + async (params: { input: string; model: CustomConfig['model'] }) => this.runSingle(params, context) ); // Run the requests in parallel, respecting the maxConcurrentRequests value - const inputs = params.input.map((input) => ({ + const inputs = input.map((input) => ({ input, - model: params.model, + model, })); const responses = await pMap(inputs, throttled, { concurrency }); @@ -89,7 +117,7 @@ protected async runSingle( params: { input: string; - model: string; + model: CustomConfig['model']; requestOpts?: { headers?: KYOptions['headers']; }; @@ -111,6 +139,7 @@ // Don't need tokens for this model const tokens = { prompt: 0, completion: 0, total: 0 } as const; const { input, model } = params; + await Promise.allSettled( this.events?.onApiResponse?.map((event) => Promise.resolve( @@ -131,7 +160,9 @@ } /** Clone the model and merge/override the given properties. */ - extend(args?: PartialSparseVectorModelArgs): this { + extend( + args?: PartialSparseVectorModelArgs + ): this { return new SparseVectorModel({ cacheKey: this.cacheKey, cache: this.cache, diff --git a/src/model/telemetry/extractors.ts b/src/model/telemetry/extractors.ts index f39a2b5..27aac66 100644 --- a/src/model/telemetry/extractors.ts +++ b/src/model/telemetry/extractors.ts @@ -43,8 +43,8 @@ export function extractAttrsFromParams(params: { modelProvider?: string; max_tokens?: number; temperature?: number; - functions?: ChatParams['functions']; - tools?: ChatParams['tools']; + functions?: ChatParams['functions']; + tools?: ChatParams['tools']; messages?: ChatMessage[]; prompt?: string | string[]; input?: string[]; @@ -99,7 +99,9 @@ } } -function extractAttrsFromFunctions(funcs?: ChatParams['functions']): AttrMap { +function extractAttrsFromFunctions( + funcs?: ChatParams['functions'] +): AttrMap { const attrs: AttrMap = {}; if (!funcs) return attrs; funcs.forEach((func, index) => { @@ -111,7 +113,7 @@ return attrs; } -function extractAttrsFromTools(tools?: ChatParams['tools']): AttrMap { +function extractAttrsFromTools(tools?: ChatParams['tools']): AttrMap { const attrs: AttrMap = {}; if (!tools) return attrs; tools.forEach((tool, index) => { diff --git a/src/model/types.ts b/src/model/types.ts index d3dad77..8eca220 100644 --- a/src/model/types.ts +++ b/src/model/types.ts @@ -1,6 +1,8 @@ /* eslint-disable no-use-before-define */ -import { type Options as KYOptions } from 'ky'; import { + type AIChatClient, + type AICompletionClient, + type AIEmbeddingClient, type ChatParams, type ChatResponse, type ChatStreamResponse, @@ -8,8 +10,8 @@ type CompletionResponse, type EmbeddingParams, type EmbeddingResponse, - type OpenAIClient, -} from 'openai-fetch'; +} from 'ai-fetch'; +import { type Options as KYOptions } from 'ky'; import { type ChatModel } from './chat.js'; import { type CompletionModel } from './completion.js'; @@ -120,8 +122,19 @@ export namespace Base { /** Client for making API calls. Extended by specific model clients. */ export type Client = any; - export interface Config { - model: string; + + export type AvailableModels = C extends Client ? 
{ + [K in keyof C]: C[K] extends (params: infer P) => any + ? P extends { model: infer M } + ? M + : never + : never; + }[keyof C] + : never; + + export interface Config { + model: AvailableModels; } export interface Run { [key: string]: any; @@ -130,43 +146,58 @@ export namespace Model { headers?: KYOptions['headers']; }; } - export interface Params extends Config, Run {} + export interface Params extends Config, Run {} export interface Response { cached: boolean; latency?: number; cost?: number; } - export type Model = AbstractModel; + export type Model = AbstractModel< + Client, + Config, + Run, + Response, + any + >; } /** * Chat Model */ export namespace Chat { - export type Client = { - createChatCompletion: OpenAIClient['createChatCompletion']; - streamChatCompletion: OpenAIClient['streamChatCompletion']; - }; + export type Client = AIChatClient; + + type AvailableModels = C extends Client + ? C extends { createChatCompletion: (params: infer P) => any } + ? P extends { model: infer M } + ? M + : never + : never + : never; + + type Params = ChatParams>; + export interface Run extends Base.Run { messages: Msg[]; + handleUpdate?: (chunk: string) => void; } - export interface Config extends Base.Config { + export interface Config extends Base.Config { /** Handle new chunk from streaming requests. */ handleUpdate?: (chunk: string) => void; - frequency_penalty?: ChatParams['frequency_penalty']; - function_call?: ChatParams['function_call']; - functions?: ChatParams['functions']; - logit_bias?: ChatParams['logit_bias']; - max_tokens?: ChatParams['max_tokens']; - model: ChatParams['model']; - presence_penalty?: ChatParams['presence_penalty']; - response_format?: ChatParams['response_format']; - seed?: ChatParams['seed']; - stop?: ChatParams['stop']; - temperature?: ChatParams['temperature']; - tools?: ChatParams['tools']; - tool_choice?: ChatParams['tool_choice']; - top_p?: ChatParams['top_p']; + frequency_penalty?: Params['frequency_penalty']; + function_call?: Params['function_call']; + functions?: Params['functions']; + logit_bias?: Params['logit_bias']; + max_tokens?: Params['max_tokens']; + model: Params['model']; + presence_penalty?: Params['presence_penalty']; + response_format?: Params['response_format']; + seed?: Params['seed']; + stop?: Params['stop']; + temperature?: Params['temperature']; + tools?: Params['tools']; + tool_choice?: Params['tool_choice']; + top_p?: Params['top_p']; } export interface Response extends Base.Response, ChatResponse { message: Msg; @@ -176,23 +207,32 @@ export namespace Model { /** A chunk recieved from a streaming response */ export type CompletionChunk = InnerType; export type ApiResponse = ChatResponse; - export type Model = ChatModel; + export type Model = ChatModel>; } /** * Completion model */ export namespace Completion { - export type Client = { - createCompletions: OpenAIClient['createCompletions']; - }; + export type Client = AICompletionClient; + + type AvailableModels = C extends Client + ? C extends { createCompletion: (params: infer P) => any } + ? P extends { model: infer M } + ? 
M + : never + : never + : never; + + type Params = CompletionParams>; + export interface Run extends Base.Run { prompt: string | string[] | number[] | number[][] | null; } - export interface Config - extends Base.Config, - Omit { - model: CompletionParams['model']; + export interface Config + extends Base.Config, + Omit, 'prompt' | 'user'> { + model: Params['model']; } export interface Response extends Base.Response, CompletionResponse { completion: string; @@ -208,9 +248,18 @@ export namespace Model { * Embedding Model */ export namespace Embedding { - export type Client = { - createEmbeddings: OpenAIClient['createEmbeddings']; - }; + export type Client = AIEmbeddingClient; + + type AvailableModels = C extends Client + ? C extends { createEmbeddings: (params: infer P) => any } + ? P extends { model: infer M } + ? M + : never + : never + : never; + + type Params = EmbeddingParams>; + export interface Run extends Base.Run { input: string[]; } @@ -224,10 +273,10 @@ export namespace Model { maxRequestsPerMin: number; maxConcurrentRequests: number; } - export interface Config - extends Base.Config, - Omit { - model: EmbeddingParams['model']; + export interface Config + extends Base.Config, + Omit, 'input' | 'user'> { + model: Params['model']; batch?: Partial; throttle?: Partial; } @@ -242,8 +291,9 @@ export namespace Model { * Event handlers for logging and debugging */ export interface Events< - MParams extends Base.Params, - MResponse extends Base.Response, + C extends Model.Base.Client, + MParams extends Model.Base.Params, + MResponse extends Model.Base.Response, MCtx extends Model.Ctx, AResponse = any, > { @@ -251,14 +301,14 @@ export namespace Model { timestamp: string; modelType: Type; modelProvider: Provider; - params: Readonly; + params: Readonly>>; context: Readonly; }) => void | Promise)[]; - onApiResponse?: ((event: { + onApiResponse?: (

>(event: { timestamp: string; modelType: Type; modelProvider: Provider; - params: Readonly; + params: Readonly

; response: Readonly; latency: number; context: Readonly; @@ -334,7 +384,7 @@ export namespace Model { export interface Run extends Model.Base.Run { input: string[]; } - export interface Config extends Model.Base.Config { + export interface Config extends Model.Base.Config { concurrency?: number; throttleLimit?: number; throttleInterval?: number; diff --git a/src/model/utils/message-util.ts b/src/model/utils/message-util.ts index 4e0ab4b..17ccc21 100644 --- a/src/model/utils/message-util.ts +++ b/src/model/utils/message-util.ts @@ -1,5 +1,5 @@ +import { type ChatMessage, type ChatResponse } from 'ai-fetch'; import dedent from 'dedent'; -import { type ChatMessage, type ChatResponse } from 'openai-fetch'; import { type Jsonifiable } from 'type-fest'; import { type Msg } from '../types.js'; diff --git a/src/swarm/swarm.ts b/src/swarm/swarm.ts index 59e3eca..a4bfabf 100644 --- a/src/swarm/swarm.ts +++ b/src/swarm/swarm.ts @@ -1,6 +1,7 @@ import pMap from 'p-map'; -import { ChatModel, type Msg, MsgUtil } from '../index.js'; +import { ChatModel, createOpenAIClient, type Msg, MsgUtil } from '../index.js'; +import { type Model } from '../model/types.js'; import { getErrorMsg } from '../model/utils/errors.js'; import { type Agent, @@ -11,14 +12,27 @@ import { } from './types.js'; export class Swarm { - chatModel: ChatModel; - defaultModel: string; - - constructor(args?: { chatModel?: ChatModel }) { - this.defaultModel = 'gpt-4o'; + chatModel: ChatModel< + Model.Ctx, + Model.Chat.Client, + Model.Chat.Config + >; + defaultModel: Model.Base.AvailableModels; + + constructor(args?: { + chatModel?: ChatModel< + Model.Ctx, + Model.Chat.Client, + Model.Chat.Config + >; + }) { this.chatModel = args?.chatModel || - new ChatModel({ params: { model: this.defaultModel } }); + new ChatModel({ + client: createOpenAIClient(), + params: { model: 'gpt-4o' }, + }); + this.defaultModel = this.chatModel.params.model; } async run(args: {
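For reference, a minimal sketch of how the new Model.Base.AvailableModels helper added in src/model/types.ts above is expected to infer the set of legal `model` values from a client's method signatures. The FakeClient shape and the import path are illustrative assumptions only, not part of the patch.

import { type Model } from './src/model/types.js';

// Hypothetical client: a single chat method whose params allow two model ids.
type FakeClient = {
  createChatCompletion: (params: {
    model: 'model-a' | 'model-b';
    messages: unknown[];
  }) => Promise<unknown>;
};

// AvailableModels maps over the client's methods, infers each method's params
// type, and extracts its `model` property, so this should resolve to
// 'model-a' | 'model-b'. This is what lets ChatModel constrain `params.model`
// to whatever the supplied client (OpenAI, Anthropic, ...) actually supports.
type Models = Model.Base.AvailableModels<FakeClient>;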