diff --git a/src/extension/chat-service.ts b/src/extension/chat-service.ts index 80841b91..69fa5dfe 100644 --- a/src/extension/chat-service.ts +++ b/src/extension/chat-service.ts @@ -560,21 +560,23 @@ export class ChatService { private async loadFileContents(files: FileItem[]): Promise { if (!files?.length) return '' - let fileContents = ''; + let fileContents = '' for (const file of files) { try { - const content = await fs.readFile(file.path, 'utf-8'); - fileContents += `File: ${file.name}\n\n${content}\n\n`; + const content = await fs.readFile(file.path, 'utf-8') + fileContents += `File: ${file.name}\n\n${content}\n\n` } catch (error) { - console.error(`Error reading file ${file.path}:`, error); + console.error(`Error reading file ${file.path}:`, error) } } - return fileContents.trim(); + return fileContents.trim() } - - public async streamChatCompletion(messages: Message[], filePaths: FileItem[]) { + public async streamChatCompletion( + messages: Message[], + filePaths: FileItem[] + ) { this._completion = '' this.sendEditorLanguage() const editor = window.activeTextEditor @@ -604,27 +606,37 @@ export class ChatService { additionalContext += `Additional Context:\n${ragContext}\n\n` } - const fileContents = await this.loadFileContents(filePaths); + const fileContents = await this.loadFileContents(filePaths) if (fileContents) { - additionalContext += `File Contents:\n${fileContents}\n\n`; + additionalContext += `File Contents:\n${fileContents}\n\n` } - const updatedMessages = [systemMessage, ...messages.slice(0, -1)] + const provider = this.getProvider() + + if (!provider) return + + const conversation = [] + + conversation.push(...messages.slice(0, -1)) + + if (!provider.modelName.includes('claude')) { + conversation.unshift(systemMessage) + } if (additionalContext) { const lastMessageContent = `${cleanedText}\n\n${additionalContext.trim()}` - updatedMessages.push({ + conversation.push({ role: USER, content: lastMessageContent }) } else { - updatedMessages.push({ + conversation.push({ ...lastMessage, content: cleanedText }) } updateLoadingMessage(this._webView, 'Thinking') - const request = this.buildStreamRequest(updatedMessages) + const request = this.buildStreamRequest(conversation) if (!request) return const { requestBody, requestOptions } = request return this.streamResponse({ requestBody, requestOptions }) @@ -678,13 +690,20 @@ export class ChatService { ? `${prompt}\n\nAdditional Context:\n${ragContext}` : prompt - const conversation: Message[] = [ - systemMessage, - { - role: USER, - content: userContent - } - ] + const provider = this.getProvider() + + if (!provider) return [] + + const conversation = [] + + conversation.push({ + role: USER, + content: userContent + }) + + if (!provider.modelName.includes('claude')) { + conversation.push(systemMessage) + } return conversation } diff --git a/src/extension/conversation-history.ts b/src/extension/conversation-history.ts index e8291a1d..1f58c729 100644 --- a/src/extension/conversation-history.ts +++ b/src/extension/conversation-history.ts @@ -21,7 +21,9 @@ import { EXTENSION_SESSION_NAME, SYMMETRY_EMITTER_KEY, SYMMETRY_DATA_MESSAGE, - TITLE_GENERATION_PROMPT_MESAGE + TITLE_GENERATION_PROMPT_MESAGE, + USER, + ASSISTANT } from '../common/constants' import { SessionManager } from './session-manager' import { SymmetryService } from './symmetry-service' @@ -157,6 +159,13 @@ export class ConversationHistory { const requestOptions = this.getRequestOptions(provider) + if (messages.length === 1 && messages[0].role === ASSISTANT) { + messages.unshift({ + role: USER, + content: 'Request to review code.' + }) + } + const requestBody = createStreamRequestBody(provider.provider, { model: provider.modelName, numPredictChat: this.config.numPredictChat, @@ -164,7 +173,7 @@ export class ConversationHistory { messages: [ ...messages, { - role: 'user', + role: USER, content: TITLE_GENERATION_PROMPT_MESAGE } ], @@ -236,7 +245,7 @@ export class ConversationHistory { conversation ) this.webView?.postMessage({ - type: CONVERSATION_EVENT_NAME.getActiveConversation, + type: CONVERSATION_EVENT_NAME.setActiveConversation, value: { data: conversation } @@ -287,7 +296,7 @@ export class ConversationHistory { messages: [ ...conversation.messages, { - role: 'user', + role: USER, content: TITLE_GENERATION_PROMPT_MESAGE } ], diff --git a/src/extension/review-service.ts b/src/extension/review-service.ts index 8e97d2d7..eb47e653 100644 --- a/src/extension/review-service.ts +++ b/src/extension/review-service.ts @@ -12,6 +12,7 @@ import { EVENT_NAME, EXTENSION_CONTEXT_NAME, GITHUB_EVENT_NAME, + USER, WEBUI_TABS } from '../common/constants' import { StreamResponse } from 'symmetry-core' @@ -134,7 +135,7 @@ export class GithubService extends ConversationHistory { const messages = [ { - role: 'user', + role: USER, content: prompt } ] @@ -181,11 +182,6 @@ export class GithubService extends ConversationHistory { return streamResponse({ body: requestBody, options: requestOptions, - onStart: (controller: AbortController) => { - this.webView?.onDidReceiveMessage(() => { - controller?.abort() - }) - }, onData: (streamResponse) => { const provider = this.getProvider() if (!provider) return diff --git a/src/webview/hooks.ts b/src/webview/hooks.ts index cd97352b..46f38ef3 100644 --- a/src/webview/hooks.ts +++ b/src/webview/hooks.ts @@ -479,6 +479,9 @@ export const useConversationHistory = () => { if (message?.type === CONVERSATION_EVENT_NAME.getConversations) { setConversations(message.value.data) } + if (message?.type === CONVERSATION_EVENT_NAME.setActiveConversation) { + setConversation(message.value.data) + } } }