From 8b8181cd1e61d84e35e27d9fdb64466e8a1b4546 Mon Sep 17 00:00:00 2001 From: rj Date: Mon, 16 Dec 2024 20:02:46 +0000 Subject: [PATCH] (feature) basic function calling (#425) * function calling initial commit * function calling works at a basic level * refactor styles and base classes * add missing file * lint * lint * lint * move interfaces * 3.20.0 --- package-lock.json | 4 +- package.json | 6 +- src/common/constants.ts | 18 +- src/common/tool-definitions.ts | 151 +++++++++ src/common/types.ts | 115 +++++-- src/extension/base.ts | 55 +++- src/extension/chat-service.ts | 223 ++++++------- src/extension/conversation-history.ts | 73 ++--- src/extension/embeddings.ts | 18 +- src/extension/{api.ts => llm.ts} | 23 +- src/extension/provider-manager.ts | 20 +- src/extension/provider-options.ts | 11 +- src/extension/providers/base.ts | 72 ++--- src/extension/providers/completion.ts | 68 ++-- src/extension/providers/panel.ts | 1 - src/extension/providers/sidebar.ts | 2 - src/extension/review-service.ts | 79 ++--- src/extension/symmetry-service.ts | 81 ++--- src/extension/symmetry-ws.ts | 4 +- src/extension/templates.ts | 12 - src/extension/tools.ts | 292 +++++++++++++++++ src/extension/tree.ts | 2 +- src/extension/utils.ts | 57 ++-- src/index.ts | 41 +-- src/webview/assets/locales/en.json | 15 +- src/webview/chat.tsx | 198 +++++++----- src/webview/hooks.ts | 96 +++--- src/webview/main.tsx | 2 +- src/webview/message.tsx | 38 ++- src/webview/provider-select.tsx | 34 +- src/webview/styles/providers.module.css | 4 + src/webview/styles/tool-execution.module.css | 324 +++++++++++++++++++ src/webview/tool-execution.tsx | 174 ++++++++++ src/webview/utils.ts | 11 +- 34 files changed, 1674 insertions(+), 650 deletions(-) create mode 100644 src/common/tool-definitions.ts rename src/extension/{api.ts => llm.ts} (87%) create mode 100644 src/extension/tools.ts create mode 100644 src/webview/styles/tool-execution.module.css create mode 100644 src/webview/tool-execution.tsx diff --git a/package-lock.json b/package-lock.json index 465dba7f..de0e9b5b 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "twinny", - "version": "3.19.25", + "version": "3.20.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "twinny", - "version": "3.19.25", + "version": "3.20.0", "cpu": [ "x64", "arm64" diff --git a/package.json b/package.json index 4376a3d3..df47f69d 100644 --- a/package.json +++ b/package.json @@ -2,7 +2,7 @@ "name": "twinny", "displayName": "twinny - AI Code Completion and Chat", "description": "Locally hosted AI code completion plugin for vscode", - "version": "3.19.25", + "version": "3.20.0", "icon": "assets/icon.png", "keywords": [ "code-inference", @@ -159,10 +159,6 @@ "command": "twinny.stopGeneration", "title": "Stop generation" }, - { - "command": "twinny.getGitCommitMessage", - "title": "Generate git commit message" - }, { "command": "twinny.disable", "title": "Disable twinny", diff --git a/src/common/constants.ts b/src/common/constants.ts index 201fcbdb..cda9e7d3 100644 --- a/src/common/constants.ts +++ b/src/common/constants.ts @@ -62,7 +62,7 @@ export const EVENT_NAME = { twinnyNewDocument: "twinny-new-document", twinnyNotification: "twinny-notification", twinnyOnCompletion: "twinny-on-completion", - twinnyOnEnd: "twinny-on-end", + twinnyOnCompletionEnd: "twinny-on-end", twinnyOnLoading: "twinny-on-loading", twinnyOpenDiff: "twinny-open-diff", twinnyRerankThresholdChanged: "twinny-rerank-threshold-changed", @@ -104,7 +104,6 @@ export const TWINNY_COMMAND_NAME = { openPanelChat: "twinny.openPanelChat", openChat: "twinny.openChat", refactor: "twinny.refactor", - sendTerminalText: "twinny.sendTerminalText", settings: "twinny.settings", stopGeneration: "twinny.stopGeneration", templateCompletion: "twinny.templateCompletion", @@ -122,6 +121,12 @@ export const CONVERSATION_EVENT_NAME = { setActiveConversation: "twinny.set-active-conversation", } +export const TOOL_EVENT_NAME = { + runAllTools: "run-all-tools", + runTool: "run-tool", + rejectTool: "run-on-tool", +} + export const PROVIDER_EVENT_NAME = { addProvider: "twinny.add-provider", copyProvider: "twinny.copy-provider", @@ -171,18 +176,19 @@ export const EXTENSION_SETTING_KEY = { export const EXTENSION_CONTEXT_NAME = { twinnyConversationHistory: "twinnyConversationHistory", + twinnyEnableRag: "twinnyEnableRag", + twinnyEnableTools: "twinnyEnableTools", twinnyGeneratingText: "twinnyGeneratingText", twinnyManageProviders: "twinnyManageProviders", twinnyManageTemplates: "twinnyManageTemplates", - twinnyReviewTab: "twinnyReviewTab", - twinnyRerankThreshold: "twinnyRerankThreshold", twinnyMaxChunkSize: "twinnyMaxChunkSize", twinnyMinChunkSize: "twinnyMinChunkSize", twinnyOverlapSize: "twinnyOverlapSize", - twinnyRelevantFilePaths: "twinnyRelevantFilePaths", twinnyRelevantCodeSnippets: "twinnyRelevantCodeSnippets", + twinnyRelevantFilePaths: "twinnyRelevantFilePaths", + twinnyRerankThreshold: "twinnyRerankThreshold", + twinnyReviewTab: "twinnyReviewTab", twinnySymmetryTab: "twinnySymmetryTab", - twinnyEnableRag: "twinnyEnableRag", } export const EXTENSION_SESSION_NAME = { diff --git a/src/common/tool-definitions.ts b/src/common/tool-definitions.ts new file mode 100644 index 00000000..80da1b6c --- /dev/null +++ b/src/common/tool-definitions.ts @@ -0,0 +1,151 @@ +import { JSONSchema7 } from "json-schema" + +import { FunctionTool } from "./types" + +export const tools: FunctionTool[] = [ + { + type: "function", + function: { + name: "openFile", + description: "Open a file in the editor", + parameters: { + type: "object", + properties: { + path: { + type: "string", + description: "Path to the file relative to workspace root" + }, + preview: { + type: "boolean", + description: "Open in preview mode" + }, + viewColumn: { + type: "string", + enum: ["beside", "active", "new"], + description: "Where to open the file" + }, + encoding: { + type: "string", + description: "File encoding (e.g. 'utf-8')" + }, + revealIfOpen: { + type: "boolean", + description: "If true, reveal the tab if the file is already open" + } + }, + required: ["path"] + } satisfies JSONSchema7 + } + }, + + { + type: "function", + function: { + name: "editFile", + description: "Edit a file", + parameters: { + type: "object", + properties: { + path: { + type: "string", + description: "Path to the file relative to workspace root" + }, + edit: { + type: "string", + description: "Text edit to apply" + }, + createIfNotExists: { + type: "boolean", + description: "Create file if it doesn't exist" + }, + backupBeforeEdit: { + type: "boolean", + description: "Create a backup copy before editing" + } + }, + required: ["path", "edit"] + } satisfies JSONSchema7 + } + }, + + { + type: "function", + function: { + name: "createFile", + description: "Create a new file in the workspace", + parameters: { + type: "object", + properties: { + path: { + type: "string", + description: "Relative path from workspace root" + }, + content: { + type: "string", + description: "Content to write to file" + }, + openAfterCreate: { + type: "boolean", + description: "Whether to open the file after creation" + }, + createIntermediateDirs: { + type: "boolean", + description: "Create intermediate directories if they don't exist" + }, + fileTemplate: { + type: "string", + description: "Template to use for file content" + }, + permissions: { + type: "string", + description: "File permissions (e.g. '0644')" + } + }, + required: ["path", "content"] + } satisfies JSONSchema7 + } + }, + + { + type: "function", + function: { + name: "runCommand", + description: "Run a shell command in the integrated terminal", + parameters: { + type: "object", + properties: { + command: { + type: "string", + description: "Command to execute" + }, + cwd: { + type: "string", + description: "Working directory relative to workspace root" + }, + env: { + type: "object", + description: "Additional environment variables" + }, + shell: { + type: "string", + description: "Specific shell to use" + }, + timeout: { + type: "number", + description: "Command timeout in milliseconds" + }, + captureOutput: { + type: "boolean", + description: "If true, capture command output and return it" + }, + runInBackground: { + type: "boolean", + description: + "If true, run the command in background without blocking" + } + }, + required: ["command"] + } satisfies JSONSchema7 + } + } +] diff --git a/src/common/types.ts b/src/common/types.ts index d364828a..a1f5c8f4 100644 --- a/src/common/types.ts +++ b/src/common/types.ts @@ -1,3 +1,4 @@ +import type { JSONSchema7 } from "json-schema" import { serverMessageKeys } from "symmetry-core" import { InlineCompletionItem, InlineCompletionList, Uri } from "vscode" @@ -22,6 +23,7 @@ export interface RequestOptionsOllama extends RequestBodyBase { export interface StreamBodyOpenAI extends RequestBodyBase { messages?: Message[] | Message max_tokens: number + tools?: FunctionTool[] } export interface PrefixSuffix { @@ -29,13 +31,19 @@ export interface PrefixSuffix { suffix: string } - export interface RepositoryLevelData { - uri: Uri; - text: string; - name: string; - isOpen: boolean; - relevanceScore: number; + uri: Uri + text: string + name: string + isOpen: boolean + relevanceScore: number +} + +type ToolCall = { + function: { + name: string + arguments: Record + } } export interface StreamResponse { @@ -44,8 +52,9 @@ export interface StreamResponse { response: string content: string message: { - content: string, + content: string role: "assistant" + tool_calls?: ToolCall[] } done: boolean context: number[] @@ -56,14 +65,35 @@ export interface StreamResponse { eval_count: number eval_duration: number type?: string + tool_calls?: ToolCall[] + system_fingerprint: string choices: [ { text: string delta: { content: string } + index: number + message: { + role: "assistant" + content: string + tool_calls?: Array<{ + id: string + type: "function" + function: { + name: string + arguments: string + } + }> + } + finish_reason: "stop" | "tool_calls" } ] + usage: { + prompt_tokens: number + completion_tokens: number + total_tokens: number + } } export interface LanguageType { @@ -82,20 +112,16 @@ export type ClientMessageWithData = ClientMessage & ClientMessage & ClientMessage -export interface ServerMessage { +export interface ServerMessage { type: string - value: { - completion: string - data?: T - error?: boolean - errorMessage?: string - type: string - } + data: T } export interface Message { role: string content: string | undefined + tools?: Record + id?: string } export interface GithubPullRequestMessage { @@ -113,7 +139,7 @@ export interface Conversation { export const Theme = { Light: "Light", Dark: "Dark", - Contrast: "Contrast", + Contrast: "Contrast" } as const export interface DefaultTemplate { @@ -170,10 +196,10 @@ export interface StreamRequestOptions { export interface StreamRequest { body: RequestBodyBase | StreamBodyOpenAI options: StreamRequestOptions - onEnd?: () => void + onEnd?: (response?: StreamResponse) => void onStart?: (controller: AbortController) => void onError?: (error: Error) => void - onData: (streamResponse: T) => void + onData: (streamResponse: StreamResponse) => void } export interface UiTabs { @@ -186,7 +212,7 @@ export const apiProviders = { LMStudio: "lmstudio", Ollama: "ollama", Oobabooga: "oobabooga", - OpenWebUI: "openwebui", + OpenWebUI: "openwebui" } as const export interface ApiModel { @@ -334,3 +360,54 @@ export interface LMStudioEmbedding { usage: LMSEmbeddingUsage } +export type FunctionTool = { + type: "function" + function: { + name: string + description: string + parameters: JSONSchema7 + } +} + +export interface Tool { + id: string + name: string + arguments: Record + status?: string + error?: string +} + +export interface CreateFileArgs { + path: string + content: string + openAfterCreate?: boolean + createIntermediateDirs?: boolean + fileTemplate?: string + permissions?: string +} + +export interface RunCommandArgs { + command: string + cwd?: string + env?: Record + shell?: string + timeout?: number + captureOutput?: boolean + runInBackground?: boolean +} + +export interface OpenFileArgs { + path: string + preview?: boolean + viewColumn?: "beside" | "active" | "new" + encoding?: string + revealIfOpen?: boolean +} + +export interface EditFileArgs { + path: string + edit: string + createIfNotExists?: boolean + backupBeforeEdit?: boolean +} + diff --git a/src/extension/base.ts b/src/extension/base.ts index 66216d03..eec9abf1 100644 --- a/src/extension/base.ts +++ b/src/extension/base.ts @@ -1,9 +1,18 @@ import * as vscode from "vscode" +import { ACTIVE_CHAT_PROVIDER_STORAGE_KEY, EVENT_NAME, EXTENSION_CONTEXT_NAME } from "../common/constants" +import { tools } from "../common/tool-definitions" +import { Message,StreamRequestOptions as LlmRequestOptions } from "../common/types" + +import { TwinnyProvider } from "./provider-manager" +import { createStreamRequestBody } from "./provider-options" + export class Base { public config = vscode.workspace.getConfiguration("twinny") + public context?: vscode.ExtensionContext - constructor () { + constructor (context: vscode.ExtensionContext) { + this.context = context vscode.workspace.onDidChangeConfiguration((event) => { if (!event.affectsConfiguration("twinny")) { return @@ -12,6 +21,50 @@ export class Base { }) } + public getProvider = () => { + const provider = this.context?.globalState.get( + ACTIVE_CHAT_PROVIDER_STORAGE_KEY + ) + return provider + } + + public buildStreamRequest(messages?: Message[] | Message[]) { + const provider = this.getProvider() + + if (!provider) return + + const requestOptions: LlmRequestOptions = { + hostname: provider.apiHostname, + port: provider.apiPort ? Number(provider.apiPort) : undefined, + path: provider.apiPath, + protocol: provider.apiProtocol, + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: `Bearer ${provider.apiKey}` + } + } + + const useToolsName = `${EVENT_NAME.twinnyGlobalContext}-${EXTENSION_CONTEXT_NAME.twinnyEnableTools}` + const toolsEnabled = this.context?.globalState.get(useToolsName) as number + const functionTools = toolsEnabled ? tools : undefined + + const requestBody = createStreamRequestBody( + provider.provider, + { + model: provider.modelName, + numPredictChat: this.config.numPredictChat, + temperature: this.config.temperature, + messages, + keepAlive: this.config.keepAlive + }, + functionTools + ) + + return { requestOptions, requestBody } + } + + public updateConfig() { this.config = vscode.workspace.getConfiguration("twinny") } diff --git a/src/extension/chat-service.ts b/src/extension/chat-service.ts index a54fa3e2..a35a8a81 100644 --- a/src/extension/chat-service.ts +++ b/src/extension/chat-service.ts @@ -12,7 +12,7 @@ import { } from "vscode" import { - ACTIVE_CHAT_PROVIDER_STORAGE_KEY, + ASSISTANT, DEFAULT_RELEVANT_CODE_COUNT, DEFAULT_RELEVANT_FILE_COUNT, DEFAULT_RERANK_THRESHOLD, @@ -33,28 +33,23 @@ import { ServerMessage, StreamRequestOptions, StreamResponse, - TemplateData + TemplateData, + Tool } from "../common/types" import { kebabToSentence } from "../webview/utils" -import { streamResponse } from "./api" import { Base } from "./base" import { EmbeddingDatabase } from "./embeddings" -import { TwinnyProvider } from "./provider-manager" -import { createStreamRequestBody } from "./provider-options" +import { llm } from "./llm" import { Reranker } from "./reranker" import { SessionManager } from "./session-manager" import { SymmetryService } from "./symmetry-service" import { TemplateProvider } from "./template-provider" -import { - getChatDataFromProvider, - getLanguage, - updateLoadingMessage -} from "./utils" +import { Tools } from "./tools" +import { getLanguage, getResponseData, updateLoadingMessage } from "./utils" export class ChatService extends Base { private _completion = "" - private _context?: ExtensionContext private _controller?: AbortController private _db?: EmbeddingDatabase private _promptTemplate = "" @@ -64,6 +59,8 @@ export class ChatService extends Base { private _templateProvider?: TemplateProvider private _webView?: Webview private _sessionManager: SessionManager | undefined + private _tools?: Tools + private _conversation: Message[] = [] constructor( statusBar: StatusBarItem, @@ -74,16 +71,16 @@ export class ChatService extends Base { sessionManager: SessionManager | undefined, symmetryService: SymmetryService ) { - super() + super(extensionContext) this._webView = webView this._statusBar = statusBar this._templateProvider = new TemplateProvider(templateDir) this._reranker = new Reranker() - this._context = extensionContext this._db = db this._sessionManager = sessionManager this._symmetryService = symmetryService this.setupSymmetryListeners() + this._tools = new Tools(webView, extensionContext) } private setupSymmetryListeners() { @@ -92,11 +89,11 @@ export class ChatService extends Base { (completion: string) => { this._webView?.postMessage({ type: EVENT_NAME.twinnyOnCompletion, - value: { - completion: completion.trimStart(), - data: getLanguage() + data: { + content: completion.trimStart(), + role: ASSISTANT } - } as ServerMessage) + } as ServerMessage) } ) } @@ -113,7 +110,7 @@ export class ChatService extends Base { if (!embedding) return [] const relevantFileCountContext = `${EVENT_NAME.twinnyGlobalContext}-${EXTENSION_CONTEXT_NAME.twinnyRelevantFilePaths}` - const stored = this._context?.globalState.get( + const stored = this.context?.globalState.get( relevantFileCountContext ) as number const relevantFileCount = Number(stored) || DEFAULT_RELEVANT_FILE_COUNT @@ -134,7 +131,7 @@ export class ChatService extends Base { private getRerankThreshold() { const rerankThresholdContext = `${EVENT_NAME.twinnyGlobalContext}-${EXTENSION_CONTEXT_NAME.twinnyRerankThreshold}` - const stored = this._context?.globalState.get( + const stored = this.context?.globalState.get( rerankThresholdContext ) as number const rerankThreshold = stored || DEFAULT_RERANK_THRESHOLD @@ -206,7 +203,7 @@ export class ChatService extends Base { if (await this._db.hasEmbeddingTable(table)) { const relevantCodeCountContext = `${EVENT_NAME.twinnyGlobalContext}-${EXTENSION_CONTEXT_NAME.twinnyRelevantCodeSnippets}` - const stored = this._context?.globalState.get( + const stored = this.context?.globalState.get( relevantCodeCountContext ) as number const relevantCodeCount = Number(stored) || DEFAULT_RELEVANT_CODE_COUNT @@ -270,104 +267,106 @@ export class ChatService extends Base { return "" } - private getProvider = () => { - const provider = this._context?.globalState.get( - ACTIVE_CHAT_PROVIDER_STORAGE_KEY - ) - return provider - } - private buildStreamRequest(messages?: Message[] | Message[]) { - const provider = this.getProvider() - if (!provider) return - const requestOptions: StreamRequestOptions = { - hostname: provider.apiHostname, - port: provider.apiPort ? Number(provider.apiPort) : undefined, - path: provider.apiPath, - protocol: provider.apiProtocol, - method: "POST", - headers: { - "Content-Type": "application/json" - } - } - if (provider.apiKey) { - requestOptions.headers["Authorization"] = `Bearer ${provider.apiKey}` - } + async getMessageTools(data: { + type: "function_call" + calls: Tool[] + }): Promise> { + const tools: Record = {} + if (!data.calls?.length) return {} - const requestBody = createStreamRequestBody(provider.provider, { - model: provider.modelName, - numPredictChat: this.config.numPredictChat, - temperature: this.config.temperature, - messages, - keepAlive: this.config.keepAlive - }) + for (const call of data.calls) { + tools[call.name] = { + arguments: call.arguments, + name: call.name, + status: "pending", + id: call.id + } + } - return { requestOptions, requestBody } + return tools } - private onStreamData = ( - streamResponse: StreamResponse, - onEnd?: (completion: string) => void - ) => { - const provider = this.getProvider() - if (!provider) return - + private onLlmData = async (response: StreamResponse) => { try { - const data = getChatDataFromProvider(provider.provider, streamResponse) - this._completion = this._completion + data - if (onEnd) return + const data = getResponseData(response) + + this._completion = this._completion + data.content + this._webView?.postMessage({ type: EVENT_NAME.twinnyOnCompletion, - value: { - completion: this._completion.trimStart(), - data: getLanguage(), - type: this._promptTemplate + data: { + content: this._completion.trimStart(), + role: ASSISTANT } - } as ServerMessage) + } as ServerMessage) } catch (error) { console.error("Error parsing JSON:", error) return } } - private onStreamEnd = (onEnd?: (completion: string) => void) => { + private onLlmEnd = async (response?: StreamResponse) => { this._statusBar.text = "$(code)" commands.executeCommand( "setContext", EXTENSION_CONTEXT_NAME.twinnyGeneratingText, false ) - if (onEnd) { - onEnd(this._completion) + + if (response) { + const data = getResponseData(response) + + if (data.calls) { + const tools = await this.getMessageTools(data) + + this._webView?.postMessage({ + type: EVENT_NAME.twinnyOnCompletionEnd, + data: { + content: "Twinny would like to use the following tools:", + role: ASSISTANT, + tools, + id: crypto.randomUUID() + } + } as ServerMessage) + + return + } + this._webView?.postMessage({ - type: EVENT_NAME.twinnyOnEnd - } as ServerMessage) + type: EVENT_NAME.twinnyOnCompletionEnd, + data: { + content: data.content, + role: ASSISTANT + } + } as ServerMessage) + return } + this._webView?.postMessage({ - type: EVENT_NAME.twinnyOnEnd, - value: { - completion: this._completion.trimStart(), - data: getLanguage(), - type: this._promptTemplate + type: EVENT_NAME.twinnyOnCompletionEnd, + data: { + content: this._completion.trimStart(), + role: ASSISTANT } - } as ServerMessage) + } as ServerMessage) } - private onStreamError = (error: Error) => { + private onLlmError = (error: Error) => { this._webView?.postMessage({ - type: EVENT_NAME.twinnyOnEnd, - value: { - error: true, - errorMessage: `==## ERROR ##== : ${error.message}` // Highlight errors on webview + type: EVENT_NAME.twinnyOnCompletionEnd, + data: { + content: `==## ERROR ##== : ${error.message}`, + role: ASSISTANT } } as ServerMessage) } - private onStreamStart = (controller: AbortController) => { + private onLlmStart = (controller: AbortController) => { this._controller = controller commands.executeCommand( "setContext", @@ -390,13 +389,12 @@ export class ChatService extends Base { true ) this._webView?.postMessage({ - type: EVENT_NAME.twinnyOnEnd, - value: { - completion: this._completion.trimStart(), - data: getLanguage(), - type: this._promptTemplate + type: EVENT_NAME.twinnyOnCompletionEnd, + data: { + content: this._completion.trimStart(), + role: ASSISTANT } - } as ServerMessage) + } as ServerMessage) } private buildTemplatePrompt = async ( @@ -419,41 +417,34 @@ export class ChatService extends Base { return { prompt: prompt || "", selection: selectionContext } } - private streamResponse({ + private callLlm({ requestBody, - requestOptions, - onEnd + requestOptions }: { requestBody: RequestBodyBase requestOptions: StreamRequestOptions - onEnd?: (completion: string) => void }) { - return streamResponse({ + return llm({ body: requestBody, options: requestOptions, - onData: (streamResponse) => - this.onStreamData(streamResponse as StreamResponse, onEnd), - onEnd: () => this.onStreamEnd(onEnd), - onStart: this.onStreamStart, - onError: this.onStreamError + onStart: this.onLlmStart, + onData: this.onLlmData, + onEnd: this.onLlmEnd, + onError: this.onLlmError }) } private sendEditorLanguage = () => { this._webView?.postMessage({ type: EVENT_NAME.twinnySendLanguage, - value: { - data: getLanguage() - } + data: getLanguage() } as ServerMessage) } private focusChatTab = () => { this._webView?.postMessage({ type: EVENT_NAME.twinnySetTab, - value: { - data: WEBUI_TABS.chat - } + data: WEBUI_TABS.chat } as ServerMessage) } @@ -589,31 +580,31 @@ export class ChatService extends Base { if (!provider) return - const conversation = [] + this._conversation = [] - conversation.push(...messages.slice(0, -1)) + this._conversation.push(...messages.slice(0, -1)) if (!provider.modelName.includes("claude")) { - conversation.unshift(systemMessage) + this._conversation.unshift(systemMessage) } if (additionalContext) { const lastMessageContent = `${cleanedText}\n\n${additionalContext.trim()}` - conversation.push({ + this._conversation.push({ role: USER, content: lastMessageContent }) } else { - conversation.push({ + this._conversation.push({ ...lastMessage, content: cleanedText }) } updateLoadingMessage(this._webView, "Thinking") - const request = this.buildStreamRequest(conversation) + const request = this.buildStreamRequest(this._conversation) if (!request) return const { requestBody, requestOptions } = request - return this.streamResponse({ requestBody, requestOptions }) + return this.callLlm({ requestBody, requestOptions }) } public async getTemplateMessages( @@ -640,11 +631,10 @@ export class ChatService extends Base { }) this._webView?.postMessage({ type: EVENT_NAME.twinnyAddMessage, - value: { - completion: kebabToSentence(template) + "\n\n" + "```\n" + selection, - data: getLanguage() + data: { + content: kebabToSentence(template) + "\n\n" + "```\n" + selection } - } as ServerMessage) + } as ServerMessage) } const systemMessage = { @@ -685,7 +675,6 @@ export class ChatService extends Base { public async streamTemplateCompletion( promptTemplate: string, context?: string, - onEnd?: (completion: string) => void, skipMessage?: boolean ) { const messages = await this.getTemplateMessages( @@ -697,6 +686,6 @@ export class ChatService extends Base { if (!request) return const { requestBody, requestOptions } = request - return this.streamResponse({ requestBody, requestOptions, onEnd }) + return this.callLlm({ requestBody, requestOptions }) } } diff --git a/src/extension/conversation-history.ts b/src/extension/conversation-history.ts index 3e2ec6c6..2a6801c3 100644 --- a/src/extension/conversation-history.ts +++ b/src/extension/conversation-history.ts @@ -8,7 +8,7 @@ import { CONVERSATION_STORAGE_KEY, EXTENSION_SESSION_NAME, TITLE_GENERATION_PROMPT_MESAGE, - USER, + USER } from "../common/constants" import { ClientMessage, @@ -16,22 +16,20 @@ import { Message, RequestBodyBase, ServerMessage, - StreamRequestOptions, - StreamResponse, + StreamRequestOptions } from "../common/types" -import { streamResponse } from "./api" import { Base } from "./base" +import { llm } from "./llm" import { TwinnyProvider } from "./provider-manager" import { createStreamRequestBody } from "./provider-options" import { SessionManager } from "./session-manager" import { SymmetryService } from "./symmetry-service" -import { getChatDataFromProvider } from "./utils" +import { getResponseData } from "./utils" type Conversations = Record | undefined export class ConversationHistory extends Base { - public context: ExtensionContext public webView: Webview private _sessionManager: SessionManager | undefined private _symmetryService: SymmetryService @@ -43,8 +41,7 @@ export class ConversationHistory extends Base { sessionManager: SessionManager | undefined, symmetryService: SymmetryService ) { - super() - this.context = context + super(context) this.webView = webView this._sessionManager = sessionManager this._symmetryService = symmetryService @@ -82,7 +79,7 @@ export class ConversationHistory extends Base { streamConversationTitle({ requestBody, - requestOptions, + requestOptions }: { requestBody: RequestBodyBase requestOptions: StreamRequestOptions @@ -94,19 +91,16 @@ export class ConversationHistory extends Base { return new Promise((resolve, reject) => { try { - return streamResponse({ + return llm({ body: requestBody, options: requestOptions, onData: (streamResponse) => { - const data = getChatDataFromProvider( - provider.provider, - streamResponse as StreamResponse - ) - this._title = this._title + data + const data = getResponseData(streamResponse) + this._title = this._title + data.content }, onEnd: () => { return resolve(this._title) - }, + } }) } catch (e) { return reject(e) @@ -129,8 +123,8 @@ export class ConversationHistory extends Base { method: "POST", headers: { "Content-Type": "application/json", - Authorization: `Bearer ${provider.apiKey}`, - }, + Authorization: `Bearer ${provider.apiKey}` + } } } @@ -149,10 +143,10 @@ export class ConversationHistory extends Base { ...messages, { role: USER, - content: TITLE_GENERATION_PROMPT_MESAGE, - }, + content: TITLE_GENERATION_PROMPT_MESAGE + } ], - keepAlive: this.config.keepAlive, + keepAlive: this.config.keepAlive }) return { requestOptions, requestBody } @@ -170,7 +164,7 @@ export class ConversationHistory extends Base { numPredictChat: this.config.numPredictChat, temperature: this.config.temperature, messages, - keepAlive: this.config.keepAlive, + keepAlive: this.config.keepAlive }) return { requestOptions, requestBody } @@ -192,52 +186,47 @@ export class ConversationHistory extends Base { const conversations = this.getConversations() || {} this.webView?.postMessage({ type: CONVERSATION_EVENT_NAME.getConversations, - value: { - data: conversations, - }, + data: conversations }) } getConversations(): Conversations { - const conversations = this.context.globalState.get< + const conversations = this.context?.globalState.get< Record >(CONVERSATION_STORAGE_KEY) return conversations } resetConversation() { - this.context.globalState.update(ACTIVE_CONVERSATION_STORAGE_KEY, undefined) + this.context?.globalState.update(ACTIVE_CONVERSATION_STORAGE_KEY, undefined) this.setActiveConversation(undefined) } updateConversation(conversation: Conversation) { const conversations = this.getConversations() || {} if (!conversation.id) return - this.context.globalState.update(CONVERSATION_STORAGE_KEY, { + this.context?.globalState.update(CONVERSATION_STORAGE_KEY, { ...conversations, - [conversation.id]: conversation, + [conversation.id]: conversation }) this.setActiveConversation(conversation) } setActiveConversation(conversation: Conversation | undefined) { - this.context.globalState.update( + this.context?.globalState.update( ACTIVE_CONVERSATION_STORAGE_KEY, conversation ) this.webView?.postMessage({ type: CONVERSATION_EVENT_NAME.setActiveConversation, - value: { - data: conversation, - }, + data: conversation } as ServerMessage) this.getAllConversations() } getActiveConversation() { - const conversation: Conversation | undefined = this.context.globalState.get( - ACTIVE_CONVERSATION_STORAGE_KEY - ) + const conversation: Conversation | undefined = + this.context?.globalState.get(ACTIVE_CONVERSATION_STORAGE_KEY) this.setActiveConversation(conversation) return conversation } @@ -246,15 +235,15 @@ export class ConversationHistory extends Base { const conversations = this.getConversations() || {} if (!conversation?.id) return delete conversations[conversation.id] - this.context.globalState.update(CONVERSATION_STORAGE_KEY, { - ...conversations, + this.context?.globalState.update(CONVERSATION_STORAGE_KEY, { + ...conversations }) this.setActiveConversation(undefined) this.getAllConversations() } clearAllConversations() { - this.context.globalState.update(CONVERSATION_STORAGE_KEY, {}) + this.context?.globalState.update(CONVERSATION_STORAGE_KEY, {}) this.setActiveConversation(undefined) } @@ -263,7 +252,7 @@ export class ConversationHistory extends Base { if (activeConversation) return this.updateConversation({ ...activeConversation, - messages: conversation.messages, + messages: conversation.messages }) if (!conversation.messages.length || conversation.messages.length > 2) @@ -280,10 +269,10 @@ export class ConversationHistory extends Base { const newConversation: Conversation = { id, title: this._title || "", - messages: conversation.messages, + messages: conversation.messages } conversations[id] = newConversation - this.context.globalState.update(CONVERSATION_STORAGE_KEY, conversations) + this.context?.globalState.update(CONVERSATION_STORAGE_KEY, conversations) this.setActiveConversation(newConversation) this._title = "" } diff --git a/src/extension/embeddings.ts b/src/extension/embeddings.ts index 0be63d07..544b9d88 100644 --- a/src/extension/embeddings.ts +++ b/src/extension/embeddings.ts @@ -5,7 +5,6 @@ import ignore from "ignore" import path from "path" import * as vscode from "vscode" -import { ACTIVE_EMBEDDINGS_PROVIDER_STORAGE_KEY } from "../common/constants" import { logger } from "../common/logger" import { apiProviders, @@ -16,8 +15,8 @@ import { StreamRequestOptions as RequestOptions } from "../common/types" -import { fetchEmbedding } from "./api" import { Base } from "./base" +import { fetchEmbedding } from "./llm" import { TwinnyProvider } from "./provider-manager" import { getDocumentSplitChunks, readGitSubmodulesFile } from "./utils" @@ -26,15 +25,13 @@ export class EmbeddingDatabase extends Base { private _filePaths: EmbeddedDocument[] = [] private _db: lancedb.Connection | null = null private _dbPath: string - private _extensionContext?: vscode.ExtensionContext private _workspaceName = vscode.workspace.name || "" private _documentTableName = `${this._workspaceName}-documents` private _filePathTableName = `${this._workspaceName}-file-paths` - constructor(dbPath: string, extensionContext: vscode.ExtensionContext) { - super() + constructor(dbPath: string, context: vscode.ExtensionContext) { + super(context) this._dbPath = dbPath - this._extensionContext = extensionContext } public async connect() { @@ -45,11 +42,6 @@ export class EmbeddingDatabase extends Base { } } - private getProvider = () => - this._extensionContext?.globalState.get( - ACTIVE_EMBEDDINGS_PROVIDER_STORAGE_KEY - ) - public async fetchModelEmbedding(content: string) { const provider = this.getProvider() @@ -144,14 +136,14 @@ export class EmbeddingDatabase extends Base { cancellable: true }, async (progress) => { - if (!this._extensionContext) return + if (!this.context) return const promises = filePaths.map(async (filePath) => { const content = await fs.promises.readFile(filePath, "utf-8") const chunks = await getDocumentSplitChunks( content, filePath, - this._extensionContext + this.context ) const filePathEmbedding = await this.fetchModelEmbedding(filePath) diff --git a/src/extension/api.ts b/src/extension/llm.ts similarity index 87% rename from src/extension/api.ts rename to src/extension/llm.ts index 5ba242ca..c6a5b262 100644 --- a/src/extension/api.ts +++ b/src/extension/llm.ts @@ -1,5 +1,5 @@ import { Logger } from "../common/logger" -import { StreamRequest } from "../common/types" +import { StreamRequest as LlmRequest } from "../common/types" import { logStreamOptions, @@ -9,7 +9,7 @@ import { const log = Logger.getInstance() -export async function streamResponse(request: StreamRequest) { +export async function llm(request: LlmRequest) { logStreamOptions(request) const { body, options, onData, onEnd, onError, onStart } = request const controller = new AbortController() @@ -45,6 +45,16 @@ export async function streamResponse(request: StreamRequest) { onStart?.(controller) + if (body.stream === false) { + const text = await response.text() + const json = safeParseJsonResponse(text) + + if (!json || !onData) return + + onEnd?.(json) + return + } + const reader = response.body .pipeThrough(new TextDecoderStream()) .pipeThrough( @@ -70,6 +80,7 @@ export async function streamResponse(request: StreamRequest) { if (buffer) { try { const json = safeParseJsonResponse(buffer) + if (!json) return onData(json) } catch (e) { onError?.(new Error("Error parsing JSON data from event")) @@ -98,7 +109,11 @@ export async function streamResponse(request: StreamRequest) { onEnd?.() } else if (error.name === "TimeoutError") { onError?.(error) - log.logConsoleError(Logger.ErrorType.Timeout, "Failed to establish connection", error) + log.logConsoleError( + Logger.ErrorType.Timeout, + "Failed to establish connection", + error + ) } else { log.logConsoleError(Logger.ErrorType.Fetch_Error, "Fetch error", error) onError?.(error) @@ -108,7 +123,7 @@ export async function streamResponse(request: StreamRequest) { } } -export async function fetchEmbedding(request: StreamRequest) { +export async function fetchEmbedding(request: LlmRequest) { const { body, options, onData } = request const controller = new AbortController() diff --git a/src/extension/provider-manager.ts b/src/extension/provider-manager.ts index 092f8099..0e6af1cf 100644 --- a/src/extension/provider-manager.ts +++ b/src/extension/provider-manager.ts @@ -81,9 +81,7 @@ export class ProviderManager { public focusProviderTab = () => { this._webView.postMessage({ type: PROVIDER_EVENT_NAME.focusProviderTab, - value: { - data: WEBUI_TABS.providers, - }, + data: WEBUI_TABS.providers, } as ServerMessage) } @@ -194,9 +192,7 @@ export class ProviderManager { const providers = this.getProviders() || {} this._webView?.postMessage({ type: PROVIDER_EVENT_NAME.getAllProviders, - value: { - data: providers, - }, + data: providers, }) } @@ -206,9 +202,7 @@ export class ProviderManager { ) this._webView?.postMessage({ type: PROVIDER_EVENT_NAME.getActiveChatProvider, - value: { - data: provider, - }, + data: provider, }) return provider } @@ -219,9 +213,7 @@ export class ProviderManager { ) this._webView?.postMessage({ type: PROVIDER_EVENT_NAME.getActiveFimProvider, - value: { - data: provider, - }, + data: provider, }) return provider } @@ -232,9 +224,7 @@ export class ProviderManager { ) this._webView?.postMessage({ type: PROVIDER_EVENT_NAME.getActiveEmbeddingsProvider, - value: { - data: provider, - }, + data: provider, }) return provider } diff --git a/src/extension/provider-options.ts b/src/extension/provider-options.ts index d56c73d0..d635c2f8 100644 --- a/src/extension/provider-options.ts +++ b/src/extension/provider-options.ts @@ -1,6 +1,7 @@ import { USER } from "../common/constants" import { apiProviders, + FunctionTool, Message, RequestBodyBase, RequestOptionsOllama, @@ -10,21 +11,22 @@ import { export function createStreamRequestBody( provider: string, options: { - temperature: number numPredictChat: number model: string messages?: Message[] keepAlive?: string | number - } + }, + tools?: FunctionTool[], ): RequestBodyBase | RequestOptionsOllama | StreamBodyOpenAI { switch (provider) { case apiProviders.Ollama: case apiProviders.OpenWebUI: return { model: options.model, - stream: true, + stream: !tools?.length, messages: options.messages, + tools: tools, keep_alive: options.keepAlive === "-1" ? -1 : options.keepAlive, @@ -37,7 +39,8 @@ export function createStreamRequestBody( default: return { model: options.model, - stream: true, + stream: !tools?.length, + tools: tools, max_tokens: options.numPredictChat, messages: options.messages, temperature: options.temperature, diff --git a/src/extension/providers/base.ts b/src/extension/providers/base.ts index 735397fc..31a01e85 100644 --- a/src/extension/providers/base.ts +++ b/src/extension/providers/base.ts @@ -2,12 +2,12 @@ import { serverMessageKeys } from "symmetry-core" import * as vscode from "vscode" import { + ACTIVE_FIM_PROVIDER_STORAGE_KEY, EVENT_NAME, EXTENSION_SESSION_NAME, SYMMETRY_EMITTER_KEY, SYSTEM, TWINNY_COMMAND_NAME, - WORKSPACE_STORAGE_KEY } from "../../common/constants" import { logger } from "../../common/logger" import { @@ -15,15 +15,17 @@ import { ClientMessage, FileItem, InferenceRequest, + LanguageType, Message, - ServerMessage + ServerMessage, + ThemeType } from "../../common/types" import { ChatService } from "../chat-service" import { ConversationHistory } from "../conversation-history" import { DiffManager } from "../diff" import { EmbeddingDatabase } from "../embeddings" import { OllamaService } from "../ollama" -import { ProviderManager } from "../provider-manager" +import { ProviderManager, TwinnyProvider } from "../provider-manager" import { GithubService as ReviewService } from "../review-service" import { SessionManager } from "../session-manager" import { SymmetryService } from "../symmetry-service" @@ -156,6 +158,12 @@ export class BaseProvider { }) } + public getFimProvider = () => { + return this.context.globalState.get( + ACTIVE_FIM_PROVIDER_STORAGE_KEY + ) + } + private sendLocaleToWebView = () => { this.webView?.postMessage({ type: EVENT_NAME.twinnySetLocale, @@ -225,14 +233,6 @@ export class BaseProvider { return } this.conversationHistory?.resetConversation() - this._chatService?.streamTemplateCompletion( - "commit-message", - diff, - (completion: string) => { - vscode.commands.executeCommand("twinny.sendTerminalText", completion) - }, - true - ) } private twinnyNewConversation = () => { @@ -250,9 +250,7 @@ export class BaseProvider { private setTab = (tab: ClientMessage) => { this.webView?.postMessage({ type: EVENT_NAME.twinnySetTab, - value: { - data: tab as string - } + data: tab, } as ServerMessage) } @@ -275,10 +273,7 @@ export class BaseProvider { const config = vscode.workspace.getConfiguration("twinny") this.webView?.postMessage({ type: EVENT_NAME.twinnyGetConfigValue, - value: { - data: config.get(message.key as string), - type: message.key - } + data: config.get(message.key), } as ServerMessage) } @@ -287,9 +282,7 @@ export class BaseProvider { const files = await this._fileTreeProvider?.getAllFiles() this.webView?.postMessage({ type: EVENT_NAME.twinnyFileListResponse, - value: { - data: files - } + data: files }) } } @@ -308,9 +301,7 @@ export class BaseProvider { } this.webView?.postMessage({ type: EVENT_NAME.twinnyFetchOllamaModels, - value: { - data: models - } + data: models } as ServerMessage) } catch (e) { return @@ -321,9 +312,7 @@ export class BaseProvider { const templates = this._templateProvider.listTemplates() this.webView?.postMessage({ type: EVENT_NAME.twinnyListTemplates, - value: { - data: templates - } + data: templates } as ServerMessage) } @@ -398,7 +387,7 @@ export class BaseProvider { ) this.webView?.postMessage({ type: `${EVENT_NAME.twinnyGlobalContext}-${message.key}`, - value: storedData + data: storedData }) } @@ -409,17 +398,15 @@ export class BaseProvider { private getCurrentLanguage = () => { this.webView?.postMessage({ type: EVENT_NAME.twinnySendLanguage, - value: { - data: getLanguage() - } - } as ServerMessage) + data: getLanguage() + } as ServerMessage) } private getSessionContext = (data: ClientMessage) => { if (!data.key) return undefined return this.webView?.postMessage({ type: `${EVENT_NAME.twinnySessionContext}-${data.key}`, - value: this._sessionManager?.get(data.key) + data: this._sessionManager?.get(data.key) }) } @@ -436,19 +423,19 @@ export class BaseProvider { ) this.webView?.postMessage({ type: `${EVENT_NAME.twinnyGetWorkspaceContext}-${message.key}`, - value: storedData + data: storedData } as ServerMessage) } private setWorkspaceContext = (message: ClientMessage) => { - const value = message.data + const data = message.data this.context.workspaceState.update( `${EVENT_NAME.twinnyGetWorkspaceContext}-${message.key}`, - value + data ) this.webView?.postMessage({ type: `${EVENT_NAME.twinnyGetWorkspaceContext}-${message.key}`, - value + data }) } @@ -460,19 +447,14 @@ export class BaseProvider { private sendThemeToWebView() { this.webView?.postMessage({ type: EVENT_NAME.twinnySendTheme, - value: { - data: getTheme() - } - }) + data: getTheme() + } as ServerMessage) } private sendTextSelectionToWebView(text: string) { this.webView?.postMessage({ type: EVENT_NAME.twinnyTextSelection, - value: { - type: WORKSPACE_STORAGE_KEY.selection, - completion: text - } + data: text }) } } diff --git a/src/extension/providers/completion.ts b/src/extension/providers/completion.ts index 71c88c80..ebc3e420 100644 --- a/src/extension/providers/completion.ts +++ b/src/extension/providers/completion.ts @@ -22,7 +22,6 @@ import Parser, { SyntaxNode } from "web-tree-sitter" import "string_score" import { - ACTIVE_FIM_PROVIDER_STORAGE_KEY, FIM_TEMPLATE_FORMAT, LINE_BREAK_REGEX, MAX_CONTEXT_LINE_COUNT, @@ -43,7 +42,6 @@ import { StreamResponse } from "../../common/types" import { getLineBreakCount } from "../../webview/utils" -import { streamResponse } from "../api" import { Base } from "../base" import { cache } from "../cache" import { CompletionFormatter } from "../completion-formatter" @@ -53,6 +51,7 @@ import { getFimTemplateRepositoryLevel, getStopWords } from "../fim-templates" +import { llm } from "../llm" import { getNodeAtPosition, getParser } from "../parser" import { TwinnyProvider } from "../provider-manager" import { createStreamRequestBodyFim } from "../provider-options" @@ -76,7 +75,6 @@ export class CompletionProvider private _completion = "" private _debouncer: NodeJS.Timeout | undefined private _document: TextDocument | null - private _extensionContext: ExtensionContext private _fileInteractionCache: FileInteractionCache private _isMultilineCompletion = false private _lastCompletionMultiline = false @@ -96,9 +94,9 @@ export class CompletionProvider statusBar: StatusBarItem, fileInteractionCache: FileInteractionCache, templateProvider: TemplateProvider, - extensionContext: ExtensionContext + context: ExtensionContext ) { - super() + super(context) this._abortController = null this._document = null this._lock = new AsyncLock() @@ -106,7 +104,29 @@ export class CompletionProvider this._statusBar = statusBar this._fileInteractionCache = fileInteractionCache this._templateProvider = templateProvider - this._extensionContext = extensionContext + } + + private buildFimRequest(prompt: string, provider: TwinnyProvider) { + const body = createStreamRequestBodyFim(provider.provider, prompt, { + model: provider.modelName, + numPredictFim: this.config.numPredictFim, + temperature: this.config.temperature, + keepAlive: this.config.eepAlive + }) + + const options: StreamRequestOptions = { + hostname: provider.apiHostname, + port: provider.apiPort ? Number(provider.apiPort) : undefined, + path: provider.apiPath, + protocol: provider.apiProtocol, + method: "POST", + headers: { + "Content-Type": "application/json", + Authorization: provider.apiKey ? `Bearer ${provider.apiKey}` : "" + } + } + + return { options, body } } public async provideInlineCompletionItems( @@ -182,9 +202,12 @@ export class CompletionProvider this._lock.acquire("twinny.completion", async () => { const provider = this.getProvider() if (!provider) return - const request = this.buildStreamRequest(prompt, provider) + const request = this.buildFimRequest(prompt, provider) + + if (!request) return + try { - await streamResponse({ + await llm({ body: request.body, options: request.options, onStart: (controller) => (this._abortController = controller), @@ -224,29 +247,6 @@ export class CompletionProvider } } - private buildStreamRequest(prompt: string, provider: TwinnyProvider) { - const body = createStreamRequestBodyFim(provider.provider, prompt, { - model: provider.modelName, - numPredictFim: this.config.numPredictFim, - temperature: this.config.temperature, - keepAlive: this.config.eepAlive - }) - - const options: StreamRequestOptions = { - hostname: provider.apiHostname, - port: provider.apiPort ? Number(provider.apiPort) : undefined, - path: provider.apiPath, - protocol: provider.apiProtocol, - method: "POST", - headers: { - "Content-Type": "application/json", - Authorization: provider.apiKey ? `Bearer ${provider.apiKey}` : "" - } - } - - return { options, body } - } - private onData(data: StreamResponse | undefined): string { if (!this._provider) return "" @@ -588,11 +588,7 @@ export class CompletionProvider ) } - private getProvider = () => { - return this._extensionContext.globalState.get( - ACTIVE_FIM_PROVIDER_STORAGE_KEY - ) - } + public setAcceptedLastCompletion(value: boolean) { this._acceptedLastCompletion = value diff --git a/src/extension/providers/panel.ts b/src/extension/providers/panel.ts index 9ee90120..cda25578 100644 --- a/src/extension/providers/panel.ts +++ b/src/extension/providers/panel.ts @@ -4,7 +4,6 @@ import { getNonce } from "../utils" import { BaseProvider } from "./base" -// TODO export class FullScreenProvider extends BaseProvider { private _panel?: vscode.WebviewPanel diff --git a/src/extension/providers/sidebar.ts b/src/extension/providers/sidebar.ts index a1955f83..83cde52a 100644 --- a/src/extension/providers/sidebar.ts +++ b/src/extension/providers/sidebar.ts @@ -7,8 +7,6 @@ import { getNonce } from "../utils" import { BaseProvider } from "./base" export class SidebarProvider extends BaseProvider { - public context: vscode.ExtensionContext - constructor( statusBarItem: vscode.StatusBarItem, context: vscode.ExtensionContext, diff --git a/src/extension/review-service.ts b/src/extension/review-service.ts index 060ea3ad..67398d64 100644 --- a/src/extension/review-service.ts +++ b/src/extension/review-service.ts @@ -1,27 +1,28 @@ import { commands, ExtensionContext, Webview } from "vscode" import { + ASSISTANT, EVENT_NAME, EXTENSION_CONTEXT_NAME, GITHUB_EVENT_NAME, USER, - WEBUI_TABS, + WEBUI_TABS } from "../common/constants" import { ClientMessage, + Message, RequestBodyBase, ServerMessage, StreamRequestOptions, - StreamResponse, - TemplateData, + TemplateData } from "../common/types" -import { streamResponse } from "./api" import { ConversationHistory } from "./conversation-history" +import { llm } from "./llm" import { SessionManager } from "./session-manager" import { SymmetryService } from "./symmetry-service" import { TemplateProvider } from "./template-provider" -import { getChatDataFromProvider, updateLoadingMessage } from "./utils" +import { getResponseData, updateLoadingMessage } from "./utils" export class GithubService extends ConversationHistory { private _completion = "" @@ -57,7 +58,7 @@ export class GithubService extends ConversationHistory { private async loadReviewTemplate(diff: string): Promise { return await this._templateProvider.readTemplate("review", { - code: diff, + code: diff }) } @@ -68,7 +69,7 @@ export class GithubService extends ConversationHistory { const prs = await this.getPullRequests(data.owner, data.repo) this.webView.postMessage({ type: GITHUB_EVENT_NAME.getPullRequests, - value: { data: prs }, + data: prs }) } @@ -92,30 +93,28 @@ export class GithubService extends ConversationHistory { ) this.webView.postMessage({ type: GITHUB_EVENT_NAME.getPullRequestReview, - value: { data: review }, + data: review }) } getHeaders() { return { Authorization: `Bearer ${this.config.githubToken}`, - Accept: "application/vnd.github.v3.diff", + Accept: "application/vnd.github.v3.diff" } } private focusChatTab = () => { this.webView.postMessage({ type: EVENT_NAME.twinnySetTab, - value: { - data: WEBUI_TABS.chat, - }, + data: WEBUI_TABS.chat } as ServerMessage) } async getPullRequests(owner: string, repo: string) { const url = `https://api.github.com/repos/${owner}/${repo}/pulls` const response = await fetch(url, { - headers: this.getHeaders(), + headers: this.getHeaders() }) return response.json() } @@ -129,7 +128,7 @@ export class GithubService extends ConversationHistory { const headers = this.getHeaders() const url = `https://api.github.com/repos/${owner}/${repo}/pulls/${number}` const response = await fetch(url, { - headers, + headers }) const diff = await response.text() const prompt = await this.loadReviewTemplate(`${title} \n\n ${diff}`) @@ -137,8 +136,8 @@ export class GithubService extends ConversationHistory { const messages = [ { role: USER, - content: prompt, - }, + content: prompt + } ] const request = this.buildStreamRequest(messages) @@ -150,16 +149,13 @@ export class GithubService extends ConversationHistory { this.resetConversation() setTimeout(async () => { - this.webView?.postMessage({ type: EVENT_NAME.twinnyAddMessage, - value: { - completion: prompt, - }, + data: prompt }) this.webView?.postMessage({ - type: EVENT_NAME.twinnyOnLoading, + type: EVENT_NAME.twinnyOnLoading }) commands.executeCommand( @@ -176,7 +172,7 @@ export class GithubService extends ConversationHistory { streamCodeReview({ requestBody, - requestOptions, + requestOptions }: { requestBody: RequestBodyBase requestOptions: StreamRequestOptions @@ -188,7 +184,7 @@ export class GithubService extends ConversationHistory { return new Promise((_, reject) => { try { - return streamResponse({ + return llm({ body: requestBody, options: requestOptions, onStart: (controller: AbortController) => { @@ -204,39 +200,28 @@ export class GithubService extends ConversationHistory { if (!provider) return try { - const data = getChatDataFromProvider( - provider.provider, - streamResponse as StreamResponse - ) - this._completion = this._completion + data + const data = getResponseData(streamResponse) + this._completion = this._completion + data.content this.webView.postMessage({ type: EVENT_NAME.twinnyOnCompletion, - value: { - completion: this._completion.trimStart(), - }, - } as ServerMessage) + data: { + role: ASSISTANT, + content: this._completion.trimStart() + } + } as ServerMessage) } catch (error) { console.error("Error parsing JSON:", error) return } }, - onEnd: () => { - this.webView?.postMessage({ - type: EVENT_NAME.twinnyOnEnd, - value: { - completion: this._completion.trimStart(), - }, - }) - this._completion = "" - }, onError: (error: Error) => { this.webView?.postMessage({ - type: EVENT_NAME.twinnyOnEnd, - value: { - error: true, - errorMessage: error.message, - }, - } as ServerMessage) + type: EVENT_NAME.twinnyOnCompletionEnd, + data: { + role: ASSISTANT, + content: `Something went wrong ${error.message}` + } + } as ServerMessage) } }) } catch (e) { diff --git a/src/extension/symmetry-service.ts b/src/extension/symmetry-service.ts index d5aafb65..9eea29ef 100644 --- a/src/extension/symmetry-service.ts +++ b/src/extension/symmetry-service.ts @@ -16,6 +16,7 @@ import { commands, ExtensionContext, Webview, workspace } from "vscode" import { ACTIVE_CHAT_PROVIDER_STORAGE_KEY, + ASSISTANT, EVENT_NAME, EXTENSION_CONTEXT_NAME, EXTENSION_SESSION_NAME, @@ -25,6 +26,7 @@ import { } from "../common/constants" import { ClientMessage, + Message, Peer, ServerMessage, StreamResponse, @@ -38,7 +40,7 @@ import { SessionManager } from "./session-manager" import { SymmetryWs } from "./symmetry-ws" import { createSymmetryMessage, - getChatDataFromProvider, + getResponseData, safeParseJson, safeParseJsonResponse, updateSymmetryStatus @@ -173,13 +175,11 @@ export class SymmetryService extends EventEmitter { this._providerPeer = peer this.setupProviderListeners(peer) this.notifyWebView(EVENT_NAME.twinnyConnectedToSymmetry, { - data: { - modelName: connection.modelName, - name: connection.name, - provider: connection.provider - } + modelName: connection.modelName, + name: connection.name, + provider: connection.provider }) - this.notifyWebView(EVENT_NAME.twinnySetTab, { data: WEBUI_TABS.chat }) + this.notifyWebView(EVENT_NAME.twinnySetTab, WEBUI_TABS.chat) this._sessionManager?.set( EXTENSION_SESSION_NAME.twinnySymmetryConnection, connection @@ -193,8 +193,8 @@ export class SymmetryService extends EventEmitter { private setupProviderListeners(peer: Peer) { peer.on("data", (chunk: Buffer) => { - const str = chunk.toString() - if (str.includes(serverMessageKeys.inferenceEnded)) + const response = chunk.toString() + if (response.includes(serverMessageKeys.inferenceEnded)) this.handleInferenceEnd() this.handleIncomingData(chunk, (response: StreamResponse) => this.processResponseData(response) @@ -204,28 +204,11 @@ export class SymmetryService extends EventEmitter { private processResponseData(response: StreamResponse) { if (!this._symmetryProvider) return - const data = getChatDataFromProvider(this._symmetryProvider, response) - this._completion += data + const data = getResponseData(response) + this._completion += data.content if (data) this.emit(SYMMETRY_EMITTER_KEY.inference, this._completion) } - private handleInferenceEnd() { - commands.executeCommand( - "setContext", - EXTENSION_CONTEXT_NAME.twinnyGeneratingText, - false - ) - if (!this._completion) return - - this._webView?.postMessage({ - type: EVENT_NAME.twinnyOnEnd, - value: { - completion: this._completion.trimStart() - } - } as ServerMessage) - this._completion = "" - } - private handleIncomingData = ( chunk: Buffer, cb: (data: StreamResponse) => void @@ -245,6 +228,24 @@ export class SymmetryService extends EventEmitter { } } + private handleInferenceEnd() { + commands.executeCommand( + "setContext", + EXTENSION_CONTEXT_NAME.twinnyGeneratingText, + false + ) + if (!this._completion) return + + this._webView?.postMessage({ + type: EVENT_NAME.twinnyOnCompletionEnd, + data: { + role: ASSISTANT, + content: this._completion.trimStart() + } + } as ServerMessage) + this._completion = "" + } + private getSymmetryConfigPath(): string { const homeDir = os.homedir() return path.join(homeDir, ".config", "symmetry", "provider.yaml") @@ -274,7 +275,7 @@ export class SymmetryService extends EventEmitter { public: true, serverKey: this._config.symmetryServerKey, systemMessage: "", - userSecret: "", + userSecret: "" } const symmetryConfiguration = yaml.dump(config) @@ -290,7 +291,9 @@ export class SymmetryService extends EventEmitter { return yaml.load(configStr) as ProviderConfig } - private updateProviderConfig = async (provider: TwinnyProvider): Promise => { + private updateProviderConfig = async ( + provider: TwinnyProvider + ): Promise => { const configPath = this.getSymmetryConfigPath() const configDir = path.dirname(configPath) @@ -312,9 +315,15 @@ export class SymmetryService extends EventEmitter { if (!config.dataPath) updates.dataPath = configDir const updatedConfig = { ...config, ...updates } - await fs.promises.writeFile(configPath, yaml.dump(updatedConfig), "utf8") + await fs.promises.writeFile( + configPath, + yaml.dump(updatedConfig), + "utf8" + ) } else { - config = this.createProviderConfig(this.getChatProvider() as TwinnyProvider) + config = this.createProviderConfig( + this.getChatProvider() as TwinnyProvider + ) await fs.promises.writeFile(configPath, yaml.dump(config), "utf8") } } catch (error) { @@ -338,7 +347,7 @@ export class SymmetryService extends EventEmitter { const sessionTypeName = `${EVENT_NAME.twinnySessionContext}-${sessionKey}` this._webView?.postMessage({ type: sessionTypeName, - value: "connecting" + data: "connecting" }) await this._client.init() @@ -346,7 +355,7 @@ export class SymmetryService extends EventEmitter { this._sessionManager?.set(sessionKey, "connected") this._webView?.postMessage({ type: sessionTypeName, - value: "connected" + data: "connected" }) } catch (error) { console.error("Failed to start provider:", error) @@ -357,8 +366,8 @@ export class SymmetryService extends EventEmitter { } } - private notifyWebView(type: string, value: any = {}) { - this._webView?.postMessage({ type, value }) + private notifyWebView(type: string, data: any = {}) { + this._webView?.postMessage({ type, data }) } public getChatProvider() { diff --git a/src/extension/symmetry-ws.ts b/src/extension/symmetry-ws.ts index efd6fefc..a9f7c34f 100644 --- a/src/extension/symmetry-ws.ts +++ b/src/extension/symmetry-ws.ts @@ -20,9 +20,7 @@ export class SymmetryWs { const parsedData = JSON.parse(data.toString()) this._webView?.postMessage({ type: EVENT_NAME.twinnySymmetryModels, - value: { - data: parsedData?.allPeers?.filter((peer: any) => peer.online) - } + data: parsedData?.allPeers?.filter((peer: any) => peer.online) }) } catch (error) { console.error("Error parsing WebSocket message:", error) diff --git a/src/extension/templates.ts b/src/extension/templates.ts index fe5a082e..b860c1c0 100644 --- a/src/extension/templates.ts +++ b/src/extension/templates.ts @@ -76,18 +76,6 @@ These file paths may be relevant to your query: {{{code}}} Consider these in your response if pertinent. Disregard if not relevant.`.trim() - }, - { - name: "commit-message", - template: ` -Generate a concise git commit message. -Respond with a single line of text, maximum 100 characters. - -Example: "Added a new feature" - -Unidiff: \`\`\`{{code}}\`\`\` - -`.trim() }, { name: "fim", diff --git a/src/extension/tools.ts b/src/extension/tools.ts new file mode 100644 index 00000000..7ee51942 --- /dev/null +++ b/src/extension/tools.ts @@ -0,0 +1,292 @@ +import * as path from "path" +import { TextEncoder } from "util" +import * as vscode from "vscode" + +import { EVENT_NAME, TOOL_EVENT_NAME } from "../common/constants" +import { ClientMessage, CreateFileArgs, EditFileArgs, Message, OpenFileArgs, RunCommandArgs, ServerMessage, Tool } from "../common/types" + +import { Base } from "./base" + +export class Tools extends Base { + private _workspaceRoot: string | undefined + private webView: vscode.Webview + + constructor(webView: vscode.Webview, context: vscode.ExtensionContext) { + super(context) + this.webView = webView + const workspaceFolders = vscode.workspace.workspaceFolders + if (workspaceFolders) this._workspaceRoot = workspaceFolders[0].uri.fsPath + this.setUpEventListeners() + } + + setUpEventListeners() { + this.webView?.onDidReceiveMessage((message: ClientMessage) => { + this.handleMessage(message) + }) + } + + public async rejectTool(call: Tool, message: Message): Promise { + if (message.tools && message.tools[call.name]) { + message.tools[call.name].status = "rejected" + } + + this.webView?.postMessage({ + type: EVENT_NAME.twinnyOnCompletionEnd, + data: message + } as ServerMessage) + } + + public async runTool(call: Tool, message?: Message): Promise { + const method = this?.[call.name as keyof Tools] + if (method && typeof method === "function") { + if (message?.tools && message.tools[call.name]) { + message.tools[call.name].status = "running" + this.webView?.postMessage({ + type: EVENT_NAME.twinnyOnCompletionEnd, + data: message + } as ServerMessage) + } + + const boundMethod = method.bind(this) as ( + args: unknown + ) => Promise + const result = await boundMethod(call.arguments) + if (message?.tools && message.tools[call.name]) { + message.tools[call.name].status = "success" + this.webView?.postMessage({ + type: EVENT_NAME.twinnyOnCompletionEnd, + data: message + } as ServerMessage) + } + + return result + } + + return "" + } + + private async runAllTools(message: Message | undefined) { + if (!message?.tools) return + + for (const [toolName, tool] of Object.entries(message.tools)) { + try { + message.tools[toolName].status = "running" + this.webView?.postMessage({ + type: EVENT_NAME.twinnyOnCompletionEnd, + data: message + } as ServerMessage) + + await this.runTool(tool, message) + } catch (error: unknown) { + if (error instanceof Error && message.tools[toolName]) { + message.tools[toolName].error = error.message + message.tools[toolName].status = "error" + this.webView?.postMessage({ + type: EVENT_NAME.twinnyOnCompletionEnd, + data: message + } as ServerMessage) + } + } + } + } + + async handleMessage( + message: ClientMessage + ) { + const { type } = message + switch (type) { + case TOOL_EVENT_NAME.runAllTools: + return await this.runAllTools(message.data as Message) + + case TOOL_EVENT_NAME.runTool: { + const data = message.data as { message: Message; tool: Tool } + try { + await this.runTool(data.tool, data.message) + } catch (error: unknown) { + if (error instanceof Error) { + const tools = data.message.tools + if (!tools) return + tools[data.tool.name].error = error.message + tools[data.tool.name].status = "error" + this.webView?.postMessage({ + type: EVENT_NAME.twinnyOnCompletionEnd, + data: data.message + } as ServerMessage) + } + } + return + } + + case TOOL_EVENT_NAME.rejectTool: { + const data = message.data as { message: Message; tool: Tool } + await this.rejectTool(data.tool, data.message) + return + } + } + } + + async createFile(args: CreateFileArgs): Promise { + const { + path: filePath, + content, + openAfterCreate = false, + createIntermediateDirs, + fileTemplate + } = args + const fullPath = path.join(this._workspaceRoot || "", filePath) + const uri = vscode.Uri.file(fullPath) + + const finalContent = fileTemplate ? `${fileTemplate}\n${content}` : content + + try { + if (createIntermediateDirs) { + await vscode.workspace.fs.createDirectory( + vscode.Uri.file(path.dirname(fullPath)) + ) + } + + const ws = new vscode.WorkspaceEdit() + ws.createFile(uri, { overwrite: false }) + await vscode.workspace.applyEdit(ws) + + const enc = new TextEncoder() + await vscode.workspace.fs.writeFile(uri, enc.encode(finalContent)) + + if (openAfterCreate) { + const doc = await vscode.workspace.openTextDocument(uri) + await vscode.window.showTextDocument(doc) + } + + return `File created successfully at ${fullPath}` + } catch (error) { + throw new Error( + `Failed to create file: ${ + error instanceof Error ? error.message : String(error) + }` + ) + } + } + + async runCommand(args: RunCommandArgs): Promise { + const { + command, + cwd = this._workspaceRoot, + env, + shell, + timeout, + runInBackground + } = args + + return new Promise((resolve, reject) => { + const terminal = vscode.window.createTerminal({ + name: "Extension Command", + cwd: cwd, + env: env, + shellPath: shell + }) + + if (timeout) { + setTimeout(() => { + reject(new Error("Command timed out")) + }, timeout) + } + + terminal.sendText(command, true) + + if (!runInBackground) { + terminal.show() + } + + resolve("Command executed successfully!") + }) + } + + async openFile(args: OpenFileArgs): Promise { + const { + path: filePath, + preview = false, + viewColumn = "active", + } = args + const fullPath = path.join(this._workspaceRoot || "", filePath) + const uri = vscode.Uri.file(fullPath) + + try { + const doc = await vscode.workspace.openTextDocument(uri) + let column: vscode.ViewColumn | undefined = vscode.ViewColumn.Active + if (viewColumn === "beside") { + column = vscode.ViewColumn.Beside + } else if (viewColumn === "new") { + column = vscode.ViewColumn.Active + } + + await vscode.window.showTextDocument(doc, { + preview, + viewColumn: column + }) + + return `File opened successfully: ${fullPath}` + } catch (error) { + throw new Error( + `Failed to open file: ${ + error instanceof Error ? error.message : String(error) + }` + ) + } + } + + async editFile(args: EditFileArgs): Promise { + const { + path: filePath, + createIfNotExists = true, + backupBeforeEdit = false, + edit + } = args + const fullPath = path.join(this._workspaceRoot || "", filePath) + const uri = vscode.Uri.file(fullPath) + + try { + let fileExists = true + try { + await vscode.workspace.fs.stat(uri) + } catch { + fileExists = false + } + + if (!fileExists) { + if (createIfNotExists) { + const ws = new vscode.WorkspaceEdit() + ws.createFile(uri, { overwrite: false }) + await vscode.workspace.applyEdit(ws) + } else { + throw new Error(`File does not exist: ${fullPath}`) + } + } + + if (backupBeforeEdit && fileExists) { + const backupUri = vscode.Uri.file(fullPath + ".bak") + const data = await vscode.workspace.fs.readFile(uri) + await vscode.workspace.fs.writeFile(backupUri, data) + } + + const doc = await vscode.workspace.openTextDocument(uri) + const entireRange = new vscode.Range( + doc.positionAt(0), + doc.positionAt(doc.getText().length) + ) + + const ws = new vscode.WorkspaceEdit() + ws.replace(uri, entireRange, edit) + + await vscode.workspace.applyEdit(ws) + await doc.save() + + return `File edited successfully: ${fullPath}` + } catch (error) { + throw new Error( + `Failed to edit file: ${ + error instanceof Error ? error.message : String(error) + }` + ) + } + } +} diff --git a/src/extension/tree.ts b/src/extension/tree.ts index cf75d367..3b3b9f80 100644 --- a/src/extension/tree.ts +++ b/src/extension/tree.ts @@ -21,7 +21,7 @@ export class FileTreeProvider { this._workspaceRoot = workspaceFolders[0].uri.fsPath } - provideTextDocumentContent(): string { + getWorkSpaceTree(): string { return this.generateFileTree(this._workspaceRoot) } diff --git a/src/extension/utils.ts b/src/extension/utils.ts index d7178d22..4ac4f9b0 100644 --- a/src/extension/utils.ts +++ b/src/extension/utils.ts @@ -26,12 +26,10 @@ import { EVENT_NAME, EXTENSION_CONTEXT_NAME, knownErrorMessages, - LINE_BREAK_REGEX, MULTILINE_TYPES, NORMALIZE_REGEX, OPENING_BRACKETS, QUOTES, - QUOTES_REGEX, SKIP_DECLARATION_SYMBOLS, TWINNY } from "../common/constants" @@ -313,24 +311,26 @@ export const getTheme = () => { } } -export const getChatDataFromProvider = ( - provider: string, - data: StreamResponse -) => { - switch (provider) { - case apiProviders.Ollama: - case apiProviders.OpenWebUI: - return data?.choices[0].delta?.content - ? data?.choices[0].delta.content - : "" - case apiProviders.LlamaCpp: - return data?.content - case apiProviders.LiteLLM: - default: - if (data?.choices[0].delta.content === "undefined") return "" - return data?.choices[0].delta?.content - ? data?.choices[0].delta.content - : "" +export const getResponseData = (data: StreamResponse) => { + const toolCalls = data?.choices?.[0]?.message?.tool_calls + + if (toolCalls?.length) { + return { + type: "function_call" as const, + calls: toolCalls.map((call) => ({ + id: call.id, + name: call.function.name, + arguments: JSON.parse(call.function.arguments) + })) + } + } + + return { + type: "content" as const, + content: + data?.choices?.[0]?.delta?.content || + data.choices[0].message?.content || + "" } } @@ -451,15 +451,6 @@ export function createSymmetryMessage( return JSON.stringify({ key, data }) } -export const getSanitizedCommitMessage = (commitMessage: string) => { - const sanitizedMessage = commitMessage - .replace(QUOTES_REGEX, "") - .replace(LINE_BREAK_REGEX, "") - .trim() - - return `git commit -m "${sanitizedMessage}"` -} - export const getNormalisedText = (text: string) => text.replace(NORMALIZE_REGEX, " ") @@ -597,9 +588,7 @@ export const updateLoadingMessage = ( ) => { webView?.postMessage({ type: EVENT_NAME.twinnySendLoader, - value: { - data: message - } + data: message } as ServerMessage) } @@ -609,9 +598,7 @@ export const updateSymmetryStatus = ( ) => { webView?.postMessage({ type: EVENT_NAME.twinnySendSymmetryMessage, - value: { - data: message - } + data: message } as ServerMessage) } diff --git a/src/index.ts b/src/index.ts index efab9926..467c4ed5 100644 --- a/src/index.ts +++ b/src/index.ts @@ -7,7 +7,7 @@ import { languages, StatusBarAlignment, window, - workspace, + workspace } from "vscode" import * as vscode from "vscode" @@ -16,7 +16,7 @@ import { EXTENSION_CONTEXT_NAME, EXTENSION_NAME, TWINNY_COMMAND_NAME, - WEBUI_TABS, + WEBUI_TABS } from "./common/constants" import { ServerMessage } from "./common/types" import { setContext } from "./extension/context" @@ -29,8 +29,6 @@ import { SessionManager } from "./extension/session-manager" import { TemplateProvider } from "./extension/template-provider" import { delayExecution, - getSanitizedCommitMessage, - getTerminal, } from "./extension/utils" import { getLineBreakCount } from "./webview/utils" @@ -137,9 +135,7 @@ export async function activate(context: ExtensionContext) { ) sidebarProvider.webView?.postMessage({ type: EVENT_NAME.twinnySetTab, - value: { - data: WEBUI_TABS.providers, - }, + data: WEBUI_TABS.providers } as ServerMessage) }), commands.registerCommand( @@ -152,9 +148,7 @@ export async function activate(context: ExtensionContext) { ) sidebarProvider.webView?.postMessage({ type: EVENT_NAME.twinnySetTab, - value: { - data: WEBUI_TABS.symmetry, - }, + data: WEBUI_TABS.symmetry } as ServerMessage) } ), @@ -168,9 +162,7 @@ export async function activate(context: ExtensionContext) { ) sidebarProvider.webView?.postMessage({ type: EVENT_NAME.twinnySetTab, - value: { - data: WEBUI_TABS.history, - }, + data: WEBUI_TABS.history } as ServerMessage) } ), @@ -182,9 +174,7 @@ export async function activate(context: ExtensionContext) { ) sidebarProvider.webView?.postMessage({ type: EVENT_NAME.twinnySetTab, - value: { - data: WEBUI_TABS.review, - }, + data: WEBUI_TABS.review } as ServerMessage) }), commands.registerCommand(TWINNY_COMMAND_NAME.manageTemplates, async () => { @@ -195,9 +185,7 @@ export async function activate(context: ExtensionContext) { ) sidebarProvider.webView?.postMessage({ type: EVENT_NAME.twinnySetTab, - value: { - data: WEBUI_TABS.settings, - }, + data: WEBUI_TABS.settings } as ServerMessage) }), commands.registerCommand(TWINNY_COMMAND_NAME.hideBackButton, () => { @@ -231,9 +219,7 @@ export async function activate(context: ExtensionContext) { commands.executeCommand(TWINNY_COMMAND_NAME.hideBackButton) sidebarProvider.webView?.postMessage({ type: EVENT_NAME.twinnySetTab, - value: { - data: WEBUI_TABS.chat, - }, + data: WEBUI_TABS.chat } as ServerMessage) }), commands.registerCommand(TWINNY_COMMAND_NAME.settings, () => { @@ -242,13 +228,6 @@ export async function activate(context: ExtensionContext) { EXTENSION_NAME ) }), - commands.registerCommand( - TWINNY_COMMAND_NAME.sendTerminalText, - async (commitMessage: string) => { - const terminal = await getTerminal() - terminal?.sendText(getSanitizedCommitMessage(commitMessage), false) - } - ), commands.registerCommand(TWINNY_COMMAND_NAME.getGitCommitMessage, () => { commands.executeCommand(TWINNY_COMMAND_NAME.focusSidebar) sidebarProvider.conversationHistory?.resetConversation() @@ -258,11 +237,11 @@ export async function activate(context: ExtensionContext) { sidebarProvider.conversationHistory?.resetConversation() sidebarProvider.newConversation() sidebarProvider.webView?.postMessage({ - type: EVENT_NAME.twinnyStopGeneration, + type: EVENT_NAME.twinnyStopGeneration } as ServerMessage) }), commands.registerCommand(TWINNY_COMMAND_NAME.openPanelChat, () => { - commands.executeCommand("workbench.action.closeSidebar"); + commands.executeCommand("workbench.action.closeSidebar") fullScreenProvider.createOrShowPanel() }), workspace.onDidCloseTextDocument((document) => { diff --git a/src/webview/assets/locales/en.json b/src/webview/assets/locales/en.json index cb95389c..3f741e98 100644 --- a/src/webview/assets/locales/en.json +++ b/src/webview/assets/locales/en.json @@ -19,6 +19,7 @@ "conversation-history": "Conversation History", "copy-code": "Copy Code", "copy-provider": "Copy Provider", + "createFile": "Create file", "delete-message": "Delete message", "delete-provider": "Delete Provider", "disconnect": "Disconnect", @@ -28,6 +29,8 @@ "edit-provider": "Edit Provider", "embed-documents": "Embed documents", "embedding-provider": "Embedding provider", + "enable-tools": "Enable tools (Function calling)", + "error": "Error", "fim-template": "FIM Template", "fim": "Fill-in-middle", "hostname-placeholder": "Enter a hostname e.g 'localhost'", @@ -43,14 +46,17 @@ "new-document": "New Document", "no-connections-found": "No connections found. Please add a new connection to get started.", "no-result": "No result", + "no-tools": "No tools to run", "nothing-to-see-here": "Nothing to see here.", "number-code-filepaths": "The number of file paths to be used as context.", "number-code-snippets": "The number of code snippets to be used as context.", "open-diff": "Open Diff", "open-template-editor": "Open template editor", + "openFile": "Open file", "overlap-size": "Overlap size", "owner-repo-name": "This tab will help you review pull requests in your repository, enter the owner and repository name below to get started. For now only GitHub is supported, set your GitHub token in the settings tab to get started.", "path": "Path", + "pending": "Pending", "placeholder": "How can twinny help you today?", "port-placeholder": "Enter a port number e.g '11434'", "port": "Port", @@ -65,6 +71,7 @@ "regenerate-message": "Regenerate message", "relevant-code-snippets": "Relevant code snippets", "relevant-file-paths": "Relevant file paths", + "remove": "Remove", "repository-level": "Repository level", "rerank-probability-threshold": "Rerank probability threshold", "rerank-threshold-description": "The lower the threshold, the more likely a result is to be included.", @@ -72,20 +79,24 @@ "reset-providers": "Reset Providers", "reset-to-default": "Reset to default", "review-pull-requests": "Review pull requests", + "run-all": "Run All", + "run": "Run", + "runCommand": "Run command", + "running": "Running", "save-edit": "Save edit", "save": "Save", "scroll-down": "Scroll down to the bottom", "share-gpu-resources": "You can also share your GPU resources by connecting to Symmetry as a provider using your active twinny provider configuration. All connections are peer to peer, encrypted end-to-end and secure.", "status": "Status", "stop-generation": "Stop generation", + "success": "Success", "symmetry-description": "Symmetry is a peer-to-peer AI inference network that allows secure, direct connections between users. When you connect as a consumer, Symmetry matches you with a provider based on your model selection.", "symmetry-inference-network": "Symmetry Inference Network", "template-settings-description": "Select the templates you want to use in the chat interface.", - "template-settings": "Template settings", "thinking": "Thinking...", "toggle-auto-scroll": "Toggle auto scroll on/off", "toggle-embedding-options": "Toggle embedding options on/off", "toggle-provider-selection": "Toggle provider selection", + "tools": "Tools", "type": "Type" } - \ No newline at end of file diff --git a/src/webview/chat.tsx b/src/webview/chat.tsx index 06a860fe..fe7ef6e0 100644 --- a/src/webview/chat.tsx +++ b/src/webview/chat.tsx @@ -2,27 +2,29 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react" import { useTranslation } from "react-i18next" import Mention from "@tiptap/extension-mention" import Placeholder from "@tiptap/extension-placeholder" -import { Editor, EditorContent, JSONContent,useEditor } from "@tiptap/react" +import { Editor, EditorContent, JSONContent, useEditor } from "@tiptap/react" import StarterKit from "@tiptap/starter-kit" import { VSCodeBadge, VSCodeButton, VSCodeDivider, - VSCodePanelView, + VSCodePanelView } from "@vscode/webview-ui-toolkit/react" import cn from "classnames" import { ASSISTANT, EVENT_NAME, + TOOL_EVENT_NAME, USER, - WORKSPACE_STORAGE_KEY, + WORKSPACE_STORAGE_KEY } from "../common/constants" import { ClientMessage, MentionType, - Message as MessageType, + Message, ServerMessage, + Tool } from "../common/types" import { EmbeddingOptions } from "./embedding-options" @@ -32,17 +34,14 @@ import useAutosizeTextArea, { useSuggestion, useSymmetryConnection, useTheme, - useWorkSpaceContext, + useWorkSpaceContext } from "./hooks" -import { - DisabledAutoScrollIcon, - EnabledAutoScrollIcon, -} from "./icons" +import { DisabledAutoScrollIcon, EnabledAutoScrollIcon } from "./icons" import ChatLoader from "./loader" -import { Message } from "./message" +import { Message as MessageComponent } from "./message" import { ProviderSelect } from "./provider-select" import { Suggestions } from "./suggestions" -import { CustomKeyMap, getCompletionContent } from "./utils" +import { CustomKeyMap } from "./utils" import styles from "./styles/index.module.css" @@ -60,8 +59,8 @@ export const Chat = (props: ChatProps): JSX.Element => { const theme = useTheme() const { t } = useTranslation() const [isLoading, setIsLoading] = useState(false) - const [messages, setMessages] = useState() - const [completion, setCompletion] = useState() + const [messages, setMessages] = useState() + const [completion, setCompletion] = useState() const markdownRef = useRef(null) const { symmetryConnection } = useSymmetryConnection() @@ -71,7 +70,7 @@ export const Chat = (props: ChatProps): JSX.Element => { useWorkSpaceContext(WORKSPACE_STORAGE_KEY.showProviders) const { context: showEmbeddingOptionsContext, - setContext: setShowEmbeddingOptionsContext, + setContext: setShowEmbeddingOptionsContext } = useWorkSpaceContext(WORKSPACE_STORAGE_KEY.showEmbeddingOptions) const { conversation, saveLastConversation, setActiveConversation } = useConversationHistory() @@ -89,34 +88,50 @@ export const Chat = (props: ChatProps): JSX.Element => { const selection = useSelection(scrollToBottom) - const handleCompletionEnd = (message: ServerMessage) => { - if (message.value) { - setMessages((prev) => { - const messages = [ - ...(prev || []), - { - role: ASSISTANT, - content: getCompletionContent(message), - }, - ] + const handleCompletionEnd = (message: ServerMessage) => { + if (!message.data) { + setCompletion(null) + setIsLoading(false) + generatingRef.current = false + return + } - saveLastConversation({ - ...conversation, - messages: messages, - }) - return messages + setMessages((prev) => { + if (message.data.id) { + const existingIndex = prev?.findIndex((m) => m.id === message.data.id) + + if (existingIndex !== -1) { + const updatedMessages = [...(prev || [])] + + updatedMessages[existingIndex || 0] = message.data + + saveLastConversation({ + ...conversation, + messages: updatedMessages + }) + return updatedMessages + } + } + + const messages = [...(prev || []), message.data] + saveLastConversation({ + ...conversation, + messages: messages }) - setTimeout(() => { - editorRef.current?.commands.focus() - stopRef.current = false - }, 200) - } + return messages + }) + + setTimeout(() => { + editorRef.current?.commands.focus() + stopRef.current = false + }, 200) + setCompletion(null) setIsLoading(false) generatingRef.current = false } - const handleAddTemplateMessage = (message: ServerMessage) => { + const handleAddTemplateMessage = (message: ServerMessage) => { if (stopRef.current) { generatingRef.current = false return @@ -124,24 +139,15 @@ export const Chat = (props: ChatProps): JSX.Element => { generatingRef.current = true setIsLoading(false) scrollToBottom() - setMessages((prev) => [ - ...(prev || []), - { - role: USER, - content: message.value.completion as string, - }, - ]) + setMessages((prev) => [...(prev || []), message.data]) } - const handleCompletionMessage = (message: ServerMessage) => { + const handleCompletionMessage = (message: ServerMessage) => { if (stopRef.current) { generatingRef.current = false return } - setCompletion({ - role: ASSISTANT, - content: getCompletionContent(message), - }) + setCompletion(message.data) scrollToBottom() } @@ -154,19 +160,19 @@ export const Chat = (props: ChatProps): JSX.Element => { const message: ServerMessage = event.data switch (message.type) { case EVENT_NAME.twinnyAddMessage: { - handleAddTemplateMessage(message) + handleAddTemplateMessage(message as ServerMessage) break } case EVENT_NAME.twinnyOnCompletion: { - handleCompletionMessage(message) + handleCompletionMessage(message as ServerMessage) break } case EVENT_NAME.twinnyOnLoading: { handleLoadingMessage() break } - case EVENT_NAME.twinnyOnEnd: { - handleCompletionEnd(message) + case EVENT_NAME.twinnyOnCompletionEnd: { + handleCompletionEnd(message as ServerMessage) break } case EVENT_NAME.twinnyStopGeneration: { @@ -186,7 +192,7 @@ export const Chat = (props: ChatProps): JSX.Element => { const handleStopGeneration = () => { stopRef.current = true global.vscode.postMessage({ - type: EVENT_NAME.twinnyStopGeneration, + type: EVENT_NAME.twinnyStopGeneration } as ClientMessage) setCompletion(null) setIsLoading(false) @@ -205,7 +211,7 @@ export const Chat = (props: ChatProps): JSX.Element => { global.vscode.postMessage({ type: EVENT_NAME.twinnyChatMessage, - data: updatedMessages, + data: updatedMessages } as ClientMessage) return updatedMessages @@ -220,12 +226,12 @@ export const Chat = (props: ChatProps): JSX.Element => { const updatedMessages = [ ...prev.slice(0, index), - ...prev.slice(index + 2), + ...prev.slice(index + 2) ] saveLastConversation({ ...conversation, - messages: updatedMessages, + messages: updatedMessages }) return updatedMessages @@ -239,12 +245,12 @@ export const Chat = (props: ChatProps): JSX.Element => { const updatedMessages = [ ...prev.slice(0, index), - { ...prev[index], content: message }, + { ...prev[index], content: message } ] global.vscode.postMessage({ type: EVENT_NAME.twinnyChatMessage, - data: updatedMessages, + data: updatedMessages } as ClientMessage) return updatedMessages @@ -262,7 +268,7 @@ export const Chat = (props: ChatProps): JSX.Element => { innerNode.attrs.label || innerNode.attrs.id.split("/").pop() || "", - path: innerNode.attrs.id, + path: innerNode.attrs.id }) } }) @@ -296,18 +302,18 @@ export const Chat = (props: ChatProps): JSX.Element => { setMessages((prevMessages) => { const updatedMessages = [ ...(prevMessages || []), - { role: USER, content: replaceMentionsInText(input, mentions) }, + { role: USER, content: replaceMentionsInText(input, mentions) } ] - const clientMessage: ClientMessage = { + const clientMessage: ClientMessage = { type: EVENT_NAME.twinnyChatMessage, data: updatedMessages, - meta: mentions, + meta: mentions } saveLastConversation({ ...conversation, - messages: updatedMessages, + messages: updatedMessages }) global.vscode.postMessage(clientMessage) @@ -331,7 +337,7 @@ export const Chat = (props: ChatProps): JSX.Element => { global.vscode.postMessage({ type: EVENT_NAME.twinnySetWorkspaceContext, key: WORKSPACE_STORAGE_KEY.autoScroll, - data: !prev, + data: !prev } as ClientMessage) if (!prev) scrollToBottom() @@ -346,7 +352,7 @@ export const Chat = (props: ChatProps): JSX.Element => { global.vscode.postMessage({ type: EVENT_NAME.twinnySetWorkspaceContext, key: WORKSPACE_STORAGE_KEY.showProviders, - data: !prev, + data: !prev } as ClientMessage) return !prev }) @@ -358,7 +364,7 @@ export const Chat = (props: ChatProps): JSX.Element => { global.vscode.postMessage({ type: EVENT_NAME.twinnySetWorkspaceContext, key: WORKSPACE_STORAGE_KEY.showEmbeddingOptions, - data: !prev, + data: !prev } as ClientMessage) return !prev }) @@ -370,9 +376,42 @@ export const Chat = (props: ChatProps): JSX.Element => { } } + const handleNewConversation = () => { + global.vscode.postMessage({ + type: EVENT_NAME.twinnyNewConversation + }) + } + + const handleRejectTool = (message: Message, tool: Tool) => { + global.vscode.postMessage({ + type: TOOL_EVENT_NAME.rejectTool, + data: { + message, + tool + } + } as ClientMessage<{ message: Message; tool: Tool }>) + } + + const handleRunTool = (message: Message, tool: Tool) => { + global.vscode.postMessage({ + type: TOOL_EVENT_NAME.runTool, + data: { + message, + tool + } + } as ClientMessage<{ message: Message; tool: Tool }>) + } + + const handleRunAllTools = (message: Message) => { + global.vscode.postMessage({ + type: TOOL_EVENT_NAME.runAllTools, + data: message + } as ClientMessage) + } + useEffect(() => { global.vscode.postMessage({ - type: EVENT_NAME.twinnyHideBackButton, + type: EVENT_NAME.twinnyHideBackButton }) }, []) @@ -404,7 +443,7 @@ export const Chat = (props: ChatProps): JSX.Element => { StarterKit, Mention.configure({ HTMLAttributes: { - class: "mention", + class: "mention" }, suggestion: memoizedSuggestion, renderText({ node }) { @@ -412,16 +451,16 @@ export const Chat = (props: ChatProps): JSX.Element => { return `${node.attrs.name ?? node.attrs.id}` } return node.attrs.id ?? "" - }, + } }), CustomKeyMap.configure({ handleSubmitForm, - clearEditor, + clearEditor }), Placeholder.configure({ placeholder: t("placeholder") // "How can twinny help you today?", - }), - ], + }) + ] }, [memoizedSuggestion] ) @@ -443,12 +482,6 @@ export const Chat = (props: ChatProps): JSX.Element => { } }, [memoizedSuggestion]) - const handleNewConversation = () => { - global.vscode.postMessage({ - type: EVENT_NAME.twinnyNewConversation, - }) - } - return (
@@ -471,12 +504,12 @@ export const Chat = (props: ChatProps): JSX.Element => {
{messages?.map((message, index) => ( - { message={message} theme={theme} index={index} + onRejectTool={handleRejectTool} + onRunTool={handleRunTool} + onRunAllTools={handleRunAllTools} /> ))} {isLoading && !completion ? ( ) : ( !!completion && ( - ) diff --git a/src/webview/hooks.ts b/src/webview/hooks.ts index 01080548..2016152e 100644 --- a/src/webview/hooks.ts +++ b/src/webview/hooks.ts @@ -44,9 +44,10 @@ const global = globalThis as any export const useSelection = (onSelect?: () => void) => { const [selection, setSelection] = useState("") const handler = (event: MessageEvent) => { - const message: ServerMessage = event.data + const message: ServerMessage = event.data if (message?.type === EVENT_NAME.twinnyTextSelection) { - setSelection(message?.value.completion.trim()) + const selection = message?.data?.trim() + setSelection(selection || "") onSelect?.() } } @@ -68,7 +69,7 @@ export const useGlobalContext = (key: string) => { const handler = (event: MessageEvent) => { const message: ServerMessage = event.data if (message?.type === `${EVENT_NAME.twinnyGlobalContext}-${key}`) { - setContextState(event.data.value) + setContextState(event.data.data) } } @@ -101,7 +102,7 @@ export const useSessionContext = (key: string) => { const handler = (event: MessageEvent) => { const message: ServerMessage = event.data if (message?.type === `${EVENT_NAME.twinnySessionContext}-${key}`) { - setContext(event.data.value) + setContext(event.data.data) } } @@ -123,7 +124,7 @@ export const useWorkSpaceContext = (key: string) => { const handler = (event: MessageEvent) => { const message: ServerMessage = event.data if (message?.type === `${EVENT_NAME.twinnyGetWorkspaceContext}-${key}`) { - setContext(event.data.value) + setContext(event.data.data) } } @@ -145,7 +146,7 @@ export const useTheme = () => { const handler = (event: MessageEvent) => { const message: ServerMessage = event.data if (message?.type === EVENT_NAME.twinnySendTheme) { - setTheme(message?.value.data) + setTheme(message?.data) } return () => window.removeEventListener("message", handler) } @@ -164,7 +165,7 @@ export const useLoading = () => { const handler = (event: MessageEvent) => { const message: ServerMessage = event.data if (message?.type === EVENT_NAME.twinnySendLoader) { - setLoader(message?.value.data) + setLoader(message?.data) } return () => window.removeEventListener("message", handler) } @@ -179,11 +180,14 @@ export const useLoading = () => { } export const useLanguage = (): LanguageType | undefined => { - const [language, setLanguage] = useState() + const [language, setLanguage] = useState() const handler = (event: MessageEvent) => { - const message: ServerMessage = event.data + const message: ServerMessage = event.data if (message?.type === EVENT_NAME.twinnySendLanguage) { - setLanguage(message?.value.data) + const language = message.data + if (language) { + setLanguage(language) + } } return () => window.removeEventListener("message", handler) } @@ -202,7 +206,7 @@ export const useTemplates = () => { const handler = (event: MessageEvent) => { const message: ServerMessage = event.data if (message?.type === EVENT_NAME.twinnyListTemplates) { - setTemplates(message?.value.data) + setTemplates(message?.data) } return () => window.removeEventListener("message", handler) } @@ -239,7 +243,7 @@ export const useGithubPRs = () => { const handler = (event: MessageEvent) => { const message = event.data if (message.type === GITHUB_EVENT_NAME.getPullRequests) { - setPRs(message.value.data) + setPRs(message.data) setIsLoading(false) } } @@ -288,26 +292,26 @@ export const useProviders = () => { Record | TwinnyProvider > = event.data if (message?.type === PROVIDER_EVENT_NAME.getAllProviders) { - if (message.value.data) { - const providers = message.value.data as Record + if (message.data) { + const providers = message.data as Record setProviders(providers) } } if (message?.type === PROVIDER_EVENT_NAME.getActiveChatProvider) { - if (message.value.data) { - const provider = message.value.data as TwinnyProvider + if (message.data) { + const provider = message.data as TwinnyProvider setChatProvider(provider) } } if (message?.type === PROVIDER_EVENT_NAME.getActiveFimProvider) { - if (message.value.data) { - const provider = message.value.data as TwinnyProvider + if (message.data) { + const provider = message.data as TwinnyProvider setFimProvider(provider) } } if (message?.type === PROVIDER_EVENT_NAME.getActiveEmbeddingsProvider) { - if (message.value.data) { - const provider = message.value.data as TwinnyProvider + if (message.data) { + const provider = message.data as TwinnyProvider setEmbeddingProvider(provider) } } @@ -409,33 +413,6 @@ export const useProviders = () => { } } -export const useConfigurationSetting = (key: string) => { - const [configurationSetting, setConfigurationSettings] = useState< - string | boolean | number - >() - - const handler = (event: MessageEvent) => { - const message: ServerMessage = event.data - if ( - message?.type === EVENT_NAME.twinnyGetConfigValue && - message.value.type === key - ) { - setConfigurationSettings(message?.value.data) - } - } - - useEffect(() => { - global.vscode.postMessage({ - type: EVENT_NAME.twinnyGetConfigValue, - key - }) - window.addEventListener("message", handler) - return () => window.removeEventListener("message", handler) - }, [key]) - - return { configurationSetting } -} - export const useConversationHistory = () => { const [conversations, setConversations] = useState< Record @@ -483,13 +460,15 @@ export const useConversationHistory = () => { } const handler = (event: MessageEvent) => { - const message = event.data - if (message.value?.data) { + const message = event.data as ServerMessage< + Record | Conversation + > + if (message?.data) { if (message?.type === CONVERSATION_EVENT_NAME.getConversations) { - setConversations(message.value.data) + setConversations(message.data as Record) } if (message?.type === CONVERSATION_EVENT_NAME.setActiveConversation) { - setConversation(message.value.data) + setConversation(message.data as Conversation) } } } @@ -518,7 +497,7 @@ export const useOllamaModels = () => { const handler = (event: MessageEvent) => { const message: ServerMessage = event.data if (message?.type === EVENT_NAME.twinnyFetchOllamaModels) { - setModels(message?.value.data) + setModels(message?.data) } return () => window.removeEventListener("message", handler) } @@ -556,7 +535,7 @@ export const useFilePaths = () => { !filePaths.current?.length && message?.type === EVENT_NAME.twinnyFileListResponse ) { - filePaths.current = message.value.data // response sets the list from vscode backend + filePaths.current = message.data } } @@ -684,6 +663,8 @@ export const useSymmetryConnection = () => { EXTENSION_SESSION_NAME.twinnySymmetryConnection ) + console.log(symmetryConnectionSession) + const { context: symmetryProviderStatus, setContext: setSymmetryProviderStatus @@ -731,19 +712,18 @@ export const useSymmetryConnection = () => { > = event.data if (message?.type === EVENT_NAME.twinnyConnectedToSymmetry) { setConnecting(false) - setSymmetryConnectionSession(message.value.data as SymmetryConnection) + setSymmetryConnectionSession(message.data as SymmetryConnection) } if (message?.type === EVENT_NAME.twinnyDisconnectedFromSymmetry) { setConnecting(false) setSymmetryConnectionSession(undefined) } if (message?.type === EVENT_NAME.twinnySendSymmetryMessage) { - setSymmetryProviderStatus(message?.value.data as string) + setSymmetryProviderStatus(message?.data as string) } if (message?.type === EVENT_NAME.twinnySymmetryModels) { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - setModels(message?.value.data as unknown as SymmetryModelProvider[]) + setModels(message?.data as SymmetryModelProvider[]) } return () => window.removeEventListener("message", handler) } @@ -785,7 +765,7 @@ export const useSymmetryConnection = () => { export const useLocale = () => { const [locale, setLocale] = useState("en") - const [ renderKey, setRenderKey ] = useState(0) + const [renderKey, setRenderKey] = useState(0) useEffect(() => { const messageHandler = (event: MessageEvent) => { if (event.data.type === EVENT_NAME.twinnySetLocale) { diff --git a/src/webview/main.tsx b/src/webview/main.tsx index f209510a..937d9eb4 100644 --- a/src/webview/main.tsx +++ b/src/webview/main.tsx @@ -34,7 +34,7 @@ export const Main = ({ fullScreen }: MainProps) => { const handler = (event: MessageEvent) => { const message: ServerMessage = event.data if (message?.type === EVENT_NAME.twinnySetTab) { - setTab(message?.value.data) + setTab(message?.data) } return () => window.removeEventListener("message", handler) } diff --git a/src/webview/message.tsx b/src/webview/message.tsx index 47bf1f16..e48caaeb 100644 --- a/src/webview/message.tsx +++ b/src/webview/message.tsx @@ -8,9 +8,10 @@ import remarkGfm from "remark-gfm" import { Markdown as TiptapMarkdown } from "tiptap-markdown" import { ASSISTANT, TWINNY, YOU } from "../common/constants" -import { Message as MessageType, ThemeType } from "../common/types" +import { Message as MessageType, ThemeType, Tool } from "../common/types" -import CodeBlock from "./code-block" +import { CodeBlock } from "./code-block" +import { ToolExecution } from "./tool-execution" import styles from "./styles/index.module.css" @@ -24,6 +25,9 @@ interface MessageProps { onRegenerate?: (index: number) => void onUpdate?: (message: string, index: number) => void theme: ThemeType | undefined + onRejectTool?: (message: MessageType, tool: Tool) => void + onRunTool?: (message: MessageType, tool: Tool) => void + onRunAllTools?: (message: MessageType) => void } const CustomKeyMap = Extension.create({ @@ -42,9 +46,9 @@ const CustomKeyMap = Extension.create({ "Shift-Enter": ({ editor }) => { editor.commands.insertContent("\n") return true - }, + } } - }, + } }) const MemoizedCodeBlock = React.memo(CodeBlock) @@ -61,7 +65,20 @@ export const Message: React.FC = React.memo( onRegenerate, onUpdate, theme, + onRunTool, + onRejectTool, + onRunAllTools }) => { + if (message?.tools) + return ( + + ) + const { t } = useTranslation() const [editing, setEditing] = React.useState(false) @@ -94,11 +111,11 @@ export const Message: React.FC = React.memo( extensions: [ StarterKit, CustomKeyMap.configure({ - handleToggleSave, + handleToggleSave }), - TiptapMarkdown, + TiptapMarkdown ], - content: message?.content, + content: message?.content }, [index] ) @@ -133,7 +150,7 @@ export const Message: React.FC = React.memo( const markdownComponents = useMemo( () => ({ pre: renderPre, - code: renderCode, + code: renderCode }), [renderPre, renderCode] ) @@ -202,10 +219,7 @@ export const Message: React.FC = React.memo(
{editing ? ( - + ) : ( { @@ -16,9 +20,12 @@ export const ProviderSelect = () => { setActiveFimProvider, providers, chatProvider, - fimProvider, + fimProvider } = useProviders() + const { context: enableTools = false, setContext: setEnableTools } = + useGlobalContext(EXTENSION_CONTEXT_NAME.twinnyEnableTools) + const handleChangeChatProvider = (e: unknown): void => { const event = e as React.ChangeEvent const value = event.target.value @@ -36,9 +43,7 @@ export const ProviderSelect = () => { return (
-
- {t("chat")} -
+
{t("chat")}
{
-
- {t("fim")} -
+
{t("fim")}
{ ))}
+
+
+ +
+
) } diff --git a/src/webview/styles/providers.module.css b/src/webview/styles/providers.module.css index 2325514e..71ab1f92 100644 --- a/src/webview/styles/providers.module.css +++ b/src/webview/styles/providers.module.css @@ -103,3 +103,7 @@ .divider { margin: 10px 0; } + +.enableTools { + margin-top: 5px; +} diff --git a/src/webview/styles/tool-execution.module.css b/src/webview/styles/tool-execution.module.css new file mode 100644 index 00000000..b92e74cd --- /dev/null +++ b/src/webview/styles/tool-execution.module.css @@ -0,0 +1,324 @@ +.root { + display: flex; + flex-direction: column; + color: var(--vscode-foreground); + font-family: var(--vscode-font-family); +} + +.headerBar { + padding: 12px 16px; + background: var(--vscode-sideBarSectionHeader-background); + display: flex; + align-items: center; + justify-content: space-between; + border-bottom: 1px solid var(--vscode-panel-border); + position: sticky; + top: 0; + z-index: 2; + backdrop-filter: blur(8px); +} + +.headerTitle { + display: flex; + align-items: center; + gap: 8px; + font-size: 13px; + font-weight: 600; + color: var(--vscode-sideBarTitle-foreground); + letter-spacing: 0.1px; +} + +.statusBadge { + display: inline-flex; + align-items: center; + justify-content: center; + min-width: 20px; + height: 20px; + padding: 0 6px; + border-radius: 10px; + font-size: 11px; + font-weight: 500; + line-height: 1; + letter-spacing: 0.2px; + transition: all 0.2s ease; +} + +.statusBadge[data-status="running"] { + background: var(--vscode-progressBar-background); + color: var(--vscode-foreground); + box-shadow: 0 0 0 1px rgba(255, 255, 255, 0.1); +} + +.statusBadge[data-status="error"] { + background: var(--vscode-errorForeground); + color: var(--vscode-editor-background); + box-shadow: 0 0 0 1px rgba(255, 255, 255, 0.1); +} + +.headerControls { + display: flex; + gap: 8px; +} + +.runAllButton { + display: inline-flex; + align-items: center; + gap: 6px; + font-size: 12px; + line-height: 1; + background: var(--vscode-button-background); + color: var(--vscode-button-foreground); + transition: all 0.2s ease; +} + +.runAllButton:hover:not(:disabled) { + background: var(--vscode-button-hoverBackground); +} + +.runAllButton:disabled { + opacity: 0.5; + cursor: not-allowed; +} + +.toolsList { + flex: 1; + overflow-y: auto; + padding: 12px 0; +} + +.toolGroup { + margin-bottom: 20px; +} + +.toolGroupHeader { + display: flex; + align-items: center; + gap: 8px; + padding: 6px 16px; + font-size: 12px; + text-transform: capitalize; + color: var(--vscode-foreground); + font-weight: 500; + letter-spacing: 0.1px; +} + +.toolGroupHeader[data-status="error"] { + color: var(--vscode-errorForeground); +} + +.toolCount { + font-size: 11px; + opacity: 0.7; + font-weight: normal; +} + +.toolItem { + position: relative; + margin: 6px 12px; + border-radius: 6px; + background: var(--vscode-editor-background); + box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1); + transition: all 0.2s ease; + border: 1px solid var(--vscode-panel-border); +} + +.toolItem:hover { + box-shadow: 0 2px 6px rgba(0, 0, 0, 0.15); +} + +.toolItem[data-status="running"] { + border-color: var(--vscode-progressBar-background); +} + +.toolItem[data-status="error"] { + border-color: var(--vscode-errorForeground); +} + +.toolItem[data-status="rejected"] { + border-color: var(--vscode-disabledForeground); + opacity: 0.7; +} + +.toolRow { + display: flex; + align-items: center; + justify-content: space-between; + padding: 10px 14px; + border-bottom: 1px solid var(--vscode-panel-border); + transition: background 0.2s ease; +} + +.toolRow:hover { + background: var(--vscode-list-hoverBackground); +} + +.toolName { + font-size: 13px; + font-weight: 600; + color: var(--vscode-foreground); +} + +.toolActions { + display: flex; + gap: 6px; +} + +.actionButton { + width: 28px; + height: 28px; + background: none; + border: 1px solid transparent; + padding: 0; + border-radius: 4px; + color: var(--vscode-foreground); + cursor: pointer; + transition: all 0.2s ease; + display: flex; + align-items: center; + justify-content: center; +} + +.actionButton:hover:not(:disabled) { + background: var(--vscode-toolbar-hoverBackground); + border-color: var(--vscode-panel-border); +} + +.actionButton:disabled { + opacity: 0.4; + cursor: not-allowed; +} + +.toolContent { + padding: 12px 14px; + font-size: 12px; + line-height: 1.5; +} + +.argumentsContainer { + margin-bottom: 12px; +} + +.argumentRow { + margin-bottom: 16px; + background: var(--vscode-textBlockQuote-background); + border-radius: 4px; + padding: 12px; +} + +.argumentHeader { + display: flex; + align-items: center; + gap: 8px; + margin-bottom: 8px; +} + +.argumentKey { + font-weight: 600; + color: var(--vscode-symbolIcon-propertyForeground); + font-size: 12px; + text-transform: lowercase; +} + +.argumentType { + font-size: 11px; + color: var(--vscode-foreground); + opacity: 0.7; + background: var(--vscode-badge-background); + padding: 2px 6px; + border-radius: 3px; +} + +.argumentValue { + padding: 8px; + border-radius: 4px; + overflow-x: auto; + background: var(--vscode-editor-background); + border: 1px solid var(--vscode-panel-border); +} + +.argumentValue pre { + margin: 0; + color: var(--vscode-foreground); + white-space: pre-wrap; + font-size: 12px; + line-height: 1.5; + font-family: var(--vscode-editor-font-family); +} + +.toolError { + font-size: 12px; + color: var(--vscode-errorForeground); + margin-top: 12px; + background: var(--vscode-inputValidation-errorBackground); + padding: 8px 12px; + border-radius: 4px; + border: 1px solid var(--vscode-inputValidation-errorBorder); + line-height: 1.5; +} + +.emptyState { + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + height: 100%; + gap: 16px; + color: var(--vscode-foreground); + opacity: 0.6; + font-size: 13px; + text-align: center; + padding: 40px 20px; +} + +.emptyState .codicon { + font-size: 36px; + opacity: 0.8; +} + +/* Animation for the running state */ +.codicon-modifier-spin { + animation: spin 1.5s linear infinite; +} + +@keyframes spin { + 100% { + transform: rotate(360deg); + } +} + +/* Status icon colors and animations */ +.statusIcon { + color: var(--vscode-foreground); + font-size: 14px; + transition: all 0.2s ease; +} + +.statusIcon[data-status="running"] { + color: var(--vscode-progressBar-background); +} + +.statusIcon[data-status="success"] { + color: var(--vscode-testing-iconPassed); +} + +.statusIcon[data-status="error"] { + color: var(--vscode-errorForeground); +} + +.statusIcon[data-status="rejected"] { + color: var(--vscode-disabledForeground); +} + +/* Hover effects */ +.toolItem:hover .toolName { + color: var(--vscode-textLink-foreground); +} + +/* Smooth transitions */ +.toolItem, +.toolRow, +.actionButton, +.statusBadge, +.argumentRow { + transition: all 0.2s ease; +} diff --git a/src/webview/tool-execution.tsx b/src/webview/tool-execution.tsx new file mode 100644 index 00000000..ac86ced3 --- /dev/null +++ b/src/webview/tool-execution.tsx @@ -0,0 +1,174 @@ +import { useTranslation } from "react-i18next" +import { VSCodeButton } from "@vscode/webview-ui-toolkit/react" + +import { Message, Tool } from "../common/types" + +import styles from "./styles/tool-execution.module.css" + +const StatusIcon = ({ status }: { status: string }) => { + const iconClass = { + pending: "codicon-circle-outline", + running: "codicon-sync codicon-modifier-spin", + success: "codicon-check", + error: "codicon-error", + rejected: "codicon-trash" + }[status] + + return +} + +interface ToolExecutionProps { + message: Message + onRunTool?: (message: Message, tool: Tool) => void + onRejectTool?: (message: Message, tool: Tool) => void + onRunAllTools?: (message: Message) => void +} + +export function ToolExecution({ + message, + onRunTool, + onRejectTool, + onRunAllTools +}: ToolExecutionProps) { + const { t } = useTranslation() + const tools = message.tools || {} + + if (!Object.keys(tools)?.length) { + return ( +
+ +

{t("no-tools")}

+
+ ) + } + + const handleRunAll = (message: Message) => { + onRunAllTools?.(message) + } + + const handleRunTool = (message: Message, tool: Tool) => { + onRunTool?.(message, tool) + } + + const handleRejectTool = (message: Message, tool: Tool) => { + onRejectTool?.(message, tool) + } + + const toolsByStatus = Object.values(tools).reduce((acc, tool) => { + const status = tool.status || "pending" + acc[status] = [...(acc[status] || []), tool] + return acc + }, {} as Record) + + const runningCount = toolsByStatus.running?.length || 0 + const errorCount = toolsByStatus.error?.length || 0 + + return ( +
+ {message.content && ( +

+ {message.content} +

+ )} +
+
+ + {t("tools")} + {runningCount > 0 && ( + + {runningCount} + + )} + {errorCount > 0 && ( + + {errorCount} + + )} +
+
+ 0} + onClick={() => handleRunAll(message)} + title={t("run-all-tools")} + > + + {t("run-all")} + +
+
+ +
+ {["running", "error", "pending", "success", "rejected"].map((status) => + toolsByStatus[status]?.length ? ( +
+
+ + {t(status)} + + {toolsByStatus[status].length} + +
+ {toolsByStatus[status].map((tool) => ( +
+
+ {t(tool.name)} +
+ handleRunTool(message, tool)} + className={styles.actionButton} + disabled={tool.status === "running"} + appearance="icon" + title={t("run-tool")} + > + + + handleRejectTool(message, tool)} + className={styles.actionButton} + disabled={ + tool.status === "running" || tool.status === "success" + } + title={t("reject-tool")} + appearance="icon" + > + + +
+
+ +
+
+ {Object.entries(tool.arguments || {}).map( + ([key, value]) => ( +
+
+ {key} + + {typeof value} + +
+
+ {JSON.stringify(value, null, 2)} +
+
+ ) + )} +
+ {!!tool.error && tool.status !== "success" && ( +
{tool.error}
+ )} +
+
+ ))} +
+ ) : null + )} +
+
+ ) +} diff --git a/src/webview/utils.ts b/src/webview/utils.ts index 9b298b3a..ccf534b9 100644 --- a/src/webview/utils.ts +++ b/src/webview/utils.ts @@ -1,9 +1,8 @@ import { MentionPluginKey } from "@tiptap/extension-mention" import { Extension } from "@tiptap/react" -import { EMPTY_MESAGE } from "../common/constants" import { CodeLanguage, supportedLanguages } from "../common/languages" -import { LanguageType, ServerMessage } from "../common/types" +import { LanguageType } from "../common/types" export const getLanguageMatch = ( language: LanguageType | undefined, @@ -31,14 +30,6 @@ export const getLanguageMatch = ( return "auto" } -export const getCompletionContent = (message: ServerMessage) => { - if (message.value.error && message.value.errorMessage) { - return message.value.errorMessage - } - - return message.value.completion || EMPTY_MESAGE -} - export const kebabToSentence = (kebabStr: string) => { if (!kebabStr) { return ""