Skip to content

Commit 7c4d9d2

Browse files
committed
feat: support SiliconFlow (close #437, close #403)
1 parent d0f76a5 commit 7c4d9d2

File tree

10 files changed

+227
-23
lines changed

10 files changed

+227
-23
lines changed

common/constants.go

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -213,6 +213,7 @@ const (
213213
ChannelTypeDify = 37
214214
ChannelTypeJina = 38
215215
ChannelCloudflare = 39
216+
ChannelTypeSiliconFlow = 40
216217

217218
ChannelTypeDummy // this one is only for count, do not add any channel after this
218219

@@ -259,4 +260,5 @@ var ChannelBaseURLs = []string{
259260
"", //37
260261
"https://api.jina.ai", //38
261262
"https://api.cloudflare.com", //39
263+
"https://api.siliconflow.cn", //40
262264
}

dto/rerank.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
package dto
22

33
type RerankRequest struct {
4-
Documents []any `json:"documents"`
5-
Query string `json:"query"`
6-
Model string `json:"model"`
7-
TopN int `json:"top_n"`
4+
Documents []any `json:"documents"`
5+
Query string `json:"query"`
6+
Model string `json:"model"`
7+
TopN int `json:"top_n"`
8+
ReturnDocuments bool `json:"return_documents,omitempty"`
9+
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
10+
OverLapTokens int `json:"overlap_tokens,omitempty"`
811
}
912

1013
type RerankResponseDocument struct {
11-
Document any `json:"document"`
14+
Document any `json:"document,omitempty"`
1215
Index int `json:"index"`
1316
RelevanceScore float64 `json:"relevance_score"`
1417
}
Lines changed: 80 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,80 @@
1+
package siliconflow
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"github.com/gin-gonic/gin"
7+
"io"
8+
"net/http"
9+
"one-api/dto"
10+
"one-api/relay/channel"
11+
"one-api/relay/channel/openai"
12+
relaycommon "one-api/relay/common"
13+
"one-api/relay/constant"
14+
)
15+
16+
// Adaptor is the SiliconFlow channel adaptor. It has no fields; all behavior
// lives in its methods.
type Adaptor struct {
17+
}
18+
19+
// ConvertAudioRequest is a stub: audio requests are not supported by this
// adaptor yet, so it always returns a nil reader and a "not implemented" error.
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
20+
//TODO implement me
21+
return nil, errors.New("not implemented")
22+
}
23+
24+
// ConvertImageRequest is a stub: image requests are not supported by this
// adaptor yet, so it always returns nil and a "not implemented" error.
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
25+
//TODO implement me
26+
return nil, errors.New("not implemented")
27+
}
28+
29+
// Init is a no-op: Adaptor has no fields to initialize from info.
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
30+
}
31+
32+
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
33+
if info.RelayMode == constant.RelayModeRerank {
34+
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
35+
} else if info.RelayMode == constant.RelayModeEmbeddings {
36+
return fmt.Sprintf("%s/v1/embeddings ", info.BaseUrl), nil
37+
} else if info.RelayMode == constant.RelayModeChatCompletions {
38+
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
39+
}
40+
return "", errors.New("invalid relay mode")
41+
}
42+
43+
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
44+
channel.SetupApiRequestHeader(info, c, req)
45+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
46+
return nil
47+
}
48+
49+
// ConvertRequest returns the OpenAI-style request unchanged; no conversion is
// applied before it is sent upstream.
func (a *Adaptor) ConvertRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
50+
return request, nil
51+
}
52+
53+
// DoRequest sends requestBody upstream by delegating to channel.DoApiRequest,
// passing this adaptor along with the context and relay info.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
54+
return channel.DoApiRequest(a, c, info, requestBody)
55+
}
56+
57+
// ConvertRerankRequest returns the rerank request unchanged. c and relayMode
// are not used.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
58+
return request, nil
59+
}
60+
61+
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage *dto.Usage, err *dto.OpenAIErrorWithStatusCode) {
62+
if info.RelayMode == constant.RelayModeRerank {
63+
err, usage = siliconflowRerankHandler(c, resp)
64+
} else if info.RelayMode == constant.RelayModeChatCompletions {
65+
if info.IsStream {
66+
err, usage = openai.OaiStreamHandler(c, resp, info)
67+
} else {
68+
err, usage = openai.OpenaiHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
69+
}
70+
}
71+
return
72+
}
73+
74+
// GetModelList returns the package-level ModelList slice (not a copy).
func (a *Adaptor) GetModelList() []string {
75+
return ModelList
76+
}
77+
78+
// GetChannelName returns the package-level ChannelName.
func (a *Adaptor) GetChannelName() string {
79+
return ChannelName
80+
}
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package siliconflow
2+
3+
// ModelList holds the model identifiers advertised for the SiliconFlow
// channel. Entries that are commented out are intentionally not advertised.
// NOTE(review): some remaining entries look like image or audio models —
// confirm they are usable through the relay modes this channel supports.
var ModelList = []string{
4+
"THUDM/glm-4-9b-chat",
5+
//"stabilityai/stable-diffusion-xl-base-1.0",
6+
//"TencentARC/PhotoMaker",
7+
"InstantX/InstantID",
8+
//"stabilityai/stable-diffusion-2-1",
9+
//"stabilityai/sd-turbo",
10+
//"stabilityai/sdxl-turbo",
11+
"ByteDance/SDXL-Lightning",
12+
"deepseek-ai/deepseek-llm-67b-chat",
13+
"Qwen/Qwen1.5-14B-Chat",
14+
"Qwen/Qwen1.5-7B-Chat",
15+
"Qwen/Qwen1.5-110B-Chat",
16+
"Qwen/Qwen1.5-32B-Chat",
17+
"01-ai/Yi-1.5-6B-Chat",
18+
"01-ai/Yi-1.5-9B-Chat-16K",
19+
"01-ai/Yi-1.5-34B-Chat-16K",
20+
"THUDM/chatglm3-6b",
21+
"deepseek-ai/DeepSeek-V2-Chat",
22+
"Qwen/Qwen2-72B-Instruct",
23+
"Qwen/Qwen2-7B-Instruct",
24+
"Qwen/Qwen2-57B-A14B-Instruct",
25+
//"stabilityai/stable-diffusion-3-medium",
26+
"deepseek-ai/DeepSeek-Coder-V2-Instruct",
27+
"Qwen/Qwen2-1.5B-Instruct",
28+
"internlm/internlm2_5-7b-chat",
29+
"BAAI/bge-large-en-v1.5",
30+
"BAAI/bge-large-zh-v1.5",
31+
"Pro/Qwen/Qwen2-7B-Instruct",
32+
"Pro/Qwen/Qwen2-1.5B-Instruct",
33+
"Pro/Qwen/Qwen1.5-7B-Chat",
34+
"Pro/THUDM/glm-4-9b-chat",
35+
"Pro/THUDM/chatglm3-6b",
36+
"Pro/01-ai/Yi-1.5-9B-Chat-16K",
37+
"Pro/01-ai/Yi-1.5-6B-Chat",
38+
"Pro/google/gemma-2-9b-it",
39+
"Pro/internlm/internlm2_5-7b-chat",
40+
"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
41+
"Pro/mistralai/Mistral-7B-Instruct-v0.2",
42+
"black-forest-labs/FLUX.1-schnell",
43+
"iic/SenseVoiceSmall",
44+
"netease-youdao/bce-embedding-base_v1",
45+
"BAAI/bge-m3",
46+
"internlm/internlm2_5-20b-chat",
47+
"Qwen/Qwen2-Math-72B-Instruct",
48+
"netease-youdao/bce-reranker-base_v1",
49+
"BAAI/bge-reranker-v2-m3",
50+
}
51+
// ChannelName is the identifier reported for the SiliconFlow channel.
var ChannelName = "siliconflow"

