Skip to content

Commit c936198

Browse files
authored
feat: add Proxy channel type and relay mode (songquanpeng#1678)
Add the Proxy channel type and relay mode to support proxying requests to custom upstream services.
1 parent 296ab01 commit c936198

File tree

17 files changed

+292
-106
lines changed

17 files changed

+292
-106
lines changed

controller/relay.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode {
3434
fallthrough
3535
case relaymode.AudioTranscription:
3636
err = controller.RelayAudioHelper(c, relayMode)
37+
case relaymode.Proxy:
38+
err = controller.RelayProxyHelper(c, relayMode)
3739
default:
3840
err = controller.RelayTextHelper(c)
3941
}
@@ -85,12 +87,15 @@ func Relay(c *gin.Context) {
8587
channelId := c.GetInt(ctxkey.ChannelId)
8688
lastFailedChannelId = channelId
8789
channelName := c.GetString(ctxkey.ChannelName)
90+
// BUG: bizErr is in race condition
8891
go processChannelRelayError(ctx, userId, channelId, channelName, bizErr)
8992
}
9093
if bizErr != nil {
9194
if bizErr.StatusCode == http.StatusTooManyRequests {
9295
bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试"
9396
}
97+
98+
// BUG: bizErr is in race condition
9499
bizErr.Error.Message = helper.MessageWithRequestId(bizErr.Error.Message, requestId)
95100
c.JSON(bizErr.StatusCode, gin.H{
96101
"error": bizErr.Error,

middleware/auth.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,12 @@ func TokenAuth() func(c *gin.Context) {
140140
return
141141
}
142142
}
143+
144+
// set channel id for proxy relay
145+
if channelId := c.Param("channelid"); channelId != "" {
146+
c.Set(ctxkey.SpecificChannelId, channelId)
147+
}
148+
143149
c.Next()
144150
}
145151
}

relay/adaptor.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"github.com/songquanpeng/one-api/relay/adaptor/ollama"
1616
"github.com/songquanpeng/one-api/relay/adaptor/openai"
1717
"github.com/songquanpeng/one-api/relay/adaptor/palm"
18+
"github.com/songquanpeng/one-api/relay/adaptor/proxy"
1819
"github.com/songquanpeng/one-api/relay/adaptor/tencent"
1920
"github.com/songquanpeng/one-api/relay/adaptor/vertexai"
2021
"github.com/songquanpeng/one-api/relay/adaptor/xunfei"
@@ -58,6 +59,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor {
5859
return &deepl.Adaptor{}
5960
case apitype.VertexAI:
6061
return &vertexai.Adaptor{}
62+
case apitype.Proxy:
63+
return &proxy.Adaptor{}
6164
}
6265
return nil
6366
}

relay/adaptor/proxy/adaptor.go

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package proxy
2+
3+
import (
4+
"fmt"
5+
"io"
6+
"net/http"
7+
"strings"
8+
9+
"github.com/gin-gonic/gin"
10+
"github.com/pkg/errors"
11+
"github.com/songquanpeng/one-api/relay/adaptor"
12+
channelhelper "github.com/songquanpeng/one-api/relay/adaptor"
13+
"github.com/songquanpeng/one-api/relay/meta"
14+
"github.com/songquanpeng/one-api/relay/model"
15+
relaymodel "github.com/songquanpeng/one-api/relay/model"
16+
)
17+
18+
var _ adaptor.Adaptor = new(Adaptor)
19+
20+
const channelName = "proxy"
21+
22+
type Adaptor struct{}
23+
24+
func (a *Adaptor) Init(meta *meta.Meta) {
25+
}
26+
27+
func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) {
28+
return nil, errors.New("notimplement")
29+
}
30+
31+
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) {
32+
for k, v := range resp.Header {
33+
for _, vv := range v {
34+
c.Writer.Header().Set(k, vv)
35+
}
36+
}
37+
38+
c.Writer.WriteHeader(resp.StatusCode)
39+
if _, gerr := io.Copy(c.Writer, resp.Body); gerr != nil {
40+
return nil, &relaymodel.ErrorWithStatusCode{
41+
StatusCode: http.StatusInternalServerError,
42+
Error: relaymodel.Error{
43+
Message: gerr.Error(),
44+
},
45+
}
46+
}
47+
48+
return nil, nil
49+
}
50+
51+
func (a *Adaptor) GetModelList() (models []string) {
52+
return nil
53+
}
54+
55+
func (a *Adaptor) GetChannelName() string {
56+
return channelName
57+
}
58+
59+
// GetRequestURL remove static prefix, and return the real request url to the upstream service
60+
func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) {
61+
prefix := fmt.Sprintf("/v1/oneapi/proxy/%d", meta.ChannelId)
62+
return meta.BaseURL + strings.TrimPrefix(meta.RequestURLPath, prefix), nil
63+
64+
}
65+
66+
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error {
67+
for k, v := range c.Request.Header {
68+
req.Header.Set(k, v[0])
69+
}
70+
71+
// remove unnecessary headers
72+
req.Header.Del("Host")
73+
req.Header.Del("Content-Length")
74+
req.Header.Del("Accept-Encoding")
75+
req.Header.Del("Connection")
76+
77+
// set authorization header
78+
req.Header.Set("Authorization", meta.APIKey)
79+
80+
return nil
81+
}
82+
83+
func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) {
84+
return nil, errors.Errorf("not implement")
85+
}
86+
87+
func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) {
88+
return channelhelper.DoRequestHelper(a, c, meta, requestBody)
89+
}

