diff --git a/package-lock.json b/package-lock.json index be3cd729..84b37bb1 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "twinny", - "version": "3.16.4", + "version": "3.16.5", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "twinny", - "version": "3.16.4", + "version": "3.16.5", "cpu": [ "x64", "arm64" diff --git a/package.json b/package.json index e9eee23b..993cb787 100644 --- a/package.json +++ b/package.json @@ -2,7 +2,7 @@ "name": "twinny", "displayName": "twinny - AI Code Completion and Chat", "description": "Locally hosted AI code completion plugin for vscode", - "version": "3.16.4", + "version": "3.16.5", "icon": "assets/icon.png", "keywords": [ "code-inference", diff --git a/src/common/constants.ts b/src/common/constants.ts index 0783ddb6..0a0484b1 100644 --- a/src/common/constants.ts +++ b/src/common/constants.ts @@ -195,6 +195,7 @@ export const FIM_TEMPLATE_FORMAT = { codegemma: 'codegemma', codellama: 'codellama', codeqwen: 'codeqwen', + codestral: 'codestral', custom: 'custom-template', deepseek: 'deepseek', llama: 'llama', @@ -216,6 +217,8 @@ export const STOP_STARCODER = ['<|endoftext|>', ''] export const STOP_CODEGEMMA = ['<|file_separator|>', '<|end_of_turn|>', ''] +export const STOP_CODESTRAL = ['[PREFIX]', '[SUFFIX]'] + export const DEFAULT_TEMPLATE_NAMES = defaultTemplates.map(({ name }) => name) export const DEFAULT_ACTION_TEMPLATES = [ diff --git a/src/extension/fim-templates.ts b/src/extension/fim-templates.ts index d7ad0e4a..c704b8af 100644 --- a/src/extension/fim-templates.ts +++ b/src/extension/fim-templates.ts @@ -3,7 +3,8 @@ import { STOP_DEEPSEEK, STOP_LLAMA, STOP_STARCODER, - STOP_CODEGEMMA + STOP_CODEGEMMA, + STOP_CODESTRAL } from '../common/constants' import { supportedLanguages } from '../common/languages' import { FimPromptTemplate } from '../common/types' @@ -61,6 +62,23 @@ export const getFimPromptTemplateDeepseek = ({ return `<|fim▁begin|>${fileContext}\n${heading}${prefix}<|fim▁hole|>${suffix}<|fim▁end|>` } +export const getFimPromptTemplateCodestral = ({ + context, + header, + fileContextEnabled, + prefixSuffix, + language +}: FimPromptTemplate) => { + const { prefix, suffix } = prefixSuffix + const { fileContext, heading } = getFileContext( + fileContextEnabled, + context, + language, + header + ) + return `${fileContext}\n\n[SUFFIX]${suffix}[PREFIX]${heading}${prefix}` +} + export const getFimPromptTemplateOther = ({ context, header, @@ -90,6 +108,10 @@ function getFimTemplateAuto(fimModel: string, args: FimPromptTemplate) { return getFimPromptTemplateDeepseek(args) } + if (fimModel.includes(FIM_TEMPLATE_FORMAT.codestral)) { + return getFimPromptTemplateCodestral(args) + } + if ( fimModel.includes(FIM_TEMPLATE_FORMAT.stableCode) || fimModel.includes(FIM_TEMPLATE_FORMAT.starcoder) || @@ -111,6 +133,10 @@ function getFimTemplateChosen(format: string, args: FimPromptTemplate) { return getFimPromptTemplateDeepseek(args) } + if (format === FIM_TEMPLATE_FORMAT.codestral) { + return getFimPromptTemplateCodestral(args) + } + if ( format === FIM_TEMPLATE_FORMAT.stableCode || format === FIM_TEMPLATE_FORMAT.starcoder || @@ -157,6 +183,10 @@ export const getStopWordsAuto = (fimModel: string) => { return STOP_CODEGEMMA } + if (fimModel.includes(FIM_TEMPLATE_FORMAT.codestral)) { + return STOP_CODESTRAL + } + return STOP_LLAMA } @@ -169,6 +199,7 @@ export const getStopWordsChosen = (format: string) => { ) return STOP_STARCODER if (format === FIM_TEMPLATE_FORMAT.codegemma) return STOP_CODEGEMMA + if (format === FIM_TEMPLATE_FORMAT.codestral) return STOP_CODESTRAL return STOP_LLAMA } diff --git a/src/extension/provider-options.ts b/src/extension/provider-options.ts index 6c1976e5..19bb0803 100644 --- a/src/extension/provider-options.ts +++ b/src/extension/provider-options.ts @@ -76,7 +76,7 @@ export function createStreamRequestBodyFim( prompt, stream: true, temperature: options.temperature, - n_predict: options.numPredictFim + max_tokens: options.numPredictFim } case apiProviders.LlamaCpp: case apiProviders.Oobabooga: @@ -84,7 +84,7 @@ export function createStreamRequestBodyFim( prompt, stream: true, temperature: options.temperature, - n_predict: options.numPredictFim + max_tokens: options.numPredictFim } case apiProviders.LiteLLM: return {