relay/channel/siliconflow/dto.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
package siliconflow
2+
3+
import "one-api/dto"
4+
5+
// SFTokens carries the token counts decoded from the "input_tokens" and
// "output_tokens" JSON fields.
type SFTokens struct {
6+
InputTokens int `json:"input_tokens"`
7+
OutputTokens int `json:"output_tokens"`
8+
}
9+
10+
// SFMeta is the "meta" object of a rerank response; it wraps the token counts
// found under its "tokens" key.
type SFMeta struct {
11+
Tokens SFTokens `json:"tokens"`
12+
}
13+
14+
// SFRerankResponse is the decoded rerank response body: the ranked documents
// under "results" and usage metadata under "meta".
type SFRerankResponse struct {
15+
Results []dto.RerankResponseDocument `json:"results"`
16+
Meta SFMeta `json:"meta"`
17+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package siliconflow
2+
3+
import (
4+
"encoding/json"
5+
"github.com/gin-gonic/gin"
6+
"io"
7+
"net/http"
8+
"one-api/dto"
9+
"one-api/service"
10+
)
11+
12+
func siliconflowRerankHandler(c *gin.Context, resp *http.Response) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
13+
responseBody, err := io.ReadAll(resp.Body)
14+
if err != nil {
15+
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
16+
}
17+
err = resp.Body.Close()
18+
if err != nil {
19+
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
20+
}
21+
var siliconflowResp SFRerankResponse
22+
err = json.Unmarshal(responseBody, &siliconflowResp)
23+
if err != nil {
24+
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
25+
}
26+
usage := &dto.Usage{
27+
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
28+
CompletionTokens: siliconflowResp.Meta.Tokens.OutputTokens,
29+
TotalTokens: siliconflowResp.Meta.Tokens.InputTokens + siliconflowResp.Meta.Tokens.OutputTokens,
30+
}
31+
rerankResp := &dto.RerankResponse{
32+
Results: siliconflowResp.Results,
33+
Usage: *usage,
34+
}
35+
36+
jsonResponse, err := json.Marshal(rerankResp)
37+
if err != nil {
38+
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
39+
}
40+
c.Writer.Header().Set("Content-Type", "application/json")
41+
c.Writer.WriteHeader(resp.StatusCode)
42+
_, err = c.Writer.Write(jsonResponse)
43+
return nil, usage
44+
}