relay/apitype/define.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ const (
1818
Cloudflare
1919
DeepL
2020
VertexAI
21+
Proxy
2122

2223
Dummy // this one is only for count, do not add any channel after this
2324
)

relay/channeltype/define.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,6 @@ const (
4444
Doubao
4545
Novita
4646
VertextAI
47+
Proxy
4748
Dummy
4849
)

relay/channeltype/helper.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ func ToAPIType(channelType int) int {
3737
apiType = apitype.DeepL
3838
case VertextAI:
3939
apiType = apitype.VertexAI
40+
case Proxy:
41+
apiType = apitype.Proxy
4042
}
4143

4244
return apiType

relay/channeltype/url.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ var ChannelBaseURLs = []string{
4444
"https://ark.cn-beijing.volces.com", // 40
4545
"https://api.novita.ai/v3/openai", // 41
4646
"", // 42
47+
"", // 43
4748
}
4849

4950
func init() {

relay/controller/proxy.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Package controller is a package for handling the relay controller
2+
package controller
3+
4+
import (
5+
"fmt"
6+
"net/http"
7+
8+
"github.com/gin-gonic/gin"
9+
"github.com/songquanpeng/one-api/common/logger"
10+
"github.com/songquanpeng/one-api/relay"
11+
"github.com/songquanpeng/one-api/relay/adaptor/openai"
12+
"github.com/songquanpeng/one-api/relay/meta"
13+
relaymodel "github.com/songquanpeng/one-api/relay/model"
14+
)
15+
16+
// RelayProxyHelper is a helper function to proxy the request to the upstream service
17+
func RelayProxyHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode {
18+
ctx := c.Request.Context()
19+
meta := meta.GetByContext(c)
20+
21+
adaptor := relay.GetAdaptor(meta.APIType)
22+
if adaptor == nil {
23+
return openai.ErrorWrapper(fmt.Errorf("invalid api type: %d", meta.APIType), "invalid_api_type", http.StatusBadRequest)
24+
}
25+
adaptor.Init(meta)
26+
27+
resp, err := adaptor.DoRequest(c, meta, c.Request.Body)
28+
if err != nil {
29+
logger.Errorf(ctx, "DoRequest failed: %s", err.Error())
30+
return openai.ErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
31+
}
32+
33+
// do response
34+
_, respErr := adaptor.DoResponse(c, resp, meta)
35+
if respErr != nil {
36+
logger.Errorf(ctx, "respErr is not nil: %+v", respErr)
37+
return respErr
38+
}
39+
40+
return nil
41+
}

relay/meta/relay_meta.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ type Meta struct {
1818
UserId int
1919
Group string
2020
ModelMapping map[string]string
21-
BaseURL string
22-
APIKey string
23-
APIType int
24-
Config model.ChannelConfig
25-
IsStream bool
21+
// BaseURL is the proxy url set in the channel config
22+
BaseURL string
23+
APIKey string
24+
APIType int
25+
Config model.ChannelConfig
26+
IsStream bool
2627
// OriginModelName is the model name from the raw user request
2728
OriginModelName string
2829
// ActualModelName is the model name after mapping

0 commit comments

Comments
 (0)