diff --git a/README.en.md b/README.en.md
index 9345a21900..82dceb5b87 100644
--- a/README.en.md
+++ b/README.en.md
@@ -60,7 +60,7 @@ _✨ Access all LLM through the standard OpenAI API format, easy to deploy & use
 1. Support for multiple large models:
    + [x] [OpenAI ChatGPT Series Models](https://platform.openai.com/docs/guides/gpt/chat-completions-api) (Supports [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
    + [x] [Anthropic Claude Series Models](https://anthropic.com)
-   + [x] [Google PaLM2 Series Models](https://developers.generativeai.google)
+   + [x] [Google PaLM2 and Gemini Series Models](https://developers.generativeai.google)
    + [x] [Baidu Wenxin Yiyuan Series Models](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
    + [x] [Alibaba Tongyi Qianwen Series Models](https://help.aliyun.com/document_detail/2400395.html)
    + [x] [Zhipu ChatGLM Series Models](https://bigmodel.cn)
diff --git a/README.ja.md b/README.ja.md
index 6faf9bee9b..089fc2b5f8 100644
--- a/README.ja.md
+++ b/README.ja.md
@@ -60,7 +60,7 @@ _✨ 標準的な OpenAI API フォーマットを通じてすべての LLM に
 1. 複数の大型モデルをサポート:
    + [x] [OpenAI ChatGPT シリーズモデル](https://platform.openai.com/docs/guides/gpt/chat-completions-api) ([Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference) をサポート)
    + [x] [Anthropic Claude シリーズモデル](https://anthropic.com)
-   + [x] [Google PaLM2 シリーズモデル](https://developers.generativeai.google)
+   + [x] [Google PaLM2/Gemini シリーズモデル](https://developers.generativeai.google)
    + [x] [Baidu Wenxin Yiyuan シリーズモデル](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
    + [x] [Alibaba Tongyi Qianwen シリーズモデル](https://help.aliyun.com/document_detail/2400395.html)
    + [x] [Zhipu ChatGLM シリーズモデル](https://bigmodel.cn)
diff --git a/README.md b/README.md
index 7e6a7b383c..8a1d6caf8a 100644
--- a/README.md
+++ b/README.md
@@ -66,7 +66,7 @@ _✨ 通过标准的 OpenAI API 格式访问所有的大模型,开箱即用
 1. 支持多种大模型:
    + [x] [OpenAI ChatGPT 系列模型](https://platform.openai.com/docs/guides/gpt/chat-completions-api)(支持 [Azure OpenAI API](https://learn.microsoft.com/en-us/azure/ai-services/openai/reference))
    + [x] [Anthropic Claude 系列模型](https://anthropic.com)
-   + [x] [Google PaLM2 系列模型](https://developers.generativeai.google)
+   + [x] [Google PaLM2/Gemini 系列模型](https://developers.generativeai.google)
    + [x] [百度文心一言系列模型](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html)
    + [x] [阿里通义千问系列模型](https://help.aliyun.com/document_detail/2400395.html)
    + [x] [讯飞星火认知大模型](https://www.xfyun.cn/doc/spark/Web.html)
diff --git a/common/constants.go b/common/constants.go
index f6860f67ba..60700ec82f 100644
--- a/common/constants.go
+++ b/common/constants.go
@@ -187,6 +187,7 @@ const (
 	ChannelTypeAIProxyLibrary = 21
 	ChannelTypeFastGPT        = 22
 	ChannelTypeTencent        = 23
+	ChannelTypeGemini         = 24
 )
 
 var ChannelBaseURLs = []string{
@@ -214,4 +215,5 @@ var ChannelBaseURLs = []string{
 	"https://api.aiproxy.io",            // 21
 	"https://fastgpt.run/api/openapi",   // 22
 	"https://hunyuan.cloud.tencent.com", //23
+	"",                                  //24
 }
diff --git a/common/model-ratio.go b/common/model-ratio.go
index ccbc05ddb5..c054fa5f66 100644
--- a/common/model-ratio.go
+++ b/common/model-ratio.go
@@ -83,6 +83,7 @@ var ModelRatio = map[string]float64{
 	"ERNIE-Bot-4":   8.572,  // ¥0.12 / 1k tokens
 	"Embedding-V1":  0.1429, // ¥0.002 / 1k tokens
 	"PaLM-2":        1,
+	"gemini-pro":    1,      // $0.00025 / 1k characters -> $0.001 / 1k tokens
 	"chatglm_turbo": 0.3572, // ¥0.005 / 1k tokens
 	"chatglm_pro":   0.7143, // ¥0.01 / 1k tokens
 	"chatglm_std":   0.3572, // ¥0.005 / 1k tokens
diff --git a/controller/channel-test.go b/controller/channel-test.go
index bba9a657ab..3aaa48972b 100644
--- a/controller/channel-test.go
+++ b/controller/channel-test.go
@@ -20,6 +20,8 @@ func testChannel(channel *model.Channel, request ChatRequest) (err error, openai
 	switch channel.Type {
 	case common.ChannelTypePaLM:
 		fallthrough
+	case common.ChannelTypeGemini:
+		fallthrough
 	case common.ChannelTypeAnthropic:
 		fallthrough
 	case common.ChannelTypeBaidu:
diff --git a/controller/model.go b/controller/model.go
index 8f79524d73..5c8aebc05a 100644
--- a/controller/model.go
+++ b/controller/model.go
@@ -423,6 +423,15 @@ func init() {
 			Root:       "PaLM-2",
 			Parent:     nil,
 		},
+		{
+			Id:         "gemini-pro",
+			Object:     "model",
+			Created:    1677649963,
+			OwnedBy:    "google",
+			Permission: permission,
+			Root:       "gemini-pro",
+			Parent:     nil,
+		},
 		{
 			Id:         "chatglm_turbo",
 			Object:     "model",
diff --git a/controller/relay-gemini.go b/controller/relay-gemini.go
new file mode 100644
index 0000000000..455e30d898
--- /dev/null
+++ b/controller/relay-gemini.go
@@ -0,0 +1,299 @@
+package controller
+
+import (
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"one-api/common"
+
+	"github.com/gin-gonic/gin"
+)
+
+// GeminiChatRequest is the JSON body sent to Gemini's generateContent endpoint.
+type GeminiChatRequest struct {
+	Contents         []GeminiChatContent        `json:"contents"`
+	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
+	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
+	Tools            []GeminiChatTools          `json:"tools,omitempty"`
+}
+
+// GeminiInlineData is a MIME-typed data payload that a GeminiPart may carry.
+type GeminiInlineData struct {
+	MimeType string `json:"mimeType"`
+	Data     string `json:"data"`
+}
+
+// GeminiPart is one piece of a message: text and/or inline data.
+type GeminiPart struct {
+	Text       string            `json:"text,omitempty"`
+	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
+}
+
+// GeminiChatContent is a single conversation turn: a role plus its parts.
+type GeminiChatContent struct {
+	Role  string       `json:"role,omitempty"`
+	Parts []GeminiPart `json:"parts"`
+}
+
+// GeminiChatSafetySettings pairs a harm category with a blocking threshold.
+type GeminiChatSafetySettings struct {
+	Category  string `json:"category"`
+	Threshold string `json:"threshold"`
+}
+
+// GeminiChatTools wraps the function declarations forwarded from the OpenAI request.
+type GeminiChatTools struct {
+	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
+}
+
+// GeminiChatGenerationConfig holds the sampling and output-length options of a request.
+type GeminiChatGenerationConfig struct {
+	Temperature     float64  `json:"temperature,omitempty"`
+	TopP            float64  `json:"topP,omitempty"`
+	TopK            float64  `json:"topK,omitempty"`
+	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
+	CandidateCount  int      `json:"candidateCount,omitempty"`
+	StopSequences   []string `json:"stopSequences,omitempty"`
+}
+
+// requestOpenAI2Gemini converts an OpenAI-style chat request into a Gemini
+// chat request.
+//
+// Roles are remapped because Gemini only accepts "user" and "model":
+// "assistant" becomes "model", and "system" becomes "user" followed by a
+// placeholder "model" turn. Any other role is passed through unchanged.
+//
+// No safety settings are sent; the block below is left commented out.
+func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
+	geminiRequest := GeminiChatRequest{
+		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
+		//SafetySettings: []GeminiChatSafetySettings{
+		//	{
+		//		Category:  "HARM_CATEGORY_HARASSMENT",
+		//		Threshold: "BLOCK_ONLY_HIGH",
+		//	},
+		//	{
+		//		Category:  "HARM_CATEGORY_HATE_SPEECH",
+		//		Threshold: "BLOCK_ONLY_HIGH",
+		//	},
+		//	{
+		//		Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+		//		Threshold: "BLOCK_ONLY_HIGH",
+		//	},
+		//	{
+		//		Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
+		//		Threshold: "BLOCK_ONLY_HIGH",
+		//	},
+		//},
+		GenerationConfig: GeminiChatGenerationConfig{
+			Temperature:     textRequest.Temperature,
+			TopP:            textRequest.TopP,
+			MaxOutputTokens: textRequest.MaxTokens,
+		},
+	}
+	if textRequest.Functions != nil {
+		geminiRequest.Tools = []GeminiChatTools{
+			{
+				FunctionDeclarations: textRequest.Functions,
+			},
+		}
+	}
+	for _, message := range textRequest.Messages {
+		content := GeminiChatContent{
+			Role:  message.Role,
+			Parts: []GeminiPart{{Text: message.StringContent()}},
+		}
+		isSystem := content.Role == "system"
+		switch content.Role {
+		case "assistant":
+			content.Role = "model"
+		case "system":
+			content.Role = "user"
+		}
+		geminiRequest.Contents = append(geminiRequest.Contents, content)
+		if isSystem {
+			// Answer the converted system prompt with a placeholder model turn.
+			geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
+				Role:  "model",
+				Parts: []GeminiPart{{Text: "ok"}},
+			})
+		}
+	}
+	return &geminiRequest
+}
+
+// GeminiChatResponse is the JSON body returned by generateContent.
+type GeminiChatResponse struct {
+	Candidates     []GeminiChatCandidate    `json:"candidates"`
+	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
+}
+
+// GeminiChatCandidate is one generated answer within a GeminiChatResponse.
+type GeminiChatCandidate struct {
+	Content       GeminiChatContent        `json:"content"`
+	FinishReason  string                   `json:"finishReason"`
+	Index         int64                    `json:"index"`
+	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+}
+
+// GeminiChatSafetyRating reports the probability assigned to one category.
+type GeminiChatSafetyRating struct {
+	Category    string `json:"category"`
+	Probability string `json:"probability"`
+}
+
+// GeminiChatPromptFeedback holds the safety ratings reported for the prompt.
+type GeminiChatPromptFeedback struct {
+	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
+}
+
+// geminiFirstPartText returns the text of the first part of content, or ""
+// when content has no parts.
+func geminiFirstPartText(content GeminiChatContent) string {
+	if len(content.Parts) == 0 {
+		return ""
+	}
+	return content.Parts[0].Text
+}
+
+// geminiFirstCandidateText returns the text of the first part of the first
+// candidate, or "" when the response has no candidates or parts.
+func geminiFirstCandidateText(response *GeminiChatResponse) string {
+	if len(response.Candidates) == 0 {
+		return ""
+	}
+	return geminiFirstPartText(response.Candidates[0].Content)
+}
+
+// readGeminiChatResponse reads, closes and decodes the upstream response body.
+// The body is closed even when reading fails.
+func readGeminiChatResponse(resp *http.Response) (*GeminiChatResponse, *OpenAIErrorWithStatusCode) {
+	responseBody, readErr := io.ReadAll(resp.Body)
+	closeErr := resp.Body.Close()
+	if readErr != nil {
+		return nil, errorWrapper(readErr, "read_response_body_failed", http.StatusInternalServerError)
+	}
+	if closeErr != nil {
+		return nil, errorWrapper(closeErr, "close_response_body_failed", http.StatusInternalServerError)
+	}
+	var geminiResponse GeminiChatResponse
+	if err := json.Unmarshal(responseBody, &geminiResponse); err != nil {
+		return nil, errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
+	}
+	return &geminiResponse, nil
+}
+
+// responseGeminiChat2OpenAI converts a Gemini response into an OpenAI chat
+// completion. Each candidate becomes one choice holding the text of its first
+// part; a candidate without parts yields an empty message instead of a panic.
+func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
+	fullTextResponse := OpenAITextResponse{
+		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+		Object:  "chat.completion",
+		Created: common.GetTimestamp(),
+		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
+	}
+	for i, candidate := range response.Candidates {
+		choice := OpenAITextResponseChoice{
+			Index: i,
+			Message: Message{
+				Role:    "assistant",
+				Content: geminiFirstPartText(candidate.Content),
+			},
+			FinishReason: stopFinishReason,
+		}
+		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
+	}
+	return &fullTextResponse
+}
+
+// streamResponseGeminiChat2OpenAI converts a Gemini response into a single
+// OpenAI stream chunk carrying the first candidate's first text part. Id and
+// Created are left for the caller to fill in.
+func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse {
+	var choice ChatCompletionsStreamResponseChoice
+	if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
+		choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text
+	}
+	choice.FinishReason = &stopFinishReason
+	var response ChatCompletionsStreamResponse
+	response.Object = "chat.completion.chunk"
+	response.Model = "gemini"
+	response.Choices = []ChatCompletionsStreamResponseChoice{choice}
+	return &response
+}
+
+// geminiChatStreamHandler relays a Gemini response to the client as an
+// OpenAI-style event stream and returns the relayed text.
+//
+// The whole upstream body is read and converted first, then emitted as one
+// chunk followed by "data: [DONE]". Everything runs on the calling goroutine,
+// so nothing is left blocked if the client disconnects, and read or decode
+// failures are returned before any byte is written to the client.
+func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
+	geminiResponse, relayErr := readGeminiChatResponse(resp)
+	if relayErr != nil {
+		return relayErr, ""
+	}
+	chunk := streamResponseGeminiChat2OpenAI(geminiResponse)
+	chunk.Id = fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+	chunk.Created = common.GetTimestamp()
+	jsonResponse, err := json.Marshal(chunk)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), ""
+	}
+	setEventStreamHeaders(c)
+	chunkSent := false
+	c.Stream(func(w io.Writer) bool {
+		if !chunkSent {
+			c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
+			chunkSent = true
+			return true
+		}
+		c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
+		return false
+	})
+	return nil, geminiFirstCandidateText(geminiResponse)
+}
+
+// geminiChatHandler relays a non-streaming Gemini response to the client as an
+// OpenAI chat completion and returns the token usage. Completion tokens are
+// counted from the first candidate's first text part only.
+func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
+	geminiResponse, relayErr := readGeminiChatResponse(resp)
+	if relayErr != nil {
+		return relayErr, nil
+	}
+	if len(geminiResponse.Candidates) == 0 {
+		return &OpenAIErrorWithStatusCode{
+			OpenAIError: OpenAIError{
+				Message: "No candidates returned",
+				Type:    "server_error",
+				Param:   "",
+				Code:    500,
+			},
+			StatusCode: resp.StatusCode,
+		}, nil
+	}
+	fullTextResponse := responseGeminiChat2OpenAI(geminiResponse)
+	completionTokens := countTokenText(geminiFirstCandidateText(geminiResponse), model)
+	usage := Usage{
+		PromptTokens:     promptTokens,
+		CompletionTokens: completionTokens,
+		TotalTokens:      promptTokens + completionTokens,
+	}
+	fullTextResponse.Usage = usage
+	jsonResponse, err := json.Marshal(fullTextResponse)
+	if err != nil {
+		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+	}
+	c.Writer.Header().Set("Content-Type", "application/json")
+	c.Writer.WriteHeader(resp.StatusCode)
+	if _, err = c.Writer.Write(jsonResponse); err != nil {
+		// The upstream call already succeeded, so usage is still returned;
+		// the failed write can only be logged.
+		common.SysError("error writing response: " + err.Error())
+	}
+	return nil, &usage
+}
diff --git a/controller/relay-text.go b/controller/relay-text.go
index a69c7f8bd9..211a34b3be 100644
--- a/controller/relay-text.go
+++ b/controller/relay-text.go
@@ -27,6 +27,7 @@ const (
 	APITypeXunfei
 	APITypeAIProxyLibrary
 	APITypeTencent
+	APITypeGemini
 )
 
 var httpClient *http.Client
@@ -118,6 +119,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiType = APITypeAIProxyLibrary
 	case common.ChannelTypeTencent:
 		apiType = APITypeTencent
+	case common.ChannelTypeGemini:
+		apiType = APITypeGemini
 	}
 	baseURL := common.ChannelBaseURLs[channelType]
 	requestURL := c.Request.URL.String()
@@ -177,6 +180,24 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 		apiKey := c.Request.Header.Get("Authorization")
 		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
 		fullRequestURL += "?key=" + apiKey
+	case APITypeGemini:
+		requestBaseURL := "https://generativelanguage.googleapis.com"
+		if baseURL != "" {
+			requestBaseURL = baseURL
+		}
+		version := "v1"
+		if c.GetString("api_version") != "" {
+			version = c.GetString("api_version")
+		}
+		action := "generateContent"
+		// actually gemini does not support stream, it's fake
+		//if textRequest.Stream {
+		//	action = "streamGenerateContent"
+		//}
+		fullRequestURL = fmt.Sprintf("%s/%s/models/%s:%s", requestBaseURL, version, textRequest.Model, action)
+		apiKey := c.Request.Header.Get("Authorization")
+		apiKey = strings.TrimPrefix(apiKey, "Bearer ")
+		fullRequestURL += "?key=" + apiKey
 	case APITypeZhipu:
 		method := "invoke"
 		if textRequest.Stream {
@@ -274,6 +295,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
 		}
 		requestBody = bytes.NewBuffer(jsonStr)
+	case APITypeGemini:
+		geminiChatRequest := requestOpenAI2Gemini(textRequest)
+		jsonStr, err := json.Marshal(geminiChatRequest)
+		if err != nil {
+			return errorWrapper(err, "marshal_text_request_failed", http.StatusInternalServerError)
+		}
+		requestBody = bytes.NewBuffer(jsonStr)
 	case APITypeZhipu:
 		zhipuRequest := requestOpenAI2Zhipu(textRequest)
 		jsonStr, err := json.Marshal(zhipuRequest)
@@ -367,6 +395,8 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			req.Header.Set("Authorization", apiKey)
 		case APITypePaLM:
 			// do not set Authorization header
+		case APITypeGemini:
+			// do not set Authorization header
 		default:
 			req.Header.Set("Authorization", "Bearer "+apiKey)
 		}
@@ -527,6 +557,25 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
 			}
 			return nil
 		}
+	case APITypeGemini:
+		if textRequest.Stream {
+			err, responseText := geminiChatStreamHandler(c, resp)
+			if err != nil {
+				return err
+			}
+			textResponse.Usage.PromptTokens = promptTokens
+			textResponse.Usage.CompletionTokens = countTokenText(responseText, textRequest.Model)
+			return nil
+		} else {
+			err, usage := geminiChatHandler(c, resp, promptTokens, textRequest.Model)
+			if err != nil {
+				return err
+			}
+			if usage != nil {
+				textResponse.Usage = *usage
+			}
+			return nil
+		}
 	case APITypeZhipu:
 		if isStream {
 			err, usage := zhipuStreamHandler(c, resp)
diff --git a/middleware/distributor.go b/middleware/distributor.go
index 8be986c916..81338130b0 100644
--- a/middleware/distributor.go
+++ b/middleware/distributor.go
@@ -87,6 +87,8 @@ func Distribute() func(c *gin.Context) {
 			c.Set("api_version", channel.Other)
 		case common.ChannelTypeXunfei:
 			c.Set("api_version", channel.Other)
+		case common.ChannelTypeGemini:
+			c.Set("api_version", channel.Other)
 		case common.ChannelTypeAIProxyLibrary:
 			c.Set("library_id", channel.Other)
 		case common.ChannelTypeAli:
diff --git a/web/src/constants/channel.constants.js b/web/src/constants/channel.constants.js
index 7640774557..264dbefb3e 100644
--- a/web/src/constants/channel.constants.js
+++ b/web/src/constants/channel.constants.js
@@ -3,6 +3,7 @@ export const CHANNEL_OPTIONS = [
   { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' },
   { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' },
   { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' },
+  { key: 24, text: 'Google Gemini', value: 24, color: 'orange' },
   { key: 15, text: '百度文心千帆', value: 15, color: 'blue' },
   { key: 17, text: '阿里通义千问', value: 17, color: 'orange' },
   { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' },
diff --git a/web/src/pages/Channel/EditChannel.js b/web/src/pages/Channel/EditChannel.js
index 62e8a1550b..114e593336 100644
--- a/web/src/pages/Channel/EditChannel.js
+++ b/web/src/pages/Channel/EditChannel.js
@@ -83,6 +83,9 @@ const EditChannel = () => {
         case 23:
           localModels = ['hunyuan'];
           break;
+        case 24:
+          localModels = ['gemini-pro'];
+          break;
       }
       setInputs((inputs) => ({ ...inputs, models: localModels }));
     }