From 93b038dca93e11457c05cd27ff7d6650a34b8d13 Mon Sep 17 00:00:00 2001 From: Riley Tomasek Date: Sun, 13 Oct 2024 22:40:59 -0400 Subject: [PATCH] Add rough Swarm implementation This commit introduces a new `Swarm` module that includes REPL functionality and tools for handling functions. Additional changes include the integration of the `chalk` library for colorful terminal output. Modifications are made to `package.json` to reflect these new dependencies and module structures. --- package.json | 5 + pnpm-lock.yaml | 3 + src/swarm/index.ts | 3 + src/swarm/repl.ts | 79 +++++++++++++++ src/swarm/swarm-tools.ts | 98 +++++++++++++++++++ src/swarm/swarm.ts | 202 +++++++++++++++++++++++++++++++++++++++ src/swarm/test.ts | 69 +++++++++++++ src/swarm/types.ts | 42 ++++++++ 8 files changed, 501 insertions(+) create mode 100644 src/swarm/index.ts create mode 100644 src/swarm/repl.ts create mode 100644 src/swarm/swarm-tools.ts create mode 100644 src/swarm/swarm.ts create mode 100644 src/swarm/test.ts create mode 100644 src/swarm/types.ts diff --git a/package.json b/package.json index 81f762e..c6ed349 100644 --- a/package.json +++ b/package.json @@ -26,6 +26,10 @@ "./extract": { "types": "./dist/extract/index.d.ts", "import": "./dist/extract/index.js" + }, + "./swarm": { + "types": "./dist/swarm/index.d.ts", + "import": "./dist/swarm/index.js" } }, "sideEffects": false, @@ -68,6 +72,7 @@ "@dexaai/eslint-config": "^1.3.6", "@sentry/node": "^8.34.0", "@types/node": "^20.14.11", + "chalk": "^5.3.0", "dotenv-cli": "^7.4.2", "eslint": "^8.57.0", "knip": "^5.33.3", diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 711752b..ab60087 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -63,6 +63,9 @@ importers: '@types/node': specifier: ^20.14.11 version: 20.14.11 + chalk: + specifier: ^5.3.0 + version: 5.3.0 dotenv-cli: specifier: ^7.4.2 version: 7.4.2 diff --git a/src/swarm/index.ts b/src/swarm/index.ts new file mode 100644 index 0000000..f23d766 --- /dev/null +++ b/src/swarm/index.ts @@ -0,0 +1,3 @@ +export { runSwarmRepl } from './repl.js'; +export { Swarm } from './swarm.js'; +export { swarmFunction, swarmHandoff } from './swarm-tools.js'; diff --git a/src/swarm/repl.ts b/src/swarm/repl.ts new file mode 100644 index 0000000..68d86eb --- /dev/null +++ b/src/swarm/repl.ts @@ -0,0 +1,79 @@ +import * as readline from 'node:readline'; + +import chalk from 'chalk'; + +import { type Msg, MsgUtil } from '../model/index.js'; +import { Swarm } from './swarm.js'; +import { type Agent } from './types.js'; + +function prettyPrintMessages(messages: Msg[]): void { + for (const message of messages) { + // Print tool results + if (MsgUtil.isToolResult(message)) { + const { content } = message; + console.log(`<== ${chalk.green(content)}`); + } + + if (message.role !== 'assistant') continue; + + // Print agent name in blue + if ('name' in message) { + process.stdout.write(`${chalk.blue(message.name || '')}: `); + } + + // Print response, if any + if (message.content) { + console.log(message.content); + } + + // Print tool calls in purple, if any + const toolCalls = MsgUtil.isToolCall(message) ? message.tool_calls : []; + if (toolCalls.length > 1) { + console.log(); + } + for (const toolCall of toolCalls) { + const { name, arguments: args } = toolCall.function; + const argObj = JSON.parse(args); + const argStr = JSON.stringify(argObj).replace(/:/g, '='); + if (name.startsWith('transfer_')) { + console.log(`<> ${chalk.yellow(name)}(${argStr.slice(1, -1)})`); + } else { + console.log(`${chalk.magenta(name)}(${argStr.slice(1, -1)})`); + } + } + } +} + +export async function runSwarmRepl( + startingAgent: Agent, + contextVariables: Record = {} +): Promise { + const client = new Swarm(); + console.log('Swarm initialized.'); + + let messages: Msg[] = []; + let agent = startingAgent; + + const rl = readline.createInterface({ + input: process.stdin, + output: process.stdout, + }); + + while (true) { + const userInput = await new Promise((resolve) => { + rl.question(`${chalk.gray('User')}: `, resolve); + }); + + messages.push({ role: 'user', content: userInput }); + + const response = await client.run({ + agent, + messages, + ctx: contextVariables, + }); + + prettyPrintMessages(response.messages); + messages = messages.concat(response.messages); + agent = response.agent; + } +} diff --git a/src/swarm/swarm-tools.ts b/src/swarm/swarm-tools.ts new file mode 100644 index 0000000..5533ede --- /dev/null +++ b/src/swarm/swarm-tools.ts @@ -0,0 +1,98 @@ +import { z } from 'zod'; + +import { + extractZodObject, + type Msg, + MsgUtil, + zodToJsonSchema, +} from '../model/index.js'; +import { getErrorMsg } from '../model/utils/errors.js'; +import { cleanString } from '../model/utils/message-util.js'; +import { type Agent, type SwarmFunc } from './types.js'; + +export function swarmFunction, Return>( + spec: { + /** Name of the function. */ + name: string; + /** Description of the function. */ + description?: string; + /** Zod schema for the arguments string. */ + argsSchema: Schema; + }, + /** Implementation of the function to call with the parsed arguments. */ + implementation: (params: z.infer) => Promise +): SwarmFunc { + /** Parse the arguments string, optionally reading from a message. */ + const parseArgs = (input: string | Msg) => { + if (typeof input === 'string') { + return extractZodObject({ schema: spec.argsSchema, json: input }); + } else if (MsgUtil.isFuncCall(input)) { + const args = input.function_call.arguments; + return extractZodObject({ schema: spec.argsSchema, json: args }); + } else { + throw new Error('Invalid input type'); + } + }; + + // Call the implementation function with the parsed arguments. + const aiFunction = async (input: string | Msg) => { + const parsedArgs = parseArgs(input); + const result = await implementation(parsedArgs); + try { + const resultStr = + typeof result === 'string' ? result : JSON.stringify(result); + return { value: resultStr }; + } catch (err) { + console.error(`Error stringifying function ${spec.name} result:`, err); + const errMsg = getErrorMsg(err); + return { + value: `Error stringifying function ${spec.name} result: ${errMsg}`, + }; + } + }; + + aiFunction.parseArgs = parseArgs; + aiFunction.argsSchema = spec.argsSchema; + aiFunction.spec = { + name: spec.name, + description: cleanString(spec.description ?? ''), + parameters: zodToJsonSchema(spec.argsSchema), + }; + + return aiFunction; +} + +/** This is a simple no-op function that can be used to transfer context to another agent. */ +export function swarmHandoff(args: { + agent: Agent; + description?: string; +}): SwarmFunc { + const { agent, description } = args; + const schema = z.object({}); + + /** Parse the arguments string, optionally reading from a message. */ + const parseArgs = (input: string | Msg) => { + if (typeof input === 'string') { + return extractZodObject({ schema, json: input }); + } else if (MsgUtil.isFuncCall(input)) { + const args = input.function_call.arguments; + return extractZodObject({ schema, json: args }); + } else { + throw new Error(`Invalid input type`); + } + }; + const aiFunction = async () => { + const value = `Transfered to ${agent.name}. Adopt the role and responsibilities of ${agent.name} and continue the conversation.`; + return { value, agent }; + }; + + aiFunction.parseArgs = parseArgs; + aiFunction.argsSchema = schema; + aiFunction.spec = { + name: `transfer_to_${agent.name}`, + description: description || '', + parameters: zodToJsonSchema(schema), + }; + + return aiFunction; +} diff --git a/src/swarm/swarm.ts b/src/swarm/swarm.ts new file mode 100644 index 0000000..2507adb --- /dev/null +++ b/src/swarm/swarm.ts @@ -0,0 +1,202 @@ +import pMap from 'p-map'; + +import { ChatModel } from '../model/chat.js'; +import { MsgUtil } from '../model/index.js'; +import { type Msg } from '../model/types.js'; +import { getErrorMsg } from '../model/utils/errors.js'; +import { + type Agent, + type CtxVal, + type SwarmFunc, + type SwarmFuncResult, + type SwarmResponse, +} from './types.js'; + +export class Swarm { + chatModel: ChatModel; + defaultModel: string; + + constructor(args?: { chatModel?: ChatModel }) { + this.defaultModel = 'gpt-4o'; + this.chatModel = + args?.chatModel || + new ChatModel({ params: { model: this.defaultModel } }); + } + + async run(args: { + agent: Agent; + messages: Msg[]; + ctx?: CtxVal; + modelOverride?: string; + maxTurns?: number; + }): Promise { + const { agent, messages, modelOverride, maxTurns = Infinity } = args; + + let activeAgent: Agent = agent; + const ctx: CtxVal = { ...args.ctx }; + const history: Msg[] = [...messages]; + const initLen = messages.length; + + while (history.length - initLen < maxTurns && activeAgent) { + // HOOK: beforeGetChatCompletion(??) => Promise + // - Use this to manage conversation history length + + // Get completion with current history, agent + const message = await this.getChatCompletion({ + agent: activeAgent, + history, + ctx, + modelOverride, + }); + + history.push({ ...message }); + + // HOOK: afterGetChatCompletion(??) => Promise + // - Use this to post-process the message from the model + + // HOOK: shouldEndTurn(??) => Promise + // - Use this to end the conversation early (eg: tool to halt loop) + + if (!MsgUtil.isToolCall(message)) { + break; + } + + // Handle function calls, updating context_variables, and switching agents + const partialResponse = await this.handleToolCalls({ + message, + functions: activeAgent.functions || [], + functionCallConcurrency: 1, + ctx, + }); + + history.push(...partialResponse.messages); + Object.assign(ctx, partialResponse.ctx); + + if (partialResponse.agent) { + activeAgent = partialResponse.agent; + } + } + + return { + messages: history.slice(initLen), + agent: activeAgent, + ctx, + }; + } + + private async getChatCompletion(args: { + agent: Agent; + history: Msg[]; + ctx: CtxVal; + modelOverride?: string; + }): Promise { + const { agent, history, modelOverride } = args; + const ctx: CtxVal = { ...args.ctx }; + const instructions = + typeof agent.instructions === 'function' + ? agent.instructions(ctx) + : agent.instructions; + const messages: Msg[] = [ + { role: 'system', content: instructions }, + ...history, + ]; + + const tools = agent.functions.map((func) => ({ + function: func.spec, + type: 'function' as const, + })); + + const response = await this.chatModel.run({ + messages, + model: modelOverride || agent.model || this.defaultModel, + tools: tools.length > 0 ? tools : undefined, + // handleUpdate: (c) => console.log('ChatModel.run update:', c), + }); + + if (MsgUtil.isToolCall(response.message)) { + return response.message; + } else if (MsgUtil.isAssistant(response.message)) { + return { ...response.message, name: agent.name }; + } else { + // TODO: not sure when this would happen so log and cast for now... + console.error('Unexpected message type:', response.message); + return response.message as unknown as Msg.Assistant; + } + } + + /** + * Handle messages that require calling functions. + * @returns An array of the new messages from the function calls + * Note: Does not include args.message in the returned array + */ + private async handleToolCalls(args: { + message: Msg; + functions?: SwarmFunc[]; + functionCallConcurrency?: number; + ctx: CtxVal; + }): Promise<{ + messages: Msg[]; + agent?: Agent; + ctx: CtxVal; + }> { + const { ctx, message, functions = [], functionCallConcurrency = 8 } = args; + const messages: Msg[] = [message]; + const funcMap = this.getFuncMap(functions); + let agent: Agent | undefined; + + // Run all the tool_calls functions and add the result messages. + if (MsgUtil.isToolCall(message)) { + await pMap( + message.tool_calls, + async (toolCall) => { + const result = await this.callFunction({ + ...toolCall.function, + funcMap, + }); + messages.push(MsgUtil.toolResult(result.value, toolCall.id)); + if (result.agent) { + agent = result.agent; + } + }, + { concurrency: functionCallConcurrency } + ); + } + + return { + messages: messages.slice(1), + agent, + ctx, + }; + } + + /** Call a function and return the result. */ + private async callFunction(args: { + name: string; + arguments: string; + funcMap: Map; + }): Promise { + const { arguments: funcArgs, name, funcMap } = args; + + const func = funcMap.get(name); + if (!func) { + console.error(`Tool ${name} not found in function map.`); + return { value: `Error: Tool ${name} not found.` }; + } + + try { + return await func(funcArgs); + } catch (err) { + const errMsg = getErrorMsg(err); + console.error(`Error running function ${name}:`, err); + return { value: `Error running function ${name}: ${errMsg}` }; + } + } + + /** Create a map of function names to functions for easy lookup. */ + private getFuncMap(functions: SwarmFunc[]): Map { + return functions.reduce((map, func) => { + map.set(func.spec.name, func); + return map; + }, new Map()); + } +} diff --git a/src/swarm/test.ts b/src/swarm/test.ts new file mode 100644 index 0000000..dce0adc --- /dev/null +++ b/src/swarm/test.ts @@ -0,0 +1,69 @@ +import { z } from 'zod'; + +import { runSwarmRepl } from './repl.js'; +import { swarmFunction, swarmHandoff } from './swarm-tools.js'; + +const weatherFunc = swarmFunction( + { + name: 'get_weather', + description: 'Gets the weather for a given location', + argsSchema: z.object({ + location: z + .string() + .describe('The city and state e.g. San Francisco, CA'), + }), + }, + async ({ location }) => { + await new Promise((resolve) => setTimeout(resolve, 500)); + const temperature = (30 + Math.random() * 70) | 0; + return { location, temperature }; + } +); + +const weatherAgent = { + name: 'weather', + instructions: 'Get the weather for a given location.', + functions: [weatherFunc], +}; + +const calendarFunc = swarmFunction( + { + name: 'get_calendar', + description: 'Gets the calendar events for today', + argsSchema: z.object({}), + }, + async () => { + return Promise.resolve( + 'Calendar events for today: Go to the central park zoo at 10am' + ); + } +); + +const calendarAgent = { + name: 'calendar', + instructions: 'Get the calendar events for today.', + functions: [calendarFunc], +}; + +const weatherHandoff = swarmHandoff({ agent: weatherAgent }); +const calendarHandoff = swarmHandoff({ agent: calendarAgent }); + +const dispatchAgent = { + name: 'dispatch', + instructions: + 'You are a helpful assistant that can dispatch tasks to other agents.', + functions: [weatherHandoff, calendarHandoff], +}; + +const dispatchHandoff = swarmHandoff({ + agent: dispatchAgent, + description: + "Transfer to the dispatch agent if you aren't able to handle the user's request.", +}); +weatherAgent.functions.push(dispatchHandoff); +calendarAgent.functions.push(dispatchHandoff); + +await runSwarmRepl(dispatchAgent).catch((err) => { + console.error(err); + process.exit(1); +}); diff --git a/src/swarm/types.ts b/src/swarm/types.ts new file mode 100644 index 0000000..eb5b15a --- /dev/null +++ b/src/swarm/types.ts @@ -0,0 +1,42 @@ +import { type z } from 'zod'; + +import { type Msg } from '../model/types.js'; + +export type CtxVal = Record; + +export type Agent = { + name: string; + model?: string; + functions: SwarmFunc[]; // eslint-disable-line no-use-before-define + instructions: string | ((args: CtxVal) => string); + toolChoice?: string; + parallelToolCalls?: boolean; +}; + +export type SwarmFuncResult = { + value: string; + agent?: Agent; + ctx?: CtxVal; +}; + +export interface SwarmFunc = z.ZodObject> { + // TODO: add support for context injection + /** The implementation of the function, with arg parsing and validation. */ + (input: string | Msg): Promise; + /** The Zod schema for the arguments string. */ + argsSchema: Schema; + /** Parse the function arguments from a message. */ + parseArgs(input: string | Msg): z.infer; + /** The function spec for the OpenAI API `functions` property. */ + spec: { + name: string; + description?: string; + parameters: Record; + }; +} + +export type SwarmResponse = { + messages: Msg[]; + agent: Agent; + ctx: CtxVal; +};