relay/constant/api_type.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ const (
2323
APITypeDify
2424
APITypeJina
2525
APITypeCloudflare
26+
APITypeSiliconFlow
2627

2728
APITypeDummy // this one is only for count, do not add any channel after this
2829
)
@@ -66,6 +67,8 @@ func ChannelType2APIType(channelType int) (int, bool) {
6667
apiType = APITypeJina
6768
case common.ChannelCloudflare:
6869
apiType = APITypeCloudflare
70+
case common.ChannelTypeSiliconFlow:
71+
apiType = APITypeSiliconFlow
6972
}
7073
if apiType == -1 {
7174
return APITypeOpenAI, false

relay/relay-text.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ func postConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelN
317317
totalTokens := promptTokens + completionTokens
318318
var logContent string
319319
if !usePrice {
320-
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f,补全倍率 %.2f", modelRatio, groupRatio, completionRatio)
320+
logContent = fmt.Sprintf("模型倍率 %.2f,补全倍率 %.2f,分组倍率 %.2f", modelRatio, completionRatio, groupRatio)
321321
} else {
322322
logContent = fmt.Sprintf("模型价格 %.2f,分组倍率 %.2f", modelPrice, groupRatio)
323323
}

relay/relay_adaptor.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"one-api/relay/channel/openai"
1717
"one-api/relay/channel/palm"
1818
"one-api/relay/channel/perplexity"
19+
"one-api/relay/channel/siliconflow"
1920
"one-api/relay/channel/task/suno"
2021
"one-api/relay/channel/tencent"
2122
"one-api/relay/channel/xunfei"
@@ -62,6 +63,8 @@ func GetAdaptor(apiType int) channel.Adaptor {
6263
return &jina.Adaptor{}
6364
case constant.APITypeCloudflare:
6465
return &cloudflare.Adaptor{}
66+
case constant.APITypeSiliconFlow:
67+
return &siliconflow.Adaptor{}
6568
}
6669
return nil
6770
}

0 commit comments

Comments
 (0)