diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 36798711a9..3034a54767 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,13 +1,13 @@ name: CI # This setup assumes that you run the unit tests with code coverage in the same -# workflow that will also print the coverage report as comment to the pull request. +# workflow that will also print the coverage report as comment to the pull request. # Therefore, you need to trigger this workflow when a pull request is (re)opened or # when new code is pushed to the branch of the pull request. In addition, you also -# need to trigger this workflow when new code is pushed to the main branch because +# need to trigger this workflow when new code is pushed to the main branch because # we need to upload the code coverage results as artifact for the main branch as # well since it will be the baseline code coverage. -# +# # We do not want to trigger the workflow for pushes to *any* branch because this # would trigger our jobs twice on pull requests (once from "push" event and once # from "pull_request->synchronize") @@ -31,7 +31,7 @@ jobs: with: go-version: ^1.22 - # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a + # When you execute your unit tests, make sure to use the "-coverprofile" flag to write a # coverage profile to a file. You will need the name of the file (e.g. "coverage.txt") # in the next step as well as the next job. 
- name: Test diff --git a/.gitignore b/.gitignore index 4e431e6588..0cedb4b401 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ logs data /web/node_modules cmd.md -.env \ No newline at end of file +.env +/one-api diff --git a/common/ctxkey/key.go b/common/ctxkey/key.go index 115558a51c..96229a7975 100644 --- a/common/ctxkey/key.go +++ b/common/ctxkey/key.go @@ -3,6 +3,7 @@ package ctxkey const ( Config = "config" Id = "id" + RequestId = "X-Oneapi-Request-Id" Username = "username" Role = "role" Status = "status" @@ -15,6 +16,7 @@ const ( Group = "group" ModelMapping = "model_mapping" ChannelName = "channel_name" + ContentType = "content_type" TokenId = "token_id" TokenName = "token_name" BaseURL = "base_url" diff --git a/common/gin.go b/common/gin.go index 815b4ee54a..c8254bfd45 100644 --- a/common/gin.go +++ b/common/gin.go @@ -3,24 +3,27 @@ package common import ( "bytes" "encoding/json" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/common/ctxkey" "io" + "reflect" "strings" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/ctxkey" ) -func GetRequestBody(c *gin.Context) ([]byte, error) { - requestBody, _ := c.Get(ctxkey.KeyRequestBody) - if requestBody != nil { - return requestBody.([]byte), nil +func GetRequestBody(c *gin.Context) (requestBody []byte, err error) { + if requestBodyCache, _ := c.Get(ctxkey.KeyRequestBody); requestBodyCache != nil { + return requestBodyCache.([]byte), nil } - requestBody, err := io.ReadAll(c.Request.Body) + requestBody, err = io.ReadAll(c.Request.Body) if err != nil { return nil, err } _ = c.Request.Body.Close() c.Set(ctxkey.KeyRequestBody, requestBody) - return requestBody.([]byte), nil + + return requestBody, nil } func UnmarshalBodyReusable(c *gin.Context, v any) error { @@ -28,18 +31,25 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error { if err != nil { return err } + + // check v should be a pointer + if v == nil || reflect.TypeOf(v).Kind() != 
reflect.Ptr { + return errors.Errorf("UnmarshalBodyReusable only accept pointer, got %v", reflect.TypeOf(v)) + } + contentType := c.Request.Header.Get("Content-Type") if strings.HasPrefix(contentType, "application/json") { - err = json.Unmarshal(requestBody, &v) - c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + err = json.Unmarshal(requestBody, v) } else { c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) - err = c.ShouldBind(&v) + err = c.ShouldBind(v) } if err != nil { - return err + return errors.Wrap(err, "unmarshal request body failed") } // Reset request body + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + return nil } diff --git a/controller/relay.go b/controller/relay.go index 49358e2597..f792b258e9 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -26,7 +26,8 @@ import ( func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { var err *model.ErrorWithStatusCode switch relayMode { - case relaymode.ImagesGenerations: + case relaymode.ImagesGenerations, + relaymode.ImagesEdits: err = controller.RelayImageHelper(c, relayMode) case relaymode.AudioSpeech: fallthrough @@ -45,10 +46,6 @@ func relayHelper(c *gin.Context, relayMode int) *model.ErrorWithStatusCode { func Relay(c *gin.Context) { ctx := c.Request.Context() relayMode := relaymode.GetByPath(c.Request.URL.Path) - if config.DebugEnabled { - requestBody, _ := common.GetRequestBody(c) - logger.Debugf(ctx, "request body: %s", string(requestBody)) - } channelId := c.GetInt(ctxkey.ChannelId) userId := c.GetInt(ctxkey.Id) bizErr := relayHelper(c, relayMode) @@ -60,7 +57,7 @@ func Relay(c *gin.Context) { channelName := c.GetString(ctxkey.ChannelName) group := c.GetString(ctxkey.Group) originalModel := c.GetString(ctxkey.OriginalModel) - go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) + go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) requestId := c.GetString(helper.RequestIdKey) retryTimes := 
config.RetryTimes if !shouldRetry(c, bizErr.StatusCode) { @@ -87,9 +84,9 @@ func Relay(c *gin.Context) { channelId := c.GetInt(ctxkey.ChannelId) lastFailedChannelId = channelId channelName := c.GetString(ctxkey.ChannelName) - // BUG: bizErr is in race condition - go processChannelRelayError(ctx, userId, channelId, channelName, bizErr) + go processChannelRelayError(ctx, userId, channelId, channelName, *bizErr) } + if bizErr != nil { if bizErr.StatusCode == http.StatusTooManyRequests { bizErr.Error.Message = "当前分组上游负载已饱和,请稍后再试" @@ -122,7 +119,10 @@ func shouldRetry(c *gin.Context, statusCode int) bool { return true } -func processChannelRelayError(ctx context.Context, userId int, channelId int, channelName string, err *model.ErrorWithStatusCode) { +func processChannelRelayError(ctx context.Context, + userId int, channelId int, channelName string, + // FIX: err should not use a pointer to avoid data race in concurrent situations + err model.ErrorWithStatusCode) { logger.Errorf(ctx, "relay error (channel id %d, user id: %d): %s", channelId, userId, err.Message) // https://platform.openai.com/docs/guides/error-codes/api-errors if monitor.ShouldDisableChannel(&err.Error, err.StatusCode) { diff --git a/go.mod b/go.mod index ada53bc33c..cfc8bcadf3 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.24.0 golang.org/x/image v0.18.0 + golang.org/x/sync v0.7.0 google.golang.org/api v0.187.0 gorm.io/driver/mysql v1.5.6 gorm.io/driver/postgres v1.5.7 @@ -99,7 +100,6 @@ require ( golang.org/x/arch v0.8.0 // indirect golang.org/x/net v0.26.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect - golang.org/x/sync v0.7.0 // indirect golang.org/x/sys v0.21.0 // indirect golang.org/x/text v0.16.0 // indirect golang.org/x/time v0.5.0 // indirect diff --git a/middleware/distributor.go b/middleware/distributor.go index 0aceb29dd8..6a45b50a21 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -64,6 
+64,7 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode if channel.SystemPrompt != nil && *channel.SystemPrompt != "" { c.Set(ctxkey.SystemPrompt, *channel.SystemPrompt) } + c.Set(ctxkey.ContentType, c.Request.Header.Get("Content-Type")) c.Set(ctxkey.ModelMapping, channel.GetModelMapping()) c.Set(ctxkey.OriginalModel, modelName) // for retry c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) diff --git a/middleware/logger.go b/middleware/logger.go index 191364f8cd..587d748c2b 100644 --- a/middleware/logger.go +++ b/middleware/logger.go @@ -2,6 +2,7 @@ package middleware import ( "fmt" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/helper" ) diff --git a/middleware/request-id.go b/middleware/request-id.go index bef09e32f6..c1f3adc22f 100644 --- a/middleware/request-id.go +++ b/middleware/request-id.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/helper" ) diff --git a/middleware/utils.go b/middleware/utils.go index 4d2f8092ef..46120f2a75 100644 --- a/middleware/utils.go +++ b/middleware/utils.go @@ -1,12 +1,13 @@ package middleware import ( - "fmt" + "strings" + "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" - "strings" ) func abortWithMessage(c *gin.Context, statusCode int, message string) { @@ -24,28 +25,30 @@ func getRequestModel(c *gin.Context) (string, error) { var modelRequest ModelRequest err := common.UnmarshalBodyReusable(c, &modelRequest) if err != nil { - return "", fmt.Errorf("common.UnmarshalBodyReusable failed: %w", err) + return "", errors.Wrap(err, "common.UnmarshalBodyReusable failed") } - if strings.HasPrefix(c.Request.URL.Path, "/v1/moderations") { + + switch { + case strings.HasPrefix(c.Request.URL.Path, "/v1/moderations"): if modelRequest.Model == 
"" { modelRequest.Model = "text-moderation-stable" } - } - if strings.HasSuffix(c.Request.URL.Path, "embeddings") { + case strings.HasSuffix(c.Request.URL.Path, "embeddings"): if modelRequest.Model == "" { modelRequest.Model = c.Param("model") } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations") { + case strings.HasPrefix(c.Request.URL.Path, "/v1/images/generations"), + strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits"): if modelRequest.Model == "" { modelRequest.Model = "dall-e-2" } - } - if strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") || strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations") { + case strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions"), + strings.HasPrefix(c.Request.URL.Path, "/v1/audio/translations"): if modelRequest.Model == "" { modelRequest.Model = "whisper-1" } } + return modelRequest.Model, nil } diff --git a/monitor/manage.go b/monitor/manage.go index 44c13612d3..268d3924ec 100644 --- a/monitor/manage.go +++ b/monitor/manage.go @@ -34,7 +34,7 @@ func ShouldDisableChannel(err *model.Error, statusCode int) bool { strings.Contains(lowerMessage, "credit") || strings.Contains(lowerMessage, "balance") || strings.Contains(lowerMessage, "permission denied") || - strings.Contains(lowerMessage, "organization has been restricted") || // groq + strings.Contains(lowerMessage, "organization has been restricted") || // groq strings.Contains(lowerMessage, "已欠费") { return true } diff --git a/one-api b/one-api deleted file mode 100755 index 4c9190bb93..0000000000 Binary files a/one-api and /dev/null differ diff --git a/relay/adaptor.go b/relay/adaptor.go index 711e63bdc6..03e8390319 100644 --- a/relay/adaptor.go +++ b/relay/adaptor.go @@ -16,6 +16,7 @@ import ( "github.com/songquanpeng/one-api/relay/adaptor/openai" "github.com/songquanpeng/one-api/relay/adaptor/palm" "github.com/songquanpeng/one-api/relay/adaptor/proxy" + "github.com/songquanpeng/one-api/relay/adaptor/replicate" 
"github.com/songquanpeng/one-api/relay/adaptor/tencent" "github.com/songquanpeng/one-api/relay/adaptor/vertexai" "github.com/songquanpeng/one-api/relay/adaptor/xunfei" @@ -61,6 +62,8 @@ func GetAdaptor(apiType int) adaptor.Adaptor { return &vertexai.Adaptor{} case apitype.Proxy: return &proxy.Adaptor{} + case apitype.Replicate: + return &replicate.Adaptor{} } return nil } diff --git a/relay/adaptor/common.go b/relay/adaptor/common.go index 8953d7a3c8..9069255aee 100644 --- a/relay/adaptor/common.go +++ b/relay/adaptor/common.go @@ -3,11 +3,13 @@ package adaptor import ( "errors" "fmt" + "io" + "net/http" + "github.com/gin-gonic/gin" "github.com/songquanpeng/one-api/common/client" + "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/relay/meta" - "io" - "net/http" ) func SetupCommonRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) { @@ -27,6 +29,9 @@ func DoRequestHelper(a Adaptor, c *gin.Context, meta *meta.Meta, requestBody io. if err != nil { return nil, fmt.Errorf("new request failed: %w", err) } + + req.Header.Set("Content-Type", c.GetString(ctxkey.ContentType)) + err = a.SetupRequestHeader(c, req, meta) if err != nil { return nil, fmt.Errorf("setup request header failed: %w", err) diff --git a/relay/adaptor/ollama/main.go b/relay/adaptor/ollama/main.go index 43317ff66f..fa1b05f0c5 100644 --- a/relay/adaptor/ollama/main.go +++ b/relay/adaptor/ollama/main.go @@ -31,8 +31,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest { TopP: request.TopP, FrequencyPenalty: request.FrequencyPenalty, PresencePenalty: request.PresencePenalty, - NumPredict: request.MaxTokens, - NumCtx: request.NumCtx, + NumPredict: request.MaxTokens, + NumCtx: request.NumCtx, }, Stream: request.Stream, } @@ -122,7 +122,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC for scanner.Scan() { data := scanner.Text() if strings.HasPrefix(data, "}") { - data = strings.TrimPrefix(data, "}") + "}" + 
data = strings.TrimPrefix(data, "}") + "}" } var ollamaResponse ChatResponse diff --git a/relay/adaptor/openai/adaptor.go b/relay/adaptor/openai/adaptor.go index 6946e402a8..fa85f52c77 100644 --- a/relay/adaptor/openai/adaptor.go +++ b/relay/adaptor/openai/adaptor.go @@ -111,10 +111,13 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met switch meta.Mode { case relaymode.ImagesGenerations: err, _ = ImageHandler(c, resp) + case relaymode.ImagesEdits: + err, _ = ImagesEditsHandler(c, resp) default: err, usage = Handler(c, resp, meta.PromptTokens, meta.ActualModelName) } } + return } diff --git a/relay/adaptor/openai/image.go b/relay/adaptor/openai/image.go index 0f89618a24..433d942171 100644 --- a/relay/adaptor/openai/image.go +++ b/relay/adaptor/openai/image.go @@ -3,12 +3,30 @@ package openai import ( "bytes" "encoding/json" - "github.com/gin-gonic/gin" - "github.com/songquanpeng/one-api/relay/model" "io" "net/http" + + "github.com/gin-gonic/gin" + "github.com/songquanpeng/one-api/relay/model" ) +// ImagesEditsHandler just copy response body to client +// +// https://platform.openai.com/docs/api-reference/images/createEdit +func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { + c.Writer.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + c.Writer.Header().Set(k, v[0]) + } + + if _, err := io.Copy(c.Writer, resp.Body); err != nil { + return ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil + } + defer resp.Body.Close() + + return nil, nil +} + func ImageHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { var imageResponse ImageResponse responseBody, err := io.ReadAll(resp.Body) diff --git a/relay/adaptor/openai/util.go b/relay/adaptor/openai/util.go index ba0cab7dbc..83beadbafc 100644 --- a/relay/adaptor/openai/util.go +++ b/relay/adaptor/openai/util.go @@ -1,8 +1,16 @@ package openai -import 
"github.com/songquanpeng/one-api/relay/model" +import ( + "context" + "fmt" + + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/model" +) func ErrorWrapper(err error, code string, statusCode int) *model.ErrorWithStatusCode { + logger.Error(context.TODO(), fmt.Sprintf("[%s]%+v", code, err)) + Error := model.Error{ Message: err.Error(), Type: "one_api_error", diff --git a/relay/adaptor/replicate/adaptor.go b/relay/adaptor/replicate/adaptor.go new file mode 100644 index 0000000000..34e42de5e5 --- /dev/null +++ b/relay/adaptor/replicate/adaptor.go @@ -0,0 +1,128 @@ +package replicate + +import ( + "bytes" + "fmt" + "io" + "net/http" + "slices" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "github.com/songquanpeng/one-api/relay/relaymode" +) + +type Adaptor struct { + meta *meta.Meta +} + +// ConvertImageRequest implements adaptor.Adaptor. 
+func (a *Adaptor) ConvertImageRequest(request *model.ImageRequest) (any, error) { + return nil, errors.New("should call replicate.ConvertImageRequest instead") +} + +func ConvertImageRequest(c *gin.Context, request *model.ImageRequest) (any, error) { + meta := meta.GetByContext(c) + + if request.ResponseFormat != "b64_json" { + return nil, errors.New("only support b64_json response format") + } + if request.N != 1 && request.N != 0 { + return nil, errors.New("only support N=1") + } + + switch meta.Mode { + case relaymode.ImagesGenerations: + return convertImageCreateRequest(request) + case relaymode.ImagesEdits: + return convertImageRemixRequest(c) + default: + return nil, errors.New("not implemented") + } +} + +func convertImageCreateRequest(request *model.ImageRequest) (any, error) { + return DrawImageRequest{ + Input: ImageInput{ + Steps: 25, + Prompt: request.Prompt, + Guidance: 3, + Seed: int(time.Now().UnixNano()), + SafetyTolerance: 5, + NImages: 1, // replicate will always return 1 image + Width: 1440, + Height: 1440, + AspectRatio: "1:1", + }, + }, nil +} + +func convertImageRemixRequest(c *gin.Context) (any, error) { + // recover request body + requestBody, err := common.GetRequestBody(c) + if err != nil { + return nil, errors.Wrap(err, "get request body") + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody)) + + rawReq := new(OpenaiImageEditRequest) + if err := c.ShouldBind(rawReq); err != nil { + return nil, errors.Wrap(err, "parse image edit form") + } + + return rawReq.toFluxRemixRequest() +} + +func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.GeneralOpenAIRequest) (any, error) { + return nil, errors.New("not implemented") +} + +func (a *Adaptor) Init(meta *meta.Meta) { + a.meta = meta +} + +func (a *Adaptor) GetRequestURL(meta *meta.Meta) (string, error) { + if !slices.Contains(ModelList, meta.OriginModelName) { + return "", errors.Errorf("model %s not supported", meta.OriginModelName) + } + + return 
fmt.Sprintf("https://api.replicate.com/v1/models/%s/predictions", meta.OriginModelName), nil +} + +func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Request, meta *meta.Meta) error { + adaptor.SetupCommonRequestHeader(c, req, meta) + req.Header.Set("Authorization", "Bearer "+meta.APIKey) + return nil +} + +func (a *Adaptor) DoRequest(c *gin.Context, meta *meta.Meta, requestBody io.Reader) (*http.Response, error) { + logger.Info(c, "send image request to replicate") + return adaptor.DoRequestHelper(a, c, meta, requestBody) +} + +func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Meta) (usage *model.Usage, err *model.ErrorWithStatusCode) { + switch meta.Mode { + case relaymode.ImagesGenerations, + relaymode.ImagesEdits: + err, usage = ImageHandler(c, resp) + default: + err = openai.ErrorWrapper(errors.New("not implemented"), "not_implemented", http.StatusInternalServerError) + } + + return +} + +func (a *Adaptor) GetModelList() []string { + return ModelList +} + +func (a *Adaptor) GetChannelName() string { + return "replicate" +} diff --git a/relay/adaptor/replicate/constant.go b/relay/adaptor/replicate/constant.go new file mode 100644 index 0000000000..157cc04592 --- /dev/null +++ b/relay/adaptor/replicate/constant.go @@ -0,0 +1,58 @@ +package replicate + +// ModelList is a list of models that can be used with Replicate. 
+// +// https://replicate.com/pricing +var ModelList = []string{ + // ------------------------------------- + // image model + // ------------------------------------- + "black-forest-labs/flux-1.1-pro", + "black-forest-labs/flux-1.1-pro-ultra", + "black-forest-labs/flux-canny-dev", + "black-forest-labs/flux-canny-pro", + "black-forest-labs/flux-depth-dev", + "black-forest-labs/flux-depth-pro", + "black-forest-labs/flux-dev", + "black-forest-labs/flux-dev-lora", + "black-forest-labs/flux-fill-dev", + "black-forest-labs/flux-fill-pro", + "black-forest-labs/flux-pro", + "black-forest-labs/flux-redux-dev", + "black-forest-labs/flux-redux-schnell", + "black-forest-labs/flux-schnell", + "black-forest-labs/flux-schnell-lora", + "ideogram-ai/ideogram-v2", + "ideogram-ai/ideogram-v2-turbo", + "recraft-ai/recraft-v3", + "recraft-ai/recraft-v3-svg", + "stability-ai/stable-diffusion-3", + "stability-ai/stable-diffusion-3.5-large", + "stability-ai/stable-diffusion-3.5-large-turbo", + "stability-ai/stable-diffusion-3.5-medium", + // ------------------------------------- + // language model + // ------------------------------------- + // "ibm-granite/granite-20b-code-instruct-8k", // TODO: implement the adaptor + // "ibm-granite/granite-3.0-2b-instruct", // TODO: implement the adaptor + // "ibm-granite/granite-3.0-8b-instruct", // TODO: implement the adaptor + // "ibm-granite/granite-8b-code-instruct-128k", // TODO: implement the adaptor + // "meta/llama-2-13b", // TODO: implement the adaptor + // "meta/llama-2-13b-chat", // TODO: implement the adaptor + // "meta/llama-2-70b", // TODO: implement the adaptor + // "meta/llama-2-70b-chat", // TODO: implement the adaptor + // "meta/llama-2-7b", // TODO: implement the adaptor + // "meta/llama-2-7b-chat", // TODO: implement the adaptor + // "meta/meta-llama-3.1-405b-instruct", // TODO: implement the adaptor + // "meta/meta-llama-3-70b", // TODO: implement the adaptor + // "meta/meta-llama-3-70b-instruct", // TODO: implement the 
adaptor + // "meta/meta-llama-3-8b", // TODO: implement the adaptor + // "meta/meta-llama-3-8b-instruct", // TODO: implement the adaptor + // "mistralai/mistral-7b-instruct-v0.2", // TODO: implement the adaptor + // "mistralai/mistral-7b-v0.1", // TODO: implement the adaptor + // "mistralai/mixtral-8x7b-instruct-v0.1", // TODO: implement the adaptor + // ------------------------------------- + // video model + // ------------------------------------- + // "minimax/video-01", // TODO: implement the adaptor +} diff --git a/relay/adaptor/replicate/image.go b/relay/adaptor/replicate/image.go new file mode 100644 index 0000000000..5cc093bfc5 --- /dev/null +++ b/relay/adaptor/replicate/image.go @@ -0,0 +1,223 @@ +package replicate + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "image" + "image/png" + "io" + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + "github.com/pkg/errors" + "github.com/songquanpeng/one-api/common/logger" + "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/meta" + "github.com/songquanpeng/one-api/relay/model" + "golang.org/x/image/webp" + "golang.org/x/sync/errgroup" +) + +// // ImagesEditsHandler just copy response body to client +// // +// // https://replicate.com/black-forest-labs/flux-fill-pro +// func ImagesEditsHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusCode, *model.Usage) { +// c.Writer.WriteHeader(resp.StatusCode) +// for k, v := range resp.Header { +// c.Writer.Header().Set(k, v[0]) +// } + +// if _, err := io.Copy(c.Writer, resp.Body); err != nil { +// return openai.ErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError), nil +// } +// defer resp.Body.Close() + +// return nil, nil +// } + +var errNextLoop = errors.New("next_loop") + +func ImageHandler(c *gin.Context, resp *http.Response) ( + *model.ErrorWithStatusCode, *model.Usage) { + if resp.StatusCode != http.StatusCreated { + payload, _ := 
io.ReadAll(resp.Body) + return openai.ErrorWrapper( + errors.Errorf("bad_status_code [%d]%s", resp.StatusCode, string(payload)), + "bad_status_code", http.StatusInternalServerError), + nil + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil + } + + respData := new(ImageResponse) + if err = json.Unmarshal(respBody, respData); err != nil { + return openai.ErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil + } + + for { + err = func() error { + // get task + taskReq, err := http.NewRequestWithContext(c.Request.Context(), + http.MethodGet, respData.URLs.Get, nil) + if err != nil { + return errors.Wrap(err, "new request") + } + + taskReq.Header.Set("Authorization", "Bearer "+meta.GetByContext(c).APIKey) + taskResp, err := http.DefaultClient.Do(taskReq) + if err != nil { + return errors.Wrap(err, "get task") + } + defer taskResp.Body.Close() + + if taskResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(taskResp.Body) + return errors.Errorf("bad status code [%d]%s", + taskResp.StatusCode, string(payload)) + } + + taskBody, err := io.ReadAll(taskResp.Body) + if err != nil { + return errors.Wrap(err, "read task response") + } + + taskData := new(ImageResponse) + if err = json.Unmarshal(taskBody, taskData); err != nil { + return errors.Wrap(err, "decode task response") + } + + switch taskData.Status { + case "succeeded": + case "failed", "canceled": + return errors.Errorf("task failed, [%s]%s", taskData.Status, taskData.Error) + default: + time.Sleep(time.Second * 3) + return errNextLoop + } + + output, err := taskData.GetOutput() + if err != nil { + return errors.Wrap(err, "get output") + } + if len(output) == 0 { + return errors.New("response output is empty") + } + + var mu sync.Mutex + var pool errgroup.Group + respBody := &openai.ImageResponse{ + Created: taskData.CompletedAt.Unix(), + Data: 
[]openai.ImageData{}, + } + + for _, imgOut := range output { + imgOut := imgOut + pool.Go(func() error { + // download image + downloadReq, err := http.NewRequestWithContext(c.Request.Context(), + http.MethodGet, imgOut, nil) + if err != nil { + return errors.Wrap(err, "new request") + } + + imgResp, err := http.DefaultClient.Do(downloadReq) + if err != nil { + return errors.Wrap(err, "download image") + } + defer imgResp.Body.Close() + + if imgResp.StatusCode != http.StatusOK { + payload, _ := io.ReadAll(imgResp.Body) + return errors.Errorf("bad status code [%d]%s", + imgResp.StatusCode, string(payload)) + } + + imgData, err := io.ReadAll(imgResp.Body) + if err != nil { + return errors.Wrap(err, "read image") + } + + imgData, err = ConvertImageToPNG(imgData) + if err != nil { + return errors.Wrap(err, "convert image") + } + + mu.Lock() + respBody.Data = append(respBody.Data, openai.ImageData{ + B64Json: fmt.Sprintf("data:image/png;base64,%s", + base64.StdEncoding.EncodeToString(imgData)), + }) + mu.Unlock() + + return nil + }) + } + + if err := pool.Wait(); err != nil { + if len(respBody.Data) == 0 { + return errors.WithStack(err) + } + + logger.Error(c, fmt.Sprintf("some images failed to download: %+v", err)) + } + + c.JSON(http.StatusOK, respBody) + return nil + }() + if err != nil { + if errors.Is(err, errNextLoop) { + continue + } + + return openai.ErrorWrapper(err, "image_task_failed", http.StatusInternalServerError), nil + } + + break + } + + return nil, nil +} + +// ConvertImageToPNG converts a WebP image to PNG format +func ConvertImageToPNG(webpData []byte) ([]byte, error) { + // bypass if it's already a PNG image + if bytes.HasPrefix(webpData, []byte("\x89PNG")) { + return webpData, nil + } + + // check if is jpeg, convert to png + if bytes.HasPrefix(webpData, []byte("\xff\xd8\xff")) { + img, _, err := image.Decode(bytes.NewReader(webpData)) + if err != nil { + return nil, errors.Wrap(err, "decode jpeg") + } + + var pngBuffer bytes.Buffer + if err := 
png.Encode(&pngBuffer, img); err != nil { + return nil, errors.Wrap(err, "encode png") + } + + return pngBuffer.Bytes(), nil + } + + // Decode the WebP image + img, err := webp.Decode(bytes.NewReader(webpData)) + if err != nil { + return nil, errors.Wrap(err, "decode webp") + } + + // Encode the image as PNG + var pngBuffer bytes.Buffer + if err := png.Encode(&pngBuffer, img); err != nil { + return nil, errors.Wrap(err, "encode png") + } + + return pngBuffer.Bytes(), nil +} diff --git a/relay/adaptor/replicate/model.go b/relay/adaptor/replicate/model.go new file mode 100644 index 0000000000..93bdcda213 --- /dev/null +++ b/relay/adaptor/replicate/model.go @@ -0,0 +1,229 @@ +package replicate + +import ( + "bytes" + "encoding/base64" + "image" + "image/png" + "io" + "mime/multipart" + "time" + + "github.com/pkg/errors" +) + +type OpenaiImageEditRequest struct { + Image *multipart.FileHeader `json:"image" form:"image" binding:"required"` + Prompt string `json:"prompt" form:"prompt" binding:"required"` + Mask *multipart.FileHeader `json:"mask" form:"mask" binding:"required"` + Model string `json:"model" form:"model" binding:"required"` + N int `json:"n" form:"n" binding:"min=0,max=10"` + Size string `json:"size" form:"size"` + ResponseFormat string `json:"response_format" form:"response_format"` +} + +// toFluxRemixRequest converts OpenAI's image edit request to Flux's remix request. +// +// Note that the mask formats of OpenAI and Flux are different: +// OpenAI's mask sets the parts to be modified as transparent (0, 0, 0, 0), +// while Flux sets the parts to be modified as opaque white (255, 255, 255, 255), +// so we need to convert the format here. +// +// Both OpenAI's Image and Mask are browser-native ImageData, +// which need to be converted to base64 dataURI format. 
+func (r *OpenaiImageEditRequest) toFluxRemixRequest() (*InpaintingImageByFlusReplicateRequest, error) { + if r.ResponseFormat != "b64_json" { + return nil, errors.New("response_format must be b64_json for replicate models") + } + + fluxReq := &InpaintingImageByFlusReplicateRequest{ + Input: FluxInpaintingInput{ + Prompt: r.Prompt, + Seed: int(time.Now().UnixNano()), + Steps: 30, + Guidance: 3, + SafetyTolerance: 5, + PromptUnsampling: false, + OutputFormat: "png", + }, + } + + imgFile, err := r.Image.Open() + if err != nil { + return nil, errors.Wrap(err, "open image file") + } + defer imgFile.Close() + imgData, err := io.ReadAll(imgFile) + if err != nil { + return nil, errors.Wrap(err, "read image file") + } + + maskFile, err := r.Mask.Open() + if err != nil { + return nil, errors.Wrap(err, "open mask file") + } + defer maskFile.Close() + + // Convert image to base64 + imageBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imgData) + fluxReq.Input.Image = imageBase64 + + // Convert mask data to RGBA + maskPNG, err := png.Decode(maskFile) + if err != nil { + return nil, errors.Wrap(err, "decode mask file") + } + + // convert mask to RGBA + var maskRGBA *image.RGBA + switch converted := maskPNG.(type) { + case *image.RGBA: + maskRGBA = converted + default: + // Convert to RGBA + bounds := maskPNG.Bounds() + maskRGBA = image.NewRGBA(bounds) + for y := bounds.Min.Y; y < bounds.Max.Y; y++ { + for x := bounds.Min.X; x < bounds.Max.X; x++ { + maskRGBA.Set(x, y, maskPNG.At(x, y)) + } + } + } + + maskData := maskRGBA.Pix + invertedMask := make([]byte, len(maskData)) + for i := 0; i+4 <= len(maskData); i += 4 { + // If pixel is fully transparent (alpha = 0), make it opaque white (all channels 255) + if maskData[i+3] == 0 { + invertedMask[i] = 255 // R + invertedMask[i+1] = 255 // G + invertedMask[i+2] = 255 // B + invertedMask[i+3] = 255 // A + } else { + // Copy original pixel + copy(invertedMask[i:i+4], maskData[i:i+4]) + } + } + + // Convert inverted mask to base64 encoded 
png image + invertedMaskRGBA := &image.RGBA{ + Pix: invertedMask, + Stride: maskRGBA.Stride, + Rect: maskRGBA.Rect, + } + + var buf bytes.Buffer + err = png.Encode(&buf, invertedMaskRGBA) + if err != nil { + return nil, errors.Wrap(err, "encode inverted mask to png") + } + + invertedMaskBase64 := "data:image/png;base64," + base64.StdEncoding.EncodeToString(buf.Bytes()) + fluxReq.Input.Mask = invertedMaskBase64 + + return fluxReq, nil +} + +// DrawImageRequest is the request to draw an image with Flux Pro +// +// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json +type DrawImageRequest struct { + Input ImageInput `json:"input"` +} + +// ImageInput is the input of DrawImageRequest +// +// https://replicate.com/black-forest-labs/flux-1.1-pro/api/schema +type ImageInput struct { + Steps int `json:"steps" binding:"required,min=1"` + Prompt string `json:"prompt" binding:"required,min=5"` + ImagePrompt string `json:"image_prompt"` + Guidance int `json:"guidance" binding:"required,min=2,max=5"` + Interval int `json:"interval" binding:"required,min=1,max=4"` + AspectRatio string `json:"aspect_ratio" binding:"required,oneof=1:1 16:9 2:3 3:2 4:5 5:4 9:16"` + SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"` + Seed int `json:"seed"` + NImages int `json:"n_images" binding:"required,min=1,max=8"` + Width int `json:"width" binding:"required,min=256,max=1440"` + Height int `json:"height" binding:"required,min=256,max=1440"` +} + +// InpaintingImageByFlusReplicateRequest is the request to inpaint an image with Flux Pro +// +// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema +type InpaintingImageByFlusReplicateRequest struct { + Input FluxInpaintingInput `json:"input"` +} + +// FluxInpaintingInput is the input of InpaintingImageByFlusReplicateRequest +// +// https://replicate.com/black-forest-labs/flux-fill-pro/api/schema +type FluxInpaintingInput struct { + Mask string `json:"mask" binding:"required"` + Image string `json:"image" 
binding:"required"` + Seed int `json:"seed"` + Steps int `json:"steps" binding:"required,min=1"` + Prompt string `json:"prompt" binding:"required,min=5"` + Guidance int `json:"guidance" binding:"required,min=2,max=5"` + OutputFormat string `json:"output_format"` + SafetyTolerance int `json:"safety_tolerance" binding:"required,min=1,max=5"` + PromptUnsampling bool `json:"prompt_unsampling"` +} + +// ImageResponse is the response to a DrawImageRequest +// +// https://replicate.com/black-forest-labs/flux-pro?prediction=kg1krwsdf9rg80ch1sgsrgq7h8&output=json +type ImageResponse struct { + CompletedAt time.Time `json:"completed_at"` + CreatedAt time.Time `json:"created_at"` + DataRemoved bool `json:"data_removed"` + Error string `json:"error"` + ID string `json:"id"` + Input DrawImageRequest `json:"input"` + Logs string `json:"logs"` + Metrics FluxMetrics `json:"metrics"` + // Output could be `string` or `[]string` + Output any `json:"output"` + StartedAt time.Time `json:"started_at"` + Status string `json:"status"` + URLs FluxURLs `json:"urls"` + Version string `json:"version"` +} + +func (r *ImageResponse) GetOutput() ([]string, error) { + switch v := r.Output.(type) { + case string: + return []string{v}, nil + case []string: + return v, nil + case nil: + return nil, nil + case []interface{}: + // convert []interface{} to []string + ret := make([]string, len(v)) + for idx, vv := range v { + if vvv, ok := vv.(string); ok { + ret[idx] = vvv + } else { + return nil, errors.Errorf("unknown output type: [%T]%v", vv, vv) + } + } + + return ret, nil + default: + return nil, errors.Errorf("unknown output type: [%T]%v", r.Output, r.Output) + } +} + +// FluxMetrics holds the metrics of an ImageResponse +type FluxMetrics struct { + ImageCount int `json:"image_count"` + PredictTime float64 `json:"predict_time"` + TotalTime float64 `json:"total_time"` +} + +// FluxURLs holds the URLs of an ImageResponse +type FluxURLs struct { + Get string `json:"get"` + Cancel string `json:"cancel"` +} diff --git
a/relay/adaptor/replicate/model_test.go b/relay/adaptor/replicate/model_test.go new file mode 100644 index 0000000000..6cde5e9417 --- /dev/null +++ b/relay/adaptor/replicate/model_test.go @@ -0,0 +1,106 @@ +package replicate + +import ( + "bytes" + "image" + "image/draw" + "image/png" + "io" + "mime/multipart" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +type nopCloser struct { + io.Reader +} + +func (n nopCloser) Close() error { return nil } + +// Custom FileHeader to override Open method +type customFileHeader struct { + *multipart.FileHeader + openFunc func() (multipart.File, error) +} + +func (c *customFileHeader) Open() (multipart.File, error) { + return c.openFunc() +} + +func TestOpenaiImageEditRequest_toFluxRemixRequest(t *testing.T) { + // Create a simple image for testing + img := image.NewRGBA(image.Rect(0, 0, 10, 10)) + draw.Draw(img, img.Bounds(), &image.Uniform{C: image.Black}, image.Point{}, draw.Src) + var imgBuf bytes.Buffer + err := png.Encode(&imgBuf, img) + require.NoError(t, err) + + // Create a simple mask for testing + mask := image.NewRGBA(image.Rect(0, 0, 10, 10)) + draw.Draw(mask, mask.Bounds(), &image.Uniform{C: image.Black}, image.Point{}, draw.Src) + var maskBuf bytes.Buffer + err = png.Encode(&maskBuf, mask) + require.NoError(t, err) + + // Create a multipart.FileHeader from the image and mask bytes + imgFileHeader, err := createFileHeader("image", "test.png", imgBuf.Bytes()) + require.NoError(t, err) + maskFileHeader, err := createFileHeader("mask", "test.png", maskBuf.Bytes()) + require.NoError(t, err) + + req := &OpenaiImageEditRequest{ + Image: imgFileHeader, + Mask: maskFileHeader, + Prompt: "Test prompt", + Model: "test-model", + ResponseFormat: "b64_json", + } + + fluxReq, err := req.toFluxRemixRequest() + require.NoError(t, err) + require.NotNil(t, fluxReq) + require.Equal(t, req.Prompt, fluxReq.Input.Prompt) + require.NotEmpty(t, fluxReq.Input.Image) + require.NotEmpty(t, fluxReq.Input.Mask) +} + 
+// createFileHeader creates a multipart.FileHeader from file bytes +func createFileHeader(fieldname, filename string, fileBytes []byte) (*multipart.FileHeader, error) { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + // Create a form file field + part, err := writer.CreateFormFile(fieldname, filename) + if err != nil { + return nil, err + } + + // Write the file bytes to the form file field + _, err = part.Write(fileBytes) + if err != nil { + return nil, err + } + + // Close the writer to finalize the form + err = writer.Close() + if err != nil { + return nil, err + } + + // Parse the multipart form + req := &http.Request{ + Header: http.Header{}, + Body: io.NopCloser(body), + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + err = req.ParseMultipartForm(int64(body.Len())) + if err != nil { + return nil, err + } + + // Retrieve the file header from the parsed form + fileHeader := req.MultipartForm.File[fieldname][0] + return fileHeader, nil +} diff --git a/relay/adaptor/vertexai/gemini/adapter.go b/relay/adaptor/vertexai/gemini/adapter.go index ceff1ed2a0..f86baee0e2 100644 --- a/relay/adaptor/vertexai/gemini/adapter.go +++ b/relay/adaptor/vertexai/gemini/adapter.go @@ -15,7 +15,7 @@ import ( ) var ModelList = []string{ - "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002", + "gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision", "gemini-1.5-pro-002", "gemini-1.5-flash-002", } type Adaptor struct { diff --git a/relay/apitype/define.go b/relay/apitype/define.go index cf7b6a0d2b..0c6a5ff11a 100644 --- a/relay/apitype/define.go +++ b/relay/apitype/define.go @@ -19,6 +19,7 @@ const ( DeepL VertexAI Proxy + Replicate Dummy // this one is only for count, do not add any channel after this ) diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index 1b58ec0902..b04bf4e273 100644 --- a/relay/billing/ratio/model.go +++ 
b/relay/billing/ratio/model.go @@ -211,6 +211,31 @@ var ModelRatio = map[string]float64{ "deepl-ja": 25.0 / 1000 * USD, // https://console.x.ai/ "grok-beta": 5.0 / 1000 * USD, + // replicate charges based on the number of generated images + // https://replicate.com/pricing + "black-forest-labs/flux-1.1-pro": 0.04 * USD, + "black-forest-labs/flux-1.1-pro-ultra": 0.06 * USD, + "black-forest-labs/flux-canny-dev": 0.025 * USD, + "black-forest-labs/flux-canny-pro": 0.05 * USD, + "black-forest-labs/flux-depth-dev": 0.025 * USD, + "black-forest-labs/flux-depth-pro": 0.05 * USD, + "black-forest-labs/flux-dev": 0.025 * USD, + "black-forest-labs/flux-dev-lora": 0.032 * USD, + "black-forest-labs/flux-fill-dev": 0.04 * USD, + "black-forest-labs/flux-fill-pro": 0.05 * USD, + "black-forest-labs/flux-pro": 0.055 * USD, + "black-forest-labs/flux-redux-dev": 0.025 * USD, + "black-forest-labs/flux-redux-schnell": 0.003 * USD, + "black-forest-labs/flux-schnell": 0.003 * USD, + "black-forest-labs/flux-schnell-lora": 0.02 * USD, + "ideogram-ai/ideogram-v2": 0.08 * USD, + "ideogram-ai/ideogram-v2-turbo": 0.05 * USD, + "recraft-ai/recraft-v3": 0.04 * USD, + "recraft-ai/recraft-v3-svg": 0.08 * USD, + "stability-ai/stable-diffusion-3": 0.035 * USD, + "stability-ai/stable-diffusion-3.5-large": 0.065 * USD, + "stability-ai/stable-diffusion-3.5-large-turbo": 0.04 * USD, + "stability-ai/stable-diffusion-3.5-medium": 0.035 * USD, } var CompletionRatio = map[string]float64{ diff --git a/relay/channeltype/define.go b/relay/channeltype/define.go index 98316959a1..f54d0e30de 100644 --- a/relay/channeltype/define.go +++ b/relay/channeltype/define.go @@ -47,5 +47,6 @@ const ( Proxy SiliconFlow XAI + Replicate Dummy ) diff --git a/relay/channeltype/helper.go b/relay/channeltype/helper.go index fae3357f8c..8839b30adb 100644 --- a/relay/channeltype/helper.go +++ b/relay/channeltype/helper.go @@ -37,6 +37,8 @@ func ToAPIType(channelType int) int { apiType = apitype.DeepL case VertextAI: apiType = 
apitype.VertexAI + case Replicate: + apiType = apitype.Replicate case Proxy: apiType = apitype.Proxy } diff --git a/relay/channeltype/url.go b/relay/channeltype/url.go index b8bd61f89e..8e271f4efb 100644 --- a/relay/channeltype/url.go +++ b/relay/channeltype/url.go @@ -47,6 +47,7 @@ var ChannelBaseURLs = []string{ "", // 43 "https://api.siliconflow.cn", // 44 "https://api.x.ai", // 45 + "https://api.replicate.com/v1/models/", // 46 } func init() { diff --git a/relay/controller/image.go b/relay/controller/image.go index 1e06e858ef..6154c74c7b 100644 --- a/relay/controller/image.go +++ b/relay/controller/image.go @@ -4,18 +4,20 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "net/http" + "strings" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/ctxkey" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay" "github.com/songquanpeng/one-api/relay/adaptor/openai" + "github.com/songquanpeng/one-api/relay/adaptor/replicate" billingratio "github.com/songquanpeng/one-api/relay/billing/ratio" "github.com/songquanpeng/one-api/relay/channeltype" "github.com/songquanpeng/one-api/relay/meta" @@ -26,7 +28,7 @@ func getImageRequest(c *gin.Context, relayMode int) (*relaymodel.ImageRequest, e imageRequest := &relaymodel.ImageRequest{} err := common.UnmarshalBodyReusable(c, imageRequest) if err != nil { - return nil, err + return nil, errors.WithStack(err) } if imageRequest.N == 0 { imageRequest.N = 1 @@ -134,7 +136,8 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus c.Set("response_format", imageRequest.ResponseFormat) var requestBody io.Reader - if isModelMapped || meta.ChannelType == channeltype.Azure { // make Azure channel request body + if strings.ToLower(c.GetString(ctxkey.ContentType)) == "application/json" && + isModelMapped || meta.ChannelType == channeltype.Azure 
{ // make Azure channel request body jsonStr, err := json.Marshal(imageRequest) if err != nil { return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) @@ -150,12 +153,11 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } adaptor.Init(meta) + // these adaptors need to convert the request switch meta.ChannelType { - case channeltype.Ali: - fallthrough - case channeltype.Baidu: - fallthrough - case channeltype.Zhipu: + case channeltype.Zhipu, + channeltype.Ali, + channeltype.Baidu: finalRequest, err := adaptor.ConvertImageRequest(imageRequest) if err != nil { return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) @@ -165,6 +167,16 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) } requestBody = bytes.NewBuffer(jsonStr) + case channeltype.Replicate: + finalRequest, err := replicate.ConvertImageRequest(c, imageRequest) + if err != nil { + return openai.ErrorWrapper(err, "convert_image_request_failed", http.StatusInternalServerError) + } + jsonStr, err := json.Marshal(finalRequest) + if err != nil { + return openai.ErrorWrapper(err, "marshal_image_request_failed", http.StatusInternalServerError) + } + requestBody = bytes.NewBuffer(jsonStr) } modelRatio := billingratio.GetModelRatio(imageModel, meta.ChannelType) @@ -172,7 +184,14 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus ratio := modelRatio * groupRatio userQuota, err := model.CacheGetUserQuota(ctx, meta.UserId) - quota := int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) + var quota int64 + switch meta.ChannelType { + case channeltype.Replicate: + // replicate always returns 1 image + quota = int64(ratio * imageCostRatio * 1000) + default: + quota = int64(ratio*imageCostRatio*1000) * int64(imageRequest.N) + } if userQuota-quota <
0 { return openai.ErrorWrapper(errors.New("user quota is not enough"), "insufficient_user_quota", http.StatusForbidden) @@ -186,7 +205,9 @@ func RelayImageHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus } defer func(ctx context.Context) { - if resp != nil && resp.StatusCode != http.StatusOK { + if resp != nil && + resp.StatusCode != http.StatusCreated && // replicate returns 201 + resp.StatusCode != http.StatusOK { return } diff --git a/relay/model/image.go b/relay/model/image.go index bab8425619..ec3e769182 100644 --- a/relay/model/image.go +++ b/relay/model/image.go @@ -1,12 +1,12 @@ package model type ImageRequest struct { - Model string `json:"model"` - Prompt string `json:"prompt" binding:"required"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - Quality string `json:"quality,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Style string `json:"style,omitempty"` - User string `json:"user,omitempty"` + Model string `json:"model" form:"model"` + Prompt string `json:"prompt" form:"prompt" binding:"required"` + N int `json:"n,omitempty" form:"n"` + Size string `json:"size,omitempty" form:"size"` + Quality string `json:"quality,omitempty" form:"quality"` + ResponseFormat string `json:"response_format,omitempty" form:"response_format"` + Style string `json:"style,omitempty" form:"style"` + User string `json:"user,omitempty" form:"user"` } diff --git a/relay/relaymode/define.go b/relay/relaymode/define.go index aa7712057d..79999826e0 100644 --- a/relay/relaymode/define.go +++ b/relay/relaymode/define.go @@ -11,6 +11,7 @@ const ( AudioSpeech AudioTranscription AudioTranslation + ImagesEdits // Proxy is a special relay mode for proxying requests to custom upstream Proxy ) diff --git a/relay/relaymode/helper.go b/relay/relaymode/helper.go index 2cde5b8510..35a0535e11 100644 --- a/relay/relaymode/helper.go +++ b/relay/relaymode/helper.go @@ -24,8 +24,11 @@ func GetByPath(path string) int { relayMode = 
AudioTranscription } else if strings.HasPrefix(path, "/v1/audio/translations") { relayMode = AudioTranslation + } else if strings.HasPrefix(path, "/v1/images/edits") { + relayMode = ImagesEdits } else if strings.HasPrefix(path, "/v1/oneapi/proxy") { relayMode = Proxy } + return relayMode } diff --git a/router/relay.go b/router/relay.go index 8f3c73030d..554a64f47b 100644 --- a/router/relay.go +++ b/router/relay.go @@ -25,7 +25,7 @@ func SetRelayRouter(router *gin.Engine) { relayV1Router.POST("/chat/completions", controller.Relay) relayV1Router.POST("/edits", controller.Relay) relayV1Router.POST("/images/generations", controller.Relay) - relayV1Router.POST("/images/edits", controller.RelayNotImplemented) + relayV1Router.POST("/images/edits", controller.Relay) relayV1Router.POST("/images/variations", controller.RelayNotImplemented) relayV1Router.POST("/embeddings", controller.Relay) relayV1Router.POST("/engines/:model/embeddings", controller.Relay) diff --git a/web/air/src/constants/channel.constants.js b/web/air/src/constants/channel.constants.js index a7e984ecf5..e7b25399b9 100644 --- a/web/air/src/constants/channel.constants.js +++ b/web/air/src/constants/channel.constants.js @@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [ { key: 43, text: 'Proxy', value: 43, color: 'blue' }, { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' }, { key: 45, text: 'xAI', value: 45, color: 'blue' }, + { key: 46, text: 'Replicate', value: 46, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' }, diff --git a/web/berry/src/constants/ChannelConstants.js b/web/berry/src/constants/ChannelConstants.js index 3539887525..375adcd958 100644 --- a/web/berry/src/constants/ChannelConstants.js +++ b/web/berry/src/constants/ChannelConstants.js @@ -185,6 +185,12 @@ export const CHANNEL_OPTIONS = { value: 45, color: 'primary' }, + 46: { + key: 46, + text:
'Replicate', + value: 46, + color: 'primary' + }, 41: { key: 41, text: 'Novita', diff --git a/web/default/src/constants/channel.constants.js b/web/default/src/constants/channel.constants.js index 5b25577d48..614255085c 100644 --- a/web/default/src/constants/channel.constants.js +++ b/web/default/src/constants/channel.constants.js @@ -31,6 +31,7 @@ export const CHANNEL_OPTIONS = [ { key: 43, text: 'Proxy', value: 43, color: 'blue' }, { key: 44, text: 'SiliconFlow', value: 44, color: 'blue' }, { key: 45, text: 'xAI', value: 45, color: 'blue' }, + { key: 46, text: 'Replicate', value: 46, color: 'blue' }, { key: 8, text: '自定义渠道', value: 8, color: 'pink' }, { key: 22, text: '知识库:FastGPT', value: 22, color: 'blue' }, { key: 21, text: '知识库:AI Proxy', value: 21, color: 'purple' },