diff --git a/src/config/modelProviders/cloudflare.ts b/src/config/modelProviders/cloudflare.ts index e1a5f55b91a2..5516ae5db2bd 100644 --- a/src/config/modelProviders/cloudflare.ts +++ b/src/config/modelProviders/cloudflare.ts @@ -1,7 +1,6 @@ import { ModelProviderCard } from '@/types/llm'; // ref https://developers.cloudflare.com/workers-ai/models/#text-generation -// api https://developers.cloudflare.com/workers-ai/configuration/open-ai-compatibility const Cloudflare: ModelProviderCard = { chatModels: [ { @@ -19,7 +18,7 @@ const Cloudflare: ModelProviderCard = { { displayName: 'hermes-2-pro-mistral-7b', enabled: true, - // functionCall: true, + functionCall: true, id: '@hf/nousresearch/hermes-2-pro-mistral-7b', tokens: 4096, }, @@ -78,6 +77,7 @@ const Cloudflare: ModelProviderCard = { }, ], checkModel: '@hf/meta-llama/meta-llama-3-8b-instruct', + disableBrowserRequest: false, id: 'cloudflare', modelList: { showModelFetcher: true, diff --git a/src/libs/agent-runtime/cloudflare/index.test.ts b/src/libs/agent-runtime/cloudflare/index.test.ts index 1e8535f07e19..9fa8363674c1 100644 --- a/src/libs/agent-runtime/cloudflare/index.test.ts +++ b/src/libs/agent-runtime/cloudflare/index.test.ts @@ -3,6 +3,8 @@ import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { ChatCompletionTool } from '@/libs/agent-runtime'; +import type { OpenAIChatMessage } from '../types'; +import * as CloudflareUtils from '../utils/cloudflareHelpers'; import * as debugStreamModule from '../utils/debugStream'; import { LobeCloudflareAI } from './index'; @@ -272,12 +274,47 @@ describe('LobeCloudflareAI', () => { }); describe('chat with tools', () => { - it('should call client.beta.tools.messages.create when tools are provided', async () => { + const tools: ChatCompletionTool[] = [ + { function: { name: 'tool1', description: 'desc1' }, type: 'function' }, + ]; + + it('should disable stream when tools are provided', async () => { + // Act & Assert + await 
instance.chat({ + messages: [{ content: 'Hello', role: 'user' }], + model: '@hf/meta-llama/meta-llama-3-8b-instruct', + temperature: 1, + tools, + }); + expect(globalThis.fetch).toHaveBeenCalled(); + + const fetchCallArgs = (globalThis.fetch as Mock).mock.calls[0]; + const body = JSON.parse(fetchCallArgs[1].body); + expect(body).toEqual( + expect.objectContaining({ + stream: false, + }), + ); + }); + + it('should remove plugin info from messages', async () => { // Arrange - const tools: ChatCompletionTool[] = [ - { function: { name: 'tool1', description: 'desc1' }, type: 'function' }, - ]; + vi.spyOn(CloudflareUtils, 'removePluginInfo').mockImplementation((messages) => messages); + const messages: OpenAIChatMessage[] = [{ content: 'Hello', role: 'user' }]; + + // Act + await instance.chat({ + messages, + model: '@hf/meta-llama/meta-llama-3-8b-instruct', + temperature: 1, + tools, + }); + + // Assert + expect(CloudflareUtils.removePluginInfo).toHaveBeenCalledWith(messages); + }); + it('should include tools', async () => { // Act await instance.chat({ messages: [{ content: 'Hello', role: 'user' }], diff --git a/src/libs/agent-runtime/cloudflare/index.ts b/src/libs/agent-runtime/cloudflare/index.ts index 885e3fd7543b..db01387be855 100644 --- a/src/libs/agent-runtime/cloudflare/index.ts +++ b/src/libs/agent-runtime/cloudflare/index.ts @@ -9,6 +9,7 @@ import { convertModelManifest, desensitizeCloudflareUrl, fillUrl, + removePluginInfo, } from '../utils/cloudflareHelpers'; import { AgentRuntimeError } from '../utils/createError'; import { debugStream } from '../utils/debugStream'; @@ -47,7 +48,9 @@ export class LobeCloudflareAI implements LobeRuntimeAI { async chat(payload: ChatStreamPayload, options?: ChatCompetitionOptions): Promise { try { - const { model, tools, ...restPayload } = payload; + const { messages: _messages, model, stream: _stream, tools, ...restPayload } = payload; + const messages = tools ? 
removePluginInfo(_messages) : _messages; + const stream = tools ? false : _stream; const functions = tools?.map((tool) => tool.function); const headers = options?.headers || {}; if (this.apiKey) { @@ -55,7 +58,12 @@ export class LobeCloudflareAI implements LobeRuntimeAI { } const url = new URL(model, this.baseURL); const response = await fetch(url, { - body: JSON.stringify({ tools: functions, ...restPayload }), + body: JSON.stringify({ + messages, + stream, + tools: functions, + ...restPayload, + }), headers: { 'Content-Type': 'application/json', ...headers }, method: 'POST', signal: options?.signal, @@ -86,7 +94,7 @@ export class LobeCloudflareAI implements LobeRuntimeAI { return StreamingResponse( responseBody - .pipeThrough(new TransformStream(new CloudflareStreamTransformer())) + .pipeThrough(new TransformStream(new CloudflareStreamTransformer(stream))) .pipeThrough(createCallbacksTransformer(options?.callback)), { headers: options?.headers }, ); diff --git a/src/libs/agent-runtime/utils/cloudflareHelpers.test.ts b/src/libs/agent-runtime/utils/cloudflareHelpers.test.ts index 3a69bd7c6994..d92f874f179c 100644 --- a/src/libs/agent-runtime/utils/cloudflareHelpers.test.ts +++ b/src/libs/agent-runtime/utils/cloudflareHelpers.test.ts @@ -1,6 +1,7 @@ // @vitest-environment node import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import type { OpenAIChatMessage } from '../types'; import * as desensitizeTool from '../utils/desensitizeUrl'; import { CloudflareStreamTransformer, @@ -10,20 +11,15 @@ import { getModelDisplayName, getModelFunctionCalling, getModelTokens, + removePluginInfo, } from './cloudflareHelpers'; -//const { -// getModelBeta, -// getModelDisplayName, -// getModelFunctionCalling, -// getModelTokens, -//} = require('./cloudflareHelpers'); - -//const cloudflareHelpers = require('./cloudflareHelpers'); -//const getModelBeta = cloudflareHelpers.__get__('getModelBeta'); -//const getModelDisplayName = 
cloudflareHelpers.__get__('getModelDisplayName'); -//const getModelFunctionCalling = cloudflareHelpers.__get__('getModelFunctionCalling'); -//const getModelTokens = cloudflareHelpers.__get__('getModelTokens'); +declare module './cloudflareHelpers' { + function getModelBeta(model: any): boolean; + function getModelDisplayName(model: any, beta: boolean): string; + function getModelFunctionCalling(model: any): boolean; + function getModelTokens(model: any): number | undefined; +} afterEach(() => { vi.restoreAllMocks(); @@ -36,6 +32,68 @@ describe('cloudflareHelpers', () => { transformer = new CloudflareStreamTransformer(); }); + describe('constructor', () => { + describe('stream', () => { + const testCases = [true, false, undefined]; + testCases.forEach((stream) => { + it(`should set stream to ${stream}`, () => { + // Act + const transformer = new CloudflareStreamTransformer(stream); + + // Assert + expect(transformer['stream']).toBe(stream); + }); + }); + }); + }); + + describe('transformDispatch', () => { + let chunk: Uint8Array; + let controller: TransformStreamDefaultController; + beforeEach(() => { + chunk = new Uint8Array(); + controller = Object.create(TransformStreamDefaultController.prototype); + if (!transformer) { + throw new Error('transformer is undefined'); + } + vi.spyOn(transformer as any, 'transformStream').mockImplementation(async (_, __) => {}); + vi.spyOn(transformer as any, 'transformNonStream').mockImplementation(async (_, __) => {}); + }); + + it('should call transformStream when stream is true', async () => { + // Arrange + transformer['stream'] = true; + + // Act + await transformer.transform(chunk, controller); + + // Assert + expect(transformer['transformStream']).toHaveBeenCalled(); // Why does toHaveBeenCalledWith here throw undefined error? 
+ }); + + it('should call transformStream when stream is undefined', async () => { + // Arrange + transformer['stream'] = undefined; + + // Act + await transformer.transform(chunk, controller); + + // Assert + expect(transformer['transformStream']).toHaveBeenCalled(); + }); + + it('should call transformNonStream when stream is undefined', async () => { + // Arrange + transformer['stream'] = false; + + // Act + await transformer.transform(chunk, controller); + + // Assert + expect(transformer['transformNonStream']).toHaveBeenCalled(); + }); + }); + describe('parseChunk', () => { let chunks: string[]; let controller: TransformStreamDefaultController; @@ -51,7 +109,6 @@ describe('cloudflareHelpers', () => { it('should parse chunk', () => { // Arrange const chunk = 'data: {"key": "value", "response": "response1"}'; - const textDecoder = new TextDecoder(); // Act transformer['parseChunk'](chunk, controller); @@ -65,7 +122,6 @@ describe('cloudflareHelpers', () => { it('should not replace `data` in text', () => { // Arrange const chunk = 'data: {"key": "value", "response": "data: a"}'; - const textDecoder = new TextDecoder(); // Act transformer['parseChunk'](chunk, controller); @@ -75,113 +131,337 @@ describe('cloudflareHelpers', () => { expect(chunks[0]).toBe('event: text\n'); expect(chunks[1]).toBe('data: "data: a"\n\n'); }); + + it('should stop at <|im_end|>', () => { + // Arrange + const chunk = 'data: {"key": "value", "response": "<|im_end|>"}'; + + // Act + transformer['parseChunk'](chunk, controller); + + // Assert + expect(chunks.length).toBe(2); + expect(chunks[0]).toBe('event: stop\n'); + expect(chunks[1]).toBe('data: "<|im_end|>"\n\n'); + }); }); describe('transform', () => { - const textDecoder = new TextDecoder(); const textEncoder = new TextEncoder(); let chunks: string[]; beforeEach(() => { chunks = []; - vi.spyOn( - transformer as any as { - parseChunk: (chunk: string, controller: TransformStreamDefaultController) => void; - }, - 'parseChunk', - 
).mockImplementation((chunk: string, _) => { - chunks.push(chunk); - }); }); - it('should split single chunk', async () => { - // Arrange - const chunk = textEncoder.encode('data: {"key": "value", "response": "response1"}\n\n'); + describe('transformStream', () => { + beforeEach(() => { + vi.spyOn( + transformer as any as { + parseChunk: (chunk: string, controller: TransformStreamDefaultController) => void; + }, + 'parseChunk', + ).mockImplementation((chunk: string, _) => { + chunks.push(chunk); + }); + }); - // Act - await transformer.transform(chunk, undefined!); + it('should split single chunk', async () => { + // Arrange + const chunk = textEncoder.encode('data: {"key": "value", "response": "response1"}\n\n'); - // Assert - expect(chunks.length).toBe(1); - expect(chunks[0]).toBe('data: {"key": "value", "response": "response1"}'); - }); + // Act + await transformer['transformStream'](chunk, undefined!); - it('should split multiple chunks', async () => { - // Arrange - const chunk = textEncoder.encode( - 'data: {"key": "value", "response": "response1"}\n\n' + - 'data: {"key": "value", "response": "response2"}\n\n', - ); + // Assert + expect(chunks.length).toBe(1); + expect(chunks[0]).toBe('data: {"key": "value", "response": "response1"}'); + }); - // Act - await transformer.transform(chunk, undefined!); + it('should split multiple chunks', async () => { + // Arrange + const chunk = textEncoder.encode( + 'data: {"key": "value", "response": "response1"}\n\n' + + 'data: {"key": "value", "response": "response2"}\n\n', + ); - // Assert - expect(chunks.length).toBe(2); - expect(chunks[0]).toBe('data: {"key": "value", "response": "response1"}'); - expect(chunks[1]).toBe('data: {"key": "value", "response": "response2"}'); - }); + // Act + await transformer['transformStream'](chunk, undefined!); - it('should ignore empty chunk', async () => { - // Arrange - const chunk = textEncoder.encode('\n\n'); + // Assert + expect(chunks.length).toBe(2); + 
expect(chunks[0]).toBe('data: {"key": "value", "response": "response1"}'); + expect(chunks[1]).toBe('data: {"key": "value", "response": "response2"}'); + }); - // Act - await transformer.transform(chunk, undefined!); + it('should ignore empty chunk', async () => { + // Arrange + const chunk = textEncoder.encode('\n\n'); - // Assert - expect(chunks.join()).toBe(''); + // Act + await transformer['transformStream'](chunk, undefined!); + + // Assert + expect(chunks.join()).toBe(''); + }); + + it('should split and concat delayed chunks', async () => { + // Arrange + const chunk1 = textEncoder.encode('data: {"key": "value", "respo'); + const chunk2 = textEncoder.encode('nse": "response1"}\n\ndata: {"key": "val'); + const chunk3 = textEncoder.encode('ue", "response": "response2"}\n\n'); + + // Act & Assert + await transformer['transformStream'](chunk1, undefined!); + expect(transformer['parseChunk']).not.toHaveBeenCalled(); + expect(chunks.length).toBe(0); + expect(transformer['buffer']).toBe('data: {"key": "value", "respo'); + + await transformer['transformStream'](chunk2, undefined!); + expect(chunks.length).toBe(1); + expect(chunks[0]).toBe('data: {"key": "value", "response": "response1"}'); + expect(transformer['buffer']).toBe('data: {"key": "val'); + + await transformer['transformStream'](chunk3, undefined!); + expect(chunks.length).toBe(2); + expect(chunks[1]).toBe('data: {"key": "value", "response": "response2"}'); + expect(transformer['buffer']).toBe(''); + }); + + it('should ignore standalone [DONE]', async () => { + // Arrange + const chunk = textEncoder.encode('data: [DONE]\n\n'); + + // Act + await transformer['transformStream'](chunk, undefined!); + + // Assert + expect(transformer['parseChunk']).not.toHaveBeenCalled(); + expect(chunks.length).toBe(0); + expect(transformer['buffer']).toBe(''); + }); + + it('should ignore [DONE] in chunk', async () => { + // Arrange + const chunk = textEncoder.encode( + 'data: {"key": "value", "response": "response1"}\n\ndata: 
[DONE]\n\n', + ); + + // Act + await transformer['transformStream'](chunk, undefined!); + + // Assert + expect(chunks.length).toBe(1); + expect(chunks[0]).toBe('data: {"key": "value", "response": "response1"}'); + expect(transformer['buffer']).toBe(''); + }); }); - it('should split and concat delayed chunks', async () => { - // Arrange - const chunk1 = textEncoder.encode('data: {"key": "value", "respo'); - const chunk2 = textEncoder.encode('nse": "response1"}\n\ndata: {"key": "val'); - const chunk3 = textEncoder.encode('ue", "response": "response2"}\n\n'); - - // Act & Assert - await transformer.transform(chunk1, undefined!); - expect(transformer['parseChunk']).not.toHaveBeenCalled(); - expect(chunks.length).toBe(0); - expect(transformer['buffer']).toBe('data: {"key": "value", "respo'); - - await transformer.transform(chunk2, undefined!); - expect(chunks.length).toBe(1); - expect(chunks[0]).toBe('data: {"key": "value", "response": "response1"}'); - expect(transformer['buffer']).toBe('data: {"key": "val'); - - await transformer.transform(chunk3, undefined!); - expect(chunks.length).toBe(2); - expect(chunks[1]).toBe('data: {"key": "value", "response": "response2"}'); - expect(transformer['buffer']).toBe(''); + describe('transformNonStream', () => { + let controller: TransformStreamDefaultController; + + beforeEach(() => { + controller = Object.create(TransformStreamDefaultController.prototype); + vi.spyOn(controller, 'enqueue').mockImplementation((chunk) => { + chunks.push(chunk); + }); + }); + + it('should parse text response', async () => { + // Arrange + const chunk = textEncoder.encode('{"result": {"key": "value", "response": "result1"}}'); + + // Act + await transformer['transformNonStream'](chunk, controller); + + // Assert + expect(chunks.length).toBe(2); + expect(chunks[0]).toBe('event: text\n'); + expect(chunks[1]).toBe('data: "result1"\n\n'); + }); + + it('should parse tool response', async () => { + // Arrange + const chunk = textEncoder.encode( + 
'{"result": {"response": null, "tool_calls": [{"key": "value"}]}}', + ); + vi.spyOn(CloudflareStreamTransformer as any, 'enqueueToolCalls').mockImplementation( + (_, controller: any) => { + controller.enqueue('event: tool_calls\n'); + controller.enqueue('data: [{"key": "value"}]\n\n'); + }, + ); + + // Act + await transformer['transformNonStream'](chunk, controller); + + // Assert + expect(chunks.length).toBe(2); + expect(chunks[0]).toBe('event: tool_calls\n'); + expect(chunks[1]).toBe('data: [{"key": "value"}]\n\n'); + expect(CloudflareStreamTransformer['enqueueToolCalls']).toHaveBeenCalled(); + }); + + it('should combine text and tool response', async () => { + // Arrange + const chunk = textEncoder.encode( + '{"result": {"key": "value", "response": "result1", "tool_calls": [{"key": "value"}]}}', + ); + vi.spyOn(CloudflareStreamTransformer as any, 'enqueueToolCalls').mockImplementation( + (_, controller: any) => { + controller.enqueue('event: tool_calls\n'); + controller.enqueue('data: [{"key": "value"}]\n\n'); + }, + ); + + // Act + await transformer['transformNonStream'](chunk, controller); + + // Assert + expect(chunks.length).toBe(4); + expect(chunks[0]).toBe('event: text\n'); + expect(chunks[1]).toBe('data: "result1"\n\n'); + expect(chunks[2]).toBe('event: tool_calls\n'); + expect(chunks[3]).toBe('data: [{"key": "value"}]\n\n'); + expect(CloudflareStreamTransformer['enqueueToolCalls']).toHaveBeenCalled(); + }); }); - it('should ignore standalone [DONE]', async () => { - // Arrange - const chunk = textEncoder.encode('data: [DONE]\n\n'); + describe('getRandomId', () => { + it('should contain prefix', () => { + // Arrange + const prefix = 'prefix'; + const length = 8; - // Act - await transformer.transform(chunk, undefined!); + // Act + const id = CloudflareStreamTransformer['getRandomId'](prefix, length); - // Assert - expect(transformer['parseChunk']).not.toHaveBeenCalled(); - expect(chunks.length).toBe(0); - expect(transformer['buffer']).toBe(''); + // 
Assert + expect(id).toSatisfy((id: string) => id.startsWith(prefix)); + }); + + it('should have correct length', () => { + // Arrange + const prefix = 'prefix'; + const length = 8; + const expectedLength = prefix.length + length; + + // Act + const id = CloudflareStreamTransformer['getRandomId'](prefix, length); + const idLength = id.length; + + // Assert + expect(idLength).toBe(expectedLength); + }); + + it('should contain only alphanumeric characters', () => { + // Arrange + const prefix = ''; + const length = 32; + + // Act + const id = CloudflareStreamTransformer['getRandomId'](prefix, length); + + // Assert + expect(id).toMatch(/^[a-zA-Z0-9]+$/); + }); + + it('should be unique', () => { + // Arrange + const prefix = ''; + const length = 8; + const arrLen = 16; + + // Act + const ids = Array.from({ length: arrLen }, () => + CloudflareStreamTransformer['getRandomId'](prefix, length), + ); + const uniqueIds = new Set(ids); + const uniqueCount = uniqueIds.size; + + // Assert + expect(uniqueCount).toBe(arrLen); + }); }); - it('should ignore [DONE] in chunk', async () => { - // Arrange - const chunk = textEncoder.encode( - 'data: {"key": "value", "response": "response1"}\n\ndata: [DONE]\n\n', - ); + describe('convertToolCall', () => { + const randomId = 'randomId'; + beforeEach(() => { + vi.spyOn(CloudflareStreamTransformer as any, 'getRandomId').mockReturnValue(randomId); + vi.spyOn(JSON, 'stringify'); + }); - // Act - await transformer.transform(chunk, undefined!); + it('should convert tool call', () => { + // Arrange + const toolCall = { name: 'name', arguments: { key: 'value' } }; + const index = 6; - // Assert - expect(chunks.length).toBe(1); - expect(chunks[0]).toBe('data: {"key": "value", "response": "response1"}'); - expect(transformer['buffer']).toBe(''); + // Act & Assert + const converted = CloudflareStreamTransformer['convertToolCall'](toolCall, index); + expect(converted).toBeInstanceOf(Object); + + const keys = Object.keys(converted); + 
expect(keys.length).toBe(4); + + expect(keys).toContain('function'); + expect(converted['function']).toBeInstanceOf(Object); + + const functionKeys = Object.keys(converted['function']); + expect(functionKeys.length).toBe(2); + expect(functionKeys).toContain('arguments'); + expect(functionKeys).toContain('name'); + + const _functionArguments = converted['function']['arguments']; + expect(JSON.stringify).toHaveBeenCalledWith(toolCall.arguments); + expect(typeof _functionArguments).toBe('string'); + + const functionArguments = JSON.parse(_functionArguments); + expect(functionArguments).toEqual(toolCall.arguments); + + expect(converted['function']['name']).toBe(toolCall.name); + + expect(keys).toContain('id'); + expect(CloudflareStreamTransformer['getRandomId']).toHaveBeenCalledWith('call_', 24); + expect(converted['id']).toBe(randomId); + + expect(keys).toContain('index'); + expect(converted['index']).toBe(index); + + expect(keys).toContain('type'); + expect(converted['type']).toBe('function'); + }); + }); + + describe('enqueueToolCalls', () => { + let controller: TransformStreamDefaultController; + const convertedToolCall = { name: 'convertedToolCall' }; + + beforeEach(() => { + controller = Object.create(TransformStreamDefaultController.prototype); + vi.spyOn(controller, 'enqueue').mockImplementation((chunk) => { + chunks.push(chunk); + }); + vi.spyOn(CloudflareStreamTransformer as any, 'convertToolCall').mockReturnValue( + convertedToolCall, + ); + }); + + it('should enqueue tool calls', async () => { + // Arrange + const toolCalls = [ + { name: 'name1', arguments: { key1: 'value1', key2: 'value2' } }, + { name: 'name2', arguments: { key: 'value' } }, + ]; + const expected = `data: ${JSON.stringify([convertedToolCall, convertedToolCall])}\n\n`; + + // Act + await CloudflareStreamTransformer['enqueueToolCalls'](toolCalls, controller); + + // Assert + expect(chunks.length).toBe(2); + expect(chunks[0]).toBe('event: tool_calls\n'); + expect(chunks[1]).toBe(expected); + 
}); }); }); }); @@ -320,6 +600,12 @@ describe('cloudflareHelpers', () => { const functionCalling = getModelFunctionCalling(model); expect(functionCalling).toBe(false); }); + + it('should return false if exception occurs', () => { + const model = {}; + const functionCalling = getModelFunctionCalling(model); + expect(functionCalling).toBe(false); + }); }); describe('getModelTokens', () => { @@ -336,4 +622,57 @@ describe('cloudflareHelpers', () => { }); }); }); + + describe('removePluginInfo', () => { + it('should return messages as is if no plugin info', () => { + // Arrange + const messages: OpenAIChatMessage[] = [ + { content: 'content1', role: 'system' }, + { content: 'content2', role: 'user' }, + ]; + + // Act + const result = removePluginInfo(messages); + + // Assert + expect(result).toEqual(messages); + }); + + it('should remove plugin info', () => { + // Arrange + const system: OpenAIChatMessage = { + content: ` +plugin info +`, + role: 'system', + }; + const user: OpenAIChatMessage = { content: 'content', role: 'user' }; + const messages: OpenAIChatMessage[] = [system, user]; + + // Act + const result = removePluginInfo(messages); + + // Assert + expect(result).toEqual([user]); + }); + + it('should remove plugin info and keep other system messages', () => { + // Arrange + const system: OpenAIChatMessage = { + content: `system + +plugin info +`, + role: 'system', + }; + const user: OpenAIChatMessage = { content: 'content', role: 'user' }; + const messages: OpenAIChatMessage[] = [system, user]; + + // Act + const result = removePluginInfo(messages); + + // Assert + expect(result).toEqual([{ content: 'system\n', role: 'system' }, user]); + }); + }); }); diff --git a/src/libs/agent-runtime/utils/cloudflareHelpers.ts b/src/libs/agent-runtime/utils/cloudflareHelpers.ts index 5fd596f4c7c5..873392c74597 100644 --- a/src/libs/agent-runtime/utils/cloudflareHelpers.ts +++ b/src/libs/agent-runtime/utils/cloudflareHelpers.ts @@ -1,18 +1,52 @@ +import { OpenAIChatMessage } 
/** One entry of `result.tool_calls` as read from a non-streaming response body. */
type CloudflareToolCall = {
  arguments: object;
  name: string;
};

/** Alphabet used for generated tool-call ids. */
const RANDOM_CHARSET = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789';

/**
 * Transformer for `new TransformStream(...)` that turns a Cloudflare response body into
 * `event: …\n` / `data: …\n\n` text frames.
 *
 * Two input formats are handled, selected by the `stream` constructor argument:
 * - `stream !== false`: the body is a sequence of `data: {json}` frames separated by a
 *   blank line and terminated by `data: [DONE]`.
 * - `stream === false`: the body is a single JSON document of the form
 *   `{ "result": { "response": …, "tool_calls": […] } }`.
 */
class CloudflareStreamTransformer {
  private textDecoder = new TextDecoder();
  /** Text received so far that does not yet form a complete frame / JSON document. */
  private buffer: string = '';
  private stream: boolean | undefined;
  /** A streamed `response` equal to this token is reported as a `stop` event. */
  private static readonly END_TOKEN = '<|im_end|>';
  /** Matches a frame that consists of nothing but the `[DONE]` terminator. */
  private static readonly DONE_FRAME = /^data:\s*\[DONE]$/;

  /**
   * @param stream - `false` selects the single-JSON-document parser; `true` or
   *   `undefined` selects the `data:` frame parser.
   */
  constructor(stream: boolean | undefined = undefined) {
    this.stream = stream;
  }

  /**
   * Entry point invoked for every incoming chunk; dispatches on the configured mode.
   *
   * @param chunk - Raw bytes of the response body.
   * @param controller - Receives the generated `event:` / `data:` strings.
   */
  public async transform(
    chunk: Uint8Array,
    controller: TransformStreamDefaultController,
  ): Promise<void> {
    if (this.stream === false) {
      await this.transformNonStream(chunk, controller);
    } else {
      await this.transformStream(chunk, controller);
    }
  }

  /**
   * End-of-input hook (a `TransformStream` calls `flush` after the last chunk).
   *
   * In non-streaming mode a non-empty buffer at this point means the body never became
   * valid JSON; that is reported instead of being dropped silently. In streaming mode a
   * trailing partial frame is discarded.
   *
   * @throws SyntaxError when a non-streaming body ended before forming valid JSON.
   */
  public flush(): void {
    const leftover = this.buffer;
    this.buffer = '';
    if (this.stream === false && leftover.trim() !== '') {
      throw new SyntaxError('Cloudflare response ended before a complete JSON body was received');
    }
  }

  /**
   * Converts one `data: {json}` frame into an event/data pair.
   *
   * @param chunk - A single frame, with or without the leading `data: ` prefix.
   * @param controller - Receives `event: stop` when the frame's `response` equals
   *   {@link CloudflareStreamTransformer.END_TOKEN}, otherwise `event: text`, followed by
   *   the JSON-encoded `response`.
   * @throws SyntaxError when the frame is not valid JSON.
   */
  private parseChunk(chunk: string, controller: TransformStreamDefaultController): void {
    const json = chunk.replace(/^data: /, '');
    const response: unknown = JSON.parse(json).response;
    const event = response === CloudflareStreamTransformer.END_TOKEN ? 'stop' : 'text';
    controller.enqueue(`event: ${event}\n`);
    controller.enqueue(`data: ${JSON.stringify(response)}\n\n`);
  }

  /**
   * Splits incoming text into blank-line separated frames and forwards each complete
   * frame to `parseChunk`. An unterminated trailing frame is kept in `buffer` until the
   * next chunk arrives.
   */
  private async transformStream(
    chunk: Uint8Array,
    controller: TransformStreamDefaultController,
  ): Promise<void> {
    // `stream: true` keeps a multi-byte character that straddles two chunks intact.
    const text = this.buffer + this.textDecoder.decode(chunk, { stream: true });
    this.buffer = '';

    const segments = text.split('\n\n');
    // Whatever follows the last separator is not terminated yet.
    const tail = segments.pop() ?? '';

    for (const segment of segments) {
      const frame = segment.trim();
      if (frame === '') {
        continue; // blank frame: nothing to parse
      }
      // Only a frame that IS the terminator ends processing; text that merely
      // contains "[DONE]" is ordinary content.
      if (CloudflareStreamTransformer.DONE_FRAME.test(frame)) {
        return;
      }
      this.parseChunk(frame, controller);
    }

    if (tail.trim() !== '') {
      this.buffer = tail; // does not need to be trimmed.
    } // else drop.
  }

  /**
   * Accumulates the non-streaming body and, as soon as it parses as JSON, emits a
   * `text` event for a truthy `result.response` and a `tool_calls` event for
   * `result.tool_calls`.
   *
   * @throws TypeError when the parsed body has no `result` object.
   */
  private async transformNonStream(
    chunk: Uint8Array,
    controller: TransformStreamDefaultController,
  ): Promise<void> {
    this.buffer += this.textDecoder.decode(chunk, { stream: true });

    let payload: unknown;
    try {
      payload = JSON.parse(this.buffer);
    } catch {
      // The body may be delivered in several chunks; keep buffering until it parses.
      // `flush` reports a body that never becomes valid.
      return;
    }
    this.buffer = '';

    const result = (
      payload as {
        result?: { response?: string | null; tool_calls?: CloudflareToolCall[] } | null;
      } | null
    )?.result;
    if (result === null || typeof result !== 'object') {
      throw new TypeError('Cloudflare response does not contain a `result` object');
    }

    const response = result['response'];
    const toolCalls = result['tool_calls'];
    if (response) {
      controller.enqueue(`event: text\n`);
      controller.enqueue(`data: ${JSON.stringify(response)}\n\n`);
    }
    if (toolCalls) {
      await CloudflareStreamTransformer.enqueueToolCalls(toolCalls, controller);
    }
  }

  /**
   * Emits a single `tool_calls` event whose data is the JSON array of converted calls.
   */
  private static async enqueueToolCalls(
    toolCalls: CloudflareToolCall[],
    controller: TransformStreamDefaultController,
  ): Promise<void> {
    const converted = toolCalls.map((toolCall, index) =>
      CloudflareStreamTransformer.convertToolCall(toolCall, index),
    );
    controller.enqueue(`event: tool_calls\n`);
    controller.enqueue(`data: ${JSON.stringify(converted)}\n\n`);
  }

  /**
   * Builds the emitted tool-call object: arguments serialized to a JSON string, a
   * generated `call_…` id, the call's position and `type: 'function'`.
   */
  private static convertToolCall(toolCall: CloudflareToolCall, index: number) {
    return {
      function: {
        arguments: JSON.stringify(toolCall.arguments),
        name: toolCall.name,
      },
      id: CloudflareStreamTransformer.getRandomId('call_', 24),
      index,
      type: 'function',
    };
  }

  /**
   * Returns `prefix` followed by `length` random characters from `RANDOM_CHARSET`.
   *
   * Random bytes at or above the largest multiple of the charset size are rejected so
   * that every character is equally likely (a plain `byte % size` favours the first
   * `256 % size` characters).
   */
  private static getRandomId(prefix: string, length: number): string {
    const size = RANDOM_CHARSET.length;
    const limit = 256 - (256 % size);
    const bytes = new Uint8Array(length);
    let suffix = '';
    while (suffix.length < length) {
      crypto.getRandomValues(bytes);
      for (const byte of bytes) {
        if (byte < limit && suffix.length < length) {
          suffix += RANDOM_CHARSET.charAt(byte % size);
        }
      }
    }
    return prefix + suffix;
  }
}
/**
 * Matches the plugin description block, from its opening `<plugins_info>` tag through
 * its closing tag. The `s` flag lets `.` span the line breaks inside the block.
 */
const PLUGIN_INFO_REGEX = /<plugins_info>.*<\/plugins_info>/s;

/**
 * Strips the `<plugins_info>…</plugins_info>` block from a leading system message.
 *
 * - If the first message is not a `system` message, or its content is not a plain
 *   string, the input array is returned unchanged.
 * - If nothing but whitespace remains after stripping, the system message is dropped.
 * - Otherwise a copy of the system message with the stripped content is returned in
 *   front of the remaining messages. The input array and its messages are not mutated.
 *
 * @param messages - Chat messages; only the first one is inspected.
 * @returns The messages without plugin information in the leading system message.
 */
function removePluginInfo(messages: OpenAIChatMessage[]): OpenAIChatMessage[] {
  const [systemMessage, ...restMessages] = messages;
  if (systemMessage?.role !== 'system') {
    // Unlikely
    return messages;
  }

  const { content } = systemMessage;
  if (typeof content !== 'string') {
    // Non-string content has no text to run the pattern against.
    return messages;
  }

  const stripped = content.replace(PLUGIN_INFO_REGEX, '');
  if (stripped.trim() === '') {
    return restMessages;
  }
  return [{ ...systemMessage, content: stripped }, ...restMessages];
}