From c936198ac8cf3c51bb839d11efea0bfd06fedc57 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Mon, 22 Jul 2024 22:51:19 +0800 Subject: [PATCH] feat: add Proxy channel type and relay mode (#1678) Add the Proxy channel type and relay mode to support proxying requests to custom upstream services. --- controller/relay.go | 5 + middleware/auth.go | 6 + relay/adaptor.go | 3 + relay/adaptor/proxy/adaptor.go | 89 ++++++++++++ relay/apitype/define.go | 1 + relay/channeltype/define.go | 1 + relay/channeltype/helper.go | 2 + relay/channeltype/url.go | 1 + relay/controller/proxy.go | 41 ++++++ relay/meta/relay_meta.go | 11 +- relay/relaymode/define.go | 2 + relay/relaymode/helper.go | 2 + router/relay.go | 1 + web/air/src/constants/channel.constants.js | 14 +- web/berry/src/constants/ChannelConstants.js | 6 + .../src/constants/channel.constants.js | 85 ++++++------ web/default/src/pages/Channel/EditChannel.js | 128 ++++++++++-------- 17 files changed, 292 insertions(+), 106 deletions(-) create mode 100644 relay/adaptor/proxy/adaptor.go create mode 100644 relay/controller/proxy.go diff --git a/controller/relay.go b/controller/relay.go index 932e023b41..49358e2597 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -34,6 +34,8 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { fallthrough case relaymode.AudioTranscription: err = controller.RelayAudioHelper(c, relayMode) + case relaymode.Proxy: + err = controller.RelayProxyHelper(c, relayMode) default: err = controller.RelayTextHelper(c) } @@ -85,12 +87,15 @@ func Relay(c *gin.Context) { channelId := c.GetInt(ctxkey.ChannelId) lastFailedChannelId = channelId channelName := c.GetString(ctxkey.ChannelName) + // BUG: bizErr is in race condition go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) } if bizErr != nil { if bizErr.StatusCode == http.StatusTooManyRequests { bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" } + + // BUG: bizErr is in race condition bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId) c.JSON(bizErr.StatusCode, gin.H{ "error": bizErr.Error, diff --git a/middleware/auth.go b/middleware/auth.go index 5cba490a09..e00198384e 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -140,6 +140,12 @@ func TokenAuth() func(c *gin.Context) { return } } + + // set channel id for proxy relay + if channelId := c.Param("channelid"); channelId != "" { + c.Set(ctxkey.SpecificChannelId, channelId) + } + c.Next() } } diff --git a/relay/adaptor.go b/relay/adaptor.go index 7fc83651a3..711e63bdc6 100644 --- a/relay/adaptor.go +++ b/relay/adaptor.go @@ -15,6 +15,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/ollama" "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/palm" + "github.com/songquanpeng/one-api/relay/adaptor/proxy" "github.com/songquanpeng/one-api/relay/adaptor/tencent" "github.com/songquanpeng/one-api/relay/adaptor/vertexai" "github.com/songquanpeng/one-api/relay/adaptor/xunfei" @@ -58,6 +59,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor { return &deepl.Adaptor{} case apitype.VertexAI: return &vertexai.Adaptor{} + case apitype.Proxy: + return &proxy.Adaptor{} } return nil } diff --git a/relay/adaptor/proxy/adaptor.go b/relay/adaptor/proxy/adaptor.go new file mode 100644 index 0000000000..670c76289a --- /dev/null +++ b/relay/adaptor/proxy/adaptor.go @@ -0,0 +1,89 @@ +package proxy + +import ( + "fmt" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/relay/adaptor" + channelhelper "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +var _ adaptor.Adaptor = new(Adaptor) + +const channelName = "proxy" + +type Adaptor struct{} + +func (a *Adaptor) Init(meta *meta.Meta) { +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + return nil, errors.New("notimplement") +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + for k, v := range resp.Header { + for _, vv := range v { + c.Writer.Header().Set(k, vv) + } + } + + c.Writer.WriteHeader(resp.StatusCode) + if _, gerr := io.Copy(c.Writer, resp.Body); gerr != nil { + return nil, &relaymodel.ErrorWithStatusCode{ + StatusCode: http.StatusInternalServerError, + Error: relaymodel.Error{ + Message: gerr.Error(), + }, + } + } + + return nil, nil +} + +func (a *Adaptor) GetModelList() (models []string) { + return nil +} + +func (a *Adaptor) GetChannelName() string { + return channelName +} + +// GetRequestURL remove static prefix, and return the real request url to the upstream service +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + prefix := fmt.Sprintf("/v1/oneapi/proxy/%d", meta.ChannelId) + return meta.BaseURL + strings.TrimPrefix(meta.RequestURLPath, prefix), nil + +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + for k, v := range c.Request.Header { + req.Header.Set(k, v[0]) + } + + // remove unnecessary headers + req.Header.Del("Host") + req.Header.Del("Content-Length") + req.Header.Del("Accept-Encoding") + req.Header.Del("Connection") + + // set authorization header + req.Header.Set("Authorization", meta.APIKey) + + return nil +} + +func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + return nil, errors.Errorf("not implement") +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + return channelhelper.DoRequestHelper(a, c, meta, requestBody) +} diff --git a/relay/apitype/define.go b/relay/apitype/define.go index 212a1b6b1c..cf7b6a0d2b 100644 --- a/relay/apitype/define.go +++ b/relay/apitype/define.go @@ -18,6 +18,7 @@ const ( Cloudflare DeepL VertexAI + Proxy Dummy // this one is only for count, do not add any channel after this ) diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go index d1e7fcef07..e3b0c98ef0 100644 --- a/relay/channeltype/define.go +++ b/relay/channeltype/define.go @@ -44,5 +44,6 @@ const ( Doubao Novita VertextAI + Proxy Dummy ) diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go index 67270a6730..fae3357f8c 100644 --- a/relay/channeltype/helper.go +++ b/relay/channeltype/helper.go @@ -37,6 +37,8 @@ func ToAPIType(channelType int) int { apiType = apitype.DeepL case VertextAI: apiType = apitype.VertexAI + case Proxy: + apiType = apitype.Proxy } return apiType diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go index 20a24ab029..b5026713b9 100644 --- a/relay/channeltype/url.go +++ b/relay/channeltype/url.go @@ -44,6 +44,7 @@ var ChannelBaseURLs = []string{ "https://ark.cn-beijing.volces.com", // 40 "https://api.novita.ai/v3/openai", // 41 "", // 42 + "", // 43 } func init() { diff --git a/relay/controller/proxy.go b/relay/controller/proxy.go new file mode 100644 index 0000000000..dcaf15a979 --- /dev/null +++ b/relay/controller/proxy.go @@ -0,0 +1,41 @@ +// Package controller is a package for handling the relay controller +package controller + +import ( + "fmt" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + relaymodel "github.com/songquanpeng/one-api/relay/model" +) + +// RelayProxyHelper is a helper function to proxy the request to the upstream service +func RelayProxyHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { + ctx := c.Request.Context() + meta := meta.GetByContext(c) + + adaptor := relay.GetAdaptor(meta.APIType) + if adaptor == nil { + return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest) + } + adaptor.Init(meta) + + resp, err := adaptor.DoRequest(c, meta, c.Request.Body) + if err != nil { + logger.Errorf(ctx, "DoRequest failed: %s", err.Error()) + return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError) + } + + // do response + _, respErr := adaptor.DoResponse(c, resp, meta) + if respErr != nil { + logger.Errorf(ctx, "respErr is not nil: %+v", respErr) + return respErr + } + + return nil +} diff --git a/relay/meta/relay_meta.go b/relay/meta/relay_meta.go index 04977db585..b1761e9a7c 100644 --- a/relay/meta/relay_meta.go +++ b/relay/meta/relay_meta.go @@ -18,11 +18,12 @@ type Meta struct { UserId int Group string ModelMapping map[string]string - BaseURL string - APIKey string - APIType int - Config model.ChannelConfig - IsStream bool + // BaseURL is the proxy url set in the channel config + BaseURL string + APIKey string + APIType int + Config model.ChannelConfig + IsStream bool // OriginModelName is the model name from the raw user request OriginModelName string // ActualModelName is the model name after mapping diff --git a/relay/relaymode/define.go b/relay/relaymode/define.go index 96d094382c..aa7712057d 100644 --- a/relay/relaymode/define.go +++ b/relay/relaymode/define.go @@ -11,4 +11,6 @@ const ( AudioSpeech AudioTranscription AudioTranslation + // Proxy is a special relay mode for proxying requests to custom upstream + Proxy ) diff --git a/relay/relaymode/helper.go b/relay/relaymode/helper.go index 926dd42ec4..2cde5b8510 100644 --- a/relay/relaymode/helper.go +++ b/relay/relaymode/helper.go @@ -24,6 +24,8 @@ func GetByPath(path string) int { relayMode = AudioTranscription } else if strings.HasPrefix(path, "/v1/audio/translations") { relayMode = AudioTranslation + } else if strings.HasPrefix(path, "/v1/oneapi/proxy") { + relayMode = Proxy } return relayMode } diff --git a/router/relay.go b/router/relay.go index 65072c869b..094ea5fb51 100644 --- a/router/relay.go +++ b/router/relay.go @@ -19,6 +19,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router := router.Group("/v1") relayV1Router.Use(middleware.RelayPanicRecover(), middleware.TokenAuth(), middleware.Distribute()) { + relayV1Router.Any("/oneapi/proxy/:channelid/*target", controller.Relay) relayV1Router.POST("/completions", controller.Relay) relayV1Router.POST("/chat/completions", controller.Relay) relayV1Router.POST("/edits", controller.Relay) diff --git a/web/air/src/constants/channel.constants.js b/web/air/src/constants/channel.constants.js index 4bf035f977..18293f5f67 100644 --- a/web/air/src/constants/channel.constants.js +++ b/web/air/src/constants/channel.constants.js @@ -1,10 +1,13 @@ export const CHANNEL_OPTIONS = [ { key: 1, text: 'OpenAI', value: 1, color: 'green' }, { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, + { key: 33, text: 'AWS', value: 33, color: 'black' }, { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, { key: 28, text: 'Mistral AI', value: 28, color: 'orange' }, + { key: 41, text: 'Novita', value: 41, color: 'purple' }, + { key: 40, text: '字节跳动豆包', value: 40, color: 'blue' }, { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, @@ -17,6 +20,15 @@ export const CHANNEL_OPTIONS = [ { key: 29, text: 'Groq', value: 29, color: 'orange' }, { key: 30, text: 'Ollama', value: 30, color: 'black' }, { key: 31, text: '零一万物', value: 31, color: 'green' }, + { key: 32, text: '阶跃星辰', value: 32, color: 'blue' }, + { key: 34, text: 'Coze', value: 34, color: 'blue' }, + { key: 35, text: 'Cohere', value: 35, color: 'blue' }, + { key: 36, text: 'DeepSeek', value: 36, color: 'black' }, + { key: 37, text: 'Cloudflare', value: 37, color: 'orange' }, + { key: 38, text: 'DeepL', value: 38, color: 'black' }, + { key: 39, text: 'together.ai', value: 39, color: 'blue' }, + { key: 42, text: 'VertexAI', value: 42, color: 'blue' }, + { key: 43, text: 'Proxy', value: 43, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, @@ -34,4 +46,4 @@ export const CHANNEL_OPTIONS = [ for (let i = 0; i < CHANNEL_OPTIONS.length; i++) { CHANNEL_OPTIONS[i].label = CHANNEL_OPTIONS[i].text; -} \ No newline at end of file +} diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index ac2e73a6b4..acfda37b4b 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -167,6 +167,12 @@ export const CHANNEL_OPTIONS = { value: 42, color: 'primary' }, + 43: { + key: 43, + text: 'Proxy', + value: 43, + color: 'primary' + }, 41: { key: 41, text: 'Novita', diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index b17f56c082..b2a7101654 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -1,44 +1,45 @@ export const CHANNEL_OPTIONS = [ - {key: 1, text: 'OpenAI', value: 1, color: 'green'}, - {key: 14, text: 'Anthropic Claude', value: 14, color: 'black'}, - {key: 33, text: 'AWS', value: 33, color: 'black'}, - {key: 3, text: 'Azure OpenAI', value: 3, color: 'olive'}, - {key: 11, text: 'Google PaLM2', value: 11, color: 'orange'}, - {key: 24, text: 'Google Gemini', value: 24, color: 'orange'}, - {key: 28, text: 'Mistral AI', value: 28, color: 'orange'}, - {key: 41, text: 'Novita', value: 41, color: 'purple'}, - {key: 40, text: '字节跳动豆包', value: 40, color: 'blue'}, - {key: 15, text: '百度文心千帆', value: 15, color: 'blue'}, - {key: 17, text: '阿里通义千问', value: 17, color: 'orange'}, - {key: 18, text: '讯飞星火认知', value: 18, color: 'blue'}, - {key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet'}, - {key: 19, text: '360 智脑', value: 19, color: 'blue'}, - {key: 25, text: 'Moonshot AI', value: 25, color: 'black'}, - {key: 23, text: '腾讯混元', value: 23, color: 'teal'}, - {key: 26, text: '百川大模型', value: 26, color: 'orange'}, - {key: 27, text: 'MiniMax', value: 27, color: 'red'}, - {key: 29, text: 'Groq', value: 29, color: 'orange'}, - {key: 30, text: 'Ollama', value: 30, color: 'black'}, - {key: 31, text: '零一万物', value: 31, color: 'green'}, - {key: 32, text: '阶跃星辰', value: 32, color: 'blue'}, - {key: 34, text: 'Coze', value: 34, color: 'blue'}, - {key: 35, text: 'Cohere', value: 35, color: 'blue'}, - {key: 36, text: 'DeepSeek', value: 36, color: 'black'}, - {key: 37, text: 'Cloudflare', value: 37, color: 'orange'}, - {key: 38, text: 'DeepL', value: 38, color: 'black'}, - {key: 39, text: 'together.ai', value: 39, color: 'blue'}, - {key: 42, text: 'VertexAI', value: 42, color: 'blue'}, - {key: 8, text: '自定义渠道', value: 8, color: 'pink'}, - {key: 22, text: '知识库:FastGPT', value: 22, color: 'blue'}, - {key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple'}, - {key: 20, text: '代理:OpenRouter', value: 20, color: 'black'}, - {key: 2, text: '代理:API2D', value: 2, color: 'blue'}, - {key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown'}, - {key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple'}, - {key: 10, text: '代理:AI Proxy', value: 10, color: 'purple'}, - {key: 4, text: '代理:CloseAI', value: 4, color: 'teal'}, - {key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet'}, - {key: 9, text: '代理:AI.LS', value: 9, color: 'yellow'}, - {key: 12, text: '代理:API2GPT', value: 12, color: 'blue'}, - {key: 13, text: '代理:AIGC2D', value: 13, color: 'purple'} + { key: 1, text: 'OpenAI', value: 1, color: 'green' }, + { key: 14, text: 'Anthropic Claude', value: 14, color: 'black' }, + { key: 33, text: 'AWS', value: 33, color: 'black' }, + { key: 3, text: 'Azure OpenAI', value: 3, color: 'olive' }, + { key: 11, text: 'Google PaLM2', value: 11, color: 'orange' }, + { key: 24, text: 'Google Gemini', value: 24, color: 'orange' }, + { key: 28, text: 'Mistral AI', value: 28, color: 'orange' }, + { key: 41, text: 'Novita', value: 41, color: 'purple' }, + { key: 40, text: '字节跳动豆包', value: 40, color: 'blue' }, + { key: 15, text: '百度文心千帆', value: 15, color: 'blue' }, + { key: 17, text: '阿里通义千问', value: 17, color: 'orange' }, + { key: 18, text: '讯飞星火认知', value: 18, color: 'blue' }, + { key: 16, text: '智谱 ChatGLM', value: 16, color: 'violet' }, + { key: 19, text: '360 智脑', value: 19, color: 'blue' }, + { key: 25, text: 'Moonshot AI', value: 25, color: 'black' }, + { key: 23, text: '腾讯混元', value: 23, color: 'teal' }, + { key: 26, text: '百川大模型', value: 26, color: 'orange' }, + { key: 27, text: 'MiniMax', value: 27, color: 'red' }, + { key: 29, text: 'Groq', value: 29, color: 'orange' }, + { key: 30, text: 'Ollama', value: 30, color: 'black' }, + { key: 31, text: '零一万物', value: 31, color: 'green' }, + { key: 32, text: '阶跃星辰', value: 32, color: 'blue' }, + { key: 34, text: 'Coze', value: 34, color: 'blue' }, + { key: 35, text: 'Cohere', value: 35, color: 'blue' }, + { key: 36, text: 'DeepSeek', value: 36, color: 'black' }, + { key: 37, text: 'Cloudflare', value: 37, color: 'orange' }, + { key: 38, text: 'DeepL', value: 38, color: 'black' }, + { key: 39, text: 'together.ai', value: 39, color: 'blue' }, + { key: 42, text: 'VertexAI', value: 42, color: 'blue' }, + { key: 43, text: 'Proxy', value: 43, color: 'blue' }, + { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, + { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, + { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, + { key: 20, text: '代理:OpenRouter', value: 20, color: 'black' }, + { key: 2, text: '代理:API2D', value: 2, color: 'blue' }, + { key: 5, text: '代理:OpenAI-SB', value: 5, color: 'brown' }, + { key: 7, text: '代理:OhMyGPT', value: 7, color: 'purple' }, + { key: 10, text: '代理:AI Proxy', value: 10, color: 'purple' }, + { key: 4, text: '代理:CloseAI', value: 4, color: 'teal' }, + { key: 6, text: '代理:OpenAI Max', value: 6, color: 'violet' }, + { key: 9, text: '代理:AI.LS', value: 9, color: 'yellow' }, + { key: 12, text: '代理:API2GPT', value: 12, color: 'blue' }, + { key: 13, text: '代理:AIGC2D', value: 13, color: 'purple' } ]; diff --git a/web/default/src/pages/Channel/EditChannel.js b/web/default/src/pages/Channel/EditChannel.js index 64ff22f5ea..b967907e32 100644 --- a/web/default/src/pages/Channel/EditChannel.js +++ b/web/default/src/pages/Channel/EditChannel.js @@ -170,7 +170,7 @@ const EditChannel = () => { showInfo('请填写渠道名称和渠道密钥!'); return; } - if (inputs.models.length === 0) { + if (inputs.type !== 43 && inputs.models.length === 0) { showInfo('请至少选择一个模型!'); return; } @@ -370,63 +370,75 @@ const EditChannel = () => { ) } - - { - copy(value).then(); - }} - selection - onChange={handleInputChange} - value={inputs.models} - autoComplete='new-password' - options={modelOptions} - /> - -
- - - - 填入 - } - placeholder='输入自定义模型名称' - value={customModel} - onChange={(e, { value }) => { - setCustomModel(value); - }} - onKeyDown={(e) => { - if (e.key === 'Enter') { - addCustomModel(); - e.preventDefault(); - } - }} - /> -
- - - + { + inputs.type !== 43 && ( + + { + copy(value).then(); + }} + selection + onChange={handleInputChange} + value={inputs.models} + autoComplete='new-password' + options={modelOptions} + /> + + ) + } + { + inputs.type !== 43 && ( +
+ + + + 填入 + } + placeholder='输入自定义模型名称' + value={customModel} + onChange={(e, { value }) => { + setCustomModel(value); + }} + onKeyDown={(e) => { + if (e.key === 'Enter') { + addCustomModel(); + e.preventDefault(); + } + }} + /> +
+ ) + } + { + inputs.type !== 43 && ( + + + + ) + } { inputs.type === 33 && (